Coverage for src/flag_gems/runtime/backend/_arm/ops/bmm.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import triton_lang_extension as tle 

9 

10 

11# @libentry() 

@triton.autotune(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    # Batched matrix multiply: each program computes one (TILE_M, TILE_N) tile
    # of O[b] = A[b] @ B[b].
    #
    # Addressing below treats A as `batch` dense row-major (M, K) matrices,
    # B as (K, N) and O as (M, N), with no padding between rows or batches.
    # Grid axes: 0 -> tiles along M, 1 -> tiles along N, 2 -> batch index.
    #
    # TILE_* / GROUP_M come from the autotune config. DIVISIBLE_* are
    # presumably produced by the "bmm" heuristics config and state that the
    # matching dimension is a multiple of its tile size, so the corresponding
    # bounds mask can be skipped (TODO confirm against the heuristics).

    # batch offsets: move all three base pointers to this program's batch.
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: linearize the (x, y) program id, then hand out tiles in
        # groups of GROUP_M tile-rows, walking the M index fastest in a group.
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group is narrower when gridx is not a multiple of GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Bounds masks are only materialized when the dimension has a ragged tail.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # Accumulate in float32 regardless of the input element type.
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Every masked load passes other=0.0. Without `other`, Triton leaves
        # masked-off lanes undefined, and lanes masked out along K are
        # multiplied in tl.dot and summed into *valid* outputs, so they must
        # read as exact zeros. Unmasked loads must not pass `other`, because
        # Triton rejects `other` when no mask is given.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                a = tl.load(a_ptrs)
            else:
                a = tl.load(a_ptrs, mask=mask_m[:, None], other=0.0)
            if DIVISIBLE_N:
                b = tl.load(b_ptrs)
            else:
                b = tl.load(b_ptrs, mask=mask_n[None, :], other=0.0)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                a = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0)
            else:
                a = tl.load(
                    a_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0
                )
            if DIVISIBLE_N:
                b = tl.load(b_ptrs, mask=mask_k[:, None], other=0.0)
            else:
                b = tl.load(
                    b_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0
                )

        # Advance to the next K tile: TILE_K columns of A, TILE_K rows of B.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    # Only write the in-bounds part of the output tile.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

114 

115 

def bmm(A, B):
    """Return the batched matrix product of ``A`` and ``B`` via ``bmm_kernel``.

    Args:
        A: 3-D tensor unpacked as ``(batch, M, K)``.
        B: 3-D tensor unpacked as ``(batch, K, N)``; its batch and ``K``
            sizes must equal those of ``A``.

    Returns:
        A new ``(batch, M, N)`` tensor with ``A``'s dtype and device.

    Raises:
        ValueError: if ``A`` or ``B`` is not 3-D (shape unpacking fails).
        RuntimeError: if the batch or inner (``K``) sizes of ``A`` and ``B``
            differ.
    """
    logging.debug("GEMS BMM")
    batch, M, K = A.shape
    batch_b, K_b, N = B.shape
    # The kernel addresses B using A's batch and K sizes, so a mismatch would
    # read out of bounds instead of failing. Reject it up front.
    if batch_b != batch or K_b != K:
        raise RuntimeError(
            f"bmm: expected B to have shape [{batch}, {K}, N] to match A "
            f"{list(A.shape)}, but got {list(B.shape)}"
        )
    # The kernel computes offsets assuming dense row-major storage.
    A = A.contiguous()
    B = B.contiguous()
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def grid_fn(meta):
        # One program per (M tile, N tile, batch element).
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    # with torch_device_fn.device(A.device):
    bmm_kernel[grid_fn](A, B, out, M, N, K)
    return out

131 return out