Coverage for src/flag_gems/ops/log_softmax.py: 53%

115 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr = 8,
    BLOCK_N: tl.constexpr = 256,
):
    """Compute log-softmax along the middle axis of an (M, N, K)-indexed buffer.

    Elements are addressed as ``m * N * K + n * K + k``, i.e. both buffers are
    read/written as row-major (M, N, K) arrays. Program axis 0 selects a block
    of ``BLOCK_M`` rows and program axis 1 selects ``k``. The reduction runs
    over ``n`` in tiles of ``BLOCK_N``.

    Args:
        output_ptr: Pointer to the result buffer (same addressing as the input).
        input_ptr: Pointer to the input buffer.
        M: Number of rows (extent before the reduced axis).
        N: Extent of the reduced axis.
        K: Extent after the reduced axis.
        BLOCK_M: Rows handled per program.
        BLOCK_N: Reduced-axis elements handled per loop iteration.
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # TODO(chenfeiyu): consider float64 and add a utility function to get accumulator type
    # Per-lane running maximum (m) and running sum of exp(x - m) (z), kept in
    # float32 regardless of the input dtype.
    m = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    z = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # Element-wise `&` (not Python `and`) combines the two block-shaped
        # predicates; the comparisons are parenthesised because `&` binds
        # tighter than `<`.
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        input_ptrs = input_ptr + offset
        # Out-of-range lanes read -inf so they never win the max and
        # contribute exp(-inf) == 0 to the sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m_new = tl.maximum(inp, m)
        # When both the old max and the new value are -inf, (m - m_new) is
        # NaN; keep z unchanged for those lanes instead of poisoning it.
        all_neg_inf = m_new == float("-inf")
        z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
        m = m_new

    # Fold the BLOCK_N per-lane partials into one max and one sum per row,
    # rescaling each lane's sum to the common row maximum.
    m_reduced = tl.max(m, 1)
    z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
    m = m_reduced

    # Second pass: log_softmax(x) = x - max - log(sum(exp(x - max))).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        o = inp - m[:, None] - tl.log(z[:, None])
        tl.store(output_ptr + offset, o, mask=mask)

56 

57 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Compute the input gradient of log-softmax from its output and grad.

    For each (row, k) pair this evaluates
    ``in_grad = out_grad - exp(out) * sum_n(out_grad)`` along the ``n`` axis.
    All three buffers are addressed as ``m * N * K + n * K + k``, i.e. as
    row-major (M, N, K) arrays. Program axis 0 selects a block of ``BLOCK_M``
    rows and program axis 1 selects ``k``.

    Args:
        out_ptr: Pointer to the forward log-softmax result.
        out_grad_ptr: Pointer to the gradient w.r.t. the forward result.
        in_grad_ptr: Pointer to the buffer receiving the input gradient.
        M: Number of rows (extent before the reduced axis).
        N: Extent of the reduced axis.
        K: Extent after the reduced axis.
        BLOCK_M: Rows handled per program (supplied by the autotuner).
        BLOCK_N: Reduced-axis elements handled per loop iteration.
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # First pass: accumulate sum_n(out_grad) per row in float32.
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # Element-wise `&` (not Python `and`) combines the block predicates.
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_grad_ptrs = out_grad_ptr + offsets
        # `other=0.0` is required here: masked-off lanes are added into the
        # accumulator and later summed across the row, so they must be a
        # defined neutral value rather than whatever a masked load yields.
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Second pass: apply the gradient formula element-wise. Masked-off lanes
    # are never stored, so their loaded values do not matter here.
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)

96 

97 

def log_softmax_out(self, dim, half_to_float=False, *, out):
    """Compute log-softmax of ``self`` along ``dim`` and write it into ``out``.

    Args:
        self: Input tensor.
        dim: Axis to reduce over; negative values count from the end.
        half_to_float: When true the result dtype is ``torch.float32``,
            otherwise it is ``self.dtype``.
        out: Destination tensor. It is resized to the input shape when the
            shapes differ.

    Returns:
        ``out``, holding the result.

    Raises:
        AssertionError: If ``dim`` is outside ``[-self.ndim, self.ndim)``.
        RuntimeError: If ``out.dtype`` does not match the result dtype.
    """
    logger.debug("GEMS LOG_SOFTMAX_OUT")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    # Collapse the shape to (M, N, K): M = extent before `dim`, N = extent of
    # `dim`, K = extent after `dim`.
    M = 1
    N = self.shape[dim]
    for size in self.shape[:dim]:
        M *= size
    inp = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    if tuple(out.shape) != tuple(inp.shape):
        out.resize_(inp.shape)
    if out.dtype != dtype:
        raise RuntimeError(
            f"_log_softmax.out: expected out dtype {dtype}, got {out.dtype}"
        )
    # Nothing to compute for an empty tensor; returning here also avoids a
    # division by zero below when M or N is 0.
    if inp.numel() == 0:
        return out
    K = inp.numel() // M // N

    # The kernel writes with row-major linear offsets, so a non-contiguous
    # `out` cannot be handed to it directly: compute into a contiguous scratch
    # buffer and copy the result back afterwards.
    result = out if out.is_contiguous() else torch.empty_like(inp, dtype=dtype)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            result,
            inp,
            M,
            N,
            K,
            num_warps=8,
        )
    if result is not out:
        out.copy_(result)
    return out

134 

135 

def log_softmax(self, dim, half_to_float=False):
    """Return the log-softmax of ``self`` along ``dim`` as a new tensor.

    Args:
        self: Input tensor.
        dim: Axis to reduce over; negative values count from the end.
        half_to_float: When true the result dtype is ``torch.float32``,
            otherwise it is ``self.dtype``.

    Returns:
        A newly allocated tensor with the input's shape.

    Raises:
        AssertionError: If ``dim`` is outside ``[-self.ndim, self.ndim)``.
    """
    logger.debug("GEMS LOG_SOFTMAX")
    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    dtype = torch.float32 if half_to_float else self.dtype
    # Materialise the contiguous input once and reuse it both as the
    # allocation template and as the tensor handed to log_softmax_out, so a
    # non-contiguous input is copied a single time instead of twice.
    inp = self.contiguous()
    out = torch.empty_like(inp, dtype=dtype)
    return log_softmax_out(inp, dim, half_to_float, out=out)

143 

144 

def log_softmax_backward_out(grad_output, output, dim, input_dtype, *, out):
    """Compute the log-softmax input gradient and write it into ``out``.

    Args:
        grad_output: Gradient with respect to the forward result.
        output: The forward log-softmax result.
        dim: Axis the forward pass reduced over; negative values count from
            the end.
        input_dtype: Required dtype of ``out``.
        out: Destination tensor. It is resized to ``output``'s shape when the
            shapes differ.

    Returns:
        ``out``, holding the input gradient.

    Raises:
        AssertionError: If ``dim`` is outside ``[-output.ndim, output.ndim)``.
        RuntimeError: If ``out.dtype`` differs from ``input_dtype``.
    """
    logger.debug("GEMS LOG_SOFTMAX_BACKWARD_OUT")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    # Collapse the shape to (M, N, K): M = extent before `dim`, N = extent of
    # `dim`, K = extent after `dim`.
    M = 1
    N = output.shape[dim]
    for size in output.shape[:dim]:
        M *= size

    # The kernel addresses every buffer with row-major linear offsets, so both
    # operands must be contiguous (a no-op when they already are).
    grad_output = grad_output.contiguous()
    output = output.contiguous()
    if tuple(out.shape) != tuple(output.shape):
        out.resize_(output.shape)
    if out.dtype != input_dtype:
        raise RuntimeError(
            f"_log_softmax_backward_data.out: expected out dtype {input_dtype}, got {out.dtype}"
        )
    # Nothing to compute for an empty tensor; returning here also avoids a
    # division by zero below when M or N is 0.
    if output.numel() == 0:
        return out
    K = output.numel() // M // N

    # A non-contiguous `out` cannot be written by the kernel directly: compute
    # into a contiguous scratch buffer and copy the result back afterwards.
    in_grad = (
        out if out.is_contiguous() else torch.empty_like(output, dtype=input_dtype)
    )

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(out.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
        )
    if in_grad is not out:
        out.copy_(in_grad)
    return out

178 

179 

def log_softmax_backward(grad_output, output, dim, input_dtype):
    """Return the log-softmax input gradient as a new tensor.

    Args:
        grad_output: Gradient with respect to the forward result.
        output: The forward log-softmax result.
        dim: Axis the forward pass reduced over.
        input_dtype: Dtype of the returned gradient.

    Returns:
        A newly allocated tensor with ``output``'s shape and ``input_dtype``.
    """
    logger.debug("GEMS LOG_SOFTMAX_BACKWARD")
    # Request a contiguous buffer explicitly: the default `preserve_format`
    # would copy the strides of a non-contiguous dense `output`, while the
    # backward kernel writes its result with row-major linear offsets.
    in_grad = torch.empty_like(
        output, dtype=input_dtype, memory_format=torch.contiguous_format
    )
    return log_softmax_backward_out(grad_output, output, dim, input_dtype, out=in_grad)