Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mean.py: 0%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import builtins 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

13from ..utils.block_size_utils import get_block_size_1d 

14 

# Module logger: a child of the shared "flag_gems" logger, named after this
# module's __name__ with any leading dots stripped.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

16 

17 

@libentry()
@triton.jit
def mean_scalar_kernel(inp, out, M, BLOCK_SIZE: tl.constexpr):
    """Scalar mean over all M elements.

    On XPU (USE_XHPC): intercepted by baidu::xpu::api::mean binding.
    Triton fallback (single CTA): sequential accumulation for correctness.

    Params for binding:
        kernelParams[0] = inp, kernelParams[1] = out
        kernelConsts[2] = M, kernelConsts[3] = BLOCK_SIZE

    Args:
        inp: Base pointer of the input; read as M consecutive elements.
        out: Pointer that receives the single mean value.
        M: Number of elements to average.
        BLOCK_SIZE: Compile-time count of elements loaded per loop iteration.
    """
    # Per-lane partial sums, accumulated in float32 whatever the input dtype.
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, M, BLOCK_SIZE):
        offset = off + tl.arange(0, BLOCK_SIZE)
        # Lanes at or beyond M are masked off and load 0.0, so they add nothing.
        mask = offset < M
        v = tl.load(inp + offset, mask=mask, other=0.0).to(tl.float32)
        acc += v
    # Collapse the lanes into one total, then divide by the element count.
    result = tl.sum(acc) / M
    tl.store(out, result)

36 

37 

def mean(inp, *, dtype=None):
    """Compute the mean of every element of ``inp`` into a 0-d tensor.

    Args:
        inp: Input tensor; all of its elements are averaged.
        dtype: Dtype of the returned tensor. Falls back to ``inp.dtype``
            when ``None``.

    Returns:
        A 0-d tensor on ``inp.device`` written by ``mean_scalar_kernel``.
    """
    logger.debug("GEMS MEAN")
    numel = inp.numel()
    out_dtype = inp.dtype if dtype is None else dtype
    block_size = get_block_size_1d(numel, inp.element_size())
    result = torch.empty([], dtype=out_dtype, device=inp.device)

    # The kernel is launched as a single program that walks all elements.
    launch_grid = (1, 1, 1)
    with torch_device_fn.device(inp.device):
        mean_scalar_kernel[launch_grid](
            inp, result, numel, block_size, buffer_size_limit=2048
        )
    return result

49 

50 

def heur_m_block_size(args):
    """Heuristic for BLOCK_M: ceil(M / 12) rounded up to a power of two.

    ``args`` is the kernel-argument mapping supplied by ``triton.heuristics``;
    only ``args["M"]`` is read.
    """
    rows_per_cluster = triton.cdiv(args["M"], 12)  # cluster_num
    return triton.next_power_of_2(rows_per_cluster)

53 

54 

def heur_n_block_size(args):
    """Heuristic for BLOCK_N: N rounded up to a power of two, capped at 8192.

    ``args`` is the kernel-argument mapping supplied by ``triton.heuristics``;
    only ``args["N"]`` is read.
    """
    padded_cols = triton.next_power_of_2(args["N"])
    return builtins.min(padded_cols, 8192)

57 

58 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("mean"),
#     key=["M", "N"],
# )
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """2-D reduction: reduce N-dim for each of M rows.

    On XPU (USE_XHPC): intercepted by baidu::xpu::api::mean_dim binding.

    Params for binding:
        kernelParams[0] = X, kernelParams[1] = Mean
        kernelParams[2] = M, kernelParams[3] = N (runtime scalars)
        kernelConsts[4] = BLOCK_M (constexpr), kernelConsts[5] = BLOCK_N (constexpr)

    Args:
        X: Base pointer of the input, addressed as M rows of N consecutive
            elements (row r starts at ``X + r * N``).
        Mean: Base pointer of the output, one value per row.
        M: Number of rows.
        N: Number of elements reduced per row.
        BLOCK_M: Rows handled per program; filled in by ``heur_m_block_size``.
        BLOCK_N: Columns loaded per loop iteration; filled in by
            ``heur_n_block_size``.
    """
    # Map the program id to the row of X it should compute.
    # pid is a [BLOCK_M, 1] column of row indices for this program.
    pid = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Mean = Mean + pid
    # Rows at or beyond M are masked out of both the loads and the final store.
    row_mask = pid < M

    # Compute mean
    # Partial sums are kept per (row, column-lane) in float32.
    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # An element is loaded only when both its row and its column are in
        # range; masked-off elements read as 0.0.
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _mean += a
    # Sum across the column lanes, then divide by the row length.
    mean = tl.sum(_mean, axis=1) / N
    # Restore the trailing axis so the value lines up with the Mean pointers.
    mean = mean[:, None]
    tl.store(Mean, mean, row_mask)

97 

98 

def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Mean of ``x`` over the given dimension(s).

    Args:
        x: Input tensor.
        dim: Dimension(s) to reduce. ``None`` reduces every element; a single
            ``int`` or an iterable of ints (negative values allowed) selects
            specific dimensions.
        keepdim: When True the reduced dimensions are kept with size 1;
            otherwise they are removed from the result.
        dtype: Dtype of the result. Falls back to ``x.dtype`` when ``None``.

    Returns:
        A tensor holding the mean, shaped according to ``dim`` and ``keepdim``.
    """
    logger.debug("GEMS MEAN DIM")

    if dtype is None:
        dtype = x.dtype
    if dim is None:
        out = mean(x, dtype=dtype)
        # `mean` returns a 0-d tensor, which is already the keepdim=False
        # result. Only keepdim=True needs the size-1 axes put back.
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    # Accept a bare int as well as a sequence of ints.
    if isinstance(dim, int):
        dim = [dim]

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]
    x = dim_compress(x, dim)
    # N = number of elements reduced per output value; `shape` becomes the
    # keepdim=True output shape (reduced axes set to 1).
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = x.numel() // N

    # Edge case: M=1 means all dims are reduced → global mean over N elements.
    # mean_dim XPU API does not support M=1.
    if M == 1:
        scalar_out = mean(x, dtype=dtype)  # 0-d tensor
        out = scalar_out.reshape(shape)
        if not keepdim:
            out = out.squeeze(dim)
        return out

    # Edge case: N=1 means reducing a trivial (size-1) dimension.
    # mean of 1 element = that element; just copy with dtype conversion.
    # mean_dim XPU API does not support N=1.
    if N == 1:
        out = x.to(dtype=dtype).reshape(shape)
        if not keepdim:
            out = out.squeeze(dim)
        return out

    out = torch.empty(shape, dtype=dtype, device=x.device)
    # One program per BLOCK_M rows; BLOCK_M is supplied by the kernel heuristics.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(x.device):
        mean_dim_kernel[grid](x, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out = out.squeeze(dim)
    return out