Coverage for src/flag_gems/ops/logsumexp.py: 36%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def logsumexp_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Kernel for logsumexp when reduction dimension is not the innermost.

    The input is addressed as ``m * N * K + n * K + k`` and the output as
    ``m * K + k``; the reduction runs over ``n``. Program axis 0 selects ``m``
    and program axis 1 selects a tile of ``TILE_K`` consecutive ``k`` values.

    ``TILE_N``, ``TILE_K`` and ``ONE_TILE_PER_CTA`` are compile-time constants
    provided by the ``softmax_non_inner`` heuristic config. When
    ``ONE_TILE_PER_CTA`` is true a single ``TILE_N``-row tile is loaded and
    reduced without looping (presumably only chosen when ``TILE_N >= N`` --
    confirm against the heuristic); otherwise the ``n`` axis is walked in
    steps of ``TILE_N`` with a running (max, scaled-sum) pair per lane.

    Accumulation is done in float32. Columns that are entirely ``-inf`` yield
    ``-inf``; columns whose maximum is ``+inf`` yield ``+inf`` instead of the
    NaN that a naive ``exp(inf - inf)`` would produce.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]

    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)[:, None]
        inp_offset = pid_m * N * K + n_offsets * K + k_offsets
        mask = (n_offsets < N) & (k_offsets < K)
        input_ptrs = input_ptr + inp_offset
        # Out-of-range lanes read as -inf so they contribute exp(-inf) == 0.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, axis=0, keep_dims=True)
        # Subtracting an infinite max would evaluate inf - inf == NaN, so use
        # a zero shift whenever the column max is +inf or -inf.
        m_is_inf = (m == float("inf")) | (m == float("-inf"))
        safe_m = tl.where(m_is_inf, tl.zeros_like(m), m)
        e = tl.exp(inp - safe_m)
        z = tl.sum(e, axis=0, keep_dims=True)
        out = safe_m + tl.log(z)
        # If all inputs were -inf, result should be -inf
        out = tl.where(m == float("-inf"), m, out)
        out_offset = pid_m * K + k_offsets
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out, mask=k_offsets < K)
    else:
        # Per-lane running maximum and sum of exp(x - running max).
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)[:, None]
            inp_offsets = pid_m * N * K + n_offsets * K + k_offsets
            mask = (n_offsets < N) & (k_offsets < K)
            inp = tl.load(input_ptr + inp_offsets, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            # Equal operands (including two equal infinities) get an exponent
            # of exactly 0, which avoids NaN from inf - inf. Finite values are
            # unaffected because their difference is already 0 when equal.
            rescale = tl.exp(tl.where(m == m_new, 0.0, m - m_new))
            term = tl.exp(tl.where(inp == m_new, 0.0, inp - m_new))
            # Lanes that have only seen -inf keep their sum untouched.
            z = tl.where(all_neg_inf, z, z * rescale + term)
            m = m_new

        # Merge the TILE_N per-lane partial results into one value per k.
        m_reduced = tl.max(m, axis=0, keep_dims=True)
        lane_scale = tl.exp(tl.where(m == m_reduced, 0.0, m - m_reduced))
        z = tl.sum(z * lane_scale, axis=0, keep_dims=True)
        m = m_reduced
        # Handle case where all inputs were -inf
        out = tl.where(m == float("-inf"), m, m + tl.log(z))
        out_offset = pid_m * K + k_offsets
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out, mask=k_offsets < K)

75 

76 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def logsumexp_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Kernel for logsumexp when reduction dimension is the innermost.

    Each program (axis 0) reduces the ``N`` consecutive elements starting at
    ``input_ptr + pid_m * N`` and writes one value to ``output_ptr + pid_m``.

    ``TILE_N`` and ``ONE_TILE_PER_CTA`` are compile-time constants provided by
    the ``softmax_inner`` heuristic config. When ``ONE_TILE_PER_CTA`` is true
    a single ``TILE_N``-wide tile is loaded and reduced without looping
    (presumably only chosen when ``TILE_N >= N`` -- confirm against the
    heuristic); otherwise the row is walked in steps of ``TILE_N`` with a
    running (max, scaled-sum) pair per lane.

    Accumulation is done in float32. Rows that are entirely ``-inf`` yield
    ``-inf``; rows whose maximum is ``+inf`` yield ``+inf`` instead of the NaN
    that a naive ``exp(inf - inf)`` would produce.
    """
    pid_m = ext.program_id(0)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # Out-of-range lanes read as -inf so they contribute exp(-inf) == 0.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, axis=0)
        # Subtracting an infinite max would evaluate inf - inf == NaN, so use
        # a zero shift whenever the row max is +inf or -inf.
        m_is_inf = (m == float("inf")) | (m == float("-inf"))
        safe_m = tl.where(m_is_inf, 0.0, m)
        e = tl.exp(inp - safe_m)
        z = tl.sum(e, axis=0)
        out = safe_m + tl.log(z)
        # If all inputs were -inf, result should be -inf
        out = tl.where(m == float("-inf"), m, out)
        output_ptrs = output_ptr + pid_m
        tl.store(output_ptrs, out)
    else:
        # Per-lane running maximum and sum of exp(x - running max).
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N

        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            # Equal operands (including two equal infinities) get an exponent
            # of exactly 0, which avoids NaN from inf - inf. Finite values are
            # unaffected because their difference is already 0 when equal.
            rescale = tl.exp(tl.where(m == m_new, 0.0, m - m_new))
            term = tl.exp(tl.where(inp == m_new, 0.0, inp - m_new))
            # Lanes that have only seen -inf keep their sum untouched.
            z = tl.where(all_neg_inf, z, z * rescale + term)
            m = m_new

        # Merge the TILE_N per-lane partial results into a single scalar.
        m_reduced = tl.max(m, axis=0)
        lane_scale = tl.exp(tl.where(m == m_reduced, 0.0, m - m_reduced))
        z = tl.sum(z * lane_scale, axis=0)
        m = m_reduced
        # Handle case where all inputs were -inf
        out = tl.where(m == float("-inf"), m, m + tl.log(z))
        output_ptrs = output_ptr + pid_m
        tl.store(output_ptrs, out)

129 

130 

def logsumexp(inp, dim, keepdim=False):
    """Compute ``log(sum(exp(inp)))`` along ``dim`` using the Triton kernels.

    Args:
        inp: Input tensor. It is made contiguous before the kernel launch.
        dim: Dimension to reduce. Either an int (negative values count from
            the end) or a list/tuple of ints. An empty list/tuple performs no
            reduction and returns a copy of ``inp``; several dims are reduced
            one after another.
        keepdim: If true, reduced dimensions are kept with size 1; otherwise
            they are squeezed out of the result.

    Returns:
        A tensor with the dtype and device of ``inp`` holding the reduction.

    Raises:
        AssertionError: If an integer ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS LOGSUMEXP")

    if isinstance(dim, (list, tuple)):
        if not dim:
            # Empty dim list means no reduction, just return the input
            return inp.clone()
        if len(dim) == 1:
            dim = dim[0]
        else:
            # Normalise negatives and drop duplicates: reducing or squeezing
            # the same axis twice could otherwise remove an unrelated size-1
            # axis. Every step keeps the reduced axis, so indices never shift
            # during the reductions; squeezing from the highest index down
            # keeps the remaining indices valid.
            reduce_dims = sorted({d % inp.ndim for d in dim}, reverse=True)
            result = inp
            for d in reduce_dims:
                result = logsumexp(result, d, keepdim=True)
            if not keepdim:
                for d in reduce_dims:
                    result = result.squeeze(d)
            return result

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    inp = inp.contiguous()

    # View the tensor as (M, N, K): M = product of the leading extents,
    # N = reduced extent, K = product of the trailing extents. Computing K
    # directly (rather than numel // M // N) stays well defined when an
    # extent is zero.
    N = inp.shape[dim]
    M = 1
    for extent in inp.shape[:dim]:
        M *= extent
    K = 1
    for extent in inp.shape[dim + 1 :]:
        K *= extent

    # Output shape with reduction dimension set to 1
    shape = list(inp.shape)
    shape[dim] = 1
    out = torch.empty(shape, dtype=inp.dtype, device=inp.device)

    if M == 0 or K == 0:
        # The output has no elements; there is nothing to launch.
        pass
    elif N == 0:
        # Reducing over an empty axis: log(sum of nothing) = log(0) = -inf.
        out.fill_(float("-inf"))
    else:
        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                logsumexp_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                grid = (M, 1, 1)
                logsumexp_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

189 return out