Coverage for src/flag_gems/ops/logsumexp.py: 36%
120 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
# Module-level logger; the logsumexp entry point emits a debug record per call.
logger = logging.getLogger(__name__)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def logsumexp_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Kernel for logsumexp when reduction dimension is not the innermost.

    The input is addressed as a contiguous (M, N, K) array and reduced over
    the middle axis N; the output is addressed as (M, K). Program axis 0
    selects the index along M and program axis 1 selects a TILE_K-wide
    slice of K.

    TILE_N, TILE_K and ONE_TILE_PER_CTA come from the "softmax_non_inner"
    heuristic config. The single-tile path only reads rows [0, TILE_N), so
    it is only correct when TILE_N >= N; presumably the heuristic only sets
    ONE_TILE_PER_CTA in that case (TODO confirm). Otherwise N is streamed in
    TILE_N chunks with a running max and a rescaled running sum.

    Infinite maxima follow torch.logsumexp: a column whose max is -inf gives
    -inf, and a column containing +inf gives +inf (instead of the NaN that
    evaluating ``inf - inf`` would produce), assuming no NaN inputs.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]

    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)[:, None]
        inp_offset = pid_m * N * K + n_offsets * K + k_offsets
        mask = (n_offsets < N) & (k_offsets < K)
        input_ptrs = input_ptr + inp_offset
        # Masked lanes read -inf so they contribute exp(-inf) = 0 to the sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, axis=0, keep_dims=True)
        # A +/-inf maximum cannot be subtracted safely (inf - inf = NaN), so
        # shift by 0 instead and return the infinite maximum directly below.
        m_is_inf = (m == float("-inf")) | (m == float("inf"))
        safe_m = tl.where(m_is_inf, tl.zeros_like(m), m)
        e = tl.exp(inp - safe_m)
        z = tl.sum(e, axis=0, keep_dims=True)
        out = safe_m + tl.log(z)
        # All -inf inputs -> -inf; any +inf input -> +inf.
        out = tl.where(m_is_inf, m, out)
        out_offset = pid_m * K + k_offsets
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out, mask=k_offsets < K)
    else:
        # Per-lane running max (m) and sum of exp(x - m) (z).
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)[:, None]
            inp_offsets = pid_m * N * K + n_offsets * K + k_offsets
            mask = (n_offsets < N) & (k_offsets < K)
            inp = tl.load(input_ptr + inp_offsets, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            # Keep z unchanged while a lane has only seen -inf, avoiding
            # exp(-inf - -inf) = NaN in the rescale.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Merge the TILE_N per-lane partial results into one value per column.
        m_reduced = tl.max(m, axis=0, keep_dims=True)
        z = tl.sum(z * tl.exp(m - m_reduced), axis=0, keep_dims=True)
        m = m_reduced
        # An infinite maximum is the answer itself: -inf when every input was
        # -inf, +inf when any input was +inf (z is NaN in that case).
        m_is_inf = (m == float("-inf")) | (m == float("inf"))
        out = tl.where(m_is_inf, m, m + tl.log(z))
        out_offset = pid_m * K + k_offsets
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out, mask=k_offsets < K)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def logsumexp_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Kernel for logsumexp when reduction dimension is the innermost.

    The input is addressed as a contiguous (M, N) array and reduced over N;
    program axis 0 selects the row and writes one scalar to
    ``output_ptr + row``.

    TILE_N and ONE_TILE_PER_CTA come from the "softmax_inner" heuristic
    config. The single-tile path only reads columns [0, TILE_N), so it is
    only correct when TILE_N >= N; presumably the heuristic guarantees this
    (TODO confirm). Otherwise the row is streamed in TILE_N chunks with a
    running max and a rescaled running sum.

    Infinite maxima follow torch.logsumexp: a row whose max is -inf gives
    -inf, and a row containing +inf gives +inf (instead of NaN from
    ``inf - inf``), assuming no NaN inputs.
    """
    pid_m = ext.program_id(0)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # Masked lanes read -inf so they contribute exp(-inf) = 0 to the sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, axis=0)
        # A +/-inf maximum cannot be subtracted safely (inf - inf = NaN), so
        # shift by 0 instead and return the infinite maximum directly below.
        m_is_inf = (m == float("-inf")) | (m == float("inf"))
        safe_m = tl.where(m_is_inf, 0.0, m)
        e = tl.exp(inp - safe_m)
        z = tl.sum(e, axis=0)
        out = safe_m + tl.log(z)
        # All -inf inputs -> -inf; any +inf input -> +inf.
        out = tl.where(m_is_inf, m, out)
        output_ptrs = output_ptr + pid_m
        tl.store(output_ptrs, out)
    else:
        # Per-lane running max (m) and sum of exp(x - m) (z).
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N

        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            # Keep z unchanged while a lane has only seen -inf, avoiding
            # exp(-inf - -inf) = NaN in the rescale.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Merge the TILE_N per-lane partial results into a single value.
        m_reduced = tl.max(m, axis=0)
        z = tl.sum(z * tl.exp(m - m_reduced), axis=0)
        m = m_reduced
        # An infinite maximum is the answer itself: -inf when every input was
        # -inf, +inf when any input was +inf (z is NaN in that case).
        m_is_inf = (m == float("-inf")) | (m == float("inf"))
        out = tl.where(m_is_inf, m, m + tl.log(z))
        output_ptrs = output_ptr + pid_m
        tl.store(output_ptrs, out)
def logsumexp(inp, dim, keepdim=False):
    """Compute ``log(sum(exp(inp)))`` over ``dim`` in a numerically stable way.

    Args:
        inp: Input tensor. It is made contiguous before the kernel launch.
        dim: An int, or a list/tuple of ints, naming the dimension(s) to
            reduce. Negative values count from the end. An empty list/tuple
            performs no reduction and returns a copy of ``inp``. A 0-dim
            tensor accepts ``dim`` 0 or -1, as if it were a 1-element vector.
        keepdim: If True, reduced dimensions are kept with size 1; otherwise
            they are removed from the result.

    Returns:
        A tensor with the same dtype and device as ``inp``. Reducing over an
        empty dimension yields ``-inf``.

    Raises:
        AssertionError: If a dimension index is out of range.
        RuntimeError: If the same dimension is listed more than once.
    """
    logger.debug("GEMS LOGSUMEXP")

    # A 0-dim tensor is treated as having a single (wrapped) dimension.
    rank = max(inp.ndim, 1)

    if isinstance(dim, (list, tuple)):
        if len(dim) == 0:
            # Empty dim list means no reduction, just return the input
            return inp.clone()
        if len(dim) == 1:
            dim = dim[0]
        else:
            # Validate each dim *before* wrapping it; wrapping first would let
            # an out-of-range dim silently alias a valid one via ``%``.
            reduce_dims = set()
            for d in dim:
                assert -rank <= d < rank, "Invalid dim"
                d = d % rank
                if d in reduce_dims:
                    raise RuntimeError(
                        f"dim {d} appears multiple times in the list of dims"
                    )
                reduce_dims.add(d)
            # Reducing with keepdim=True never shifts indices, so the dims can
            # be reduced one at a time in any order.
            reduce_order = sorted(reduce_dims, reverse=True)
            result = inp
            for d in reduce_order:
                result = logsumexp(result, d, keepdim=True)
            if not keepdim:
                # Squeeze from the highest index down so that each squeeze
                # does not shift the indices still waiting to be removed.
                for d in reduce_order:
                    result = result.squeeze(d)
            return result

    assert -rank <= dim < rank, "Invalid dim"
    if inp.ndim == 0:
        # Reduce as a 1-element vector; the squeezed result is 0-dim, which is
        # also the correct shape for keepdim=True on a 0-dim input.
        return logsumexp(inp.reshape(1), 0)

    dim = dim % inp.ndim
    shape = inp.shape
    # View the input as (M, N, K) with the reduced dimension in the middle.
    # Products are taken over shape slices so empty tensors don't divide by 0.
    M = shape[:dim].numel()
    N = shape[dim]
    K = shape[dim + 1 :].numel()
    inp = inp.contiguous()

    # Output shape with reduction dimension set to 1
    out_shape = list(shape)
    out_shape[dim] = 1
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    if N == 0:
        # log(sum(exp([]))) = log(0) = -inf, matching torch.logsumexp.
        out.fill_(float("-inf"))
    elif out.numel() > 0:
        with torch_device_fn.device(inp.device):
            if K > 1:
                # Reduced dim is not innermost: tile over (M, K).
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                logsumexp_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                # Reduced dim is innermost: one program per row.
                grid = (M, 1, 1)
                logsumexp_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out