Coverage for src/flag_gems/runtime/backend/_spacemit/ops/softmax.py: 0%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.softmax import softmax_backward as common_softmax_backward 

8from flag_gems.utils import tl_extra_shim 

9 

10logger = logging.getLogger(__name__) 

11exp = tl_extra_shim.exp 

12 

13 

14@triton.jit 

15def softmax_kernel_spacemit( 

16 output_ptr, 

17 input_ptr, 

18 input_row_stride, 

19 output_row_stride, 

20 n_rows, 

21 n_cols, 

22 ROW_SIZE: tl.constexpr, 

23 COL_SIZE: tl.constexpr, 

24): 

25 row_start = tl.program_id(0) * ROW_SIZE 

26 element_ty = output_ptr.type.element_ty 

27 

28 for row_idx in range(row_start, row_start + ROW_SIZE): 

29 if row_idx < n_rows: 

30 denominator = tl.zeros((1,), dtype=tl.float32) 

31 row_max = tl.full((COL_SIZE,), value=-float("inf"), dtype=tl.float32) 

32 

33 for col_idx in range(0, n_cols, COL_SIZE): 

34 input_block_ptr = tl.make_block_ptr( 

35 base=input_ptr + row_idx * input_row_stride, 

36 shape=(n_cols,), 

37 strides=(1,), 

38 offsets=(col_idx,), 

39 block_shape=(COL_SIZE,), 

40 order=(0,), 

41 ) 

42 row = tl.load( 

43 input_block_ptr, boundary_check=(0,), padding_option="neg_inf" 

44 ).to(tl.float32) 

45 row_max = tl.maximum(row, row_max) 

46 

47 row_max_total = tl.max(row_max, axis=0) 

48 

49 for col_idx in range(0, n_cols, COL_SIZE): 

50 input_block_ptr = tl.make_block_ptr( 

51 base=input_ptr + row_idx * input_row_stride, 

52 shape=(n_cols,), 

53 strides=(1,), 

54 offsets=(col_idx,), 

55 block_shape=(COL_SIZE,), 

56 order=(0,), 

57 ) 

58 output_block_ptr = tl.make_block_ptr( 

59 base=output_ptr + row_idx * output_row_stride, 

60 shape=(n_cols,), 

61 strides=(1,), 

62 offsets=(col_idx,), 

63 block_shape=(COL_SIZE,), 

64 order=(0,), 

65 ) 

66 row = tl.load( 

67 input_block_ptr, boundary_check=(0,), padding_option="neg_inf" 

68 ).to(tl.float32) 

69 numerator = exp(row - row_max_total) 

70 denominator += tl.sum(numerator, axis=0) 

71 tl.store( 

72 output_block_ptr, numerator.to(element_ty), boundary_check=(0,) 

73 ) 

74 

75 inv_denom = 1.0 / denominator 

76 for col_idx in range(0, n_cols, COL_SIZE): 

77 output_block_ptr = tl.make_block_ptr( 

78 base=output_ptr + row_idx * output_row_stride, 

79 shape=(n_cols,), 

80 strides=(1,), 

81 offsets=(col_idx,), 

82 block_shape=(COL_SIZE,), 

83 order=(0,), 

84 ) 

85 exp_out = tl.load(output_block_ptr, boundary_check=(0,)).to(tl.float32) 

86 tl.store( 

87 output_block_ptr, 

88 (exp_out * inv_denom).to(element_ty), 

89 boundary_check=(0,), 

90 ) 

91 

92 

93def _spacemit_softmax_lastdim(inp, out): 

94 n_rows, n_cols = inp.shape 

95 row_size = 1 if n_rows < 2 else (2 if n_rows < 8 else 4) 

96 col_size = 64 

97 grid = lambda meta: (triton.cdiv(n_rows, meta["ROW_SIZE"]),) 

98 softmax_kernel_spacemit[grid]( 

99 out, 

100 inp, 

101 inp.stride(0), 

102 out.stride(0), 

103 n_rows, 

104 n_cols, 

105 ROW_SIZE=row_size, 

106 COL_SIZE=col_size, 

107 ) 

108 

109 

110def softmax(self, dim, half_to_float=False): 

111 logger.debug("GEMS_SPACEMIT SOFTMAX") 

112 

113 assert dim >= -self.ndim and dim < self.ndim, "Invalid dim" 

114 dim = dim % self.ndim 

115 

116 if half_to_float: 

117 dtype = torch.float32 

118 else: 

119 dtype = self.dtype 

120 

121 inp = self.contiguous() 

122 

123 n_cols = inp.shape[-1] 

124 n_rows = inp.numel() // n_cols 

125 inp_2d = inp.view(n_rows, n_cols) 

126 out_2d = torch.empty_like(inp_2d, dtype=dtype) 

127 _spacemit_softmax_lastdim(inp_2d, out_2d) 

128 return out_2d.view_as(inp) 

129 

130 

131def softmax_backward(grad_output, output, dim, input_dtype): 

132 logger.debug("GEMS_SPACEMIT SOFTMAX_VJP") 

133 return common_softmax_backward(grad_output, output, dim, input_dtype)