Coverage for src/flag_gems/runtime/backend/_mthreads/ops/one_hot.py: 0%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

# Module logger. Its name is a fixed "_mthreads.ops" prefix followed by the
# last dotted component of this module's import name.
logger = logging.getLogger(
    name=f"flag_gems.runtime.backend._mthreads.ops.{__name__.split('.')[-1]}",
)

14 

15 

@libentry()
@triton.jit
def one_hot_kernel_16(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Write complete one-hot rows (both 0s and 1s) using 16 class lanes.

    Each program handles BLOCK_SIZE consecutive input elements. For every
    element it stores `actual_classes` values starting at
    `element_index * actual_classes` in the output; lanes at or beyond
    `actual_classes` are masked off, so this variant covers at most 16 classes.
    """
    # NOTE(review): the width of this offset arithmetic depends on what
    # ext.program_id returns - confirm it is wide enough for large outputs.
    rows = ext.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_ok = rows < num_elements

    # Out-of-range rows read a dummy 0; their stores are masked below.
    labels = tl.load(input_ptr + rows, mask=row_ok, other=0)

    lanes = tl.arange(0, 16)
    lane_ok = lanes < actual_classes

    dst = (rows * actual_classes)[:, None] + lanes[None, :]
    hot = tl.where(labels[:, None] == lanes[None, :], 1, 0)
    tl.store(output_ptr + dst, hot, mask=row_ok[:, None] & lane_ok[None, :])

39 

40 

@libentry()
@triton.jit
def one_hot_kernel_32(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Write complete one-hot rows (both 0s and 1s) using 32 class lanes.

    Each program handles BLOCK_SIZE consecutive input elements. For every
    element it stores `actual_classes` values starting at
    `element_index * actual_classes` in the output; lanes at or beyond
    `actual_classes` are masked off, so this variant covers at most 32 classes.
    """
    # NOTE(review): the width of this offset arithmetic depends on what
    # ext.program_id returns - confirm it is wide enough for large outputs.
    rows = ext.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_ok = rows < num_elements

    # Out-of-range rows read a dummy 0; their stores are masked below.
    labels = tl.load(input_ptr + rows, mask=row_ok, other=0)

    lanes = tl.arange(0, 32)
    lane_ok = lanes < actual_classes

    dst = (rows * actual_classes)[:, None] + lanes[None, :]
    hot = tl.where(labels[:, None] == lanes[None, :], 1, 0)
    tl.store(output_ptr + dst, hot, mask=row_ok[:, None] & lane_ok[None, :])

64 

65 

@libentry()
@triton.jit
def one_hot_kernel_64(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Write complete one-hot rows (both 0s and 1s) using 64 class lanes.

    Each program handles BLOCK_SIZE consecutive input elements. For every
    element it stores `actual_classes` values starting at
    `element_index * actual_classes` in the output; lanes at or beyond
    `actual_classes` are masked off, so this variant covers at most 64 classes.
    """
    # NOTE(review): the width of this offset arithmetic depends on what
    # ext.program_id returns - confirm it is wide enough for large outputs.
    rows = ext.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_ok = rows < num_elements

    # Out-of-range rows read a dummy 0; their stores are masked below.
    labels = tl.load(input_ptr + rows, mask=row_ok, other=0)

    lanes = tl.arange(0, 64)
    lane_ok = lanes < actual_classes

    dst = (rows * actual_classes)[:, None] + lanes[None, :]
    hot = tl.where(labels[:, None] == lanes[None, :], 1, 0)
    tl.store(output_ptr + dst, hot, mask=row_ok[:, None] & lane_ok[None, :])

89 

90 

@libentry()
@triton.jit
def one_hot_set_one_kernel(
    input_ptr,
    output_ptr,
    num_elements,
    num_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Scatter a single 1 per input element into the output.

    Only the position `element_index * num_classes + class_index` is written
    for each element, so the output must already be filled with zeros.
    The class index is used as-is; it is not range-checked here.
    """
    rows = ext.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = rows < num_elements

    # Out-of-range rows read a dummy 0; their stores are masked below.
    hot_col = tl.load(input_ptr + rows, mask=in_bounds, other=0)
    tl.store(output_ptr + (rows * num_classes + hot_col), 1, mask=in_bounds)

112 

113 

def one_hot(tensor: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """Return the one-hot encoding of an int64 index tensor.

    Args:
        tensor: Index tensor of dtype ``torch.int64``.
        num_classes: Size of the appended class dimension. The default ``-1``
            infers it as ``tensor.max() + 1``.

    Returns:
        An int64 tensor of shape ``(*tensor.shape, num_classes)`` on the same
        device as ``tensor``.

    Raises:
        RuntimeError: If ``tensor`` is not int64; if ``tensor`` is empty and
            ``num_classes <= 0``; if an explicit ``num_classes`` is not
            positive; if any class value is negative; or if any class value
            is not smaller than an explicit ``num_classes``.
    """
    logger.debug("GEMS_MTHREADS ONE_HOT")

    if tensor.dtype != torch.int64:
        raise RuntimeError(
            "one_hot is only applicable to index tensor of type LongTensor."
        )

    if tensor.numel() == 0:
        if num_classes <= 0:
            raise RuntimeError(
                "Can not infer total number of classes from empty tensor."
            )
        # No elements to fill, so an uninitialized allocation is sufficient.
        return torch.empty(
            (*tensor.shape, num_classes), device=tensor.device, dtype=torch.int64
        )

    infer_classes = num_classes == -1

    # Reject a non-positive explicit class count up front. Checked after the
    # value-range test this branch could never fire, because every
    # non-negative value is already >= a non-positive num_classes.
    if not infer_classes and num_classes < 1:
        raise RuntimeError("num_classes should be positive")

    if infer_classes:
        num_classes = int(tensor.max().item()) + 1

    if (tensor < 0).any():
        raise RuntimeError("Class values must be non-negative.")

    # With an inferred count every value is <= max < num_classes, so the
    # upper-bound reduction is only needed for an explicit count.
    if not infer_classes and (tensor >= num_classes).any():
        raise RuntimeError("Class values must be smaller than num_classes.")

    # CPU tensors bypass the Triton kernels.
    if tensor.device.type == "cpu":
        out = torch.zeros((*tensor.shape, num_classes), device="cpu", dtype=torch.int64)
        out.scatter_(-1, tensor.unsqueeze(-1), 1)
        return out

    # The kernels index a flat, contiguous view of the input.
    flat_input = tensor.contiguous().view(-1)
    num_elements = flat_input.numel()

    if num_classes <= 64:
        # Small class counts: these kernels store every 0 and 1 themselves,
        # so the output buffer does not need to be pre-initialized.
        if num_classes <= 16:
            kernel = one_hot_kernel_16
        elif num_classes <= 32:
            kernel = one_hot_kernel_32
        else:
            kernel = one_hot_kernel_64
        allocate = torch.empty
        block_size = 128
    else:
        # Large class counts: start from zeros and only write the 1s.
        kernel = one_hot_set_one_kernel
        allocate = torch.zeros
        block_size = 1024

    def grid(meta):
        return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(tensor.device):
        out = allocate(
            num_elements * num_classes, device=tensor.device, dtype=torch.int64
        )
        kernel[grid](
            flat_input,
            out,
            num_elements,
            num_classes,
            BLOCK_SIZE=block_size,
        )

    return out.view(*tensor.shape, num_classes)