Coverage for src/flag_gems/ops/softmax.py: 36%
247 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.zeros import zero_
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
# Module-scoped logger; the softmax entry points in this module emit their
# debug traces through it.
logger = logging.getLogger(name=__name__)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax along the middle axis of a row-major ``(M, N, K)`` buffer.

    Program ``(pid_m, pid_k)`` normalizes, along the length-``N`` axis
    (memory stride ``K``), the ``TILE_K`` columns starting at
    ``pid_k * TILE_K`` of outer slice ``pid_m``.

    Args:
        output_ptr: destination buffer, addressed with the input's offsets.
        input_ptr: source buffer.
        M, N, K: sizes of the flattened outer, reduced and inner axes.
        TILE_N, TILE_K: tile sizes along N and K; not passed by the host, so
            they come from the ``softmax_non_inner`` heuristic.
        ONE_TILE_PER_CTA: when true the whole N axis is handled as a single
            tile in one pass; otherwise an online max/sum pass over N is
            followed by a normalization pass.
    """
    pid_k = ext.program_id(1)
    pid_m = ext.program_id(0)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        # Compute in the output dtype, as softmax_kernel_inner does: when the
        # output is float32 (half_to_float) the math must not run in half
        # precision.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Per-lane running max and running sum of exp(x - max), in float32.
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            # While a lane's max is still -inf, m - m_new is -inf - (-inf),
            # i.e. NaN; keep the lane's (zero) sum unchanged instead.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Merge the TILE_N per-lane (max, sum) partials of every column.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
        m = m_reduced

        # specialization does not improve performance in this example, as tested
        previous_multiple = prev_multiple_of(N, TILE_N)
        # Normalization pass; visits the N tiles from last to first.
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            o = tl.exp(inp - m[None, :]) / z[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)
@triton.jit
def next_multiple_of(a, b):
    """Return the smallest multiple of ``b`` that is >= ``a`` (a >= 0, b > 0)."""
    # Ceiling division, then scale back up to a multiple of b.
    return tl.cdiv(a, b) * b
@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` strictly less than ``a`` (b > 0)."""
    # One step of b below ceil(a / b) * b, e.g. (10, 4) -> 8 and (8, 4) -> 4.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax along the last axis of a row-major ``(M, N)`` buffer.

    Program ``pid_m`` normalizes row ``pid_m``, whose ``N`` elements are
    contiguous in memory.

    Args:
        output_ptr: destination buffer, same row-major layout as the input.
        input_ptr: source buffer.
        M: number of rows; not read inside the kernel body.
        N: row length.
        TILE_N: tile width along N; not passed by the host, so it comes from
            the ``softmax_inner`` heuristic.
        ONE_TILE_PER_CTA: when true a row is handled as one tile in a single
            pass; otherwise an online max/sum pass over the row is followed
            by a normalization pass.
    """
    pid_m = ext.program_id(0)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # Compute in the output dtype (float32 when the output is float32).
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Per-lane running max and running sum of exp(x - max), in float32.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        # Move both pointers to the start of this program's row.
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        previous_multiple = prev_multiple_of(N, TILE_N)
        # Full tiles only: every offset here is below previous_multiple < N,
        # so the loads need no mask.
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input
            # (while a lane's max is -inf, m - m_new would be NaN, so the
            # lane's sum is left unchanged)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # specialize the last iteration
        # (runs once, over the final, possibly partial, tile; masked)
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Merge the per-lane partials into the row max and the row sum.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        previous_multiple = prev_multiple_of(N, TILE_N)
        # specialize the first iteration
        # (the normalization pass runs back to front: this single iteration
        # covers the last, possibly partial, tile at previous_multiple; masked)
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining tiles, from previous_multiple - TILE_N down to offset 0;
        # all are full tiles, so no mask.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)
163# ------------------------ backward -------------------------------
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax backward along the middle axis of row-major ``(M, N, K)`` data.

    Computes ``in_grad = out * (out_grad - sum_n(out * out_grad))`` for the
    ``TILE_K`` columns starting at ``pid_k * TILE_K`` of outer slice
    ``pid_m``; the sum runs over the length-``N`` axis (memory stride ``K``).
    All arithmetic is done in float32.

    Args:
        out_ptr: forward softmax output.
        out_grad_ptr: gradient w.r.t. the forward output.
        in_grad_ptr: destination for the gradient w.r.t. the forward input.
        M, N, K: sizes of the flattened outer, reduced and inner axes.
        TILE_N, TILE_K, ONE_TILE_PER_CTA: not passed by the host; they come
            from the autotune configs / ``softmax_backward_non_inner``
            heuristic.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        # other=0.0: masked-off lanes feed the sum over N below, and Triton
        # does not specify their value when `other` is omitted.
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: per-lane partial sums of out * out_grad over the N tiles.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            # other=0.0 keeps masked-off lanes out of the accumulation.
            out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            scale += out_tile * out_grad_tile
            offsets_n += TILE_N
            offsets += TILE_N * K
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: re-read both inputs and write in_grad tile by tile.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            offsets_n += TILE_N
            offsets += TILE_N * K
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax backward along the last axis of row-major ``(M, N)`` data.

    Program ``pid_m`` handles rows ``[pid_m * TILE_M, pid_m * TILE_M +
    TILE_M)`` and computes ``in_grad = out * (out_grad - sum(out * out_grad))``
    per row, in float32.

    Args:
        out_ptr: forward softmax output.
        out_grad_ptr: gradient w.r.t. the forward output.
        in_grad_ptr: destination for the gradient w.r.t. the forward input.
        M: number of rows.
        N: row length.
        TILE_M, TILE_N, ONE_TILE_PER_CTA: not passed by the host; they come
            from the autotune configs / ``softmax_backward_inner`` heuristic.
    """
    pid_m = ext.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        # other=0.0: masked-off lanes feed the row sum below, and Triton does
        # not specify their value when `other` is omitted.
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: per-lane partial sums of out * out_grad over the N tiles.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # other=0.0 keeps masked-off lanes out of the accumulation.
            out_tile = tl.load(
                out_ptr + offsets,
                mask=mask,
                other=0.0,
                eviction_policy="evict_last",
            ).to(tl.float32)
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: re-read both inputs and write in_grad tile by tile.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N
def softmax_out(self, dim, half_to_float=False, *, out):
    """Compute the softmax of ``self`` along ``dim`` into ``out``.

    ``out`` is resized to ``self.shape`` when its shape differs; it may be
    non-contiguous, in which case the result is computed into a contiguous
    scratch tensor and copied back.

    Args:
        self: input tensor (any layout; a contiguous copy is used).
        dim: dimension to normalize over, in ``[-ndim, ndim)``; a 0-d tensor
            accepts -1 or 0 and is treated as a single-element vector.
        half_to_float: when true the result dtype is float32, otherwise it
            is ``self.dtype``.
        out: destination tensor; must already have the result dtype.

    Returns:
        ``out``.

    Raises:
        AssertionError: ``dim`` is out of range.
        RuntimeError: ``out.dtype`` differs from the result dtype.
    """
    logger.debug("GEMS SOFTMAX_OUT")

    # A 0-d tensor behaves like a 1-element vector (dim -1 or 0).
    ndim = max(self.ndim, 1)
    assert dim >= -ndim and dim < ndim, "Invalid dim"

    if self.numel() == 0:
        if tuple(out.shape) != tuple(self.shape):
            out.resize_(self.shape)
        zero_(out)
        return out

    dtype = torch.float32 if half_to_float else self.dtype
    # Validate before touching `out`, so a failed call leaves it unmodified.
    if out.dtype != dtype:
        raise RuntimeError(f"_softmax.out: expected out dtype {dtype}, got {out.dtype}")
    if tuple(out.shape) != tuple(self.shape):
        out.resize_(self.shape)

    # View the input as (M, N, K): N is the reduced dim, M/K the products of
    # the dims before/after it.
    dim = dim % ndim
    shape = tuple(self.shape) or (1,)
    M = 1
    for size in shape[:dim]:
        M *= size
    N = shape[dim]
    K = self.numel() // M // N

    inp = self.contiguous()
    # The kernels address the output as a dense row-major buffer, so a strided
    # `out` is filled through a contiguous scratch tensor.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty_like(inp, dtype=dtype, memory_format=torch.contiguous_format)

    with torch_device_fn.device(inp.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_kernel_non_inner[grid](
                dst,
                inp,
                M,
                N,
                K,
            )
        else:
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](
                dst,
                inp,
                M,
                N,
            )

    if dst is not out:
        out.copy_(dst)
    return out
def softmax(self, dim, half_to_float=False):
    """Return the softmax of ``self`` along ``dim`` in a new tensor.

    The result is a contiguous tensor of ``self.shape`` whose dtype is
    float32 when ``half_to_float`` is true and ``self.dtype`` otherwise (also
    for empty inputs). The computation is delegated to ``softmax_out``.

    Raises:
        AssertionError: ``dim`` is out of range.
    """
    logger.debug("GEMS SOFTMAX")

    # A 0-d tensor behaves like a 1-element vector (dim -1 or 0).
    ndim = max(self.ndim, 1)
    assert dim >= -ndim and dim < ndim, "Invalid dim"

    dtype = torch.float32 if half_to_float else self.dtype
    # Allocate contiguously: the kernels write a dense row-major result, while
    # torch.empty_like would copy the layout of a strided (e.g. transposed)
    # input.
    out = torch.empty(tuple(self.shape), dtype=dtype, device=self.device)
    return softmax_out(self, dim, half_to_float, out=out)
def softmax_backward_out(grad_output, output, dim, input_dtype, *, grad_input):
    """Compute the softmax input gradient into ``grad_input``.

    ``grad_input = output * (grad_output - sum(output * grad_output, dim))``.
    ``grad_input`` is resized to ``output.shape`` when its shape differs; it
    may be non-contiguous, in which case the result is computed into a
    contiguous scratch tensor and copied back.

    Args:
        grad_output: gradient w.r.t. the softmax output.
        output: the forward softmax result.
        dim: dimension softmax was taken over, in ``[-ndim, ndim)``; a 0-d
            tensor accepts -1 or 0.
        input_dtype: required dtype of ``grad_input``.
        grad_input: destination tensor.

    Returns:
        ``grad_input``.

    Raises:
        AssertionError: ``dim`` is out of range.
        RuntimeError: ``grad_input.dtype`` differs from ``input_dtype``.
    """
    logger.debug("GEMS SOFTMAX_BACKWARD_OUT")

    # A 0-d tensor behaves like a 1-element vector (dim -1 or 0).
    ndim = max(output.ndim, 1)
    assert dim >= -ndim and dim < ndim, "Invalid dim"

    # Validate before touching `grad_input`, so a failed call leaves it as is.
    if grad_input.dtype != input_dtype:
        raise RuntimeError(
            f"_softmax_backward_data.out: expected grad_input dtype {input_dtype}, got {grad_input.dtype}"
        )
    if tuple(grad_input.shape) != tuple(output.shape):
        grad_input.resize_(output.shape)
    # Nothing to compute; also avoids dividing by a zero M or N below and
    # launching a kernel over an empty buffer.
    if output.numel() == 0:
        return grad_input

    # View the tensors as (M, N, K): N is the reduced dim, M/K the products of
    # the dims before/after it.
    dim = dim % ndim
    shape = tuple(output.shape) or (1,)
    M = 1
    for size in shape[:dim]:
        M *= size
    N = shape[dim]
    K = output.numel() // M // N

    # The kernels address every operand as a dense row-major buffer.
    output = output.contiguous()
    grad_output = grad_output.contiguous()
    if grad_input.is_contiguous():
        dst = grad_input
    else:
        dst = torch.empty_like(
            output, dtype=input_dtype, memory_format=torch.contiguous_format
        )

    with torch_device_fn.device(grad_input.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                dst,
                M,
                N,
                K,
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                dst,
                M,
                N,
            )

    if dst is not grad_input:
        grad_input.copy_(dst)
    return grad_input
def softmax_backward(grad_output, output, dim, input_dtype):
    """Return the softmax input gradient in a new tensor of ``input_dtype``.

    The result has ``output``'s shape and is computed by
    ``softmax_backward_out``.
    """
    logger.debug("GEMS SOFTMAX_BACKWARD")
    # Allocate contiguously: the kernels write a dense row-major result, while
    # the default preserve_format would copy the layout of a strided `output`.
    in_grad = torch.empty_like(
        output, dtype=input_dtype, memory_format=torch.contiguous_format
    )
    return softmax_backward_out(
        grad_output, output, dim, input_dtype, grad_input=in_grad
    )