Coverage for src/flag_gems/ops/_safe_softmax.py: 73%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
10logger = logging.getLogger(__name__)
@triton.jit
def _safe_softmax_kernel(
    input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr
):
    """Row-wise softmax that yields zeros for rows consisting only of -inf.

    Each program instance handles one row; the row start is computed as
    ``row_id * n_cols``, i.e. rows are expected to be stored back to back.
    The row is walked in tiles of ``BLOCK_SIZE`` columns, so rows wider than
    ``BLOCK_SIZE`` are reduced and written in full. All arithmetic is done in
    float32.

    Args:
        input_ptr: Pointer to the input elements.
        output_ptr: Pointer to the output elements, laid out like the input.
        n_rows: Number of rows. Not read inside the kernel; the launch grid
            decides which rows are processed.
        n_cols: Number of elements per row.
        BLOCK_SIZE: Compile-time tile width along the row (passed to
            ``tl.arange``, which requires a power of two).
    """
    row_id = tl.program_id(0)
    # Widen before multiplying so the element offset cannot wrap around
    # 32-bit arithmetic on very large tensors.
    row_start = row_id.to(tl.int64) * n_cols
    in_row_ptr = input_ptr + row_start
    out_row_ptr = output_ptr + row_start
    lane = tl.arange(0, BLOCK_SIZE)

    # Pass 1: row maximum. Keep a per-lane running maximum across tiles and
    # reduce it once at the end. Out-of-range lanes load -inf, which is the
    # identity for max.
    lane_max = tl.full([BLOCK_SIZE], -float("inf"), dtype=tl.float32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lane
        mask = cols < n_cols
        x = tl.load(in_row_ptr + cols, mask=mask, other=-float("inf"))
        lane_max = tl.maximum(lane_max, x.to(tl.float32))
    x_max = tl.max(lane_max, axis=0)
    # A row whose maximum is -inf contains nothing but -inf; the "safe"
    # variant defines its softmax as all zeros instead of NaN.
    all_neginf = x_max == -float("inf")

    # Pass 2: sum of exp(x - max). Out-of-range lanes hold -inf and therefore
    # contribute exp(-inf) == 0 whenever the maximum is finite.
    lane_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lane
        mask = cols < n_cols
        x = tl.load(in_row_ptr + cols, mask=mask, other=-float("inf"))
        lane_sum += tl.exp(x.to(tl.float32) - x_max)
    sum_exp = tl.sum(lane_sum, axis=0)

    # Pass 3: normalise and store. For an all -inf row the arithmetic above
    # produces NaN (-inf - -inf), which the select below replaces with zeros.
    zeros = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lane
        mask = cols < n_cols
        x = tl.load(in_row_ptr + cols, mask=mask, other=-float("inf"))
        softmax = tl.exp(x.to(tl.float32) - x_max) / sum_exp
        softmax = tl.where(all_neginf, zeros, softmax)
        tl.store(out_row_ptr + cols, softmax, mask=mask)
def _safe_softmax(x: torch.Tensor, dim: int = -1, dtype: torch.dtype = None):
    """Compute softmax along ``dim`` using ``_safe_softmax_kernel``.

    The reduction dimension is moved to the last axis, the data is made
    contiguous and converted to float32, and the kernel is launched with one
    program per row. The float32 result is then cast to the requested dtype
    and the axes are swapped back.

    Args:
        x: Input tensor with at least one dimension, located on the
            ``flag_gems.device`` device type.
        dim: Dimension to normalise over; negative values count from the end.
        dtype: Output dtype. When ``None`` the result uses ``x.dtype``.

    Returns:
        A tensor with the same shape as ``x``. When ``dim`` is not the last
        dimension the result is a transposed view of a contiguous buffer.

    Raises:
        AssertionError: If ``x`` is on the wrong device type, has no
            dimensions, or ``dim`` is out of range.
    """
    logger.debug("GEMS _SAFE_SOFTMAX")
    assert (
        x.device.type == flag_gems.device
    ), f"Input tensor must be on {flag_gems.device} device"
    assert x.ndim >= 1, "Input tensor must have at least 1 dimension"

    dim = dim if dim >= 0 else x.ndim + dim
    assert 0 <= dim < x.ndim, "Invalid dim for softmax"

    out_dtype = x.dtype if dtype is None else dtype

    # Nothing to compute for an empty tensor. Returning here also avoids the
    # division by zero below when the softmax dimension itself has length 0.
    if x.numel() == 0:
        return torch.empty_like(x, dtype=out_dtype)

    # Swapping `dim` with the last axis is its own inverse, so the same
    # transpose is used to restore the layout at the end.
    needs_swap = dim != x.ndim - 1
    y = x.transpose(dim, -1).contiguous() if needs_swap else x.contiguous()

    n_cols = y.shape[-1]
    n_rows = y.numel() // n_cols

    # The kernel is fed float32 data regardless of the input dtype.
    y_fp32 = y.float()
    out_fp32 = torch.empty_like(y_fp32)

    # Tile width: smallest power of two >= n_cols (n_cols >= 1 here), capped
    # at 4096.
    # NOTE(review): confirm the kernel iterates over column tiles when
    # n_cols exceeds this cap; a single-tile kernel would leave the columns
    # past the cap unwritten.
    BLOCK_SIZE = min(4096, 1 << (n_cols - 1).bit_length())
    grid = (n_rows,)

    _safe_softmax_kernel[grid](y_fp32, out_fp32, n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE)

    out = out_fp32.to(out_dtype)
    if needs_swap:
        out = out.transpose(dim, -1)
    return out