diff --git a/gimmik/__init__.py b/gimmik/__init__.py index b32ebdc..cd21134 100644 --- a/gimmik/__init__.py +++ b/gimmik/__init__.py @@ -8,6 +8,7 @@ from gimmik.hip import HIPMatMul from gimmik.metal import MetalMatMul from gimmik.opencl import OpenCLMatMul +from gimmik.ptx import PTXMatMul def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', @@ -22,7 +23,8 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', 'cuda': CUDAMatMul, 'ispc': ISPCMatMul, 'hip': HIPMatMul, - 'opencl': OpenCLMatMul + 'opencl': OpenCLMatMul, + 'ptx': PTXMatMul } mm = platmap[platform](alpha*mat, beta, None, n, ldb, ldc) diff --git a/gimmik/base.py b/gimmik/base.py index f547afc..0ecc29a 100644 --- a/gimmik/base.py +++ b/gimmik/base.py @@ -144,7 +144,8 @@ def _render_kernel(self, dtype, tplname, tplargs): src = tpl.render(**tplargs) # At single precision suffix all floating point constants by 'f' - if dtype == 'float': + # (PTX doesn't use an 'f' suffix for FP literals) + if dtype == 'float' and self.platform != 'ptx': src = re.sub(r'(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?', r'\g<0>f', src) diff --git a/gimmik/cuda.py b/gimmik/cuda.py index b18c509..9e1da43 100644 --- a/gimmik/cuda.py +++ b/gimmik/cuda.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +import numpy as np + from gimmik.base import MatMul @@ -8,7 +10,15 @@ class CUDAMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): + @staticmethod + def is_suitable(arr): + nnz = np.count_nonzero(arr) + nuq = len(np.unique(np.abs(arr))) + density = nnz / arr.size + return (nuq <= 28) or (density <= 0.15) + + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + **kwargs): # B loading, C streaming kernel yield ('cstream', {}, {}) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako new file mode 100644 index 0000000..dbd8433 --- /dev/null +++ b/gimmik/kernels/ptx/base.mako @@ -0,0 +1,4 @@ +.version ${ptx[0]}.${ptx[1]} +.target sm_${cc[0]}${cc[1]}${'a' if cc[0] >= 9 else ''} +.address_size 64 +${next.body()} diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako new file mode 100644 index 0000000..2ef85e9 --- /dev/null +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -0,0 +1,261 @@ +<%inherit file='base'/> + +<% +mx = partition(A, into=msplit, by='rows') +bchunks = chunk(bix, bsz) +m_per_group = max(len(mcx) for mcx in mx) +bsub_bytes = 2 * bsz * blockx * dwidth_i +def bsub_off(buf, idx): + return (buf * bsz + idx) * blockx * dwidth_i +use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, bsub_thread; +% if use_cpasync: + .reg .u32 bsub_sm_thread; +% endif + .reg .${pftype} bv, csub<${m_per_group}>; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _bsub[${bsub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 bsub_thread, _bsub; + add.u64 bsub_thread, bsub_thread, _tx_off; + } +% if use_cpasync: + { + .reg .u64 _sm64; + cvta.to.shared.u64 _sm64, bsub_thread; + cvt.u32.u64 bsub_sm_thread, _sm64; + } +% endif + +% for cid, mcx in enumerate(mx): +## cid = ${cid}, rows ${mcx} + setp.ne.u32 p_skip, tid_y, ${cid}; + @p_skip bra $L_END_CID_${cid}; + +% if use_cpasync: +## Async fill of chunk 0 +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% endfor + cp.async.commit_group; + cp.async.wait_all; + bar.sync 0; +% else: +## Sync fill of chunk 0 +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: + { + .reg .${pftype} _bv; +% if n is None: + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} _bv, [_bptr]; +% else: + ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; +% endif + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + } +% endfor + bar.sync 0; +% endif + +## Main loop over B-chunks (double-buffered) +% for bb in range(len(bchunks)): +<% + buf_cur = bb % 2 + buf_next = (bb + 1) % 2 +%> +% if not loop.last: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[bb + 1]) if i % msplit == cid]: +% if use_cpasync: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% else: + { + .reg .${pftype} _bv; +% if n is None: + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} _bv, [_bptr]; +% else: + ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; +% endif + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + } +% endif +% endfor +% if use_cpasync: + cp.async.commit_group; +% endif +% endif + +% for idx, kx in enumerate(bchunks[bb]): + ld.shared.${pftype} bv, [bsub_thread + ${bsub_off(buf_cur, idx)}]; +% for j, row_j in enumerate(mcx): +<% jx = A[row_j, kx] %> +% if jx != 0 and kx == afix[row_j]: + mul.${pftype} csub${j}, bv, ${jx}; +% elif jx != 0: + fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; +% endif +% if kx == alix[row_j]: +% if beta_zero: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp; +% endif + } +% endif +% endif +% endfor +% endfor +% if use_cpasync: +% if not loop.last: + cp.async.wait_all; +% endif +% endif + bar.sync 0; +% endfor +## End of Main loop over B-chunks + +## Handle zero rows in this cid's group +% if has_zero_rows: +% for row_j in mcx: +% if afix[row_j] == -1: +% if beta_zero: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _tmp; +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +% endfor +% endif + +$L_END_CID_${cid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako new file mode 100644 index 0000000..45eb1a7 --- /dev/null +++ b/gimmik/kernels/ptx/bstream.mako @@ -0,0 +1,159 @@ +<%inherit file='base'/> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} csub<${m}>, bv<${len(bix)}>; + .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + +## Batch-load active B columns +% for i, kx in enumerate(bix): +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; + } +% else: + ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +% endfor + +% if not beta_zero: +## Pre-load C so per-row completion is a plain store +% for j in range(m): +% if afix[j] != -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} csub${j}, [_cptr]; + } +% else: + ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*j*dwidth_i}]; +% endif +% endif +% endfor +% for j in range(m): +% if afix[j] != -1: + mul.${pftype} csub${j}, csub${j}, ${float(beta)}; +% endif +% endfor +% endif + +## Main compute +% for kx in bix: +% for j, jx in enumerate(A[:, kx]): +% if jx != 0: +% if beta_zero and kx == afix[j]: + mul.${pftype} csub${j}, bv${bix[kx]}, ${jx}; +% else: + fma.rn.${pftype} csub${j}, bv${bix[kx]}, ${jx}, csub${j}; +% endif +% endif +% if kx == alix[j]: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], csub${j}; +% endif + +% endif +% endfor +% endfor + +% if has_zero_rows: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% for j, jx in enumerate(afix): +% if jx == -1 and beta_zero: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + +% elif jx == -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif +% endif +% endfor + } +% endif + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako new file mode 100644 index 0000000..5d704de --- /dev/null +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -0,0 +1,176 @@ +<%inherit file='base'/> + +<% +kparts = partition(A, ksplit, by='cols') +cchunks = chunk(list(range(m)), csz) +cv_per_thread = -(-csz // ksplit) +bv_per_thread = max(len(kbx) for kbx in kparts) +csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, csub_thread; + .reg .${pftype} bv<${bv_per_thread}>, cv<${cv_per_thread}>, dotp; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _csub[${csub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 csub_thread, _csub; + add.u64 csub_thread, csub_thread, _tx_off; + } + +% for bid, kbx in enumerate(kparts): +## bid = ${bid}: ${len(kbx)} B columns, ksplit=${ksplit} + setp.ne.u32 p_skip, tid_y, ${bid}; + @p_skip bra $L_END_BID_${bid}; + +<% + loaded = set() + kbx_idx = {kx: i for i, kx in enumerate(kbx)} +%> + +% for cchunk_i, cchunk in enumerate(cchunks): +## Chunk ${cchunk_i}: partial dot-product +% for row_idx, j in enumerate(cchunk): +<% + nz = [(kbx_idx[kx], kx, A[j, kx]) for kx in kbx if A[j, kx] != 0] + owner_bid = row_idx % ksplit +%> +% for (kxi, kx, jx) in nz: +% if kx not in loaded: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${kxi}, [_bptr]; + } +% else: + ld.weak.global.cg.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +<% loaded.add(kx) %> +% endif +% endfor +% if nz: +% for kxi, kx, jx in nz: +% if loop.first: + mul.${pftype} dotp, bv${kxi}, ${jx}; +% else: + fma.rn.${pftype} dotp, bv${kxi}, ${jx}, dotp; +% endif +% endfor +% else: + mov.${pftype} dotp, ${fzero}; +% endif +% if owner_bid == bid: + mov.${pftype} cv${row_idx // ksplit}, dotp; +% else: +<% csub_idx = bid - (1 if bid > owner_bid else 0) %> + st.shared.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}], dotp; +% endif +% endfor + bar.sync 0; + +## Combine phase (owned rows only) +% for row_idx, j in enumerate(cchunk): +% if row_idx % ksplit == bid: + mov.${pftype} dotp, cv${row_idx // ksplit}; +% for other_bid in range(ksplit): +% if other_bid != bid: +<% csub_idx = other_bid - (1 if other_bid > (row_idx % ksplit) else 0) %> + { + .reg .${pftype} _tmp; + ld.shared.${pftype} _tmp, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}]; + add.${pftype} dotp, dotp, _tmp; + } +% endif +% endfor +% if beta_zero: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% endif +% endfor + bar.sync 0; +% endfor + +$L_END_BID_${bid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako new file mode 100644 index 0000000..ce7301d --- /dev/null +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -0,0 +1,86 @@ +<%inherit file='base'/> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .f64 bv_a<${len(bix)}>, bv_b<${len(bix)}>, dotp_a, dotp_b; + .reg .pred p1; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x, _tid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 _tid_x, %tid.x; + mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, 16, b; + mad.lo.u64 c_base, _id64, 16, c; + } + +## Batch-load B column pairs +% for i, kx in enumerate(bix): + ld.weak.global.cg.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; +% endfor + +## Main compute: two parallel dot-product streams per thread +% for j in range(m): +% if row_nz[j]: +% for kx, jx in row_nz[j]: +% if loop.first: + mul.f64 dotp_a, bv_a${bix[kx]}, ${jx}; + mul.f64 dotp_b, bv_b${bix[kx]}, ${jx}; +% else: + fma.rn.f64 dotp_a, bv_a${bix[kx]}, ${jx}, dotp_a; + fma.rn.f64 dotp_b, bv_b${bix[kx]}, ${jx}, dotp_b; +% endif +% endfor +% if beta_zero: + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% else: + { + .reg .f64 _ca, _cb; + ld.weak.global.cg.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.f64 _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.f64 _cb, _cb, ${float(beta)}, dotp_b; + st.weak.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif + +% else: +## Zero row of A +% if beta_zero: + { + .reg .f64 _z; + mov.f64 _z, ${fzero}; + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_z, _z}; + } +% elif beta != 1: + { + .reg .f64 _ca, _cb; + ld.weak.global.cg.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.f64 _ca, _ca, ${float(beta)}; + mul.f64 _cb, _cb, ${float(beta)}; + st.weak.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako new file mode 100644 index 0000000..9ce4c4d --- /dev/null +++ b/gimmik/kernels/ptx/cstream.mako @@ -0,0 +1,147 @@ +<%inherit file='base'/> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} bv<${len(bix)}>, dotp; + .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + +## Batch-load active B columns +% for i, kx in enumerate(bix): +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; + } +% else: + ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +% endfor + +## Compute and store each output row +% for j in range(m): +% if row_nz[j]: +% for kx, jx in row_nz[j]: +% if loop.first: + mul.${pftype} dotp, bv${bix[kx]}, ${jx}; +% else: + fma.rn.${pftype} dotp, bv${bix[kx]}, ${jx}, dotp; +% endif +% endfor +% if beta_zero: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% else: +## Zero row of A +% if beta_zero: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _tmp; +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako new file mode 100644 index 0000000..3df43c0 --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -0,0 +1,183 @@ +<%inherit file='base'/> + +.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { + ${', '.join(a_u64)} +}; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 ag_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .${pftype} a_frag; +% for nt in range(nn): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + + { + .reg .u32 cta; + mov.u32 cta, %ctaid.x; + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; + @pwarp_exit bra $L_EXIT; + +% for nt in range(nn): + add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } +% if not n_col_aligned: + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif +% endfor + + // A thread base: &Ag[0] + lane*8 + { + .reg .u64 t64, a_glb_base, lane64; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 ag_thr_base, a_glb_base, t64; + } + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for mt in range(m_tiles): +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% for nt in range(nn): +% for mt in range(m_tiles): +% if beta_zero: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} + } +% endif +% endfor +% endfor + +% for ki in range(k_tiles): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and loop.parent.last) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> + { + .reg .u64 baddr; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: + .reg .pred pbrow; + { + .reg .u32 brow; + add.u32 brow, r_mod4, ${ki * 4}; + setp.lt.u32 pbrow, brow, ${k}; + } +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} + } +% endfor +% for mt in range(m_tiles): + ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for nt in range(nn): +% for mt in range(m_tiles): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + } +% endfor +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako new file mode 100644 index 0000000..9a88b64 --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -0,0 +1,302 @@ +<%inherit file='base'/> + +<% +# Cooperative-copy params (gA-only) +blockx = 32 * warps_per_cta +a_pairs = m_tiles*k_tiles*32 // 2 +a_pairs_tail = m_tiles*k_tiles*32 % 2 +copy_v2_iters = (a_pairs + blockx - 1) // blockx +bs = bool(block_stealing) +%> + +% if bs: +.shared .align 8 .b64 ${kname}_mbar; +.shared .align 16 .b8 ${kname}_workid[16]; +% endif +.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { + ${', '.join(a_u64)} +}; +.shared .align 16 .b64 ${kname}_As[${m_tiles*k_tiles*32}]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 as_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .${pftype} a_frag; +% if bs: + .reg .u32 ctaid; + .reg .u32 mbar_a, work_a; + .reg .pred p_root, p_done, p_have; +% endif +% for nt in range(nn): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + +% if bs: + setp.eq.u32 p_root, tid, 0; + mov.u32 mbar_a, ${kname}_mbar; + mov.u32 work_a, ${kname}_workid; + @p_root mbarrier.init.shared::cta.b64 [mbar_a], 1; + bar.sync 0; +% endif + + // Cooperative copy A from .global to .shared via v2 loads + { + .reg .u64 a_glb_base, a_smem_base; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + mov.u64 a_smem_base, ${kname}_As; +% for ci in range(copy_v2_iters): +<% + base_pair = ci * blockx + pairs_this = min(blockx, a_pairs - base_pair) +%> + { + .reg .u32 pidx; + .reg .u64 off64, gaddr, saddr; + .reg .${pftype} v0, v1; +% if loop.last and pairs_this < blockx: + .reg .pred plast; + add.u32 pidx, tid, ${base_pair}; + setp.lt.u32 plast, pidx, ${a_pairs}; + mul.wide.u32 off64, pidx, ${2 * dwidth_i}; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + @plast ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; + @plast st.shared.v2.${pftype} [saddr], {v0, v1}; +% else: + add.u32 pidx, tid, ${base_pair}; + mul.wide.u32 off64, pidx, ${2 * dwidth_i}; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; + st.shared.v2.${pftype} [saddr], {v0, v1}; +% endif + } +% endfor +% if a_pairs_tail: + // Tail element (only when m_tiles*k_tiles*32 is odd) + { + .reg .pred plast; + .reg .u64 gaddr, saddr; + .reg .${pftype} v; + setp.eq.u32 plast, tid, 0; + add.u64 gaddr, a_glb_base, ${(m_tiles*k_tiles*32-1) * dwidth_i}; + add.u64 saddr, a_smem_base, ${(m_tiles*k_tiles*32-1) * dwidth_i}; + @plast ld.weak.global.cg.${pftype} v, [gaddr]; + @plast st.shared.${pftype} [saddr], v; + } +% endif + } + bar.sync 0; + + // Lane-only base; lifted out of the optional steal loop + { + .reg .u64 t64, a_smem_base, lane64; + mov.u64 a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; + } + +% for mt in range(m_tiles): +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% if bs: + mov.u32 ctaid, %ctaid.x; +$L_LOOP: +% endif + + { + .reg .u32 cta; +% if bs: + mov.u32 cta, ctaid; +% else: + mov.u32 cta, %ctaid.x; +% endif + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; +% if bs: + @pwarp_exit bra $L_STEAL; +% else: + @pwarp_exit bra $L_EXIT; +% endif + +% for nt in range(nn): + add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } +% if not n_col_aligned: + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif +% endfor + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for nt in range(nn): +% for mt in range(m_tiles): +% if beta_zero: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} + } +% endif +% endfor +% endfor + +% for ki in range(k_tiles): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and loop.parent.last) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> + { + .reg .u64 baddr; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: + .reg .pred pbrow; + { + .reg .u32 brow; + add.u32 brow, r_mod4, ${ki * 4}; + setp.lt.u32 pbrow, brow, ${k}; + } +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} + } +% endfor +% for mt in range(m_tiles): + ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for nt in range(nn): +% for mt in range(m_tiles): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + } +% endfor +% endfor + +% if bs: +$L_STEAL: + // Root issues async try_cancel + waits; bar.sync orders the workid load + @!p_root bra $L_AFTER_WAIT; + { + .reg .u64 state; + mbarrier.arrive.expect_tx.shared::cta.b64 state, [mbar_a], 16; + clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [work_a], [mbar_a]; +$L_WAIT: + mbarrier.try_wait.shared::cta.b64 p_done, [mbar_a], state, 10000000; + @!p_done bra $L_WAIT; + } +$L_AFTER_WAIT: + bar.sync 0; + + { + .reg .b128 resp; + ld.shared::cta.b128 resp, [work_a]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_have, resp; + @!p_have bra $L_FIN; + // 1D grid: extract just x + clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 ctaid, resp; + } + bra.uni $L_LOOP; + +$L_FIN: + bar.sync 0; + @p_root mbarrier.inval.shared::cta.b64 [mbar_a]; +% endif + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako new file mode 100644 index 0000000..e151372 --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -0,0 +1,432 @@ +<%inherit file='base'/> + +<%def name="producer_init_setup()"> + // Producer warp: initial A bulk-copy + B load for ctaid_x's work + @!p_prod bra.uni $L_AFTER_INIT_B; + { + .reg .b32 n_start0; + .reg .u64 a_glb; + mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; + mov.u64 a_glb, ${kname}_Ag; + cvta.to.global.u64 a_glb, a_glb; + @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes + [a_smem], [a_glb], ${m_tiles*k_tiles*32 * 8}, [tma_mbar]; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes + m_tiles*k_tiles*32 * 8}; + bar.warp.sync 0xffffffff; + .reg .b64 state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 state, [tma_mbar]; +$L_TMA_INIT_W: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait}; + @!p1 bra.uni $L_TMA_INIT_W; + .reg .b64 _state2; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state2, [bready_mbar]; + } +$L_AFTER_INIT_B: + + +<%def name="compute_warp_body()"> + // --- Compute Warps + @!p_compute bra.uni $L_AFTER_COMPUTE; + + // Wait on B + { + .reg .pred p1; +$L_WAIT_BRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [bready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_BRDY; + } + + // MMA + { + .reg .b32 b_sm_a; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_sm_a, b2_smem, b1_smem, p_ph; + + .reg .b32 a_thr_a; + { + .reg .b32 t; + shl.b32 t, lane, 3; + add.u32 a_thr_a, a_smem, t; + } +% for nt in range(nn): + .reg .b32 b_thr_a_${nt}; + { + .reg .b32 bcol_g, t_off; + add.u32 bcol_g, base_bcol, ${8 * nt}; + shl.b32 t_off, bcol_g, 3; + add.u32 b_thr_a_${nt}, b_sm_a, t_off; + } +% endfor + +% if beta_zero: + // beta=0: skip shared-staging entirely; compute warps store MMA + // outputs straight to global C with N-tail predication. + .reg .u64 c_glob_addr; + ld.param.u64 c_glob_addr, [c_desc]; + cvta.to.global.u64 c_glob_addr, c_glob_addr; +% else: + .reg .b32 c_thr_smem; + { + .reg .b32 t1, ccol_b; + mul.lo.u32 t1, base_crow, ${n_per_cta * dwidth_i}; + shl.b32 ccol_b, base_ccol, 3; + add.u32 c_thr_smem, c_smem, t1; + add.u32 c_thr_smem, c_thr_smem, ccol_b; + } +% endif + + // Zero accumulators +% for mt in range(m_tiles): +% for nt in range(nn): + .reg .${pftype} d_x_${mt}_${nt}, d_y_${mt}_${nt}; + mov.${pftype} d_x_${mt}_${nt}, ${fzero}; + mov.${pftype} d_y_${mt}_${nt}, ${fzero}; +% endfor +% endfor + + .reg .${pftype} a_f; +% for mt in range(m_tiles): +% for kt in range(k_tiles): +<% + k_tail = (k_rem != 0 and loop.last) +%> + { + .reg .b32 a_a; + add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_tiles) * dwidth_i}; + ld.shared.${pftype} a_f, [a_a]; +% if k_tail: + .reg .pred pbrow_${mt}_${kt}; + { + .reg .b32 brow; + add.u32 brow, base_brow, ${4 * kt}; + setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; + } +% endif +% for nt in range(nn): + { + .reg .b32 b_a, b_row; + .reg .${pftype} b_f; + add.u32 b_row, base_brow, ${4 * kt}; + mul.lo.u32 b_row, b_row, ${n_per_cta * dwidth_i}; + add.u32 b_a, b_thr_a_${nt}, b_row; +% if k_tail: + mov.${pftype} b_f, ${fzero}; + @pbrow_${mt}_${kt} ld.shared.${pftype} b_f, [b_a]; +% else: + ld.shared.${pftype} b_f, [b_a]; +% endif + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}, {a_f}, {b_f}, + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor + } +% endfor +% endfor + +% if beta_zero: + .reg .u64 c_thr_glob_base; + { + .reg .u32 thr_col_off, thr_addr_off_lo; + add.u32 thr_col_off, base_ccol, n_start_curr; + mad.lo.u32 thr_addr_off_lo, base_crow, ${ldc}, thr_col_off; + .reg .u64 thr_byte_off; + mul.wide.u32 thr_byte_off, thr_addr_off_lo, ${dwidth_i}; + add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; + } +% for mt in range(m_tiles): +<% + row_tail = (m_pad > m) and ((mt + 1) * 8 > m) +%> +% if row_tail: + .reg .pred p_row_${mt}; + { + .reg .b32 crow; + add.u32 crow, base_crow, ${8 * mt}; + setp.lt.u32 p_row_${mt}, crow, ${m}; + } +% endif +% for nt in range(nn): + { + .reg .pred p_st; + .reg .u32 g_ccol; + add.u32 g_ccol, base_ccol, ${8 * nt}; + add.u32 g_ccol, g_ccol, n_start_curr; + setp.lt.u32 p_st, g_ccol, ${n}; +% if row_tail: + and.pred p_st, p_st, p_row_${mt}; +% endif + .reg .u64 _c_addr; + add.u64 _c_addr, c_thr_glob_base, ${(mt * 8 * ldc + nt * 8) * dwidth_i}; + @p_st st.weak.global.v2.${pftype} [_c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% else: + // Wait until producer's prev-iter TMA-store of C has drained. + { + .reg .pred p1; +$L_WAIT_CSTORE: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cstored_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CSTORE; + } + + // Vector-store {d_x, d_y} pairs to csmem. M-tail / N-tail OOB rows + // are dropped by the C tensor map. +% for mt in range(m_tiles): +% for nt in range(nn): + { + .reg .b32 csaddr; + add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + nt * c_ntile_smem_stride}; + st.shared.v2.${pftype} [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% endif + +% if not beta_zero: + bar.sync 1, ${comp_threads}; + fence.proxy.async.shared::cta; + { + .reg .b64 _state; + @p_tid0 mbarrier.arrive.shared::cta.b64 _state, [cready_mbar]; + } +% endif + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_C: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_C; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + } +$L_AFTER_COMPUTE: + + +<%def name="data_warp_body()"> + // --- Data Movement Warp + @!p_prod bra.uni $L_AFTER_DATA; + { + .reg .b32 n_c_store; + mul.lo.u32 n_c_store, block_idx_x, ${n_per_cta}; + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_D: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_D; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + + // TMA loads of next B + { + mul.lo.u32 n_start_next, block_idx_x, ${n_per_cta}; + .reg .b32 b_next; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_next, b1_smem, b2_smem, p_ph; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b_next], [bdesc_addr, {n_start_next, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + @p_warp_lead cp.async.bulk.commit_group; + } + bar.warp.sync 0xffffffff; + +% if not beta_zero: + // TMA reduce+store of C (beta=1 only; beta=0 uses direct global + // stores from compute warps, so the producer does no C work). + { + .reg .pred p1; + .reg .b64 _c_state; +$L_WAIT_CRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CRDY; + @p_warp_lead cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group + [cdesc_addr, {n_c_store, 0}], [c_smem]; + @p_warp_lead cp.async.bulk.commit_group; + @p_warp_lead cp.async.bulk.wait_group 0; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _c_state, [cstored_mbar]; + } +% endif + + // Wait for next B to be ready, then signal B and C ready + { + .reg .b64 b_state, _bready_state, _c_state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 b_state, [tma_mbar]; +$L_WAIT_TMA: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], b_state, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_TMA; + + @p_warp_lead mbarrier.arrive.shared::cta.b64 _bready_state, [bready_mbar]; + } + } +$L_AFTER_DATA: + + +<%def name="ctrl_warp_body()"> + // --- Controller Warp + @!p_steal bra.uni $L_AFTER_CTRL; + { + .reg .pred p1, p2, p_canc; + .reg .b64 _state; + .reg .b128 resp; + @p_warp_lead fence.proxy.async.shared::cta; + @p_warp_lead clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 + [wid_smem], [steal_mbar]; + @p_warp_lead mbarrier.arrive.expect_tx.shared::cta.b64 + _state, [steal_mbar], 16; + +$L_WAIT_STEAL: + mbarrier.try_wait.parity.shared::cta.b64 p1, [steal_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_STEAL; + + // Signal new work + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_new_mbar]; + + // Query if there's new work + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + selp.b32 work, 1, 0, p_canc; + + // Wait for old work to be used +$L_WAIT_WUSED: + mbarrier.try_wait.parity.shared::cta.b64 p2, [wid_used_mbar], phase, ${mbar_maxwait}; + @!p2 bra.uni $L_WAIT_WUSED; + } +$L_AFTER_CTRL: + + +.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { + ${', '.join(a_u64)} +}; +.extern .shared .align 128 .b8 ${kname}_dynm[]; + +.visible .entry ${kname}(.param .u64 b_desc, + .param .u64 c_desc) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .b32 tid, warp, lane, phase, ctaid_x; + .reg .b32 base_brow, base_bcol, base_crow, base_ccol; + .reg .b32 work, block_idx_x, n_start_curr, n_start_next; + .reg .u64 bdesc_addr, cdesc_addr; + .reg .b32 a_smem, b1_smem, b2_smem, c_smem; + .reg .b32 tma_mbar, wid_new_mbar, bready_mbar, cready_mbar, cstored_mbar, steal_mbar; + .reg .b32 wid_used_mbar, wid_smem; + .reg .pred p_compute, p_prod, p_steal; + .reg .pred p_warp_lead; + .reg .pred p_done; + .reg .pred p_tid0; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + mov.u32 ctaid_x, %ctaid.x; + + .reg .b32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b1_smem, dynm_base, ${b1_off}; + add.u32 b2_smem, dynm_base, ${b2_off}; + add.u32 c_smem, dynm_base, ${c_off}; + add.u32 a_smem, dynm_base, ${a_off}; + add.u32 wid_smem, dynm_base, ${wid_off}; + + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; + add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; + add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; + add.u32 steal_mbar, dynm_base, ${steal_mbar_off}; + add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; + add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; + + ld.param.u64 bdesc_addr, [b_desc]; + ld.param.u64 cdesc_addr, [c_desc]; + + setp.eq.u32 p_tid0, tid, 0; + + setp.lt.u32 p_compute, warp, ${n_comp_warps}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + setp.eq.u32 p_steal, warp, ${steal_warp}; + + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + // mbarrier init (tid 0 only); pre-arrive csmem_free so compute iter 0 + // can write csmem immediately. + { + .reg .pred p_init; + setp.eq.u32 p_init, tid, 0; + .reg .b64 _state; + @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [steal_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [wid_used_mbar], ${n_comp_warps + 1}; + @p_init mbarrier.init.shared::cta.b64 [wid_new_mbar], 1; + @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; + @p_init fence.proxy.async.shared::cta; + } + bar.sync 0; + + // Compute-warp lane geometry + { + .reg .b32 t, w_n_base; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; + mul.lo.u32 w_n_base, warp, ${n_per_warp}; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; + } + + ${producer_init_setup()} + + mov.u32 block_idx_x, ctaid_x; + mov.u32 work, 1; + mov.u32 phase, 0; + +$L_LOOP: + setp.eq.u32 p_done, work, 0; + @p_done bra.uni $L_EXIT; + + mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; + + ${compute_warp_body()} + + ${data_warp_body()} + + ${ctrl_warp_body()} + + xor.b32 phase, phase, 1; + bra.uni $L_LOOP; + +$L_EXIT: + ret; +} diff --git a/gimmik/ptx.py b/gimmik/ptx.py new file mode 100644 index 0000000..f7be8f4 --- /dev/null +++ b/gimmik/ptx.py @@ -0,0 +1,295 @@ +import numpy as np + +from gimmik.base import MatMul + + +class PTXMatMul(MatMul): + platform = 'ptx' + basemeta = { + 'block': (128, 1, 1), + 'width': 1, + 'shared': 0, + 'dynamic_shared': 0 + } + + # Map Supported CC -> Minimum PTX version + PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 0), (10, 0): (8, 7), (10, 3): (8, 7), + (12, 0): (8, 7), (12, 1): (8, 7)} + + @classmethod + def is_sparse_suitable(cls, arr, cc): + nnz = np.count_nonzero(arr) + nuq = len(np.unique(np.abs(arr))) + density = nnz / arr.size + return ((nuq <= 28) or (density <= 0.15)) and cc in cls.PTX_SM + + @classmethod + def is_dense_suitable(cls, arr, cc): + cc_appropriate = cc in cls.PTX_SM and cc >= (9, 0) + return (arr.dtype == np.float64 and cc_appropriate + and arr.shape[0] <= 128 and arr.shape[1] <= 128) + + @classmethod + def is_suitable(cls, arr, cc): + return cls.is_sparse_suitable(arr, cc) or cls.is_dense_suitable(arr, cc) + + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + smem_info=None): + cc = compute_capability or (0, 0) + ptx = self.PTX_SM.get(cc, (0, 0)) + smem_info = smem_info or (48*1024, 48*1024) + base_args = { + 'ptx': ptx, + 'cc': cc, + 'smem_info': smem_info, + 'pred_emit': self._pred_emit, + 'pftype': 'f32' if dtype == 'float' else 'f64', + 'dwidth_i': 4 if dtype == 'float' else 8, + 'fzero': ('0f00000000' if dtype == 'float' + else '0d0000000000000000'), + 'beta_zero': self.beta == 0, + 'mbar_maxwait': '0x989680', + } + + if self.is_sparse_suitable(self.A, cc): + yield from self._sparse_kernel_generators(dtype, dsize, base_args) + + if self.is_dense_suitable(self.A, cc): + yield from self._dense_kernel_generators(dtype, dsize, base_args) + + def _sparse_kernel_generators(self, dtype, dsize, base_args): + # Sparse-shared template constants + base_args = base_args | { + 'has_zero_rows': bool(self.has_zero_rows), + 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) + if self.A[j, kx] != 0] for j in range(self.m)], + } + + # B loading, C streaming kernel + yield ('cstream', base_args, {'desc': 'cstream'}) + + # B streaming, C accumulation kernel + yield ('bstream', base_args, {'desc': 'bstream'}) + + # Four-way m-split B streaming, C accumulation kernel + ms, bsz, blkx = 4, 24, 32 + args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = { + 'block': (blkx, ms, 1), + 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}', + } + yield ('bstream-msplit', args, meta) + + # Single-warp LDGSTS variant for medium-M beta=0 large-K cases + if self.beta == 0 and self.m <= 320 and len(self.bix) >= 64: + ms, bsz, blkx = 1, 32, 64 + args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = { + 'block': (blkx, ms, 1), + 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}', + } + yield ('bstream-msplit', args, meta) + + # Two-way k-split B loading, C streaming kernel + ks, csz, blkx = 2, 24, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = { + 'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}', + } + yield ('cstream-ksplit', args, meta) + + # Four-way k-split for large K + K_used = len(self.bix) + if K_used > 500: + ks, csz, blkx = 4, 20, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = { + 'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}', + } + yield ('cstream-ksplit', args, meta) + + # Width-2 vector cstream for fp64 small-K + if (dtype == 'double' and self.n is not None and self.n % 2 == 0 + and K_used <= 100 + and (self.aligne is None or self.aligne % 2 == 0)): + blkx = 128 + args = base_args | {'blockx': blkx} + meta = { + 'block': (blkx, 1, 1), + 'width': 2, + 'desc': f'cstream-w2/x{blkx}', + } + yield ('cstream-w2', args, meta) + + def _dense_kernel_generators(self, dtype, dsize, base_args): + cc = base_args['cc'] or (0, 0) + + # Block stealing requires sm_100+ + block_steal = cc >= (10, 0) + if block_steal: + dense_configs = [('dense-mma-smem-gA', 4, 4)] + else: + dense_configs = [ + ('dense-mma-smem-gA', 1, 8), + ('dense-mma-smem-gA', 2, 4), + ('dense-mma-smem-gA', 4, 4), + ('dense-mma-gAd', 2, 2), + ('dense-mma-gAd', 4, 2), + ] + + for tpl, nn, w in dense_configs: + if (n_per_cta := 8 * nn * w) > self.n: + continue + setup = self._dense_mma_setup(nn, w, block_steal) + blkx = 32 * w + args = base_args | setup + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'desc': f'{tpl}/nn{nn}-w{w}{'-bs' if block_steal else ''}', + } + yield (tpl, args, meta) + + # Warp-specialised dense DMMA, required block stealing + if block_steal: + yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) + + def _dense_ws_kernel_generators(self, dtype, dsize, base_args): + static_max, dynamic_max = base_args['smem_info'] + + # (nn, compute) -- block has compute + 2 warps (producer, stealer) + ws_configs = [(1, 4), (2, 4), (4, 4)] + for nn, w in ws_configs: + if (n_per_cta := 8 * nn * w) > self.n: + continue + + setup = self._dense_mma_setup(nn, w, True) + ws_setup = self._dense_ws_setup(setup, w) + + if ws_setup['dynm_total_bytes'] > dynamic_max: + continue + + blkx = 32 * (w + 2) + args = base_args | setup | ws_setup + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'desc': f'dense-mma-ws/nn{nn}-w{w}', + 'ws_b_tile': (n_per_cta, setup['k_pad']), + 'dynamic_shared': ws_setup['dynm_total_bytes'], + } + if self.beta != 0: + meta |= {'ws_out_tile': (n_per_cta, setup['m_pad'])} + yield ('dense-mma-ws', args, meta) + + @staticmethod + def _dsmem_alloc(regions, mbars, align=16): + out, off = {}, 0 + for name, size in regions: + off = (off + align - 1) & ~(align - 1) + out[f'{name}_off'] = off + off += size + for name in mbars: + out[f'{name}_mbar_off'] = off + off += 8 + total = (off + align - 1) & ~(align - 1) + return out, total + + @classmethod + def _dense_ws_setup(cls, setup, n_comp_warps): + n_per_cta = setup['n_per_cta'] + b_tile_bytes = setup['k_pad'] * n_per_cta * 8 + c_tile_bytes = setup['m_pad'] * n_per_cta * 8 + a_bytes = setup['m_tiles'] * setup['k_tiles'] * 32 * 8 + + regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), + ('c', c_tile_bytes), ('a', a_bytes), ('wid', 16)] + mbars = ('tma', 'bready', 'cready', 'cstored', + 'steal', 'wid_new', 'wid_used') + offsets, dynm_total_bytes = cls._dsmem_alloc(regions, mbars) + + args = { + 'n_comp_warps': n_comp_warps, + 'blockx_total': 32 * (n_comp_warps + 2), + 'prod_warp': n_comp_warps, + 'steal_warp': n_comp_warps + 1, + 'comp_threads': 32 * n_comp_warps, + 'b_tile_bytes': b_tile_bytes, + 'c_mtile_smem_stride': 8 * n_per_cta * 8, + 'c_ntile_smem_stride': 8 * 8, + 'dynm_total_bytes': dynm_total_bytes, + } + + return offsets | args + + def _dense_mma_setup(self, nn, warps_per_cta, block_steal): + a = self.A + m, k = a.shape + m_tiles = (m + 7) // 8 + k_tiles = (k + 3) // 4 + k_rem = k % 4 + + # A in DMMA-fragment layout: lane l -> A[mt*8 + l//4][kt*4 + l%4] + # i.e. an (m_tiles, k_tiles) grid of row-major 8x4 tiles, packed as + # uint64 + a_pad = np.zeros((m_tiles*8, k_tiles*4)) + a_pad[:m, :k] = a + tiles = a_pad.reshape(m_tiles, 8, k_tiles, 4).swapaxes(1, 2) + a_u64 = [f'0x{u:016x}' for u in tiles.view(np.uint64).ravel()] + + n_per_warp = 8 * nn + n_per_cta = warps_per_cta * n_per_warp + + # Predicate-elision flags + n_col_aligned = (self.n is not None and self.n % n_per_warp == 0) + def pm_runtime(mt): + return (mt + 1) * 8 > m + + return { + 'warps_per_cta': warps_per_cta, + 'nn': nn, + 'm_tiles': m_tiles, + 'k_tiles': k_tiles, + 'k_rem': k_rem, + 'm_pad': m_tiles * 8, + 'k_pad': k_tiles * 4, + 'a_u64': a_u64, + 'n_per_warp': n_per_warp, + 'n_per_cta': n_per_cta, + 'frag_stride_bytes': 32 * 8, + 'b_kiter_stride': 4 * (self.ldb or 0) * 8, + 'b_ntile_stride': 8 * 8, + 'c_mtile_stride': 8 * (self.ldc or 0) * 8, + 'c_ntile_stride': 8 * 8, + 'n_col_aligned': n_col_aligned, + 'pm_runtime': pm_runtime, + 'block_stealing': block_steal, + } + + @staticmethod + def _pred_emit(instr, *preds, pred_reg=None, indent=' ' * 8): + actual = [p for p in preds if p is not None] + if not actual: + return instr + if len(actual) == 1: + return f'@{actual[0]} {instr}' + if pred_reg is None: + raise ValueError('pred_reg required when combining multiple ' + 'predicates') + lines = [f'.reg .pred {pred_reg};', + f'and.pred {pred_reg}, {actual[0]}, {actual[1]};'] + for p in actual[2:]: + lines.append(f'and.pred {pred_reg}, {pred_reg}, {p};') + lines.append(f'@{pred_reg} {instr}') + return f'\n{indent}'.join(lines) + + def _process_meta(self, meta): + if self.n is not None and 'grid' not in meta: + div = meta['block'][0]*meta['width'] + meta['grid'] = (-(-self.n // div), 1, 1)