Coverage for src/flag_gems/runtime/backend/_spacemit/ops/bmm.py: 0%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6import triton.language.extra.smt as smt 

7 

8from flag_gems import runtime 

9from flag_gems.fused import outer # noqa: E402 

10from flag_gems.ops import mul # noqa: E402 

11from flag_gems.utils import libentry, libtuner 

12 

13logger = logging.getLogger(__name__) 

14 

15 

16@libentry() 

17@libtuner( 

18 configs=runtime.get_tuned_config("bmm_spacemit"), 

19 key=["M", "N", "K"], 

20) 

21@triton.jit 

22def bmm_kernel( 

23 A, 

24 B, 

25 O, 

26 M, 

27 N, 

28 K, 

29 stride_ab, 

30 stride_am, 

31 stride_ak, 

32 stride_bb, 

33 stride_bk, 

34 stride_bn, 

35 stride_cb, 

36 stride_cm, 

37 stride_cn, 

38 TILE_M: tl.constexpr, 

39 TILE_N: tl.constexpr, 

40 EVEN_K: tl.constexpr, 

41 TILE_K: tl.constexpr, 

42 MICRO_M: tl.constexpr, 

43 MICRO_K: tl.constexpr, 

44 MICRO_N: tl.constexpr, 

45 SUB_BLK_K: tl.constexpr, 

46): 

47 pidx = tl.program_id(0) 

48 pidy = tl.program_id(1) 

49 pid_b = tl.program_id(2) 

50 

51 pid_m = pidx 

52 pid_n = pidy 

53 

54 block_m = pid_m * TILE_M 

55 block_n = pid_n * TILE_N 

56 

57 offset_a = pid_b * stride_ab 

58 offset_b = pid_b * stride_bb 

59 offset_o = pid_b * stride_cb 

60 

61 a_ptr = tl.make_block_ptr( 

62 A + offset_a, 

63 shape=(M, K), 

64 strides=(stride_am, stride_ak), 

65 offsets=(block_m, 0), 

66 block_shape=(TILE_M, TILE_K), 

67 order=(1, 0), 

68 ) 

69 

70 b_ptr = tl.make_block_ptr( 

71 B + offset_b, 

72 shape=(K, N), 

73 strides=(stride_bk, stride_bn), 

74 offsets=(0, block_n), 

75 block_shape=(TILE_K, TILE_N), 

76 order=(1, 0), 

77 ) 

78 

79 o_ptr = tl.make_block_ptr( 

80 O + offset_o, 

81 shape=(M, N), 

82 strides=(stride_cm, stride_cn), 

83 offsets=(block_m, block_n), 

84 block_shape=(TILE_M, TILE_N), 

85 order=(1, 0), 

86 ) 

87 

88 if EVEN_K: 

89 a_descriptor_load = smt.descriptor_load(a_ptr, (0, 0)) 

90 a = smt.view(a_descriptor_load, (0, 0), (TILE_M, TILE_K), (MICRO_M, MICRO_K)) 

91 b_descriptor_load = smt.descriptor_load(b_ptr, (0, 0)) 

92 b = smt.view(b_descriptor_load, (0, 0), (TILE_K, TILE_N), (MICRO_K, MICRO_N)) 

93 acc = smt.dot(a, b) 

94 else: 

95 acc = tl.zeros((TILE_M, TILE_N), dtype=A.type.element_ty) 

96 acc = smt.view(acc, (0, 0), (TILE_M, TILE_N), (MICRO_M, MICRO_N)) 

97 sub_num = (K + SUB_BLK_K - 1) // SUB_BLK_K 

98 for k in tl.range(0, sub_num): 

99 a_descriptor_load = smt.descriptor_load(a_ptr, (0, 0)) 

100 a = smt.view( 

101 a_descriptor_load, 

102 (0, k * SUB_BLK_K), 

103 (TILE_M, SUB_BLK_K), 

104 (MICRO_M, MICRO_K), 

105 ) 

106 b_descriptor_load = smt.descriptor_load(b_ptr, (0, 0)) 

107 b = smt.view( 

108 b_descriptor_load, 

109 (k * SUB_BLK_K, 0), 

110 (SUB_BLK_K, TILE_N), 

111 (MICRO_K, MICRO_N), 

112 ) 

113 acc += smt.dot(a, b) 

114 acc = smt.view(acc, (0, 0), (TILE_M, TILE_N), (1, 1)) 

115 

116 c = acc.to(o_ptr.dtype.element_ty) 

117 

118 tl.store(o_ptr, c, boundary_check=(0, 1)) 

119 

120 

121def bmm(A, B): 

122 logger.debug("GEMS_SPACEMIT BMM") 

123 batch, M, K = A.shape 

124 _, _, N = B.shape 

125 if A.stride(0) > 1 and A.stride(1) > 1: 

126 A = A.contiguous() 

127 if B.stride(0) > 1 and B.stride(1) > 1: 

128 B = B.contiguous() 

129 out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device) 

130 

131 if K == 1 and batch == 1: 

132 vec_a = A[0, :, 0] 

133 vec_b = B[0, 0, :] 

134 result = outer(vec_a, vec_b) 

135 return result.unsqueeze(0) 

136 

137 if K == 1: 

138 return mul(A, B) 

139 

140 def grid_fn(meta): 

141 return ( 

142 triton.cdiv(meta["M"], meta["TILE_M"]), 

143 triton.cdiv(meta["N"], meta["TILE_N"]), 

144 batch, 

145 ) 

146 

147 TILE_K = triton.next_power_of_2(K) 

148 SUB_BLK_K = min(1024, TILE_K) 

149 

150 bmm_kernel[grid_fn]( 

151 A, 

152 B, 

153 out, 

154 M, 

155 N, 

156 K, 

157 A.stride(0), 

158 A.stride(1), 

159 A.stride(2), 

160 B.stride(0), 

161 B.stride(1), 

162 B.stride(2), 

163 out.stride(0), 

164 out.stride(1), 

165 out.stride(2), 

166 TILE_K=TILE_K, 

167 SUB_BLK_K=SUB_BLK_K, 

168 ) 

169 return out