Coverage for src/flag_gems/runtime/backend/_cambricon/ops/sum.py: 0%

151 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry, libtuner 

10 

11from ..utils import MAX_GRID_SIZE_X, TOTAL_CORE_NUM, cfggen_reduce_op 

12from .zeros import zero_ 

13 

# Module logger registered as a child of the "flag_gems" logger (via getChild),
# so records emitted here propagate to that parent logger's handlers.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

@libentry()
@libtuner(
    configs=cfggen_reduce_op(), key=["M"], strategy=["log"], reset_to_zero=["out"]
)
@triton.jit
def sum_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Flat sum of the first M elements of ``inp`` into the single element at
    # ``out``. Each program walks the input starting at pid * BLOCK_SIZE in
    # steps of num_programs * BLOCK_SIZE, reduces its partial block, and adds
    # the result into ``out`` with tl.atomic_add -- so ``out`` must hold 0
    # before the launch.

    # Half-precision inputs accumulate in fp32. Bool (int1) inputs accumulate
    # in int32, matching sum_kernel, so counting never happens in a 1-bit type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    elif tl.constexpr(inp.dtype.element_ty == tl.int1):
        cdtype = tl.int32
    else:
        cdtype = inp.dtype.element_ty

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    _tmp = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    # Promote the start offset to int64 so offset arithmetic cannot overflow
    # int32 for large M.
    block_start = block_start.to(tl.int64)
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=0.0)
        # Explicit cast keeps the loop-carried accumulator type fixed at cdtype.
        _tmp = inp_val.to(cdtype) + _tmp

    sum_val = tl.sum(_tmp)
    tl.atomic_add(out, sum_val)

49 

50 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("sum"),
    key=["M", "N"],
    strategy=["log", "log"],
)
@triton.jit
def sum_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise sum: ``inp`` is read as M rows of N consecutive elements
    # (row r starts at inp + r * N) and the sum of row r is stored at out + r.
    # Each program handles row block sub_pid, then sub_pid + num_programs, ...
    # until all cdiv(M, BLOCK_M) row blocks are done.

    # Half precision accumulates in fp32, bool (int1) in int32.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    elif tl.constexpr(inp.dtype.element_ty == tl.int1):
        cdtype = tl.int32
    else:
        cdtype = inp.dtype.element_ty
    prog_num = tl.num_programs(0).to(tl.uint64)
    sub_pid = tl.program_id(0).to(tl.uint64)
    task_num = tl.cdiv(M, BLOCK_M).to(tl.uint64)
    while sub_pid < task_num:
        # Row indices covered by this block, shaped [BLOCK_M, 1].
        rows = sub_pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
        inp_ = inp + rows * N
        out_ = out + rows
        row_mask = rows < M

        acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            # Elementwise AND; broadcasts [BLOCK_M, 1] x [1, BLOCK_N].
            mask = row_mask & col_mask

            a = tl.load(inp_ + cols, mask, other=0).to(cdtype)
            acc += a
        row_sum = tl.sum(acc, axis=1)[:, None]
        tl.store(out_, row_sum, row_mask)
        sub_pid += prog_num

95 

96 

def sum(inp, *, dtype=None):
    """Sum every element of ``inp`` into a 0-dim tensor.

    Args:
        inp: input tensor of any shape; made contiguous because the kernel
            reads it as one flat buffer.
        dtype: result dtype. Defaults to ``inp.dtype``, except that bool input
            defaults to ``torch.int32``.

    Returns:
        A 0-dim tensor of dtype ``dtype`` on ``inp.device``.
    """
    logger.debug("GEMS_CAMBRICON SUM")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            # NOTE(review): torch.sum promotes bool to int64; this backend
            # defaults to int32 -- confirm that is intentional.
            dtype = torch.int32
    if inp.dtype is torch.bool:
        # Count bools in int32 even when the caller passed an explicit dtype,
        # so the flat kernel never accumulates in a 1-bit type.
        inp = inp.to(torch.int32)

    # sum_kernel_1 atomically adds partial sums, so ``out`` must start at 0.
    out = torch.zeros([], dtype=dtype, device=inp.device)
    if M == 0:
        # The sum over no elements is 0; avoid launching an empty grid.
        return out

    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[grid](inp, out, M)
    return out

113 

114 

def sum_out(inp, *, dtype=None, out):
    """Sum every element of ``inp`` into the caller-provided ``out``.

    Args:
        inp: input tensor of any shape.
        dtype: accepted for compatibility with the ``out=`` overload; the sum
            is written in ``out.dtype``.
        out: destination tensor; resized to 0-dim, overwritten, and returned.

    Returns:
        ``out`` itself, as required by the out= contract.
    """
    logger.debug("GEMS_CAMBRICON SUM_OUT")
    # The kernel indexes ``inp`` as a flat buffer, so strided views must be
    # materialised first or the wrong elements are summed.
    inp = inp.contiguous()
    M = inp.numel()
    if inp.dtype is torch.bool:
        # Count bools in int32 so the flat kernel never accumulates in 1 bit.
        inp = inp.to(torch.int32)

    # The kernel writes one element at out[0]; make sure that storage exists
    # (e.g. when the caller passes an empty tensor).
    out.resize_([])
    # sum_kernel_1 accumulates with atomic_add. The tuner's reset_to_zero
    # presumably only clears ``out`` while benchmarking configs (as in
    # Triton's Autotuner), so stale contents must be cleared explicitly.
    zero_(out)
    if M == 0:
        return out

    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[grid](inp, out, M)
    return out

129 

130 

def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Sum ``inp`` over ``dim``, optionally writing into ``out``.

    Shared implementation behind ``sum_dim`` and ``sum_dim_out``.

    Args:
        inp: input tensor.
        dim: dimensions to reduce. ``None`` or an empty list/tuple reduces
            over every element; a bare ``int`` is treated as ``[dim]``;
            negative indices wrap modulo ``inp.ndim``.
        keepdim: keep reduced dimensions with size 1.
        dtype: result dtype; defaults to ``inp.dtype`` (``torch.int64`` for
            bool input).
        out: optional destination; resized to the result shape, filled, and
            returned.

    Returns:
        The reduced tensor (``out`` itself when it was provided).
    """
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if isinstance(dim, int):
        dim = [dim]

    if dim is None or len(dim) == 0:
        # Full reduction over all elements.
        if dim is None:
            result = torch.sum(inp, dtype=dtype)
        else:
            result = sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        if out is None:
            return result
        # Honour the out= contract instead of silently ignoring ``out``.
        out.resize_(result.shape)
        out.copy_(result)
        return out

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]

    # Move the reduced dims to the end so each output element owns N
    # consecutive input elements.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    _out_provided = out is not None
    if _out_provided:
        dim_set = set(dim)
        if keepdim:
            out.resize_(shape)
        else:
            out.resize_([s for i, s in enumerate(shape) if i not in dim_set])
    else:
        out = torch.empty(shape, dtype=dtype, device=inp.device)

    if N == 0:
        # Every output element is a sum over zero elements (== 0). Dividing
        # numel() by N would raise ZeroDivisionError, and there is nothing for
        # the kernel to read.
        if out.numel() > 0:
            zero_(out)
    else:
        M = inp.numel() // N
        if M > 0:
            grid = lambda meta: (
                min(triton.cdiv(M, meta["BLOCK_M"]), MAX_GRID_SIZE_X // 4),
            )
            with torch_device_fn.device(inp.device):
                sum_kernel[grid](inp, out, M, N)

    if not keepdim and not _out_provided:
        for d in sorted(dim, reverse=True):
            out = out.squeeze(dim=d)
    return out

174 

175 

def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over ``dim``, special-casing tensors with no elements.

    When ``inp.numel() == 0`` the result shape is worked out on the host and
    filled with zeros, as in PyTorch, where a sum over zero elements is 0.
    Every non-empty input is forwarded to ``sum_dim_comm`` unchanged.
    """
    logger.debug("GEMS_CAMBRICON SUM DIM")
    if inp.numel() != 0:
        return sum_dim_comm(inp, dim, keepdim, dtype=dtype)

    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    ndim = inp.ndim
    reduce_all = dim is None or (isinstance(dim, (list, tuple)) and len(dim) == 0)
    if reduce_all:
        result_shape = [1] * ndim if keepdim else []
    else:
        dims = list(dim) if isinstance(dim, (list, tuple)) else [dim]
        result_shape = list(inp.shape)
        if keepdim:
            for d in dims:
                result_shape[d % ndim] = 1
        else:
            # Delete from the highest index down so earlier deletions do not
            # shift positions that are still to be removed.
            for idx in sorted((d % ndim for d in dims), reverse=True):
                del result_shape[idx]

    result = torch.empty(result_shape, dtype=dtype, device=inp.device)
    zero_(result)
    return result

212 

213 

def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """``out=`` variant of ``sum_dim``.

    Logs the call and delegates to ``sum_dim_comm`` with ``out`` forwarded;
    returns whatever ``sum_dim_comm`` returns.
    """
    logger.debug("GEMS_CAMBRICON SUM_DIM_OUT")
    result = sum_dim_comm(inp, dim, keepdim=keepdim, dtype=dtype, out=out)
    return result