Coverage for src/flag_gems/runtime/backend/_spacemit/ops/bmm.py: 0%
67 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6import triton.language.extra.smt as smt
8from flag_gems import runtime
9from flag_gems.fused import outer # noqa: E402
10from flag_gems.ops import mul # noqa: E402
11from flag_gems.utils import libentry, libtuner
# Module-scoped logger (named after this module); `bmm` emits its debug trace here.
logger = logging.getLogger(name=__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm_spacemit"),
    key=["M", "N", "K"],
)
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    stride_cb,
    stride_cm,
    stride_cn,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    TILE_K: tl.constexpr,
    MICRO_M: tl.constexpr,
    MICRO_K: tl.constexpr,
    MICRO_N: tl.constexpr,
    SUB_BLK_K: tl.constexpr,
):
    """Compute one (TILE_M, TILE_N) tile of O[b] = A[b] @ B[b].

    Grid axes: program_id(0) selects the M tile, program_id(1) the N tile and
    program_id(2) the batch index.

    A, B, O are base pointers. Each batch's (M, K), (K, N) and (M, N) matrices
    are addressed through the explicit stride_* arguments, so the caller may
    pass non-contiguous layouts.

    TILE_K and SUB_BLK_K are supplied by the `bmm` launcher as
    next_power_of_2(K) and min(1024, TILE_K). TILE_M, TILE_N, EVEN_K and
    MICRO_M/K/N are not passed by the launcher; presumably they come from
    the "bmm_spacemit" tuned configs (TODO confirm).
    """
    pidx = tl.program_id(0)
    pidy = tl.program_id(1)
    pid_b = tl.program_id(2)

    pid_m = pidx
    pid_n = pidy

    # Top-left corner of this program's output tile.
    block_m = pid_m * TILE_M
    block_n = pid_n * TILE_N

    # Element offsets of batch pid_b inside A, B and O.
    offset_a = pid_b * stride_ab
    offset_b = pid_b * stride_bb
    offset_o = pid_b * stride_cb

    # Block pointer over this program's TILE_M rows of A, spanning TILE_K columns from k = 0.
    a_ptr = tl.make_block_ptr(
        A + offset_a,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(block_m, 0),
        block_shape=(TILE_M, TILE_K),
        order=(1, 0),
    )

    # Block pointer over this program's TILE_N columns of B, spanning TILE_K rows from k = 0.
    b_ptr = tl.make_block_ptr(
        B + offset_b,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, block_n),
        block_shape=(TILE_K, TILE_N),
        order=(1, 0),
    )

    # Block pointer to this program's output tile.
    o_ptr = tl.make_block_ptr(
        O + offset_o,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(block_m, block_n),
        block_shape=(TILE_M, TILE_N),
        order=(1, 0),
    )

    if EVEN_K:
        # Single-shot path: load the whole TILE_K-wide blocks once and do a
        # single smt.dot. smt.view appears to re-tile each loaded block into
        # (MICRO_*)-sized micro-tiles; this is Spacemit-specific (confirm).
        a_descriptor_load = smt.descriptor_load(a_ptr, (0, 0))
        a = smt.view(a_descriptor_load, (0, 0), (TILE_M, TILE_K), (MICRO_M, MICRO_K))
        b_descriptor_load = smt.descriptor_load(b_ptr, (0, 0))
        b = smt.view(b_descriptor_load, (0, 0), (TILE_K, TILE_N), (MICRO_K, MICRO_N))
        acc = smt.dot(a, b)
    else:
        # Chunked path: sum ceil(K / SUB_BLK_K) partial products.
        # NOTE(review): the accumulator uses A's element type, so fp16/bf16
        # inputs accumulate in low precision unless smt.dot widens; confirm.
        acc = tl.zeros((TILE_M, TILE_N), dtype=A.type.element_ty)
        acc = smt.view(acc, (0, 0), (TILE_M, TILE_N), (MICRO_M, MICRO_N))
        sub_num = (K + SUB_BLK_K - 1) // SUB_BLK_K
        for k in tl.range(0, sub_num):
            # NOTE(review): both descriptor_load calls are independent of k.
            # Every iteration reloads the same full block and only the view
            # offset (k * SUB_BLK_K) moves. Check smt semantics before hoisting.
            a_descriptor_load = smt.descriptor_load(a_ptr, (0, 0))
            a = smt.view(
                a_descriptor_load,
                (0, k * SUB_BLK_K),
                (TILE_M, SUB_BLK_K),
                (MICRO_M, MICRO_K),
            )
            b_descriptor_load = smt.descriptor_load(b_ptr, (0, 0))
            b = smt.view(
                b_descriptor_load,
                (k * SUB_BLK_K, 0),
                (SUB_BLK_K, TILE_N),
                (MICRO_K, MICRO_N),
            )
            acc += smt.dot(a, b)
        # Presumably collapses the micro-tiled layout back to a plain
        # (TILE_M, TILE_N) tile before the store (confirm).
        acc = smt.view(acc, (0, 0), (TILE_M, TILE_N), (1, 1))

    # NOTE(review): nothing here masks K past the matrix extent. When K is
    # not a power of two, correctness presumably relies on descriptor_load
    # zero-filling outside shape=(M, K) / (K, N) (confirm).
    c = acc.to(o_ptr.dtype.element_ty)

    # boundary_check drops rows/columns of the tile that fall past M or N.
    tl.store(o_ptr, c, boundary_check=(0, 1))
def _needs_contiguous(t):
    """Return True if neither matrix dimension of the 3-D tensor *t* is unit-stride.

    The kernel receives explicit strides for every operand, so a per-matrix
    row-major (stride(2) == 1) or column-major (stride(1) == 1) layout is
    passed through unchanged. Only layouts with no unit-stride matrix
    dimension are copied. The batch dimension (dim 0) does not matter here,
    because the kernel applies it as a per-batch base offset.
    """
    return t.stride(1) > 1 and t.stride(2) > 1


def bmm(A, B):
    """Compute the batched matrix product ``A @ B`` on the Spacemit backend.

    Args:
        A: 3-D tensor of shape ``(batch, M, K)``.
        B: 3-D tensor of shape ``(batch, K, N)``.

    Returns:
        Tensor of shape ``(batch, M, N)``. On the Triton path and for empty
        shapes it is allocated with ``A.dtype`` on ``A.device``. The ``K == 1``
        shortcuts return whatever ``outer`` / ``mul`` produce.

    Raises:
        ValueError: if ``A`` or ``B`` is not 3-D (raised by tuple unpacking).
        RuntimeError: if the batch or contraction sizes of ``A`` and ``B``
            disagree.
    """
    logger.debug("GEMS_SPACEMIT BMM")
    batch, M, K = A.shape
    b_batch, b_k, N = B.shape
    # Without this check, mismatched shapes make the kernel read past the end of B.
    if b_batch != batch or b_k != K:
        raise RuntimeError(
            "Expected size for first two dimensions of batch2 tensor to be: "
            f"[{batch}, {K}] but got: [{b_batch}, {b_k}]."
        )

    # Degenerate shapes. An empty output needs no launch (the grid would be
    # zero-sized). An empty contraction is all zeros, as in torch.bmm; the
    # kernel could not be specialised for it anyway, because
    # next_power_of_2(0) would make TILE_K == 0.
    if batch == 0 or M == 0 or N == 0:
        return torch.empty((batch, M, N), dtype=A.dtype, device=A.device)
    if K == 0:
        return torch.zeros((batch, M, N), dtype=A.dtype, device=A.device)

    # K == 1 is a rank-1 update per batch, so no reduction is needed. These
    # paths run before any layout copy or output allocation, neither of
    # which they use.
    if K == 1:
        if batch == 1:
            return outer(A[0, :, 0], B[0, 0, :]).unsqueeze(0)
        # (batch, M, 1) * (batch, 1, N) broadcasts to (batch, M, N).
        return mul(A, B)

    if _needs_contiguous(A):
        A = A.contiguous()
    if _needs_contiguous(B):
        B = B.contiguous()
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def grid_fn(meta):
        # One program per (TILE_M, TILE_N) output tile, per batch.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    # The kernel loads the whole reduction dim as one power-of-two block and,
    # on its chunked path, walks it in SUB_BLK_K-sized pieces (at most 1024).
    TILE_K = triton.next_power_of_2(K)
    SUB_BLK_K = min(1024, TILE_K)

    bmm_kernel[grid_fn](
        A,
        B,
        out,
        M,
        N,
        K,
        A.stride(0),
        A.stride(1),
        A.stride(2),
        B.stride(0),
        B.stride(1),
        B.stride(2),
        out.stride(0),
        out.stride(1),
        out.stride(2),
        TILE_K=TILE_K,
        SUB_BLK_K=SUB_BLK_K,
    )
    return out