Coverage for src/flag_gems/runtime/backend/_sunrise/ops/embedding.py: 0%
103 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as ext
# Module-level logger; the forward/backward entry points emit debug traces through it.
logger = logging.getLogger(name=__name__)
@libentry()
@triton.jit
def embedding_kernel(
    out_ptr,  # pointer to the output
    in_ptr,  # pointer to the input indices
    weight_ptr,  # pointer to the weights
    N: tl.constexpr,  # row length (embedding width)
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one weight row into one output row.

    Program ``pid`` reads the index stored at ``in_ptr[pid]`` and copies the
    first ``N`` elements of the weight row at that index into output row
    ``pid``. Rows are addressed as ``base + row * N``, so both the weight and
    output buffers are assumed to be contiguous with row length ``N``.
    ``BLOCK_SIZE`` must be >= ``N``; lanes at or beyond ``N`` are masked off.
    """
    pid = ext.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < N

    src_row = tl.load(in_ptr + pid)
    row_vals = tl.load(weight_ptr + src_row * N + offs, in_bounds, other=0.0)
    tl.store(out_ptr + pid * N + offs, row_vals, in_bounds)
@libentry()
@triton.jit
def indice_freq_kernel(
    indices_freq,  # pointer to the per-row occurrence counters
    indices,  # pointer to the input indices
    elem_cnt: tl.constexpr,  # total number of indices to count
    INDICE_BLOCK_SIZE: tl.constexpr,
):
    """Count how many times each index occurs in ``indices``.

    Each program handles ``INDICE_BLOCK_SIZE`` consecutive entries of
    ``indices`` and adds 1 to ``indices_freq[idx]`` for every in-range entry
    ``idx``. ``indices_freq`` must be zero-initialised by the caller.
    """
    pid = ext.program_id(0)
    block_start = pid * INDICE_BLOCK_SIZE

    offsets = block_start + tl.arange(0, INDICE_BLOCK_SIZE)
    mask = offsets < elem_cnt

    index_element = tl.load(indices + offsets, mask=mask)
    # A gather / +1 / scatter loses counts whenever the same index occurs more
    # than once in a block (all duplicate lanes read the same old value and
    # store the same result) or in concurrently running programs. An atomic
    # add accumulates every occurrence.
    tl.atomic_add(indices_freq + index_element, 1, mask=mask)
@triton.jit
def _accumulate_grad_row(dst_row, src_row, cols, mask):
    # Add src_row[cols] into dst_row[cols] via load / add / store.
    # bf16 sources are widened to float32 before the add; the destination is
    # always loaded as float32 and the sum is stored back through dst_row.
    # NOTE(review): this read-modify-write is not atomic. If two programs
    # whose indices map to the same destination row run concurrently, one
    # update can be lost. Confirm the target backend does not overlap such
    # programs, or switch to tl.atomic_add.
    grad = tl.load(src_row + cols, mask, other=0.0)
    if tl.constexpr(grad.dtype.is_bf16()):
        grad = grad.to(tl.float32)
    acc = tl.load(dst_row + cols, mask, other=0.0).to(tl.float32)
    tl.store(dst_row + cols, acc + grad, mask=mask)


@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
    grad_in,  # pointer to the gradient w.r.t. the weight (accumulated into)
    grad_out,  # pointer to the incoming output gradient
    indices,  # pointer to the input indices
    padding_idx,  # index whose row receives no gradient
    HAS_PADDING_IDX: tl.constexpr,
    N: tl.constexpr,  # row length (embedding width)
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter-add one row of ``grad_out`` into ``grad_in``.

    Program ``pid`` adds output-gradient row ``pid`` to the ``grad_in`` row
    selected by ``indices[pid]``. When ``HAS_PADDING_IDX`` is set, rows whose
    index equals ``padding_idx`` are skipped. Rows are addressed as
    ``base + row * N``, so both buffers are assumed contiguous.
    """
    pid = ext.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    src_row = grad_out + pid * N
    row_idx = tl.load(indices + pid).to(tl.int32)
    dst_row = grad_in + row_idx * N

    if HAS_PADDING_IDX:
        if row_idx != padding_idx:
            _accumulate_grad_row(dst_row, src_row, cols, mask)
    else:
        _accumulate_grad_row(dst_row, src_row, cols, mask)
@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
    grad_out,  # pointer to the (n_rows, N) gradient, scaled in place
    indice_freq,  # pointer to per-row occurrence counts
    n_rows,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Divide each gradient row by how often its index occurred.

    Programs stride over rows: program ``p`` visits rows ``p, p + P, ...``
    (``P`` = number of launched programs) up to ``n_rows``. Rows whose count
    is greater than 1 are multiplied by ``1 / count``; all others are
    multiplied by 1.0.
    """
    first_row = ext.program_id(0)
    stride = ext.num_programs(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    for r in range(first_row, n_rows, stride):
        freq = tl.load(indice_freq + r)
        # max(freq, 1) yields 1/freq when freq > 1 and exactly 1.0 otherwise.
        scale = 1.0 / tl.maximum(freq, 1)
        row_ptrs = grad_out + r * N + cols
        grad = tl.load(row_ptrs, mask=mask)
        tl.store(row_ptrs, grad * scale, mask=mask)
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Look up rows of ``weight`` for every element of ``indices``.

    Args:
        weight: embedding table whose last dimension ``N`` is the row width;
            made contiguous before launch (presumably 2-D ``(num_weights, N)``
            -- TODO confirm against callers).
        indices: integer tensor of row indices; made contiguous before launch.
        padding_idx: accepted for signature compatibility; unused in forward.
        scale_grad_by_freq: accepted for signature compatibility; unused in
            forward.
        sparse: must be False; sparse output is not supported.

    Returns:
        Tensor of shape ``(*indices.shape, N)`` with ``weight``'s dtype,
        allocated on ``indices``' device.

    Raises:
        AssertionError: if ``sparse`` is True.
    """
    logger.debug("GEMS EMBEDDING FORWARD")
    assert not sparse, "Currently do not support sparse format"

    M = indices.numel()
    N = weight.shape[-1]
    output = torch.empty((*indices.shape, N), device=indices.device, dtype=weight.dtype)

    # Nothing to gather. This also avoids triton.next_power_of_2(0), which
    # evaluates to 0 and would hand the kernel a zero-width tl.arange.
    if M == 0 or N == 0:
        return output

    BLOCK_SIZE = triton.next_power_of_2(N)
    # TODO: remove contiguous enforcement
    # The kernel addresses rows as base + row * N and indices as base + pid.
    indices = indices.contiguous()
    weight = weight.contiguous()

    with torch_device_fn.device(weight.device):
        embedding_kernel[M,](output, indices, weight, N, BLOCK_SIZE)

    return output
def embedding_backward(
    grad_outputs,
    indices,
    num_weights,
    padding_idx=-1,
    scale_grad_by_freq=False,
    sparse=False,
):
    """Compute the dense gradient of ``embedding`` w.r.t. its weight table.

    Every row of ``grad_outputs`` (one per element of ``indices``) is added
    into row ``indices[i]`` of a zero-initialised ``(num_weights, N)`` buffer.

    Args:
        grad_outputs: incoming gradient whose last dimension ``N`` is the
            embedding width; made contiguous before launch.
        indices: integer tensor of row indices; made contiguous before launch.
        num_weights: number of rows in the returned gradient.
        padding_idx: rows whose index equals this value receive no gradient.
            ``None`` disables the check. With the default ``-1`` the check is
            still performed and only skips entries equal to ``-1``.
        scale_grad_by_freq: if True, each gradient row is divided by the number
            of times its index occurs in ``indices`` (when that count is > 1).
        sparse: must be False; sparse gradients are not supported.

    Returns:
        ``(num_weights, N)`` tensor. bf16 inputs are accumulated in float32
        and returned as bf16; other dtypes keep ``grad_outputs``' dtype.

    Raises:
        AssertionError: if ``sparse`` is True.
    """
    logger.debug("GEMS EMBEDDING BACKWARD")
    assert not sparse, "Currently do not support sparse format"

    # The kernels address grad_outputs as grad_out + pid * N and indices as
    # indices + pid, i.e. they require densely packed row-major storage.
    indices = indices.contiguous()
    grad_outputs = grad_outputs.contiguous()

    M = indices.numel()
    N = grad_outputs.shape[-1]
    is_bf16 = grad_outputs.dtype is torch.bfloat16

    # bf16 gradients are accumulated in a float32 buffer and cast back at the end.
    grad_inputs = torch.zeros(
        (num_weights, N),
        device=grad_outputs.device,
        dtype=torch.float32 if is_bf16 else grad_outputs.dtype,
    )

    # Nothing to accumulate. This also avoids triton.next_power_of_2(0) == 0,
    # which would hand the kernels a zero-width tl.arange.
    if M == 0 or N == 0:
        return grad_inputs.to(torch.bfloat16) if is_bf16 else grad_inputs

    BLOCK_SIZE = triton.next_power_of_2(N)
    HAS_PADDING_IDX = padding_idx is not None

    with torch_device_fn.device(grad_outputs.device):
        if scale_grad_by_freq:
            indice_freq = torch.zeros(
                (num_weights,),
                requires_grad=False,
                device=grad_outputs.device,
                dtype=torch.int32,
            )
            INDICE_BLOCK_SIZE = 256
            indice_grid = (triton.cdiv(M, INDICE_BLOCK_SIZE),)
            indice_freq_kernel[indice_grid](indice_freq, indices, M, INDICE_BLOCK_SIZE)

        embedding_backward_kernel[M,](
            grad_inputs,
            grad_outputs,
            indices,
            padding_idx,
            HAS_PADDING_IDX,
            N,
            BLOCK_SIZE,
        )

        if scale_grad_by_freq:
            # Launched with M programs; the kernel strides over all
            # num_weights rows, so any positive grid size covers them.
            embedding_grad_scale_kernel[M,](
                grad_inputs, indice_freq, num_weights, N, BLOCK_SIZE
            )

    return grad_inputs.to(torch.bfloat16) if is_bf16 else grad_inputs