Coverage for src/flag_gems/runtime/backend/_spacemit/ops/mean.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13 

@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the full-tensor mean: program `block_id` sums the
    # BLOCK_SIZE-wide slice [block_id * BLOCK_SIZE, (block_id + 1) * BLOCK_SIZE)
    # of the flat input and writes that partial sum to mid[block_id].
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    # Lanes past the end of the input read 0.0 so they add nothing to the sum.
    vals = tl.load(inp + idx, mask=in_bounds, other=0.0)
    tl.store(mid + block_id, tl.sum(vals, axis=0))

30 

31 

@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full-tensor mean: a single program folds the MID_SIZE
    # partial sums written by stage 1 and divides by the element count M.
    lanes = tl.arange(0, BLOCK_MID)
    # BLOCK_MID is a power of two >= MID_SIZE; surplus lanes contribute 0.0.
    partials = tl.load(mid + lanes, mask=lanes < MID_SIZE, other=0.0)
    total = tl.sum(partials, axis=0)
    tl.store(out, total / M)

41 

42 

def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-dim tensor.

    The reduction uses two kernel launches: ``mean_kernel_1`` writes one
    partial sum per block of roughly sqrt(numel) elements, then
    ``mean_kernel_2`` folds those partial sums and divides by the count.

    Args:
        inp: input tensor of any shape.
        dtype: dtype of the result; defaults to ``inp.dtype``.

    Returns:
        A 0-dim tensor of ``dtype`` on ``inp.device``. An empty input yields
        NaN, matching ``torch.mean``.
    """
    logging.debug("GEMS_SPACEMIT MEAN")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if M == 0:
        # The block-size computation below yields 0 for M == 0, which would
        # make cdiv divide by zero; torch.mean of an empty tensor is NaN.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)

    # mean_kernel_1 addresses the input as one dense flat buffer, so strided
    # views (slices, expands) must be materialised first. No-op when already
    # contiguous.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # Hold partial sums in at least float32: storing them in a half-precision
    # result dtype can overflow (fp16 max is ~65504) even when the mean is small.
    mid_dtype = torch.promote_types(dtype, torch.float32)
    mid = torch.empty((mid_size,), dtype=mid_dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        mean_kernel_2[(1, 1, 1)](mid, out, M, mid_size, block_mid)
    return out

59 

60 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mean"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, TILE_N: tl.constexpr):
    # One program per output element: program `row` reads N consecutive
    # elements starting at X + row * N and stores their mean to Mean[row].
    # (The caller is expected to lay X out as M dense rows of N elements.)

    # Promote the row index to int64 before multiplying by N so the row
    # offset cannot overflow 32 bits when M * N >= 2**31.
    row = tle.program_id(0).to(tl.int64)
    X = X + row * N
    Mean = Mean + row
    # NOTE(review): _mean starts as an fp32 scalar; for fp64 inputs the
    # loop-carried type may change inside the loop — confirm fp64 support.
    _mean = 0.0

    num_pid_n = tl.cdiv(N, TILE_N)

    x_ptr_desc = tl.make_block_ptr(
        base=X,
        shape=[N],
        strides=[1],
        offsets=[0],
        block_shape=[TILE_N],
        order=[0],
    )

    for _ in range(0, num_pid_n):
        # The last tile may run past N; pad those lanes with zeros explicitly,
        # since without a padding_option their values are undefined and would
        # be folded into the sum.
        a = tl.load(
            x_ptr_desc,
            boundary_check=[0],
            padding_option="zero",
        )

        _mean += tl.sum(a)

        x_ptr_desc = tl.advance(x_ptr_desc, [TILE_N])

    mean = _mean / N

    tl.store(Mean, mean)

97 

98 

def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Return the mean of ``x`` over the dimensions listed in ``dim``.

    Args:
        x: input tensor.
        dim: a dimension index or a sequence of them (negative values count
            from the end). ``None`` reduces over every element.
        keepdim: when True, reduced dimensions are kept as size-1 axes.
        dtype: dtype of the result; defaults to ``x.dtype``.

    Returns:
        A tensor of means. Reductions over zero elements yield NaN, matching
        ``torch.mean``.
    """
    logging.debug("GEMS_SPACEMIT MEAN_DIM")

    if dtype is None:
        dtype = x.dtype
    if dim is None:
        out = mean(x, dtype=dtype)
        # mean() returns a 0-dim tensor; only keepdim=True should restore one
        # size-1 axis per input dimension.
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out
    if isinstance(dim, int):
        dim = [dim]

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    out = torch.empty(shape, dtype=dtype, device=x.device)
    # Number of independent reductions, taken from the output shape rather
    # than x.numel() // N so that a zero-size reduced dim (N == 0) does not
    # divide by zero.
    M = out.numel()
    if M > 0:
        if N == 0:
            # Mean over zero elements is NaN, as in torch.mean.
            out.fill_(float("nan"))
        else:
            # Rearrange x so each of the M reductions reads N contiguous
            # elements, as mean_dim_kernel expects.
            x = dim_compress(x, dim)
            grid = (M,)
            with torch_device_fn.device(x.device):
                mean_dim_kernel[grid](x, out, M, N)
    if not keepdim:
        out = out.squeeze(dim)
    return out

125 

126 

def global_avg_pool(x, _output_size=None):
    """Average ``x`` over its last two dimensions, keeping them as size 1.

    Reducing dims ``[-2, -1]`` is identical to ``[2, 3]`` for 4-D input
    (presumably N, C, H, W) and also handles unbatched 3-D input, where
    ``[2, 3]`` would be wrapped by ``mean_dim`` to dims 2 and 0.

    ``_output_size`` is accepted for signature compatibility and ignored.
    NOTE(review): the result is only correct for a 1x1 output size — confirm
    callers never route other sizes here.
    """
    return mean_dim(x, dim=[-2, -1], keepdim=True)