1.fragment的定义方式有两种,矩阵型和累加器型:

wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major/wmma::col_major> a_frag;

(只能half)

wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;

(可以float)

关于这个定义时的行优先还是列优先说明:如果定义了行优先,后期加载到fragment中的就是A矩阵本身,如果定义了列优先,就是A的逆矩阵被加载到fragment。

2.加载到fragment

wmma::load_matrix_sync(a_frag, a , M);

M为矩阵a的leading dimension,即如果是行优先,则为其列数,列优先则为其行数。

3.矩阵乘法

wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);

4.将acc_frag的值存储回内存中的矩阵

wmma::store_matrix_sync(c , acc_frag, M, wmma::mem_row_major);

若选择列优先,则选择了将其结果的逆矩阵存回。

Logo

Agent 垂直技术社区,欢迎活跃、内容共建。

更多推荐