Coverage for src/flag_gems/runtime/backend/_ascend/ops/softmax.py: 0%
256 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
# Module-level logger; the op wrappers below log a debug message on entry.
logger = logging.getLogger(__name__)
# Softmax over the middle axis of a contiguous (M, N, K) view of the input.
#
# Element (m, n, k) is addressed as m * N * K + n * K + k. Program
# (pid_m, pid_k) normalizes row pid_m across all N positions for the
# TILE_K-wide slice of K that starts at pid_k * TILE_K. The output is written
# with the same offsets as the input.
#
# M is not read in the body. TILE_N, TILE_K and ONE_TILE_PER_CTA are
# compile-time parameters that the launchers in this file do not pass. They
# are presumably supplied by the "softmax_non_inner" heuristics (TODO confirm).
@libentry()
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["softmax_non_inner"])
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    pid_k = ext.program_id(1)
    pid_m = ext.program_id(0)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # The whole N extent is handled as one (TILE_N, TILE_K) tile
        # (presumably TILE_N >= N in this mode). Masked lanes load -inf, so
        # they contribute exp(-inf) = 0 to the sum.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        m = tl.max(inp, 0)
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Pass 1 (online softmax): keep a per-element float32 running max m
        # and a running sum z of exp(x - m) over all N tiles.
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        # Specializing the boundary tile did not improve performance in this
        # kernel, as tested. Every tile is therefore loaded with a mask.
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            # Where m_new is still -inf, every value seen so far is -inf.
            # Leave z unchanged there instead of evaluating
            # exp(-inf - (-inf)) = NaN.
            all_neg_inf = m_new == float("-inf")
            # Rescale the old sum to the new max, then add the new term.
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the TILE_N lanes: take the max per column, and rescale each
        # lane's partial sum to that max before summing.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
        m = m_reduced

        # Pass 2: write exp(x - max) / sum. Tiles are visited from the last
        # one (starting at prev_multiple_of(N, TILE_N)) back to the first.
        # Specializing the boundary tile did not improve performance here
        # either, as tested.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            o = tl.exp(inp - m[None, :]) / z[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)
@triton.jit
def next_multiple_of(a, b):
    # Round a up to a multiple of b: the smallest x >= a with x % b == 0
    # (assumes b > 0).
    return tl.cdiv(a, b) * b
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b strictly below a (for a >= 1, b >= 1): round a up
    # to a multiple of b, then step back by one multiple.
    return (tl.cdiv(a, b) - 1) * b
# Softmax over the last axis of a contiguous (M, N) view of the input.
#
# Program pid_m normalizes row pid_m, i.e. elements [pid_m * N, pid_m * N + N).
# M is not read in the body. TILE_N and ONE_TILE_PER_CTA are compile-time
# parameters that the launchers in this file do not pass. They presumably come
# from the "softmax_inner" heuristics (TODO confirm).
@libentry()
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["softmax_inner"])
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    pid_m = ext.program_id(0)
    if ONE_TILE_PER_CTA:
        # The whole row fits in one TILE_N tile. Values are converted to the
        # output element type before the reduction. Masked lanes load -inf.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Pass 1 (online softmax): keep a per-lane float32 running max m and
        # a running sum z of exp(x - m).
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        # Move both pointers to the start of this program's row.
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # Start of the last (possibly partial) tile. Every tile before it is
        # full, so those loads need no mask.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # The input may contain -inf. Where m_new is -inf, keep z as is
            # instead of computing exp(-inf - (-inf)) = NaN.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Specialize the last iteration: this range runs exactly once over
        # [previous_multiple, N). Masked lanes load -inf.
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Reduce the per-lane state to the row max, and to the row sum
        # rescaled to that max.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        # Pass 2: write exp(x - max) / sum, walking the tiles from the last one
        # back to the first. This recomputes the same value as in pass 1.
        previous_multiple = prev_multiple_of(N, TILE_N)
        # Specialize the first iteration: the tile at previous_multiple is the
        # only one that may be partial, so it is the only masked one here.
        # evict_first is a cache-eviction hint; each element is read once in
        # this pass.
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining full tiles, from previous_multiple - TILE_N down to 0.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)
# ------------------------ backward -------------------------------
# Softmax backward over the middle axis of a contiguous (M, N, K) view:
#     in_grad = out * (out_grad - sum over n of (out * out_grad))
# where `out` is the forward softmax output. Element (m, n, k) is addressed as
# m * N * K + n * K + k. Program (pid_m, pid_k) handles row pid_m and the
# TILE_K-wide slice of K starting at pid_k * TILE_K. Loads are upcast to
# float32 for the reduction.
#
# The kernel is autotuned on (M, N, K) over the "softmax_non_inner" tuned
# configs and also decorated with the "softmax_backward_non_inner" heuristics.
# Which of them supplies TILE_N / TILE_K / ONE_TILE_PER_CTA is not visible
# here (TODO confirm).
#
# Masked loads pass other=0.0. Without `other`, Triton leaves masked-out lanes
# undefined, and they would leak into the reduction over N whenever a tile
# extends past N.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["softmax_backward_non_inner"])
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # The whole N extent is handled as one (TILE_N, TILE_K) tile.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad elementwise over all N tiles, then
        # reduce over the N lanes to one scale per column.
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for off in range(0, N, TILE_N):
            offsets_n = tl.arange(0, TILE_N) + off
            offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            scale += out_tile * out_grad_tile
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: reload both tiles and write the input gradient.
        for off in range(0, N, TILE_N):
            offsets_n = tl.arange(0, TILE_N) + off
            offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
# Softmax backward over the last axis of a contiguous (M, N) view:
#     in_grad = out * (out_grad - sum over n of (out * out_grad))
# where `out` is the forward softmax output. Each program handles TILE_M
# consecutive rows starting at pid_m * TILE_M; the launchers in this file use
# a grid of cdiv(M, TILE_M) programs. Loads are upcast to float32.
#
# The kernel is autotuned on (M, N) over the "softmax_inner" tuned configs and
# decorated with the "softmax_backward_inner" heuristics. Which of them
# supplies TILE_M / TILE_N / ONE_TILE_PER_CTA is not visible here
# (TODO confirm).
#
# Masked loads pass other=0.0 so that masked-out lanes (rows >= M or
# columns >= N) contribute 0 to the row sums. Triton leaves them undefined
# otherwise.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=_hcu.HEURISTICS_CONFIGS["softmax_backward_inner"],
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    pid_m = ext.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Each row fits in one TILE_N-wide tile.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: per-row sum of out * out_grad, accumulated tile by tile.
        scale = tl.zeros([TILE_M], dtype=tl.float32)
        for off in range(0, N, TILE_N):
            n_offsets = tl.arange(0, TILE_N) + off
            offsets = m_offsets[:, None] * N + n_offsets
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_last is a cache hint, presumably to keep `out` resident
            # for the re-read in pass 2.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, other=0.0, eviction_policy="evict_last"
            ).to(tl.float32)
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            scale += tl.sum(out_tile * out_grad_tile, axis=1)

        # Pass 2: reload both tiles and write the input gradient.
        for off in range(0, N, TILE_N):
            n_offsets = tl.arange(0, TILE_N) + off
            offsets = m_offsets[:, None] * N + n_offsets
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, other=0.0, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
def softmax(self, dim, half_to_float=False):
    """Softmax of ``self`` along ``dim`` using the Ascend Triton kernels.

    The contiguous input is treated as an (M, N, K) array:
    - N is the size of ``dim``.
    - M is the product of the sizes before ``dim``.
    - K is the product of the sizes after ``dim``.

    The non-inner kernel is used when K > 1, otherwise the last-axis
    (inner) kernel.

    Args:
        self: input tensor.
        dim: reduction dimension; negative values count from the end.
        half_to_float: if True the result is float32, otherwise it keeps
            ``self.dtype``.

    Returns:
        A new tensor with the shape of ``self``.
    """
    logger.debug("GEMS_ASCEND SOFTMAX")

    assert -self.ndim <= dim < self.ndim, "Invalid dim"
    dim %= self.ndim
    n_reduce = self.shape[dim]
    n_outer = self.shape[:dim].numel()

    inp = self.contiguous()
    result = torch.empty_like(
        inp, dtype=torch.float32 if half_to_float else inp.dtype
    )
    # Nothing to normalize; this also keeps the division below well defined.
    if n_reduce == 0 or inp.numel() == 0:
        return result
    n_inner = inp.numel() // n_outer // n_reduce

    with torch_device_fn.device(inp.device):
        if n_inner <= 1:
            # The reduced axis is innermost: one program per row.
            softmax_kernel_inner[(n_outer, 1, 1)](result, inp, n_outer, n_reduce)
        else:
            # The reduced axis has a trailing extent: tile K across grid axis 1.
            def grid(meta):
                return (n_outer, triton.cdiv(n_inner, meta["TILE_K"]), 1)

            softmax_kernel_non_inner[grid](
                result, inp, n_outer, n_reduce, n_inner
            )
    return result
def softmax_out(self, dim, half_to_float=False, *, out):
    """Softmax of ``self`` along ``dim`` written into ``out``.

    ``out`` is resized to ``self.shape`` when the shapes differ. The result
    has ``out.dtype``; ``half_to_float`` is accepted for signature
    compatibility but is not read by this implementation.

    Args:
        self: input tensor.
        dim: reduction dimension; negative values count from the end.
        half_to_float: unused (see above).
        out: destination tensor. It may be non-contiguous; in that case the
            result is computed in a contiguous scratch tensor and copied in.

    Returns:
        ``out``.
    """
    logger.debug("GEMS_ASCEND SOFTMAX_OUT")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    if tuple(out.shape) != tuple(self.shape):
        out.resize_(self.shape)
    if self.numel() == 0:
        return out

    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    self = self.contiguous()
    K = self.numel() // M // N

    # The kernels address the destination with contiguous row-major offsets.
    # A caller-supplied non-contiguous `out` would receive values at the wrong
    # positions, so compute into a contiguous scratch tensor and copy back.
    dst = out if out.is_contiguous() else torch.empty_like(self, dtype=out.dtype)

    with torch_device_fn.device(self.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_kernel_non_inner[grid](dst, self, M, N, K)
        else:
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](dst, self, M, N)
    if dst is not out:
        out.copy_(dst)
    return out
def softmax_backward(grad_output, output, dim, input_dtype):
    """Gradient of softmax with respect to its input.

    Computes ``output * (grad_output - sum(output * grad_output, dim))`` with
    the Ascend Triton kernels, where ``output`` is the forward softmax result.

    Args:
        grad_output: gradient with respect to the softmax output. Assumed to
            have the shape of ``output``; this is not checked here.
        output: forward softmax output.
        dim: softmax dimension; negative values count from the end.
        input_dtype: dtype of the returned gradient.

    Returns:
        A new contiguous tensor with the shape of ``output`` and dtype
        ``input_dtype``.
    """
    logger.debug("GEMS_ASCEND SOFTMAX BACKWARD")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    # The kernels read both operands with contiguous row-major offsets. A
    # strided `output` would otherwise be read at the wrong positions, and
    # empty_like below would also inherit its strides.
    output = output.contiguous()
    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    # Empty tensors (M == 0 or N == 0) need no work and would otherwise
    # divide by zero when computing K.
    if output.numel() == 0:
        return in_grad
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output, grad_output, in_grad, M, N, K
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](output, grad_output, in_grad, M, N)
    return in_grad
def softmax_backward_out(grad_output, output, dim, input_dtype, *, grad_input):
    """Gradient of softmax with respect to its input, written into ``grad_input``.

    Computes ``output * (grad_output - sum(output * grad_output, dim))``.
    ``grad_input`` is resized to ``output.shape`` when the shapes differ.
    The result has ``grad_input.dtype``; ``input_dtype`` is not read by this
    implementation.

    Args:
        grad_output: gradient with respect to the softmax output. Assumed to
            have the shape of ``output``; this is not checked here.
        output: forward softmax output.
        dim: softmax dimension; negative values count from the end.
        input_dtype: unused (see above).
        grad_input: destination tensor. It may be non-contiguous; in that case
            the result is computed in a contiguous scratch tensor and copied in.

    Returns:
        ``grad_input``.
    """
    logger.debug("GEMS_ASCEND SOFTMAX BACKWARD_OUT")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    # The kernels read both operands with contiguous row-major offsets.
    output = output.contiguous()
    grad_output = grad_output.contiguous()
    if tuple(grad_input.shape) != tuple(output.shape):
        grad_input.resize_(output.shape)
    # Empty tensors (M == 0 or N == 0) need no work and would otherwise
    # divide by zero when computing K.
    if output.numel() == 0:
        return grad_input
    K = output.numel() // M // N

    # The kernels also write with contiguous offsets. Use a contiguous scratch
    # tensor when the caller's grad_input is strided, then copy it back.
    dst = (
        grad_input
        if grad_input.is_contiguous()
        else torch.empty_like(output, dtype=grad_input.dtype)
    )

    with torch_device_fn.device(grad_input.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output, grad_output, dst, M, N, K
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](output, grad_output, dst, M, N)
    if dst is not grad_input:
        grad_input.copy_(dst)
    return grad_input