Coverage for src/flag_gems/runtime/backend/_ascend/ops/index_add.py: 0%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7from flag_gems.utils import dim_compress, libentry 

8from flag_gems.utils import triton_lang_extension as tle 

9 

10logger = logging.getLogger(__name__) 

11 

12 

13@libentry() 

14@triton.jit 

15def index_add_kernel( 

16 inp_ptr, 

17 out_ptr, 

18 index_ptr, 

19 src_ptr, 

20 M, 

21 N, 

22 alpha, 

23 inp_len, 

24 BLOCK_M: tl.constexpr, 

25 BLOCK_N: tl.constexpr, 

26): 

27 pid_m = tle.program_id(axis=0) 

28 pid_n = tle.program_id(axis=1) 

29 

30 rows_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] 

31 cols_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] 

32 

33 rows_mask = rows_offset < M 

34 cols_mask = cols_offset < N 

35 block_mask = rows_mask & cols_mask 

36 

37 cur_indices = tl.load(index_ptr + cols_offset, mask=cols_mask, other=0) 

38 

39 inp_off = rows_offset * inp_len + cur_indices 

40 cur_inp = tl.load(inp_ptr + inp_off, mask=block_mask, other=0.0) 

41 

42 src_off = rows_offset * N + cols_offset 

43 cur_src = tl.load(src_ptr + src_off, mask=block_mask, other=0.0) 

44 

45 result = cur_inp + alpha * cur_src 

46 tl.store(out_ptr + inp_off, result, mask=block_mask) 

47 

48 

49def _get_block_config(M, N): 

50 BLOCK_M = 4 if M < 4096 else 8 

51 BLOCK_N = max(4, min(512, triton.next_power_of_2(N))) 

52 return BLOCK_M, BLOCK_N 

53 

54 

55def index_add(inp, dim, index, src, alpha=1): 

56 logger.debug("GEMS_ASCEND INDEX ADD") 

57 

58 inp = inp.contiguous() 

59 index = index.contiguous() 

60 src = src.contiguous() 

61 

62 dim = dim % inp.ndim 

63 inp_len = inp.size(dim) 

64 N = index.numel() 

65 M = src.numel() // N 

66 

67 final_dim = inp.ndim - 1 

68 if dim != final_dim: 

69 inp = dim_compress(inp, dim) 

70 src = dim_compress(src, dim) 

71 

72 out = inp.clone() 

73 

74 BLOCK_M, BLOCK_N = _get_block_config(M, N) 

75 grid = ( 

76 triton.cdiv(M, BLOCK_M), 

77 triton.cdiv(N, BLOCK_N), 

78 ) 

79 

80 with torch_device_fn.device(inp.device): 

81 index_add_kernel[grid]( 

82 inp, 

83 out, 

84 index, 

85 src, 

86 M, 

87 N, 

88 alpha, 

89 inp_len, 

90 BLOCK_M=BLOCK_M, 

91 BLOCK_N=BLOCK_N, 

92 ) 

93 

94 if dim != final_dim: 

95 order = list(range(out.ndim - 1)) 

96 order.insert(dim, final_dim) 

97 return out.permute(order).contiguous() 

98 else: 

99 return out 

100 

101 

102def index_add_(inp, dim, index, src, alpha=1): 

103 logger.debug("GEMS_ASCEND INDEX ADD_") 

104 

105 index = index.contiguous() 

106 src = src.contiguous() 

107 

108 dim = dim % inp.ndim 

109 inp_len = inp.size(dim) 

110 N = index.numel() 

111 M = src.numel() // N 

112 

113 final_dim = inp.ndim - 1 

114 

115 if dim != final_dim: 

116 inp_work = dim_compress(inp.clone().contiguous(), dim) 

117 src_work = dim_compress(src, dim) 

118 out_work = inp_work.clone() 

119 

120 BLOCK_M, BLOCK_N = _get_block_config(M, N) 

121 grid = ( 

122 triton.cdiv(M, BLOCK_M), 

123 triton.cdiv(N, BLOCK_N), 

124 ) 

125 

126 with torch_device_fn.device(inp.device): 

127 index_add_kernel[grid]( 

128 inp_work, 

129 out_work, 

130 index, 

131 src_work, 

132 M, 

133 N, 

134 alpha, 

135 inp_len, 

136 BLOCK_M=BLOCK_M, 

137 BLOCK_N=BLOCK_N, 

138 ) 

139 

140 order = list(range(out_work.ndim - 1)) 

141 order.insert(dim, final_dim) 

142 inp_work = out_work.permute(order).contiguous() 

143 inp.copy_(inp_work) 

144 else: 

145 inp_contig = inp.contiguous() 

146 out_contig = inp_contig.clone() 

147 

148 BLOCK_M, BLOCK_N = _get_block_config(M, N) 

149 grid = ( 

150 triton.cdiv(M, BLOCK_M), 

151 triton.cdiv(N, BLOCK_N), 

152 ) 

153 

154 with torch_device_fn.device(inp.device): 

155 index_add_kernel[grid]( 

156 inp_contig, 

157 out_contig, 

158 index, 

159 src, 

160 M, 

161 N, 

162 alpha, 

163 inp_len, 

164 BLOCK_M=BLOCK_M, 

165 BLOCK_N=BLOCK_N, 

166 ) 

167 

168 if inp.is_contiguous(): 

169 inp.copy_(out_contig) 

170 else: 

171 inp.copy_(out_contig) 

172 

173 return inp