Coverage for src/flag_gems/runtime/backend/_ascend/ops/index_add.py: 0%
82 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.runtime import torch_device_fn
7from flag_gems.utils import dim_compress, libentry
8from flag_gems.utils import triton_lang_extension as tle
# Module-level logger named after this module; index_add / index_add_ emit
# their debug trace messages through it.
logger = logging.getLogger(name=__name__)
@libentry()
@triton.jit
def index_add_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    src_ptr,
    M,
    N,
    alpha,
    inp_len,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Accumulate ``alpha * src`` into ``out`` at the columns named by ``index``.

    The buffers are addressed as row-major 2-D arrays:
      * ``out``: ``M`` rows of ``inp_len`` elements;
      * ``src``: ``M`` rows of ``N`` elements;
      * ``index``: ``N`` column indices into a row of ``out``.

    For every valid (row, j) the kernel performs
    ``out[row, index[j]] += alpha * src[row, j]``, so ``out`` must already
    hold a copy of the input. ``inp_ptr`` is no longer read; it is kept so
    the launch signature stays unchanged.

    Grid axis 0 tiles the rows in steps of ``BLOCK_M``; grid axis 1 tiles the
    index entries in steps of ``BLOCK_N``.
    """
    pid_m = tle.program_id(axis=0)
    pid_n = tle.program_id(axis=1)

    rows_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]

    rows_mask = rows_offset < M
    cols_mask = cols_offset < N
    block_mask = rows_mask & cols_mask

    # Widen before multiplying by the row length so offsets into tensors
    # with more than 2**31 elements do not overflow 32-bit arithmetic.
    rows_offset_64 = rows_offset.to(tl.int64)

    cur_indices = tl.load(index_ptr + cols_offset, mask=cols_mask, other=0)
    out_off = rows_offset_64 * inp_len + cur_indices

    src_off = rows_offset_64 * N + cols_offset
    cur_src = tl.load(src_ptr + src_off, mask=block_mask, other=0.0)

    # Repeated values in `index` send several (row, j) pairs, possibly from
    # different programs, to the same output element. An atomic add keeps
    # every contribution (torch.index_add accumulates duplicates), whereas a
    # plain load/add/store would keep only one of them.
    # NOTE(review): confirm the Ascend Triton backend supports atomic_add for
    # every dtype this op is registered for (e.g. bf16 / int64).
    tl.atomic_add(out_ptr + out_off, alpha * cur_src, mask=block_mask)
def _get_block_config(M, N):
    """Choose the ``(BLOCK_M, BLOCK_N)`` tile shape for ``index_add_kernel``.

    Rows are tiled by 4 until there are at least 4096 of them, then by 8.
    The index axis uses the next power of two of ``N``, clamped to [4, 512].
    """
    rows_per_tile = 8 if M >= 4096 else 4

    cols_per_tile = min(512, triton.next_power_of_2(N))
    if cols_per_tile < 4:
        cols_per_tile = 4

    return rows_per_tile, cols_per_tile
def index_add(inp, dim, index, src, alpha=1):
    """Out-of-place ``index_add``: add ``alpha * src`` into slices of ``inp``.

    Slice ``j`` of ``src`` along ``dim`` is added to slice ``index[j]`` of the
    result along ``dim``. Repeated indices accumulate only if the kernel
    accumulates them; the atomic version of ``index_add_kernel`` does.

    Args:
        inp: tensor to add into; it is not modified.
        dim: dimension along which to index; negative values wrap.
        index: indices into ``inp`` along ``dim`` (counted with ``numel()``).
        src: values to add. Presumably the same shape as ``inp`` except
            along ``dim``, where its size is ``index.numel()`` — TODO confirm
            callers validate this, it is not checked here.
        alpha: scalar multiplier applied to ``src``.

    Returns:
        A new contiguous tensor with the same shape as ``inp``.
    """
    logger.debug("GEMS_ASCEND INDEX ADD")

    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()
    if N == 0:
        # Nothing to add. `contiguous()` may have returned the caller's
        # tensor itself, so clone to honour the out-of-place contract.
        return inp.clone()
    M = src.numel() // N

    final_dim = inp.ndim - 1
    if dim != final_dim:
        # Move `dim` to the innermost axis so the kernel can treat both
        # tensors as row-major (rows x indexed-axis) matrices.
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)

    # The kernel only touches indexed positions; every other element must
    # already carry the input value.
    out = inp.clone()

    if M > 0:  # skip launching an empty grid
        BLOCK_M, BLOCK_N = _get_block_config(M, N)
        grid = (
            triton.cdiv(M, BLOCK_M),
            triton.cdiv(N, BLOCK_N),
        )

        with torch_device_fn.device(inp.device):
            index_add_kernel[grid](
                inp,
                out,
                index,
                src,
                M,
                N,
                alpha,
                inp_len,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

    if dim != final_dim:
        # Undo dim_compress: the last axis of `out` goes back to position `dim`.
        order = list(range(out.ndim - 1))
        order.insert(dim, final_dim)
        return out.permute(order).contiguous()
    return out
def index_add_(inp, dim, index, src, alpha=1):
    """In-place ``index_add_``: accumulate ``alpha * src`` into ``inp`` along ``dim``.

    The sum is computed out of place with ``index_add`` and written back with
    ``Tensor.copy_``. ``copy_`` handles a non-contiguous ``inp``. Because the
    result is fully materialised before the copy, the computation never reads
    a partially updated ``inp``.

    Args:
        inp: tensor updated in place.
        dim: dimension along which to index; negative values wrap.
        index: indices into ``inp`` along ``dim``.
        src: values to add, scaled by ``alpha``.
        alpha: scalar multiplier applied to ``src``.

    Returns:
        ``inp`` itself.
    """
    logger.debug("GEMS_ASCEND INDEX ADD_")

    if index.numel() == 0:
        # Nothing to add, so skip the out-of-place computation and the copy.
        return inp

    # index_add already makes the inputs contiguous, moves `dim` last,
    # launches the kernel and permutes back. The in-place variant only needs
    # to copy that result, which removes the extra full-tensor clones and the
    # duplicated launch code.
    inp.copy_(index_add(inp, dim, index, src, alpha))
    return inp