Coverage for src/flag_gems/runtime/backend/_sunrise/ops/gather.py: 0%

134 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.ops.scatter import scatter_ 

9from flag_gems.utils.code_cache import code_cache_dir 

10from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

11from flag_gems.utils.shape_utils import restride_dim 

12 

# Module logger: a child of the package-wide "flag_gems" logger, suffixed with
# this module's import name (leading dots stripped).
_LOGGER_SUFFIX = __name__.lstrip(".")
logger = logging.getLogger("flag_gems").getChild(_LOGGER_SUFFIX)

14 

15 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated gather module into ``code``.

    The header imports torch, tries ``torch_ptpu`` and falls back to aliasing
    ``torch.cuda`` as ``torch_ptpu`` when that import fails, then imports
    triton and the FlagGems helpers used by the generated kernel.

    Args:
        code: buffer that receives the generated lines (mutated in place).

    Returns:
        The same ``code`` buffer, for chaining.
    """
    third_party = (
        "import torch",
        "try:",
        " import torch_ptpu",
        "except ImportError:",
        " import torch.cuda as torch_ptpu",
        "import triton",
        "import triton.language as tl",
    )
    flag_gems_helpers = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils import triton_lang_extension as ext",
    )

    for line in third_party:
        code.writeline(line)
    code.newline()
    for line in flag_gems_helpers:
        code.writeline(line)

    # Two blank lines before the first top-level definition that follows.
    code.newline()
    code.newline()
    return code

32 

33 

def generate_gather_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write a Triton gather kernel specialized for ``rank`` dimensions.

    Emitted signature: the ``inp``, ``index`` and ``out`` pointers, then one
    scalar per dimension for each of the three shapes and the three stride
    tuples, then ``dim``, ``dim_stride``, ``N`` and the compile-time
    ``BLOCK_SIZE_N`` (fixed to 128 by the emitted heuristic).

    Emitted body: each program handles ``BLOCK_SIZE_N`` flat positions. A flat
    position is split into coordinates ``index_idx{d}`` using the *index*
    shape, with the last dimension varying fastest. Those coordinates address
    ``index``, ``inp`` and ``out`` through their own strides. The loaded index
    value times ``dim_stride`` is added to the input offset. Positions
    ``>= N`` are masked out of every load/store. ``inp_shape*``,
    ``out_shape*`` and ``dim`` are accepted but not read by the body.

    Args:
        rank: number of dimensions the kernel is specialized for.
        kernel_name: name of the emitted ``@triton.jit`` function.
        code: buffer that receives the generated lines (mutated in place).

    Returns:
        The same ``code`` buffer, for chaining.
    """
    dims = range(rank)

    def linear_offset(stride_prefix):
        # "index_idx0 * <prefix>0 + index_idx1 * <prefix>1 + ..."
        return " + ".join(f"index_idx{d} * {stride_prefix}{d}" for d in dims)

    # Blank line separating the kernel from whatever precedes it in the buffer.
    code.newline()

    code.writeline("@libentry()")
    code.writeline("@triton.heuristics({'BLOCK_SIZE_N': lambda args: 128})")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        params = ["inp, ", "index, ", "out, "]
        # NOTE: no space after the comma for inp_shape*, matching prior output.
        params.extend(f"inp_shape{d}," for d in dims)
        for prefix in (
            "index_shape",
            "out_shape",
            "inp_stride",
            "index_stride",
            "out_stride",
        ):
            params.extend(f"{prefix}{d}, " for d in dims)
        params.extend(["dim, ", "dim_stride, ", "N, ", "BLOCK_SIZE_N: tl.constexpr, "])
        code.writelines(params)
    code.writeline("):")

    with code.indent():
        code.writeline("pid = ext.program_id(0)")
        code.writeline(
            "offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)"
        )
        code.newline()

        # Peel per-dimension coordinates off the flat offset, last dim first.
        code.writeline("cur_offset = offset")
        for d in reversed(dims):
            code.writeline(f"index_idx{d} = cur_offset % index_shape{d}")
            code.writeline(f"cur_offset = cur_offset // index_shape{d}")
        code.newline()

        code.writeline(f"index_offset = {linear_offset('index_stride')}")
        code.writeline("mask = offset < N")
        code.writeline("cur_index = tl.load(index + index_offset, mask=mask, other=0)")
        code.newline()

        code.writeline(f"inp_offset = {linear_offset('inp_stride')}")
        code.writeline("inp_offset += cur_index * dim_stride")
        code.writeline("cur_inp = tl.load(inp + inp_offset, mask=mask, other=0)")
        code.newline()

        code.writeline(f"out_offset = {linear_offset('out_stride')}")
        code.writeline("tl.store(out + out_offset, value=cur_inp, mask=mask)")

    code.newline()
    code.newline()
    return code

90 

91 

def generate_gather_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write the Python launcher ``wrapper_name`` for the generated kernel.

    The emitted wrapper reads the shapes and strides of ``inp``, ``index`` and
    ``out``. It launches ``kernel_name`` on a 1-D grid of
    ``cdiv(N, BLOCK_SIZE_N)`` programs, passing every shape/stride component
    as a separate scalar (``rank`` per tensor) followed by ``dim``,
    ``dim_stride`` and ``N``. It then returns ``out``.

    Args:
        rank: number of dimensions the launcher is specialized for.
        wrapper_name: name of the emitted Python function.
        kernel_name: name of the Triton kernel it launches.
        code: buffer that receives the generated lines (mutated in place).

    Returns:
        The same ``code`` buffer, for chaining.
    """
    dims = range(rank)

    code.writeline(f"def {wrapper_name}(inp, dim, index, out, dim_stride, N):")
    with code.indent():
        for tensor in ("inp", "index", "out"):
            code.writeline(f"{tensor}_shape = {tensor}.shape")
            code.writeline(f"{tensor}_stride = {tensor}.stride()")
        code.writeline("grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE_N']), )")
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            call_args = ["inp, ", "index, ", "out, "]
            # Same ordering as the kernel signature: all shapes, then all strides.
            for seq in (
                "inp_shape",
                "index_shape",
                "out_shape",
                "inp_stride",
                "index_stride",
                "out_stride",
            ):
                call_args.extend(f"{seq}[{d}], " for d in dims)
            call_args.extend(["dim, ", "dim_stride, ", "N, "])
            code.writelines(call_args)
        code.writeline(")")
        code.writeline("return out")
    code.newline()
    code.newline()
    return code

131 

132 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write a complete gather module (imports, kernel, launcher) into ``code``.

    The kernel and launcher are specialized on ``inputs[0].ndim``.

    Args:
        inputs: the positional arguments of the call being specialized; only
            the ``ndim`` of the first element is read.
        wrapper_name: name for the emitted Python launcher.
        kernel_name: name for the emitted Triton kernel.
        code: buffer that receives the generated source.

    Returns:
        The buffer holding the full generated module.
    """
    ndim = inputs[0].ndim
    buf = generate_imports(code)
    buf = generate_gather_kernel(ndim, kernel_name, buf)
    return generate_gather_wrapper(ndim, wrapper_name, kernel_name, buf)

145 

146 

class GatherFunction:
    """Dispatcher that lazily generates and caches rank-specialized launchers.

    On the first call for a given key (see :meth:`arg_key`), the source of a
    gather kernel plus launcher is generated and written to
    ``code_cache_dir()/gather_rank_<key>.py``. It is imported as a module, and
    its ``_gather_wrapper`` is memoized in ``self.overloads``. Subsequent calls
    with the same key reuse the memoized launcher. All call arguments are
    forwarded unchanged to the launcher.
    """

    def __init__(self):
        # PID at construction time (kept for compatibility; not read here).
        self.pid = os.getpid()
        # Cache: str(key) -> generated ``_gather_wrapper`` callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, args)
            # Only cache after generation and import both succeeded.
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _build_overload(self, key, args):
        """Generate, write, and import the launcher for ``key``; return it."""
        # `import importlib` alone does not guarantee the `importlib.util`
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        code = generate_code(
            args,
            "_gather_wrapper",
            "_gather_flaggems_jit_function",
            IndentedBuffer(),
        )

        file_path = code_cache_dir() / f"gather_rank_{key}.py"
        write_atomic(file_path, code.getvalue())

        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_gather_wrapper")

    def arg_key(self, *args):
        """Return the specialization key: the ``ndim`` of the first argument."""
        return args[0].ndim

184 

185 

# Module-level dispatcher used by `gather`; it memoizes one generated launcher
# per key returned by `GatherFunction.arg_key` (the first argument's ndim).
_gather_func = GatherFunction()

187 

188 

def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather values of ``inp`` along ``dim`` at positions given by ``index``.

    Argument order matches ``torch.gather``.

    Args:
        inp: source tensor.
        dim: dimension along which ``index`` selects elements.
        index: tensor of positions; the result has ``index``'s shape.
        out: optional destination. When None, a tensor shaped like ``index``
            with ``inp``'s dtype and device is allocated.
        sparse_grad: accepted for signature compatibility; not used here.

    Returns:
        ``out``, filled with the gathered values.
    """
    logger.debug("GEMS GATHER")
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    # Byte-free element stride of `inp` along `dim`; the kernel multiplies the
    # loaded index values by it. Computed first so an invalid `dim` still raises.
    dim_stride = inp.stride(dim)
    # View of `inp` with `index`'s shape; presumably the stride along `dim` is
    # neutralized so the kernel can add `index * dim_stride` (TODO confirm
    # against restride_dim). Kept before the early return so shape/rank
    # mismatches still raise.
    inp_strided = restride_dim(inp, dim, index.shape)
    N = index.numel()
    if N == 0:
        # Nothing to gather: skip kernel generation/compilation and avoid
        # launching a zero-sized grid.
        return out
    _gather_func(inp_strided, dim, index, out, dim_stride, N)
    return out

197 return out 

198 

199 

def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of ``gather`` with respect to its input.

    Starts from zeros with ``self.shape`` (``self`` is presumably the forward
    input) and scatters ``grad`` into it along ``dim`` at ``index`` using
    ``reduce="add"``, so repeated indices accumulate.
    ``sparse_grad`` is accepted for signature compatibility and not used.

    Returns:
        Whatever ``scatter_`` returns for the zero-initialized buffer.
    """
    logger.debug("GEMS GATHER BACKWARD")
    return scatter_(grad.new_zeros(self.shape), dim, index, grad, reduce="add")