Coverage for src/flag_gems/ops/_safe_softmax.py: 73%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

10logger = logging.getLogger(__name__) 

11 

12 

# Row-wise "safe" softmax kernel.
#
# Each program instance handles one row of a row-major (n_rows, n_cols)
# buffer: the row starts at element `row_id * n_cols`. The row is walked in
# tiles of BLOCK_SIZE columns, so rows wider than BLOCK_SIZE are processed
# completely instead of being truncated to the first tile.
#
# "Safe" means a row whose entries are all -inf produces zeros instead of the
# NaNs that `exp(-inf - (-inf))` would otherwise yield.
#
# Arguments:
#   input_ptr  - pointer to the input buffer.
#   output_ptr - pointer to the output buffer (same layout as the input).
#   n_rows     - number of rows; unused in the body, kept for the call
#                signature.
#   n_cols     - number of columns per row (softmax length).
#   BLOCK_SIZE - compile-time tile width used by tl.arange.
@triton.jit
def _safe_softmax_kernel(
    input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr
):
    # Promote to int64 so `row_id * n_cols` cannot overflow 32-bit arithmetic
    # on tensors with more than 2**31 elements.
    row_id = tl.program_id(0).to(tl.int64)
    row_offset = row_id * n_cols
    lane = tl.arange(0, BLOCK_SIZE)

    # Pass 1: per-lane running maximum over all column tiles, then reduce.
    lane_max = tl.full([BLOCK_SIZE], -float("inf"), dtype=tl.float32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lane
        mask = cols < n_cols
        # Out-of-range lanes read -inf so they never win the max.
        x = tl.load(input_ptr + row_offset + cols, mask=mask, other=-float("inf"))
        lane_max = tl.maximum(lane_max, x.to(tl.float32))
    x_max = tl.max(lane_max, axis=0)
    all_neginf = x_max == -float("inf")

    # Pass 2: sum of exp(x - max). Masked lanes contribute exp(-inf) == 0
    # whenever the row maximum is finite.
    lane_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lane
        mask = cols < n_cols
        x = tl.load(input_ptr + row_offset + cols, mask=mask, other=-float("inf"))
        lane_sum += tl.exp(x.to(tl.float32) - x_max)
    sum_exp = tl.sum(lane_sum, axis=0)

    # Pass 3: normalise and store. For an all -inf row the quotient is NaN,
    # so it is replaced by zeros before the store.
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lane
        mask = cols < n_cols
        x = tl.load(input_ptr + row_offset + cols, mask=mask, other=-float("inf"))
        softmax = tl.exp(x.to(tl.float32) - x_max) / sum_exp
        softmax = tl.where(
            all_neginf, tl.zeros([BLOCK_SIZE], dtype=tl.float32), softmax
        )
        tl.store(output_ptr + row_offset + cols, softmax, mask=mask)

36 

37 

def _safe_softmax(x: torch.Tensor, dim: int = -1, dtype: torch.dtype = None):
    """Compute softmax along ``dim``, returning zeros for all ``-inf`` slices.

    The reduction dimension is moved to the last axis, the data is made
    contiguous and converted to float32, and ``_safe_softmax_kernel`` is
    launched with one program per row.

    Args:
        x: Input tensor on the ``flag_gems.device`` device, with at least one
            dimension.
        dim: Dimension to normalise over; negative values count from the end.
        dtype: Dtype of the result. When ``None`` the result uses ``x.dtype``.
            The computation itself is always carried out in float32.

    Returns:
        A tensor with the same shape as ``x``. When ``dim`` is not the last
        dimension the result is a transposed (non-contiguous) view of the
        computed buffer.

    Raises:
        AssertionError: If ``x`` is on the wrong device, has no dimensions, or
            ``dim`` is out of range.
    """
    logger.debug("GEMS _SAFE_SOFTMAX")
    assert (
        x.device.type == flag_gems.device
    ), f"Input tensor must be on {flag_gems.device} device"
    assert x.ndim >= 1, "Input tensor must have at least 1 dimension"

    dim = dim if dim >= 0 else x.ndim + dim
    assert 0 <= dim < x.ndim, "Invalid dim for softmax"

    out_dtype = x.dtype if dtype is None else dtype

    # Nothing to compute for empty tensors. This also avoids dividing by a
    # zero-sized softmax dimension and launching an empty grid.
    if x.numel() == 0:
        return torch.empty_like(x, dtype=out_dtype)

    # Swapping `dim` with the last axis is its own inverse, so the same
    # transpose restores the original layout afterwards.
    needs_swap = dim != x.ndim - 1
    y = (x.transpose(dim, -1) if needs_swap else x).contiguous()

    n_cols = y.shape[-1]
    n_rows = y.numel() // n_cols

    y_fp32 = y.float()
    out_fp32 = torch.empty_like(y_fp32)

    # NOTE(review): the tile width is capped at 4096, so rows wider than that
    # depend on the kernel iterating over several tiles - confirm against
    # `_safe_softmax_kernel`.
    BLOCK_SIZE = min(4096, triton.next_power_of_2(n_cols))

    _safe_softmax_kernel[(n_rows,)](
        y_fp32, out_fp32, n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE
    )

    out = out_fp32.to(out_dtype)
    if needs_swap:
        out = out.transpose(dim, -1)

    return out