Coverage for src/flag_gems/runtime/backend/_sunrise/ops/one_hot.py: 0%
102 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as ext
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
# Dense one-hot kernel for up to 16 classes.
#
# Each program handles BLOCK_SIZE consecutive entries of the flat index buffer
# `input_ptr` and stores, for every entry, its whole output row (1 at the
# index position, 0 elsewhere) into the flat buffer `output_ptr`, whose rows
# have stride `actual_classes`. Every output element of a processed row is
# written, so the caller may pass an uninitialised output buffer.
#
#   input_ptr      -- flat class-index buffer with `num_elements` entries
#   output_ptr     -- flat output buffer of `num_elements * actual_classes`
#   num_elements   -- number of indices to encode
#   actual_classes -- real class count; must be <= 16 (the tile width below)
#   BLOCK_SIZE     -- indices handled per program
@libentry()
@triton.jit
def one_hot_kernel_16(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    pid = ext.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    indices = tl.load(input_ptr + offsets, mask=mask, other=0)
    # The flat offset row * actual_classes can exceed the int32 range for
    # large outputs; promote explicitly instead of relying on the dtype that
    # ext.program_id happens to return.
    out_base = offsets.to(tl.int64) * actual_classes

    # Classes are processed as a fixed 16-wide tile; lanes at or beyond
    # actual_classes are masked off so they never spill into the next row.
    class_offsets = tl.arange(0, 16)
    out_offsets = out_base[:, None] + class_offsets[None, :]
    values = tl.where(indices[:, None] == class_offsets[None, :], 1, 0)
    valid_classes = class_offsets < actual_classes
    combined_mask = mask[:, None] & valid_classes[None, :]
    tl.store(output_ptr + out_offsets, values, mask=combined_mask)
# Dense one-hot kernel for up to 32 classes.
#
# Each program handles BLOCK_SIZE consecutive entries of the flat index buffer
# `input_ptr` and stores, for every entry, its whole output row (1 at the
# index position, 0 elsewhere) into the flat buffer `output_ptr`, whose rows
# have stride `actual_classes`. Every output element of a processed row is
# written, so the caller may pass an uninitialised output buffer.
#
#   input_ptr      -- flat class-index buffer with `num_elements` entries
#   output_ptr     -- flat output buffer of `num_elements * actual_classes`
#   num_elements   -- number of indices to encode
#   actual_classes -- real class count; must be <= 32 (the tile width below)
#   BLOCK_SIZE     -- indices handled per program
@libentry()
@triton.jit
def one_hot_kernel_32(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    pid = ext.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    indices = tl.load(input_ptr + offsets, mask=mask, other=0)
    # The flat offset row * actual_classes can exceed the int32 range for
    # large outputs; promote explicitly instead of relying on the dtype that
    # ext.program_id happens to return.
    out_base = offsets.to(tl.int64) * actual_classes

    # Classes are processed as a fixed 32-wide tile; lanes at or beyond
    # actual_classes are masked off so they never spill into the next row.
    class_offsets = tl.arange(0, 32)
    out_offsets = out_base[:, None] + class_offsets[None, :]
    values = tl.where(indices[:, None] == class_offsets[None, :], 1, 0)
    valid_classes = class_offsets < actual_classes
    combined_mask = mask[:, None] & valid_classes[None, :]
    tl.store(output_ptr + out_offsets, values, mask=combined_mask)
# Dense one-hot kernel for up to 64 classes.
#
# Each program handles BLOCK_SIZE consecutive entries of the flat index buffer
# `input_ptr` and stores, for every entry, its whole output row (1 at the
# index position, 0 elsewhere) into the flat buffer `output_ptr`, whose rows
# have stride `actual_classes`. Every output element of a processed row is
# written, so the caller may pass an uninitialised output buffer.
#
#   input_ptr      -- flat class-index buffer with `num_elements` entries
#   output_ptr     -- flat output buffer of `num_elements * actual_classes`
#   num_elements   -- number of indices to encode
#   actual_classes -- real class count; must be <= 64 (the tile width below)
#   BLOCK_SIZE     -- indices handled per program
@libentry()
@triton.jit
def one_hot_kernel_64(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    pid = ext.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    indices = tl.load(input_ptr + offsets, mask=mask, other=0)
    # The flat offset row * actual_classes can exceed the int32 range for
    # large outputs; promote explicitly instead of relying on the dtype that
    # ext.program_id happens to return.
    out_base = offsets.to(tl.int64) * actual_classes

    # Classes are processed as a fixed 64-wide tile; lanes at or beyond
    # actual_classes are masked off so they never spill into the next row.
    class_offsets = tl.arange(0, 64)
    out_offsets = out_base[:, None] + class_offsets[None, :]
    values = tl.where(indices[:, None] == class_offsets[None, :], 1, 0)
    valid_classes = class_offsets < actual_classes
    combined_mask = mask[:, None] & valid_classes[None, :]
    tl.store(output_ptr + out_offsets, values, mask=combined_mask)
# Sparse one-hot kernel for arbitrary class counts.
#
# Each program handles BLOCK_SIZE consecutive entries of the flat index buffer
# `input_ptr` and stores a single 1 at position `row * num_classes + index`
# of the flat buffer `output_ptr`. Nothing else is written, so the output
# buffer MUST be zero-filled by the caller beforehand.
#
#   input_ptr    -- flat class-index buffer with `num_elements` entries
#   output_ptr   -- zero-filled flat buffer of `num_elements * num_classes`
#   num_elements -- number of indices to encode
#   num_classes  -- row stride of the output (class count)
#   BLOCK_SIZE   -- indices handled per program
@libentry()
@triton.jit
def one_hot_set_one_kernel(
    input_ptr,
    output_ptr,
    num_elements,
    num_classes,
    BLOCK_SIZE: tl.constexpr,
):
    pid = ext.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    indices = tl.load(input_ptr + offsets, mask=mask, other=0)
    # This path serves large class counts, where row * num_classes easily
    # exceeds the int32 range. Promote before multiplying, not just before
    # adding the (int64) index, so the product itself cannot wrap.
    out_offsets = offsets.to(tl.int64) * num_classes + indices
    tl.store(output_ptr + out_offsets, 1, mask=mask)
def one_hot(tensor: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """One-hot encode an int64 index tensor (mirrors ``torch.nn.functional.one_hot``).

    Args:
        tensor: index tensor; must have dtype ``torch.int64`` and contain only
            non-negative values.
        num_classes: size of the new trailing class dimension. ``-1`` infers
            it as ``tensor.max() + 1``; otherwise every value must be smaller
            than it.

    Returns:
        An int64 tensor of shape ``(*tensor.shape, num_classes)`` on
        ``tensor.device``, holding 1 at each index position and 0 elsewhere.
        Tensors for which ``tensor.is_ptpu`` is false use a ``scatter_``
        fallback; the others use the Triton kernels of this module.

    Raises:
        RuntimeError: for a non-int64 dtype, an empty input without a positive
            ``num_classes``, negative values, ``num_classes < 1``, or values
            ``>= num_classes``.
    """
    logger.debug("GEMS ONE_HOT")

    if tensor.dtype != torch.int64:
        raise RuntimeError(
            "one_hot is only applicable to index tensor of type LongTensor."
        )

    if tensor.numel() == 0:
        # The class count cannot be inferred without data; when it is given,
        # the result has zero elements, so no fill is required.
        if num_classes <= 0:
            raise RuntimeError(
                "Can not infer total number of classes from empty tensor."
            )
        return torch.empty(
            (*tensor.shape, num_classes), device=tensor.device, dtype=torch.int64
        )

    inferred = num_classes == -1
    if inferred:
        num_classes = int(tensor.max().item()) + 1

    if (tensor < 0).any():
        raise RuntimeError("Class values must be non-negative.")

    if num_classes < 1:
        raise RuntimeError("num_classes should be positive")

    # An inferred num_classes is max + 1, so the upper bound holds by
    # construction; only scan (and synchronise on) the input when the caller
    # supplied num_classes explicitly.
    if not inferred and (tensor >= num_classes).any():
        raise RuntimeError("Class values must be smaller than num_classes.")

    if not tensor.is_ptpu:
        out = torch.zeros(
            (*tensor.shape, num_classes), device=tensor.device, dtype=torch.int64
        )
        out.scatter_(-1, tensor.unsqueeze(-1), 1)
        return out

    flat_input = tensor.contiguous().view(-1)
    num_elements = flat_input.numel()

    # Select a kernel by class count. The dense kernels store every element of
    # each output row (zeros included), so their buffer may start
    # uninitialised; the sparse kernel stores only the ones and therefore
    # needs a zero-filled buffer.
    if num_classes <= 16:
        kernel, block_size, alloc = one_hot_kernel_16, 128, torch.empty
    elif num_classes <= 32:
        kernel, block_size, alloc = one_hot_kernel_32, 128, torch.empty
    elif num_classes <= 64:
        kernel, block_size, alloc = one_hot_kernel_64, 128, torch.empty
    else:
        kernel, block_size, alloc = one_hot_set_one_kernel, 1024, torch.zeros

    with torch_device_fn.device(tensor.device):
        out = alloc(
            num_elements * num_classes, device=tensor.device, dtype=torch.int64
        )
        grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
        kernel[grid](
            flat_input,
            out,
            num_elements,
            num_classes,
            BLOCK_SIZE=block_size,
        )

    return out.view(*tensor.shape, num_classes)