Coverage for src/flag_gems/runtime/backend/_sunrise/ops/one_hot.py: 0%

102 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@libentry()
@triton.jit
def one_hot_kernel_16(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """Write a dense one-hot tile that is 16 classes wide.

    Each program handles ``BLOCK_SIZE`` consecutive input elements and stores
    a 0 or a 1 into every ``(element, class)`` slot whose class id is below
    ``actual_classes``.

    Args:
        input_ptr: Pointer to the flat class indices.
        output_ptr: Pointer to the flat output, addressed as
            ``element * actual_classes + class``.
        num_elements: Number of input elements.
        actual_classes: Real number of classes; it is the row stride of the
            output, and class lanes at or beyond it are masked off.
        BLOCK_SIZE: Compile-time number of input elements per program.
    """
    # Candidate class ids covered by this kernel's fixed tile width.
    class_ids = tl.arange(0, 16)
    class_in_range = class_ids < actual_classes

    # Input elements owned by this program.
    program = ext.program_id(axis=0)
    element_ids = program * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    element_in_range = element_ids < num_elements

    # Out-of-range lanes read 0; their stores are masked off below.
    labels = tl.load(input_ptr + element_ids, mask=element_in_range, other=0)

    is_hot = labels[:, None] == class_ids[None, :]
    one_hot_tile = tl.where(is_hot, 1, 0)

    row_starts = element_ids * actual_classes
    destinations = row_starts[:, None] + class_ids[None, :]
    store_mask = element_in_range[:, None] & class_in_range[None, :]
    tl.store(output_ptr + destinations, one_hot_tile, mask=store_mask)

37 

38 

@libentry()
@triton.jit
def one_hot_kernel_32(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """Write a dense one-hot tile that is 32 classes wide.

    Each program handles ``BLOCK_SIZE`` consecutive input elements and stores
    a 0 or a 1 into every ``(element, class)`` slot whose class id is below
    ``actual_classes``.

    Args:
        input_ptr: Pointer to the flat class indices.
        output_ptr: Pointer to the flat output, addressed as
            ``element * actual_classes + class``.
        num_elements: Number of input elements.
        actual_classes: Real number of classes; it is the row stride of the
            output, and class lanes at or beyond it are masked off.
        BLOCK_SIZE: Compile-time number of input elements per program.
    """
    # Candidate class ids covered by this kernel's fixed tile width.
    class_ids = tl.arange(0, 32)
    class_in_range = class_ids < actual_classes

    # Input elements owned by this program.
    program = ext.program_id(axis=0)
    element_ids = program * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    element_in_range = element_ids < num_elements

    # Out-of-range lanes read 0; their stores are masked off below.
    labels = tl.load(input_ptr + element_ids, mask=element_in_range, other=0)

    is_hot = labels[:, None] == class_ids[None, :]
    one_hot_tile = tl.where(is_hot, 1, 0)

    row_starts = element_ids * actual_classes
    destinations = row_starts[:, None] + class_ids[None, :]
    store_mask = element_in_range[:, None] & class_in_range[None, :]
    tl.store(output_ptr + destinations, one_hot_tile, mask=store_mask)

62 

63 

@libentry()
@triton.jit
def one_hot_kernel_64(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """Write a dense one-hot tile that is 64 classes wide.

    Each program handles ``BLOCK_SIZE`` consecutive input elements and stores
    a 0 or a 1 into every ``(element, class)`` slot whose class id is below
    ``actual_classes``.

    Args:
        input_ptr: Pointer to the flat class indices.
        output_ptr: Pointer to the flat output, addressed as
            ``element * actual_classes + class``.
        num_elements: Number of input elements.
        actual_classes: Real number of classes; it is the row stride of the
            output, and class lanes at or beyond it are masked off.
        BLOCK_SIZE: Compile-time number of input elements per program.
    """
    # Candidate class ids covered by this kernel's fixed tile width.
    class_ids = tl.arange(0, 64)
    class_in_range = class_ids < actual_classes

    # Input elements owned by this program.
    program = ext.program_id(axis=0)
    element_ids = program * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    element_in_range = element_ids < num_elements

    # Out-of-range lanes read 0; their stores are masked off below.
    labels = tl.load(input_ptr + element_ids, mask=element_in_range, other=0)

    is_hot = labels[:, None] == class_ids[None, :]
    one_hot_tile = tl.where(is_hot, 1, 0)

    row_starts = element_ids * actual_classes
    destinations = row_starts[:, None] + class_ids[None, :]
    store_mask = element_in_range[:, None] & class_in_range[None, :]
    tl.store(output_ptr + destinations, one_hot_tile, mask=store_mask)

87 

88 

@libentry()
@triton.jit
def one_hot_set_one_kernel(
    input_ptr,
    output_ptr,
    num_elements,
    num_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """Store a single 1 per input element at its class position.

    Only the slot ``element * num_classes + index`` is written for each
    in-range element; this kernel does not write any other output slot.

    Args:
        input_ptr: Pointer to the flat class indices.
        output_ptr: Pointer to the flat output buffer.
        num_elements: Number of input elements.
        num_classes: Row stride of the output (entries per input element).
        BLOCK_SIZE: Compile-time number of input elements per program.
    """
    # Input elements owned by this program.
    program = ext.program_id(axis=0)
    element_ids = program * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    element_in_range = element_ids < num_elements

    # Out-of-range lanes read 0; their stores are masked off below.
    labels = tl.load(input_ptr + element_ids, mask=element_in_range, other=0)

    hot_slots = element_ids * num_classes + labels
    tl.store(output_ptr + hot_slots, 1, mask=element_in_range)

106 

107 

def one_hot(tensor: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """Return the one-hot encoding of an int64 index tensor.

    Args:
        tensor: Index tensor of dtype ``torch.int64``.
        num_classes: Size of the appended class dimension. ``-1`` infers it
            as ``tensor.max() + 1``.

    Returns:
        An int64 tensor of shape ``(*tensor.shape, num_classes)`` holding 1
        at each element's class position and 0 elsewhere.

    Raises:
        RuntimeError: If ``tensor`` is not int64, if it is empty and
            ``num_classes`` is not positive, if any value is negative or not
            smaller than ``num_classes``, or if ``num_classes`` is below 1.
    """
    logger.debug("GEMS ONE_HOT")

    if tensor.dtype != torch.int64:
        raise RuntimeError(
            "one_hot is only applicable to index tensor of type LongTensor."
        )

    if tensor.numel() == 0:
        if num_classes <= 0:
            raise RuntimeError(
                "Can not infer total number of classes from empty tensor."
            )
        return torch.empty(
            (*tensor.shape, num_classes), device=tensor.device, dtype=torch.int64
        )

    if num_classes == -1:
        num_classes = int(tensor.max().item()) + 1

    if (tensor < 0).any():
        raise RuntimeError("Class values must be non-negative.")

    if num_classes < 1:
        raise RuntimeError("num_classes should be positive")

    if (tensor >= num_classes).any():
        raise RuntimeError("Class values must be smaller than num_classes.")

    out_shape = (*tensor.shape, num_classes)

    # `is_ptpu` is a backend-specific tensor attribute (presumably set for
    # tensors on the Sunrise PTPU device -- confirm). Tensors that lack the
    # attribute are treated like non-PTPU tensors and take the plain torch
    # path instead of raising AttributeError.
    if not getattr(tensor, "is_ptpu", False):
        out = torch.zeros(out_shape, device=tensor.device, dtype=torch.int64)
        out.scatter_(-1, tensor.unsqueeze(-1), 1)
        return out

    flat_input = tensor.contiguous().view(-1)
    num_elements = flat_input.numel()

    # The dense kernels write every (element, class) slot, so their output
    # may start uninitialized. The set-one kernel writes only the hot slot
    # and therefore needs a zero-filled buffer.
    if num_classes <= 16:
        kernel, block_size, allocate = one_hot_kernel_16, 128, torch.empty
    elif num_classes <= 32:
        kernel, block_size, allocate = one_hot_kernel_32, 128, torch.empty
    elif num_classes <= 64:
        kernel, block_size, allocate = one_hot_kernel_64, 128, torch.empty
    else:
        kernel, block_size, allocate = one_hot_set_one_kernel, 1024, torch.zeros

    with torch_device_fn.device(tensor.device):
        out = allocate(
            num_elements * num_classes, device=tensor.device, dtype=torch.int64
        )
        grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
        kernel[grid](
            flat_input,
            out,
            num_elements,
            num_classes,
            BLOCK_SIZE=block_size,
        )

    return out.view(out_shape)