Coverage for src/flag_gems/runtime/backend/_nvidia/turing/ops/_safe_softmax.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops._safe_softmax import _safe_softmax as generic_safe_softmax 

8 

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

10 

11 

@triton.jit
def _safe_softmax_kernel(
    input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr
):
    """Row-wise "safe" softmax over contiguous rows of length ``n_cols``.

    One program handles one row (the row index comes from program id 0).

    The first pass streams the row in BLOCK_SIZE chunks. It keeps a running
    maximum ``m_i`` and a running sum of exponentials ``l_i`` rescaled to
    that maximum. The second pass re-reads the row and stores
    ``exp(x - m_i) / l_i``.

    A row whose elements are all ``-inf`` is written as zeros instead of NaN.

    Args:
        input_ptr: pointer to the row-major input; values are upcast to fp32.
        output_ptr: pointer to the row-major output, same layout as the input.
        n_rows: number of rows; not read in the body (the launch grid
            provides one program per row).
        n_cols: number of elements in each row.
        BLOCK_SIZE: number of columns processed per loop iteration.
    """
    row_id = tl.program_id(0)
    # Promote to int64 before multiplying: an int32 row_id * n_cols overflows
    # once the buffer holds more than 2**31 elements.
    row_offset = row_id.to(tl.int64) * n_cols
    row_start = input_ptr + row_offset
    out_start = output_ptr + row_offset

    m_i = -float("inf")
    l_i = 0.0

    for col_offset in range(0, n_cols, BLOCK_SIZE):
        cols = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
        x_fp32 = x.to(tl.float32)

        m_ij = tl.max(x_fp32, axis=0)
        m_i_new = tl.maximum(m_i, m_ij)

        # Rescale factor for the sum accumulated so far. While no finite
        # value has been seen, l_i is 0 and exp(-inf - -inf) would be NaN.
        alpha = tl.where(m_i == -float("inf"), 0.0, tl.exp(m_i - m_i_new))

        # If every value seen so far (this block included) is -inf,
        # subtracting m_i_new gives -inf - (-inf) = NaN. That NaN would
        # poison l_i permanently, because NaN * 0 is still NaN. Shift by 0
        # instead so each -inf maps to exp(-inf) == 0.
        shift = tl.where(m_i_new == -float("inf"), 0.0, m_i_new)
        beta = tl.exp(x_fp32 - shift)

        sum_block = tl.sum(tl.where(mask, beta, 0.0), axis=0)
        l_i = l_i * alpha + sum_block
        m_i = m_i_new

    # A row with no finite maximum is emitted as zeros (the "safe" part).
    all_neginf = m_i == -float("inf")

    for col_offset in range(0, n_cols, BLOCK_SIZE):
        cols = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
        x_fp32 = x.to(tl.float32)

        p = tl.exp(x_fp32 - m_i) / l_i
        p = tl.where(all_neginf, 0.0, p)
        tl.store(out_start + cols, p, mask=mask)

50 

51 

def _safe_softmax(x: torch.Tensor, dim: int = -1, dtype: torch.dtype = None):
    """Softmax along ``dim`` that yields zeros (not NaN) for all-``-inf`` slices.

    Only float32 inputs take the Turing Triton kernel. Every other input
    dtype is delegated to the generic FlagGems implementation.

    Args:
        x: input tensor. On the fp32 path it must be on a CUDA device and
            have at least one dimension.
        dim: dimension to normalize over; negative values count from the end.
        dtype: optional result dtype. The fp32 path computes in float32 and
            casts the result to ``dtype``. When None, the result keeps
            ``x.dtype``.

    Returns:
        A tensor with the same shape as ``x``. When ``dim`` is not the last
        dimension, the result is a permuted (non-contiguous) view.

    Raises:
        AssertionError: on the fp32 path, if ``x`` is not on CUDA, is
            0-dimensional, or ``dim`` is out of range.
    """
    # Fallback to generic implementation for non-FP32 dtypes to avoid performance regression
    if x.dtype != torch.float32:
        return generic_safe_softmax(x, dim=dim, dtype=dtype)

    logger.debug("GEMS TURING FP32 _SAFE_SOFTMAX")
    assert x.is_cuda, "Input tensor must be on CUDA device"
    assert x.ndim >= 1, "Input tensor must have at least 1 dimension"

    dim = dim if dim >= 0 else x.ndim + dim
    assert 0 <= dim < x.ndim, "Invalid dim for softmax"

    out_dtype = dtype if dtype is not None else x.dtype

    # Nothing to compute for an empty tensor. This also avoids a
    # ZeroDivisionError in `numel() // n_cols` when the reduced dim is empty.
    if x.numel() == 0:
        return torch.empty(x.shape, dtype=out_dtype, device=x.device)

    if dim != x.ndim - 1:
        # Move the reduced dimension last so each row is contiguous. A single
        # swap is its own inverse, so the same permutation restores the layout.
        perm = list(range(x.ndim))
        perm[dim], perm[-1] = perm[-1], perm[dim]
        y = x.permute(perm).contiguous()
    else:
        perm = None
        y = x.contiguous()

    n_cols = y.shape[-1]
    n_rows = y.numel() // n_cols

    # x (and therefore y) is float32 here, so no upcast is needed.
    out = torch.empty_like(y)

    BLOCK_SIZE = min(4096, triton.next_power_of_2(n_cols))

    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16

    grid = lambda meta: (n_rows,)

    _safe_softmax_kernel[grid](
        y, out, n_rows, n_cols, num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE
    )

    out = out.to(out_dtype)
    if perm is not None:
        out = out.permute(perm)

    return out