Coverage for src/flag_gems/ops/log_softmax.py: 53%
115 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
12logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr = 8,
    BLOCK_N: tl.constexpr = 256,
):
    """Compute log-softmax along the middle axis of a buffer indexed as (M, N, K).

    Element (m, n, k) is addressed at offset ``m * N * K + n * K + k``.
    Program axis 0 selects a block of ``BLOCK_M`` rows, program axis 1 selects
    the index ``k``.  The reduction over ``n`` is done in tiles of ``BLOCK_N``
    using a running maximum and a running, rescaled sum of exponentials; a
    second pass over the same tiles writes ``x - max - log(sum)``.
    """
    row_block = ext.program_id(0)
    inner = ext.program_id(1)
    rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)

    # Loop-invariant pieces of the address and of the bounds mask.
    row_term = rows[:, None] * N * K
    row_ok = rows[:, None] < M

    # TODO(chenfeiyu): consider float64 and add a utility function to get accumulator type
    running_max = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    running_sum = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        offset = row_term + cols[None, :] * K + inner
        mask = row_ok & (cols[None, :] < N)
        # Out-of-range lanes read as -inf so they never win the max and
        # contribute exp(-inf) = 0 to the sum.
        x = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(tl.float32)
        new_max = tl.maximum(x, running_max)
        # While a lane has only seen -inf, (m - m_new) is (-inf) - (-inf);
        # keep the old sum there instead of evaluating the rescale formula.
        still_empty = new_max == float("-inf")
        rescaled = running_sum * tl.exp(running_max - new_max) + tl.exp(x - new_max)
        running_sum = tl.where(still_empty, running_sum, rescaled)
        running_max = new_max

    # Collapse the per-lane statistics into one (max, sum) pair per row.
    row_max = tl.max(running_max, 1)
    row_sum = tl.sum(running_sum * tl.exp(running_max - row_max[:, None]), 1)
    log_row_sum = tl.log(row_sum[:, None])

    for start_n in range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        offset = row_term + cols[None, :] * K + inner
        mask = row_ok & (cols[None, :] < N)
        x = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(tl.float32)
        result = x - row_max[:, None] - log_row_sum
        tl.store(output_ptr + offset, result, mask=mask)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward pass of log-softmax over a buffer indexed as (M, N, K).

    Element (m, n, k) is addressed at offset ``m * N * K + n * K + k`` in all
    three buffers.  For every (m, k) the kernel first sums ``out_grad`` over
    ``n`` and then writes ``out_grad - exp(out) * sum(out_grad)``.
    Program axis 0 selects a block of ``BLOCK_M`` rows, program axis 1 selects
    the index ``k``.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Pass 1: accumulate the sum of the incoming gradient along N.
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_grad_ptrs = out_grad_ptr + offsets
        # `other=0.0` is required: without it masked-out lanes hold undefined
        # values that would be accumulated into the row sum below.
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Pass 2: combine the row sum with the saved forward output.
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_ptrs = out_ptr + offsets
        # Masked lanes are never stored; the fill value only keeps the
        # intermediate arithmetic well defined.
        out = tl.load(out_ptrs, mask=mask, other=0.0).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)
def log_softmax_out(self, dim, half_to_float=False, *, out):
    """Compute log-softmax of ``self`` along ``dim`` and write it into ``out``.

    Args:
        self: Input tensor.
        dim: Dimension to normalize over; negative values count from the end.
        half_to_float: When true the result dtype is ``torch.float32``,
            otherwise it is the dtype of ``self``.
        out: Destination tensor.  It is resized to the input shape when the
            shapes differ and must already have the expected result dtype.

    Returns:
        ``out``.

    Raises:
        AssertionError: If ``dim`` is outside ``[-self.ndim, self.ndim)``.
        RuntimeError: If ``out`` does not have the expected dtype.
    """
    logger.debug("GEMS LOG_SOFTMAX_OUT")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim

    dtype = torch.float32 if half_to_float else self.dtype
    # Validate before touching `out` so a rejected call leaves it unmodified.
    if out.dtype != dtype:
        raise RuntimeError(
            f"_log_softmax.out: expected out dtype {dtype}, got {out.dtype}"
        )

    inp = self.contiguous()
    if tuple(out.shape) != tuple(inp.shape):
        out.resize_(inp.shape)

    # View the input as (M, N, K): M = sizes before `dim`, N = size of `dim`,
    # K = sizes after `dim`.  K is computed as a product rather than as
    # numel // M // N so that empty tensors do not divide by zero.
    M = 1
    for size in inp.shape[:dim]:
        M *= size
    N = inp.shape[dim]
    K = 1
    for size in inp.shape[dim + 1 :]:
        K *= size

    # Nothing to compute for an empty tensor.
    if inp.numel() == 0:
        return out

    # The kernel addresses its output with contiguous offsets; stage the
    # result in a contiguous buffer when `out` is a strided view.
    if out.is_contiguous():
        result = out
    else:
        result = torch.empty(inp.shape, dtype=dtype, device=inp.device)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            result,
            inp,
            M,
            N,
            K,
            num_warps=8,
        )
    if result is not out:
        out.copy_(result)
    return out
def log_softmax(self, dim, half_to_float=False):
    """Return the log-softmax of ``self`` along ``dim`` in a new tensor.

    Args:
        self: Input tensor.
        dim: Dimension to normalize over; negative values count from the end.
        half_to_float: When true the result dtype is ``torch.float32``,
            otherwise it is the dtype of ``self``.

    Returns:
        A newly allocated tensor with the shape of ``self``, filled by
        ``log_softmax_out``.

    Raises:
        AssertionError: If ``dim`` is outside ``[-self.ndim, self.ndim)``.
    """
    logger.debug("GEMS LOG_SOFTMAX")
    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    dtype = torch.float32 if half_to_float else self.dtype
    # Allocate the contiguous result directly; calling self.contiguous() here
    # would copy the data of a strided input just to obtain its shape.
    out = torch.empty_like(self, dtype=dtype, memory_format=torch.contiguous_format)
    return log_softmax_out(self, dim, half_to_float, out=out)
def log_softmax_backward_out(grad_output, output, dim, input_dtype, *, out):
    """Compute the log-softmax input gradient and write it into ``out``.

    Args:
        grad_output: Gradient with respect to the log-softmax result.
        output: The log-softmax result saved from the forward pass.
        dim: Dimension that was normalized; negative values count from the end.
        input_dtype: Expected dtype of ``out``.
        out: Destination tensor.  It is resized to the shape of ``output``
            when the shapes differ and must have dtype ``input_dtype``.

    Returns:
        ``out``.

    Raises:
        AssertionError: If ``dim`` is outside ``[-output.ndim, output.ndim)``.
        RuntimeError: If ``out`` does not have dtype ``input_dtype``.
    """
    logger.debug("GEMS LOG_SOFTMAX_BACKWARD_OUT")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim

    # Validate before touching `out` so a rejected call leaves it unmodified.
    if out.dtype != input_dtype:
        raise RuntimeError(
            f"_log_softmax_backward_data.out: expected out dtype {input_dtype}, got {out.dtype}"
        )
    if tuple(out.shape) != tuple(output.shape):
        out.resize_(output.shape)

    # View the tensors as (M, N, K): M = sizes before `dim`, N = size of
    # `dim`, K = sizes after `dim`.  K is computed as a product rather than as
    # numel // M // N so that empty tensors do not divide by zero.
    M = 1
    for size in output.shape[:dim]:
        M *= size
    N = output.shape[dim]
    K = 1
    for size in output.shape[dim + 1 :]:
        K *= size

    # Nothing to compute for an empty tensor.
    if output.numel() == 0:
        return out

    # The kernel addresses all three buffers with contiguous offsets.
    grad_output = grad_output.contiguous()
    output = output.contiguous()
    if out.is_contiguous():
        in_grad = out
    else:
        in_grad = torch.empty(output.shape, dtype=input_dtype, device=out.device)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(out.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
        )
    if in_grad is not out:
        out.copy_(in_grad)
    return out
def log_softmax_backward(grad_output, output, dim, input_dtype):
    """Return the log-softmax input gradient in a newly allocated tensor.

    Args:
        grad_output: Gradient with respect to the log-softmax result.
        output: The log-softmax result saved from the forward pass.
        dim: Dimension that was normalized; negative values count from the end.
        input_dtype: Dtype of the returned gradient.

    Returns:
        A new tensor with the shape of ``output`` and dtype ``input_dtype``,
        filled by ``log_softmax_backward_out``.
    """
    logger.debug("GEMS LOG_SOFTMAX_BACKWARD")
    # Force a contiguous layout: empty_like would otherwise copy the strides of
    # a non-contiguous `output`, while the kernel writes with contiguous offsets.
    in_grad = torch.empty_like(
        output, dtype=input_dtype, memory_format=torch.contiguous_format
    )
    return log_softmax_backward_out(grad_output, output, dim, input_dtype, out=in_grad)