Coverage for src/flag_gems/runtime/backend/_mthreads/ops/index_add.py: 0%
74 statements
coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems import runtime
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import dim_compress, libentry
9from flag_gems.utils import triton_lang_extension as ext
# Module logger, namespaced under the mthreads backend ops and suffixed with
# the last component of this module's import name.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rsplit(".", 1)[-1]
)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("index_add"))
@triton.jit
def index_add_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    src_ptr,
    M,
    N,
    alpha,
    inp_len,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Accumulate ``alpha * src`` into ``out`` at the columns named by ``index``.

    The host wrappers move the target dimension to the last axis (via
    ``dim_compress``) before launching, so the buffers are laid out as:
      - out (and inp): (M, inp_len), inp_len = size of the target dimension
      - src:           (M, N), N = number of entries in ``index``

    For each row m and each index position n:
        out[m, index[n]] += alpha * src[m, n]

    ``out`` must already hold the starting values (the wrappers pass a clone
    of ``inp``, or ``inp`` itself for the in-place op). The update is done with
    ``tl.atomic_add`` so that duplicate entries in ``index`` accumulate, as
    ``torch.index_add`` specifies; a plain load/add/store lets one duplicate
    overwrite another's contribution (and races across programs when inp and
    out alias). Consequently ``inp_ptr`` is no longer read; it is kept only so
    the launch signature stays unchanged.

    BLOCK_M / BLOCK_N are not passed by the callers; they come from the
    ``index_add`` heuristic config applied by ``@triton.heuristics``.

    NOTE(review): confirm the mthreads Triton backend implements atomic_add
    for every dtype routed to this op (e.g. bfloat16, small integer types).
    Float accumulation order under atomics is not deterministic run-to-run.
    """
    pid_m = ext.program_id(axis=0)
    pid_n = ext.program_id(axis=1)

    # 64-bit row offsets: rows * inp_len can exceed the int32 range for
    # large tensors.
    rows = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)[:, None]
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]

    row_mask = rows < M
    col_mask = cols < N
    mask = row_mask & col_mask

    # One target column per index position; masked-out positions read 0 but
    # are excluded from the update by `mask`.
    cur_indices = tl.load(index_ptr + cols, mask=col_mask, other=0)

    src_vals = tl.load(src_ptr + rows * N + cols, mask=mask, other=0.0)
    delta = (alpha * src_vals).to(out_ptr.dtype.element_ty)

    # Atomic so that repeated indices (within or across programs) all add in.
    tl.atomic_add(out_ptr + rows * inp_len + cur_indices, delta, mask=mask)
def index_add(inp, dim, index, src, alpha=1):
    """
    Out-of-place index_add for the mthreads backend.

    Computes the equivalent of ``inp.index_add(dim, index, src, alpha=alpha)``.
    For a 3-D tensor the output is::

        out[index[i], :, :] += alpha * src[i, :, :]  # if dim == 0
        out[:, index[i], :] += alpha * src[:, i, :]  # if dim == 1
        out[:, :, index[i]] += alpha * src[:, :, i]  # if dim == 2

    Args:
        inp: tensor supplying the starting values; it is not modified.
        dim: dimension to index along; negative values count from the end.
        index: positions along ``dim`` of ``inp``, one per slice of ``src``.
        src: values to add; its size along ``dim`` is expected to equal
            ``index.numel()`` and its other sizes to match ``inp`` (not
            validated here).
        alpha: scalar multiplier applied to ``src``.

    Returns:
        A new contiguous tensor with the same shape as ``inp``.

    Raises:
        IndexError: if ``dim`` is outside ``[-inp.ndim, inp.ndim - 1]``.
    """
    logger.debug("GEMS_MTHREADS INDEX ADD")

    ndim = inp.ndim
    # `dim % ndim` alone would silently wrap an out-of-range dim onto a valid
    # one and update the wrong dimension; reject it like PyTorch does.
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    dim %= ndim

    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    N = index.numel()
    if N == 0 or src.numel() == 0:
        # Nothing to add. Returning here also avoids dividing by N == 0 below
        # and skips a kernel launch with no work.
        return inp.clone()

    inp_len = inp.size(dim)
    M = src.numel() // N  # number of rows once `dim` is moved last

    # Move the target dim to the last axis for coalesced row access.
    final_dim = ndim - 1
    if dim != final_dim:
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)

    # Fresh output buffer: `inp` may still share storage with the caller's
    # tensor (contiguous() returns the tensor itself when it already is).
    out = inp.clone()

    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )

    with torch_device_fn.device(inp.device):
        index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)

    if dim == final_dim:
        return out
    # Undo dim_compress: put the last axis back at position `dim`.
    order = list(range(ndim - 1))
    order.insert(dim, final_dim)
    return out.permute(order).contiguous()
def index_add_(inp, dim, index, src, alpha=1):
    """
    In-place index_add for the mthreads backend.

    Same computation as ``index_add``, but the result is written into ``inp``,
    which is then returned (mirroring ``Tensor.index_add_``). A non-contiguous
    ``inp``, or a target dim that is not the last one, is handled by updating
    a contiguous working buffer and copying it back into ``inp``.

    Args:
        inp: tensor to update in place.
        dim: dimension to index along; negative values count from the end.
        index: positions along ``dim`` of ``inp``, one per slice of ``src``.
        src: values to add; its size along ``dim`` is expected to equal
            ``index.numel()`` and its other sizes to match ``inp`` (not
            validated here).
        alpha: scalar multiplier applied to ``src``.

    Returns:
        ``inp`` itself, after the update.

    Raises:
        IndexError: if ``dim`` is outside ``[-inp.ndim, inp.ndim - 1]``.
    """
    logger.debug("GEMS_MTHREADS INDEX ADD_")

    ndim = inp.ndim
    # `dim % ndim` alone would silently wrap an out-of-range dim onto a valid
    # one and update the wrong dimension; reject it like PyTorch does.
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    dim %= ndim

    index = index.contiguous()
    src = src.contiguous()

    N = index.numel()
    if N == 0 or src.numel() == 0:
        # Nothing to add; also avoids dividing by N == 0 below.
        return inp

    inp_len = inp.size(dim)
    M = src.numel() // N  # number of rows once `dim` is moved last
    final_dim = ndim - 1

    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )

    if dim != final_dim:
        # Work on a private copy with `dim` moved to the last axis;
        # dim_compress already makes the result contiguous.
        work = dim_compress(inp.clone(), dim)
        src = dim_compress(src, dim)
    else:
        # contiguous() returns `inp` itself when it already is contiguous, so
        # in that case the kernel updates `inp` directly with no extra copy.
        work = inp.contiguous()

    with torch_device_fn.device(inp.device):
        index_add_kernel[grid](work, work, index, src, M, N, alpha, inp_len)

    if dim != final_dim:
        # Undo dim_compress; copy_ accepts the permuted (strided) view as-is.
        order = list(range(ndim - 1))
        order.insert(dim, final_dim)
        inp.copy_(work.permute(order))
    elif not inp.is_contiguous():
        # `work` was a separate contiguous copy; write the result back.
        inp.copy_(work)

    return inp