Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mean.py: 0%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry, libtuner 

10 

11from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op 

12 

# Module logger: a child of the shared "flag_gems" logger. Leading dots are
# stripped from the module name before it is used as the child suffix.
_module_name = __name__.lstrip(".")
logger = logging.getLogger("flag_gems").getChild(_module_name)

14 

15 

@libentry()
@libtuner(
    configs=cfggen_reduce_op(), key=["M"], strategy=["log"], reset_to_zero=["out"]
)
@triton.jit
def mean_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Accumulate this program's share of ``sum(inp[0:M]) / M`` into ``out``.

    Each program starts at ``program_id * BLOCK_SIZE`` and advances by
    ``num_programs * BLOCK_SIZE`` until it passes ``M``, summing the loaded
    values lane-wise in float32. The lanes are then reduced, divided by ``M``
    and atomically added to ``out``.

    Args:
        inp: pointer to the input elements, indexed linearly from 0 to M - 1.
        out: pointer to the single accumulator. Results are added to it, so it
            is expected to hold zero at launch (the tuner is told to reset it
            through ``reset_to_zero``).
        M: number of elements to average.
        BLOCK_SIZE: number of elements handled per loop iteration.
    """
    program_idx = tl.program_id(0)
    program_count = tl.num_programs(axis=0)
    stride = program_count * BLOCK_SIZE
    # Offsets are carried as int64 from the very first block onwards.
    first_off = (program_idx * BLOCK_SIZE).to(tl.int64)
    lanes = tl.arange(0, BLOCK_SIZE)
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for base in range(first_off, M, stride):
        idx = base + lanes
        # Lanes beyond M contribute 0.0 to the sum.
        vals = tl.load(inp + idx, mask=idx < M, other=0.0)
        acc = vals + acc

    partial_mean = tl.sum(acc, axis=0) / M
    tl.atomic_add(out, partial_mean)

41 

42 

def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-d tensor.

    The input is made contiguous and reduced by ``mean_kernel_1`` into a
    float32 accumulator, which is then cast to the requested dtype.

    Args:
        inp: input tensor.
        dtype: dtype of the result; defaults to ``inp.dtype``.

    Returns:
        A 0-d tensor of ``dtype`` on ``inp.device``. For an input with no
        elements the result is NaN.
    """
    logger.debug("GEMS_CAMBRICON MEAN")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if M == 0:
        # With no elements the launch grid would be empty and the kernel's
        # divisor would be zero; return NaN directly, as torch.mean does.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)
    # One program per block of elements, capped at TOTAL_CORE_NUM.
    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    # The kernel atomically adds into `out`, so it must start at zero.
    out = torch.zeros([], dtype=torch.float32, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[grid](inp, out, M)
    return out.to(dtype)

55 

56 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mean"),
    key=["M", "N"],
    strategy=["log", "log"],
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Write the mean of each length-``N`` row of ``X`` to ``Mean``.

    ``X`` is addressed as ``M`` rows of ``N`` consecutive elements
    (element ``(r, c)`` at ``X + r * N + c``) and ``Mean`` as ``M``
    consecutive elements. Each program handles blocks of ``BLOCK_M`` rows,
    stepping by the number of launched programs until all rows are covered.

    Args:
        X: pointer to the input elements.
        Mean: pointer to the per-row results.
        M: number of rows.
        N: number of elements per row.
        BLOCK_M: rows processed together by one program.
        BLOCK_N: columns loaded per inner-loop iteration.
    """
    # Map the program id to the row of X it should compute.
    num_prog = tl.num_programs(0)
    task_num = tl.cdiv(M, BLOCK_M)
    iter_num = tl.cdiv(task_num, num_prog)
    for i in range(0, iter_num):
        rows = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        # Widen the row index before scaling by N so the element offset
        # `row * N` cannot wrap in 32-bit arithmetic on large inputs
        # (mean_kernel_1 widens its offsets to int64 in the same way).
        pid = rows.to(tl.int64)
        X_ptr = X + pid * N
        Mean_ptr = Mean + pid
        row_mask = pid < M

        # Compute mean
        _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            # Masked-out positions load 0.0; values are accumulated in float32.
            a = tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)
            _mean += a
        # Scale every column accumulator by 1/N, then reduce across columns.
        _mean /= N
        mean = tl.sum(_mean, axis=1)[:, None]
        tl.store(Mean_ptr, mean, row_mask)

89 

90 

def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Return the mean of ``x`` over the dimension(s) in ``dim``.

    Args:
        x: input tensor.
        dim: dimension index, sequence of dimension indices (negative values
            count from the end), or ``None`` to reduce over every dimension.
        keepdim: when true, reduced dimensions are kept with size 1;
            otherwise they are removed from the result.
        dtype: dtype of the result; defaults to ``x.dtype``.

    Returns:
        A tensor of ``dtype`` holding the means.
    """
    logger.debug("GEMS_CAMBRICON MEAN DIM")

    if dtype is None:
        dtype = x.dtype
    if dim is None:
        # Full reduction: `mean` returns a 0-d tensor; only keepdim=True
        # needs it expanded back to one size-1 axis per input dimension.
        out = mean(x, dtype=dtype)
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    if isinstance(dim, int):
        # Accept a single dimension index as well as a sequence.
        dim = [dim]

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]
    # NOTE(review): dim_compress presumably reorders x so the reduced dims are
    # innermost and contiguous; the kernel reads each output element's inputs
    # as N consecutive values. Confirm against flag_gems.utils.dim_compress.
    x = dim_compress(x, dim)
    # N = number of elements averaged per output; reduced dims become size 1.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = x.numel() // N
    out = torch.empty(shape, dtype=dtype, device=x.device)
    # One program per block of BLOCK_M rows, capped at TOTAL_CORE_NUM.
    grid = lambda META: (min(triton.cdiv(M, META["BLOCK_M"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(x.device):
        mean_dim_kernel[grid](x, out, M, N)
    if not keepdim:
        out = out.squeeze(dim)
    return out

116 return out