Coverage for src/flag_gems/runtime/backend/_sunrise/ops/log_softmax.py: 0%
189 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Threads per warp used when pruning autotune configs below.
# NOTE(review): 32 is the value the original filter hard-coded; confirm it
# matches the warp/wavefront width of this backend.
_WARP_SIZE = 32

# Autotune candidates for log_softmax_kernel_inner. A (TILE_N, num_warps)
# pair is kept only if the CTA has no more threads than the tile has
# elements (num_warps * _WARP_SIZE <= TILE_N). Pairs with more threads than
# elements would leave threads with nothing to do, which is most wasteful on
# small rows handled by the ONE_TILE_PER_CTA path.
_INNER_CONFIGS = [
    triton.Config({"TILE_N": tile_n}, num_warps=num_warps)
    for tile_n in (64, 128, 256, 512, 1024)
    for num_warps in (1, 2, 4, 8, 16)
    if num_warps * _WARP_SIZE <= tile_n
]
26def _one_tile_per_cta(args):
27 return args["TILE_N"] >= args["N"]
@triton.jit
def _prev_multiple_of(a, b):
    # Offset of the last (possibly partial) b-wide tile covering [0, a):
    # ceil(a / b) * b - b. For example, a=600, b=256 -> 512 and a=512, b=256 -> 256.
    num_tiles = tl.cdiv(a, b)
    return num_tiles * b - b
@libentry()
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr = 8,
    BLOCK_N: tl.constexpr = 256,
):
    """log_softmax over the middle axis of a contiguous (M, N, K) view.

    Element (m, n, k) is addressed at flat offset ``m * N * K + n * K + k``.
    Program (pid_m, pid_k) handles rows ``pid_m * BLOCK_M`` through
    ``pid_m * BLOCK_M + BLOCK_M - 1`` at trailing index ``pid_k``. It walks N
    in BLOCK_N-wide chunks twice:
    - pass 1 accumulates the max and the sum of ``exp(x - max)`` in float32;
    - pass 2 writes ``x - max - log(sum)``.
    Out-of-range rows and columns are masked on both load and store.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # TODO(chenfeiyu): consider float64 and add a utility function to get accumulator type
    # Per-lane running max (m) and running sum of exp(x - m) (z).
    m = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    z = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    # Pass 1: online max / normalizer update. Masked lanes load -inf, so they
    # contribute exp(-inf) = 0 to the sum.
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m_new = tl.maximum(inp, m)
        # While a lane has only seen -inf, m - m_new would be -inf - (-inf) = nan;
        # keep z (still 0 for such lanes) instead of rescaling it.
        all_neg_inf = m_new == float("-inf")
        z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
        m = m_new

    # Combine the per-lane partials into one max and one sum per row,
    # rescaling each lane's partial sum to the row max.
    m_reduced = tl.max(m, 1)
    z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
    m = m_reduced

    # Pass 2: reload the input and write x - max - log(sum).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        o = inp - m[:, None] - tl.log(z[:, None])
        tl.store(output_ptr + offset, o, mask=mask)
@libentry()
@triton.autotune(configs=_INNER_CONFIGS, key=["M", "N"])
@triton.heuristics({"ONE_TILE_PER_CTA": _one_tile_per_cta})
@triton.jit
def log_softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """log_softmax over the last axis of a contiguous (M, N) tensor.

    Each program handles one row; row ``pid_m`` starts at flat offset
    ``pid_m * N``.
    - When ``ONE_TILE_PER_CTA`` is set (``TILE_N >= N``), the row is
      processed as a single masked tile.
    - Otherwise it is streamed in ``TILE_N`` chunks. Pass 1 computes the
      running max and normalizer, and pass 2 writes ``x - max - log(sum)``.
    In both branches, loaded values are cast to float32 before any math, so
    the loop-carried accumulators keep a single dtype whatever the input
    dtype is.
    """
    pid_m = ext.program_id(0)

    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        mask = n_offsets < N
        inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = inp - m - tl.log(z)
        tl.store(output_ptr + offset, out, mask=mask)
    else:
        # Per-lane running max (m) and running sum of exp(x - m) (z).
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # Offset of the last (possibly partial) tile. Every tile before it is
        # full, so those tiles can be loaded without a mask.
        previous_multiple = _prev_multiple_of(N, TILE_N)

        # Pass 1a: mask-free hot loop over the full tiles.
        for start_n in range(0, previous_multiple, TILE_N):
            n_offset = start_n + tl.arange(0, TILE_N)
            # The float32 cast keeps m/z float32; a float64 input would
            # otherwise promote these loop-carried values and fail to compile.
            inp = tl.load(input_ptr + n_offset).to(tl.float32)
            m_new = tl.maximum(m, inp)
            # Avoid nan from -inf - (-inf) while a lane has only seen -inf.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Pass 1b: the masked tail tile (masked lanes load -inf and add 0).
        for start_n in range(previous_multiple, N, TILE_N):
            n_offset = start_n + tl.arange(0, TILE_N)
            mask = n_offset < N
            inp = tl.load(input_ptr + n_offset, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Fold the per-lane partials into the row max and the row normalizer.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced
        log_z = tl.log(z)

        # Pass 2: walk the tiles in reverse, starting from the tail that
        # pass 1 touched last. Presumably this improves cache reuse, and the
        # evict_first hints mark the reloaded input as not needed again.
        # Masked tail tile first.
        n_offset = previous_multiple + tl.arange(0, TILE_N)
        mask = n_offset < N
        inp = tl.load(
            input_ptr + n_offset,
            mask=mask,
            other=-float("inf"),
            eviction_policy="evict_first",
        ).to(tl.float32)
        o = inp - m - log_z
        tl.store(output_ptr + n_offset, o, mask=mask)
        # Then the full tiles, from last to first. Together with the tail tile,
        # this covers all cdiv(N, TILE_N) tiles.
        for start_n in range(TILE_N, N, TILE_N):
            n_offset = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offset, eviction_policy="evict_first").to(
                tl.float32
            )
            o = inp - m - log_z
            tl.store(output_ptr + n_offset, o)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward of log_softmax over the middle axis of a contiguous (M, N, K) view.

    With y = log_softmax(x) read from ``out_ptr`` and dy from ``out_grad_ptr``,
    this writes ``dx = dy - exp(y) * sum_n(dy)`` to ``in_grad_ptr``.
    Element (m, n, k) is addressed at ``m * N * K + n * K + k``. Program
    (pid_m, pid_k) handles BLOCK_M rows at trailing index ``pid_k``.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Pass 1: per-row sum of the incoming gradient, accumulated in float32.
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_grad_ptrs = out_grad_ptr + offsets
        # other=0.0 matters here: masked-off lanes are summed into `scale`,
        # and tl.load leaves them undefined when `other` is not given.
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Pass 2: dx = dy - exp(y) * sum(dy).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask, other=0.0).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel_opt(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward of log_softmax over the last axis of a contiguous (M, N) tensor.

    This is the K == 1 specialization of ``log_softmax_backward_kernel``:
    element (m, n) is at ``m * N + n``. It writes
    ``dx = dy - exp(y) * sum_n(dy)``, where y is read from ``out_ptr`` and
    dy from ``out_grad_ptr``.
    """
    pid_m = ext.program_id(0)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Pass 1: per-row sum of the incoming gradient, accumulated in float32.
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N + n_offset[None, :]
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_grad_ptrs = out_grad_ptr + offsets
        # other=0.0 matters here: masked-off lanes are summed into `scale`,
        # and tl.load leaves them undefined when `other` is not given.
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Pass 2: dx = dy - exp(y) * sum(dy).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N + n_offset[None, :]
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask, other=0.0).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)
def log_softmax(self, dim, half_to_float=False):
    """Compute log_softmax of ``self`` along ``dim`` with Triton kernels.

    The input is viewed as (M, N, K):
    - N is the size of ``dim``;
    - M is the product of the leading sizes;
    - K is the product of the trailing sizes.

    Args:
        self: input tensor.
        dim: dimension to normalize over; negative values count from the end.
        half_to_float: if True the result dtype is float32, otherwise it is
            ``self.dtype``.

    Returns:
        A new tensor with the shape of ``self``. Empty inputs produce an
        empty result without launching any kernel.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    # Take M and K from the shape. The old `numel() // M // N` raised
    # ZeroDivisionError whenever M or N was 0.
    M = self.shape[:dim].numel()
    N = self.shape[dim]
    K = self.shape[dim + 1 :].numel()

    inp = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    if out.numel() == 0:
        # Nothing to normalize; skip the kernel launch entirely.
        return out

    with torch_device_fn.device(inp.device):
        if K == 1:
            # The reduced dim is innermost: one program per row.
            grid = (M, 1, 1)
            log_softmax_kernel_inner[grid](out, inp, M, N)
        else:
            grid = lambda meta: (
                triton.cdiv(M, meta["BLOCK_M"]),
                K,
            )
            log_softmax_kernel[grid](
                out,
                inp,
                M,
                N,
                K,
                num_warps=16,
            )
    return out
def log_softmax_backward(grad_output, output, dim, input_dtype):
    """Gradient of log_softmax with respect to its input.

    Args:
        grad_output: gradient flowing into the log_softmax result (assumed
            to have the same shape as ``output``; the kernels index both with
            the same offsets).
        output: the forward result y = log_softmax(x, dim).
        dim: dimension the forward pass normalized over.
        input_dtype: dtype of the returned gradient.

    Returns:
        A new contiguous tensor ``dx = dy - exp(y) * sum(dy, dim)`` with the
        shape of ``output``. Empty inputs produce an empty result without
        launching any kernel.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    # Take M and K from the shape. The old `numel() // M // N` raised
    # ZeroDivisionError whenever M or N was 0.
    M = output.shape[:dim].numel()
    N = output.shape[dim]
    K = output.shape[dim + 1 :].numel()

    # The kernels index flat row-major memory, so every operand must be
    # contiguous. This includes `output`: empty_like preserves a non-contiguous
    # layout, which would also make in_grad non-contiguous.
    grad_output = grad_output.contiguous()
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    if in_grad.numel() == 0:
        return in_grad

    with torch_device_fn.device(in_grad.device):
        if K == 1:
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            log_softmax_backward_kernel_opt[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
        else:
            grid = lambda meta: (
                triton.cdiv(M, meta["BLOCK_M"]),
                K,
            )
            log_softmax_backward_kernel[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
    return in_grad