Coverage for src/flag_gems/runtime/backend/_mthreads/ops/index_add.py: 0%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems import runtime 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Module logger: fixed mthreads ops namespace followed by the last dotted
# component of this module's import name.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[2]
)

14 

15 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("index_add"))
@triton.jit
def index_add_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    src_ptr,
    M,
    N,
    alpha,
    inp_len,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Accumulate one (BLOCK_M, BLOCK_N) tile of ``src`` into ``out``.

    Addressing performed by this kernel:
    - ``src`` is read as a row-major (M, N) matrix (offset ``m * N + n``).
    - ``inp`` and ``out`` are addressed with a row stride of ``inp_len``
      (offset ``m * inp_len + index[n]``).
    - ``index`` is read as a flat array of N positions.

    For every in-bounds row m and index slot n:
        out[m, index[n]] = inp[m, index[n]] + alpha * src[m, n]

    BLOCK_M and BLOCK_N are compile-time tile sizes, presumably supplied by
    the "index_add" heuristic config attached via the decorator.

    NOTE(review): the update is a plain load / add / store rather than an
    atomic add, so if ``index`` holds the same position more than once the
    colliding lanes may overwrite each other instead of accumulating.
    Confirm whether callers guarantee unique indices.
    """
    row_block = tle.program_id(axis=0)
    slot_block = tle.program_id(axis=1)

    # Column vector of row ids and row vector of index-slot ids for the tile.
    row_ids = row_block * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    slot_ids = slot_block * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]

    # Bounds masks; the tile mask is the broadcast of both.
    row_ok = row_ids < M
    slot_ok = slot_ids < N
    tile_ok = row_ok & slot_ok

    # Destination positions selected by this tile's index slots.
    dst_pos = tl.load(index_ptr + slot_ids, mask=slot_ok, other=0)
    dst_off = row_ids * inp_len + dst_pos

    # Current destination values and the matching source tile.
    base_vals = tl.load(inp_ptr + dst_off, mask=tile_ok, other=0.0)
    src_vals = tl.load(
        src_ptr + (row_ids * N + slot_ids), mask=tile_ok, other=0.0
    )

    # Write inp + alpha * src back to the selected positions.
    tl.store(out_ptr + dst_off, base_vals + alpha * src_vals, mask=tile_ok)

73 

74 

def index_add(inp, dim, index, src, alpha=1):
    """
    Out-of-place index_add for the mthreads backend.

    Returns a new tensor equal to ``inp`` with ``alpha * src`` added along
    ``dim`` at the positions listed in ``index``. For a 3-D tensor:

        out[index[i], :, :] += alpha * src[i, :, :]  # if dim == 0
        out[:, index[i], :] += alpha * src[:, i, :]  # if dim == 1
        out[:, :, index[i]] += alpha * src[:, :, i]  # if dim == 2

    Args:
        inp: Tensor providing the initial values; it is not modified.
        dim: Dimension of ``inp`` to index along; negative values wrap.
        index: Tensor of positions along ``dim`` (flattened element count is
            used as N). Presumably 1-D with ``src.size(dim)`` entries, as in
            ``torch.index_add`` - shapes are not validated here.
        src: Tensor whose slices are scaled by ``alpha`` and added.
        alpha: Scalar multiplier applied to ``src``.

    Returns:
        A new contiguous tensor with the same shape as ``inp``.
    """
    logger.debug("GEMS_MTHREADS INDEX ADD")

    # The kernel uses flat row-major addressing, so all operands must be
    # contiguous.
    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()

    if N == 0 or src.numel() == 0:
        # Nothing to accumulate. Returning here avoids the division by N
        # below (ZeroDivisionError for an empty index) and an empty launch.
        return inp.clone()

    M = src.numel() // N

    # Move the indexed dimension to the last position so the kernel can treat
    # the tensors as (M, inp_len) and (M, N) matrices.
    final_dim = inp.ndim - 1
    needs_permute = dim != final_dim
    if needs_permute:
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)

    # The kernel only writes the indexed positions, so start from a copy.
    out = inp.clone()

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    with torch_device_fn.device(inp.device):
        index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)

    if not needs_permute:
        return out

    # Undo the layout change: put the last axis back at position ``dim``.
    order = list(range(out.ndim - 1))
    order.insert(dim, final_dim)
    return out.permute(order).contiguous()

124 

125 

def index_add_(inp, dim, index, src, alpha=1):
    """
    In-place index_add for the mthreads backend.

    Adds ``alpha * src`` into ``inp`` along ``dim`` at the positions listed
    in ``index`` and returns ``inp``.

    Args:
        inp: Tensor updated in place.
        dim: Dimension of ``inp`` to index along; negative values wrap.
        index: Tensor of positions along ``dim`` (flattened element count is
            used as N). Shapes are not validated here.
        src: Tensor whose slices are scaled by ``alpha`` and added.
        alpha: Scalar multiplier applied to ``src``.

    Returns:
        ``inp`` itself, after the update.
    """
    logger.debug("GEMS_MTHREADS INDEX ADD_")

    # The kernel uses flat row-major addressing for index and src.
    index = index.contiguous()
    src = src.contiguous()

    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()

    if N == 0 or src.numel() == 0:
        # Nothing to accumulate. Returning here avoids the division by N
        # below (ZeroDivisionError for an empty index) and an empty launch.
        return inp

    M = src.numel() // N

    final_dim = inp.ndim - 1
    needs_permute = dim != final_dim
    if needs_permute:
        # Work on a private copy laid out with the indexed dimension last;
        # the result is copied back into inp after the launch.
        work = dim_compress(inp.clone().contiguous(), dim)
        src = dim_compress(src, dim)
    else:
        # Already indexed along the last dimension; contiguous() returns inp
        # itself when it is contiguous, so the kernel then updates it
        # directly.
        work = inp.contiguous()

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    with torch_device_fn.device(inp.device):
        # Same buffer for input and output: the update happens in place.
        index_add_kernel[grid](work, work, index, src, M, N, alpha, inp_len)

    if needs_permute:
        # Put the last axis back at position ``dim`` and write the result
        # into inp. copy_ accepts a strided source, so no extra contiguous
        # copy is needed.
        order = list(range(work.ndim - 1))
        order.insert(dim, final_dim)
        inp.copy_(work.permute(order))
    elif not inp.is_contiguous():
        # The kernel ran on a contiguous copy; propagate the result.
        inp.copy_(work)

    return inp