Coverage for src/flag_gems/runtime/backend/_nvidia/turing/ops/_safe_softmax.py: 0%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops._safe_softmax import _safe_softmax as generic_safe_softmax
9logger = logging.getLogger(__name__)
@triton.jit
def _safe_softmax_kernel(
    input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr
):
    """Compute a row-wise "safe" softmax, one program per row.

    Each program handles the row selected by ``tl.program_id(0)`` and walks it
    in chunks of ``BLOCK_SIZE`` columns. Rows are addressed as
    ``ptr + row_id * n_cols``, i.e. both buffers are indexed as dense
    row-major ``(n_rows, n_cols)`` arrays.

    Pass 1 keeps a running maximum ``m_i`` and a running sum ``l_i`` of
    ``exp(x - m_i)``, rescaling the sum whenever the maximum grows.
    Pass 2 re-reads the row and stores ``exp(x - m_i) / l_i``.
    A row whose maximum is ``-inf`` (every element is ``-inf``) is written as
    all zeros instead of NaN.

    Args:
        input_ptr: pointer to the first input element.
        output_ptr: pointer to the first output element.
        n_rows: number of rows (not read inside the kernel; the launch grid
            decides how many rows are processed).
        n_cols: number of elements per row.
        BLOCK_SIZE: compile-time chunk width used with ``tl.arange``.
    """
    # Widen the row index so row_id * n_cols cannot wrap around in 32-bit
    # arithmetic when the tensor holds more than 2**31 elements.
    row_id = tl.program_id(0).to(tl.int64)
    row_start = input_ptr + row_id * n_cols
    out_start = output_ptr + row_id * n_cols

    m_i = -float("inf")
    l_i = 0.0

    # Pass 1: running max / running sum over the row.
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        cols = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        # Out-of-range lanes read -inf so they never win the max and
        # contribute exp(-inf) == 0 to the sum.
        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
        x_fp32 = x.to(tl.float32)

        m_ij = tl.max(x_fp32, axis=0)
        m_i_new = tl.maximum(m_i, m_ij)

        # While the running max is still -inf, subtracting it would evaluate
        # (-inf) - (-inf) == NaN; that NaN would land in l_i and survive the
        # later multiplication by alpha == 0, poisoning rows that start with a
        # full chunk of -inf but contain finite values afterwards. Subtract a
        # finite stand-in instead: every term is then exp(-inf) == 0.
        m_safe = tl.where(m_i_new == -float("inf"), 0.0, m_i_new)
        alpha = tl.exp(m_i - m_safe)
        beta = tl.exp(x_fp32 - m_safe)

        sum_block = tl.sum(tl.where(mask, beta, 0.0), axis=0)
        l_i = l_i * alpha + sum_block
        m_i = m_i_new

    all_neginf = m_i == -float("inf")

    # Pass 2: normalise and store.
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        cols = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
        x_fp32 = x.to(tl.float32)

        p = tl.exp(x_fp32 - m_i) / l_i
        # For an all -inf row the expression above is NaN; emit zeros instead.
        p = tl.where(all_neginf, 0.0, p)
        tl.store(out_start + cols, p, mask=mask)
def _safe_softmax(x: torch.Tensor, dim: int = -1, dtype: torch.dtype = None):
    """Safe softmax along ``dim`` with an FP32 Triton fast path.

    The fast path is taken only when ``x`` is float32 and the requested
    output ``dtype`` is ``None`` or float32. Every other combination is
    forwarded to the generic implementation, so a cast requested via
    ``dtype`` is handled there instead of being applied to an FP32 result.

    Args:
        x: input tensor; must live on a CUDA device and have at least one
            dimension when the fast path is taken.
        dim: dimension to normalise over; negative values count from the end.
        dtype: optional output dtype.

    Returns:
        A tensor with the same shape as ``x``. When ``dim`` is not the last
        dimension the result is a transposed view of a contiguous buffer.

    Raises:
        AssertionError: if ``x`` is not on CUDA, is 0-dimensional, or ``dim``
            is out of range.
    """
    # Fallback to generic implementation for non-FP32 dtypes to avoid performance regression
    if x.dtype != torch.float32 or dtype not in (None, torch.float32):
        return generic_safe_softmax(x, dim=dim, dtype=dtype)

    logger.debug("GEMS TURING FP32 _SAFE_SOFTMAX")
    assert x.is_cuda, "Input tensor must be on CUDA device"
    assert x.ndim >= 1, "Input tensor must have at least 1 dimension"

    dim = dim if dim >= 0 else x.ndim + dim
    assert 0 <= dim < x.ndim, "Invalid dim for softmax"

    # Nothing to compute for an empty tensor; this also avoids dividing by a
    # zero-length softmax dimension below.
    if x.numel() == 0:
        return torch.empty_like(x)

    # The kernel normalises over the innermost, contiguous dimension, so move
    # the target dimension there first.
    needs_transpose = dim != x.ndim - 1
    y = (x.transpose(dim, -1) if needs_transpose else x).contiguous()

    n_cols = y.shape[-1]
    n_rows = y.numel() // n_cols
    out = torch.empty_like(y)

    # tl.arange needs a power-of-two extent; cap the chunk width at 4096 and
    # let the kernel loop over longer rows.
    BLOCK_SIZE = min(4096, triton.next_power_of_2(n_cols))

    if BLOCK_SIZE >= 4096:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    else:
        num_warps = 4

    # Launch on the device that owns the data rather than whichever device
    # happens to be current.
    with torch.cuda.device(x.device):
        _safe_softmax_kernel[(n_rows,)](
            y, out, n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
        )

    # Swapping the same two dimensions again restores the caller's layout.
    return out.transpose(dim, -1) if needs_transpose else out