Coverage for src/flag_gems/runtime/backend/_cambricon/ops/fill.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, libtuner 

9 

10from ..utils import TOTAL_CORE_NUM 

11 

# Module logger: a child of the package-level "flag_gems" logger, so records
# emitted here propagate to that logger's handlers.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

# Write the scalar `value_scalar` into elements [0, N) of `out_ptr`.
#
# Each program begins at program_id * BLOCK_SIZE and advances by
# num_programs * BLOCK_SIZE (a grid-stride loop), so a launch grid smaller
# than cdiv(N, BLOCK_SIZE) still covers every element. `value_scalar` is
# excluded from specialization so different fill values reuse one compiled
# kernel.
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
    ],
    key=["N"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["value_scalar"])
def fill_scalar_kernel(
    out_ptr,
    N,
    value_scalar,
    BLOCK_SIZE: tl.constexpr,
):
    lanes = tl.arange(0, BLOCK_SIZE)
    # Promote the starting offset to int64 so the offset arithmetic below is
    # carried out in 64 bits (guards against int32 overflow for large N).
    first = (tl.program_id(0) * BLOCK_SIZE).to(tl.int64)
    stride = tl.num_programs(axis=0) * BLOCK_SIZE
    for base in range(first, N, stride):
        idx = base + lanes
        # Mask the tail block so nothing at or past N is written.
        tl.store(out_ptr + idx, value_scalar, mask=idx < N)

40 

41 

# Write the single value stored at `value_ptr` into elements [0, N) of
# `out_ptr`.
#
# Each program begins at program_id * BLOCK_SIZE and advances by
# num_programs * BLOCK_SIZE (a grid-stride loop), so a launch grid smaller
# than cdiv(N, BLOCK_SIZE) still covers every element. tl.store converts the
# loaded value to out_ptr's element type.
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["N"],
)
@triton.jit
def fill_tensor_kernel(
    out_ptr,
    N,
    value_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    # int64 start so the offset arithmetic is done in 64 bits for large N.
    block_start = (pid * BLOCK_SIZE).to(tl.int64)
    step = num_jobs * BLOCK_SIZE
    # The fill value does not change between iterations, so read it from
    # memory once per program instead of once per block.
    value_scalar = tl.load(value_ptr)
    for block_start_offset in range(block_start, N, step):
        offset = block_start_offset + tl.arange(0, BLOCK_SIZE)
        # Mask the tail block so nothing at or past N is written.
        tl.store(out_ptr + offset, value_scalar, mask=offset < N)

68 

69 

def fill_tensor(input, value):
    """Return a new tensor like ``input`` with every element set to ``value``.

    Args:
        input: Tensor whose shape/dtype/device the result copies (via
            ``torch.empty_like``).
        value: 0-dimensional tensor holding the fill value.

    Returns:
        A freshly allocated tensor; ``input`` is not modified.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    logger.debug("GEMS_CAMBRICON FILL TENSOR")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if value.device != input.device:
        # The kernel dereferences value's data pointer on input's device, so a
        # value stored elsewhere (typically a CPU scalar tensor) is not a valid
        # address there. PyTorch's fill_ reads such a value on the host via
        # .item(); do the same and take the scalar path.
        return fill_scalar(input, value.item())
    out = torch.empty_like(input)
    N = out.numel()
    if N == 0:
        # Nothing to write; skip the kernel launch.
        return out
    # At most TOTAL_CORE_NUM programs; the kernel loops over remaining blocks.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(input.device):
        fill_tensor_kernel[grid_fn](out, N, value)
    return out

84 

85 

def fill_scalar(input, value):
    """Return a new tensor like ``input`` with every element set to ``value``.

    Args:
        input: Tensor whose shape/dtype/device the result copies (via
            ``torch.empty_like``).
        value: Python scalar fill value.

    Returns:
        A freshly allocated tensor; ``input`` is not modified.
    """
    logger.debug("GEMS_CAMBRICON FILL SCALAR")
    out = torch.empty_like(input)
    N = out.numel()
    if N == 0:
        # This is an out-of-place op: return the fresh (empty) result rather
        # than an alias of ``input``. There is nothing for the kernel to write.
        return out
    # At most TOTAL_CORE_NUM programs; the kernel loops over remaining blocks.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(input.device):
        fill_scalar_kernel[grid_fn](out, N, value)
    return out

98 

99 

def fill_scalar_out(input, value, *, out=None):
    """Fill ``out`` with the scalar ``value``; allocate a result if ``out`` is None.

    Args:
        input: Source tensor; only used (via ``fill_scalar``) when ``out`` is None.
        value: Python scalar fill value.
        out: Optional destination tensor, filled over its own ``numel()``.

    Returns:
        ``out`` if given, otherwise a new tensor like ``input``.
    """
    logger.debug("GEMS_CAMBRICON FILL SCALAR_OUT")
    if out is None:
        return fill_scalar(input, value)
    N = out.numel()
    if N == 0:
        return out
    # The kernel writes N consecutive elements starting at the data pointer,
    # which addresses the right elements only for a contiguous tensor. For a
    # strided ``out`` (e.g. x[:, 0]) fill a contiguous scratch buffer and copy
    # it back; copy_ also rejects self-overlapping (expanded) destinations
    # instead of writing past their storage.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty_like(out, memory_format=torch.contiguous_format)
    # At most TOTAL_CORE_NUM programs; the kernel loops over remaining blocks.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    # Guard on the device of the tensor actually being written.
    with torch_device_fn.device(out.device):
        fill_scalar_kernel[grid_fn](dst, N, value)
    if dst is not out:
        out.copy_(dst)
    return out

111 

112 

def fill_tensor_out(input, value, *, out=None):
    """Fill ``out`` with the 0-dim tensor ``value``; allocate a result if ``out`` is None.

    Args:
        input: Source tensor; only used to build the result when ``out`` is None.
        value: 0-dimensional tensor holding the fill value.
        out: Optional destination tensor, filled over its own ``numel()``.

    Returns:
        ``out`` if given, otherwise a new tensor like ``input``.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    logger.debug("GEMS_CAMBRICON FILL_TENSOR_OUT")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    target = input if out is None else out
    if value.device != target.device:
        # The kernel would dereference value's pointer on the target device.
        # As PyTorch's fill_ does, read a value stored on another device
        # (typically a CPU scalar tensor) on the host and use the scalar path.
        return fill_scalar_out(input, value.item(), out=out)
    if out is None:
        return fill_tensor(input, value)
    N = out.numel()
    if N == 0:
        return out
    # The kernel writes N consecutive elements starting at the data pointer,
    # which addresses the right elements only for a contiguous tensor. For a
    # strided ``out`` fill a contiguous scratch buffer and copy it back.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty_like(out, memory_format=torch.contiguous_format)
    # At most TOTAL_CORE_NUM programs; the kernel loops over remaining blocks.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    # Guard on the device of the tensor actually being written.
    with torch_device_fn.device(out.device):
        fill_tensor_kernel[grid_fn](dst, N, value)
    if dst is not out:
        out.copy_(dst)
    return out

127 

128 

def fill_tensor_(self, value):
    """Set every element of ``self`` to the 0-dim tensor ``value`` in place.

    Args:
        self: Tensor to overwrite.
        value: 0-dimensional tensor holding the fill value.

    Returns:
        ``self``.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    logger.debug("GEMS_CAMBRICON FILL_TENSOR_")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if value.device != self.device:
        # The kernel would dereference value's pointer on self's device. As
        # PyTorch's fill_ does, read a value stored on another device
        # (typically a CPU scalar tensor) on the host and use the scalar path.
        return fill_scalar_(self, value.item())
    N = self.numel()
    if N == 0:
        return self
    # The kernel writes N consecutive elements starting at the data pointer,
    # which addresses the right elements only for a contiguous tensor. For a
    # strided view (e.g. x[:, 0]) fill a contiguous scratch buffer and copy it
    # back; copy_ also rejects self-overlapping (expanded) tensors instead of
    # writing past their storage.
    if self.is_contiguous():
        dst = self
    else:
        dst = torch.empty_like(self, memory_format=torch.contiguous_format)
    # At most TOTAL_CORE_NUM programs; the kernel loops over remaining blocks.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(self.device):
        fill_tensor_kernel[grid_fn](dst, N, value)
    if dst is not self:
        self.copy_(dst)
    return self

141 

142 

def fill_scalar_(self, value):
    """Set every element of ``self`` to the scalar ``value`` in place.

    Args:
        self: Tensor to overwrite.
        value: Python scalar fill value.

    Returns:
        ``self``.
    """
    logger.debug("GEMS_CAMBRICON FILL_SCALAR_")
    if 0 in self.shape:
        return self
    N = self.numel()
    # The kernel writes N consecutive elements starting at the data pointer,
    # which addresses the right elements only for a contiguous tensor. For a
    # strided view (e.g. x[:, 0]) fill a contiguous scratch buffer and copy it
    # back; copy_ also rejects self-overlapping (expanded) tensors instead of
    # writing past their storage.
    if self.is_contiguous():
        dst = self
    else:
        dst = torch.empty_like(self, memory_format=torch.contiguous_format)
    # At most TOTAL_CORE_NUM programs; the kernel loops over remaining blocks.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(self.device):
        fill_scalar_kernel[grid_fn](dst, N, value)
    if dst is not self:
        self.copy_(dst)
    return self