Coverage for src/flag_gems/ops/embedding_dense_backward.py: 53%
66 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7import flag_gems
9logger = logging.getLogger(__name__)
@triton.jit
def _embedding_dense_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_weight_ptr,
    num_weights,
    padding_idx,
    BLOCK_D: tl.constexpr,
    EMBED_DIM: tl.constexpr,
):
    """Scatter-add one row of grad_output into the row of grad_weight it indexes.

    Launch-grid axis 0 selects the row (one entry of ``indices_ptr`` and the
    matching row of ``grad_output_ptr``); axis 1 selects a ``BLOCK_D``-wide
    slice of the embedding dimension.

    Args:
        grad_output_ptr: Base pointer of the upstream gradient; addressed as
            rows of ``EMBED_DIM`` consecutive elements.
        indices_ptr: Base pointer of the flat index buffer, one entry per row.
        grad_weight_ptr: Base pointer of the accumulation buffer; addressed as
            rows of ``EMBED_DIM`` consecutive elements.
        num_weights: Exclusive upper bound for a valid index.
        padding_idx: Index value whose rows are skipped.
        BLOCK_D: Number of embedding columns handled per program.
        EMBED_DIM: Embedding width (row stride of both 2-D buffers).
    """
    # Promote the row id to int64 so `pid_n * EMBED_DIM` cannot wrap around
    # 32 bits when the gradient holds 2**31 or more elements.
    pid_n = tl.program_id(0).to(tl.int64)
    pid_d = tl.program_id(1)

    offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask_d = offs_d < EMBED_DIM  # guards the tail block when EMBED_DIM % BLOCK_D != 0

    # Same reasoning for the destination row: with int32 indices the product
    # `idx * EMBED_DIM` would otherwise be evaluated in 32 bits.
    idx = tl.load(indices_ptr + pid_n).to(tl.int64)
    # Padding rows and out-of-range indices contribute nothing.
    valid = (idx != padding_idx) & (idx >= 0) & (idx < num_weights)

    go_ptrs = grad_output_ptr + pid_n * EMBED_DIM + offs_d
    go = tl.load(go_ptrs, mask=mask_d, other=0).to(tl.float32)

    gw_ptrs = grad_weight_ptr + idx * EMBED_DIM + offs_d
    mask = mask_d & valid
    # Atomic because several rows may carry the same index and therefore
    # target the same destination row.
    tl.atomic_add(gw_ptrs, go, mask=mask)
@triton.jit
def _embedding_dense_backward_count_kernel(
    indices_ptr,
    counts_ptr,
    N,
    num_weights,
    padding_idx,
    BLOCK_N: tl.constexpr,
):
    """Count how many times each valid index occurs in the index buffer.

    Each program inspects ``BLOCK_N`` consecutive entries of ``indices_ptr``
    and atomically increments ``counts_ptr[index]`` for every entry that is
    in ``[0, num_weights)`` and differs from ``padding_idx``.

    Args:
        indices_ptr: Base pointer of the flat index buffer.
        counts_ptr: Base pointer of the per-index counter buffer.
        N: Number of entries in the index buffer.
        num_weights: Exclusive upper bound for a valid index.
        padding_idx: Index value that is never counted.
        BLOCK_N: Number of index entries handled per program.
    """
    pid = tl.program_id(0).to(tl.int64)
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < N

    # Validate in 64 bits: narrowing an int64 index to int32 before the range
    # check would let an out-of-range value alias onto a valid row.
    idx = tl.load(indices_ptr + offs, mask=mask, other=0).to(tl.int64)
    valid = mask & (idx != padding_idx) & (idx >= 0) & (idx < num_weights)

    # Atomic because the same index can occur in many positions/programs.
    tl.atomic_add(counts_ptr + idx, 1, mask=valid)
@triton.jit
def _embedding_dense_backward_kernel_scale_by_freq(
    grad_output_ptr,
    indices_ptr,
    counts_ptr,
    grad_weight_ptr,
    num_weights,
    padding_idx,
    BLOCK_D: tl.constexpr,
    EMBED_DIM: tl.constexpr,
):
    """Scatter-add one gradient row, divided by how often its index occurs.

    Launch-grid axis 0 selects the row (one entry of ``indices_ptr`` and the
    matching row of ``grad_output_ptr``); axis 1 selects a ``BLOCK_D``-wide
    slice of the embedding dimension. The loaded slice is divided by
    ``counts_ptr[index]`` before being accumulated.

    Args:
        grad_output_ptr: Base pointer of the upstream gradient; addressed as
            rows of ``EMBED_DIM`` consecutive elements.
        indices_ptr: Base pointer of the flat index buffer, one entry per row.
        counts_ptr: Base pointer of the per-index occurrence counts.
        grad_weight_ptr: Base pointer of the accumulation buffer; addressed as
            rows of ``EMBED_DIM`` consecutive elements.
        num_weights: Exclusive upper bound for a valid index.
        padding_idx: Index value whose rows are skipped.
        BLOCK_D: Number of embedding columns handled per program.
        EMBED_DIM: Embedding width (row stride of both 2-D buffers).
    """
    # int64 row id: `pid_n * EMBED_DIM` must not wrap around 32 bits for
    # gradients holding 2**31 or more elements.
    pid_n = tl.program_id(0).to(tl.int64)
    pid_d = tl.program_id(1)

    offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask_d = offs_d < EMBED_DIM  # guards the tail block when EMBED_DIM % BLOCK_D != 0

    # Keep the index in 64 bits: narrowing to int32 could alias an
    # out-of-range int64 index onto a valid row, and `idx * EMBED_DIM`
    # could overflow for large weight tables.
    idx = tl.load(indices_ptr + pid_n).to(tl.int64)
    valid = (idx != padding_idx) & (idx >= 0) & (idx < num_weights)

    go_ptrs = grad_output_ptr + pid_n * EMBED_DIM + offs_d
    # NOTE(review): the gradient is not upcast to float32 here, so the
    # division below runs in the gradient's own dtype. An fp32 variant was
    # previously present but disabled -- confirm this is intentional.
    go = tl.load(go_ptrs, mask=mask_d, other=0.0)

    # `other=1` keeps the divisor harmless for rows that are masked out.
    cnt = tl.load(counts_ptr + idx, mask=valid, other=1)
    go = go / cnt

    gw_ptrs = grad_weight_ptr + idx * EMBED_DIM + offs_d
    mask = mask_d & valid
    # Atomic because several rows may carry the same index and therefore
    # target the same destination row.
    tl.atomic_add(gw_ptrs, go, mask=mask)
def embedding_dense_backward(
    grad_output: torch.Tensor,
    indices: torch.Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
) -> torch.Tensor:
    """Compute the dense weight gradient of an embedding lookup.

    Every row of ``grad_output`` (flattened to ``(N, D)``) is accumulated into
    row ``indices[n]`` of a ``(num_weights, D)`` float32 buffer. Rows whose
    index equals ``padding_idx`` or lies outside ``[0, num_weights)`` are
    ignored by the kernels.

    Args:
        grad_output: Upstream gradient with at least two dimensions; the last
            dimension is the embedding width ``D``.
        indices: int32 or int64 tensor with one entry per row of the
            flattened ``grad_output``.
        num_weights: Number of rows in the returned gradient.
        padding_idx: Index to exclude from accumulation. ``None`` is accepted
            and is passed to the kernels as ``-1``.
        scale_grad_by_freq: If true, each contribution is divided by the
            number of occurrences of its index in ``indices``.

    Returns:
        Tensor of shape ``(num_weights, D)`` on ``grad_output``'s device with
        ``grad_output``'s dtype. Accumulation is done in float32.

    Raises:
        AssertionError: If ``indices`` has an unsupported dtype,
            ``grad_output`` has fewer than two dimensions, or the number of
            indices differs from the number of gradient rows.
        ValueError: If the inputs are not on the same ``flag_gems.device``
            device.
    """
    logger.debug("GEMS: embedding_dense_backward")
    assert indices.dtype in (
        torch.int32,
        torch.int64,
    ), "Indices must be int32 or int64."
    if (
        grad_output.device.type != flag_gems.device
        or indices.device.type != flag_gems.device
        or grad_output.device != indices.device
    ):
        raise ValueError(
            f"Inputs must be {flag_gems.device} tensors on the same device."
        )

    device = grad_output.device
    assert (
        grad_output.dim() >= 2
    ), "grad_output must have embedding dimension as the last dim."

    D = grad_output.shape[-1]
    go = grad_output.contiguous().view(-1, D)  # (N, D)
    idx = indices.contiguous().view(-1)
    N = idx.numel()

    assert go.shape[0] == N, "indices number must match grad_output rows."

    # Nothing to accumulate: return zeros directly instead of launching
    # kernels over an empty grid.
    if N == 0:
        return torch.zeros((num_weights, D), device=device, dtype=grad_output.dtype)

    grad_weight_fp32 = torch.zeros((num_weights, D), device=device, dtype=torch.float32)

    # The kernels take a plain integer; -1 can never equal an index that
    # passes their `idx >= 0` check, so it stands for "no padding index".
    pad = padding_idx if padding_idx is not None else -1

    BLOCK_D = 128
    grid = (N, triton.cdiv(D, BLOCK_D))

    if scale_grad_by_freq:
        # First pass: occurrence count per index, used as the divisor below.
        counts = torch.zeros((num_weights,), device=device, dtype=torch.int32)
        BLOCK_N = 512
        _embedding_dense_backward_count_kernel[(triton.cdiv(N, BLOCK_N),)](
            idx,
            counts,
            N,
            num_weights,
            pad,
            BLOCK_N=BLOCK_N,
        )

        _embedding_dense_backward_kernel_scale_by_freq[grid](
            go,
            idx,
            counts,
            grad_weight_fp32,
            num_weights,
            pad,
            BLOCK_D=BLOCK_D,
            EMBED_DIM=D,
        )
    else:
        _embedding_dense_backward_kernel[grid](
            go,
            idx,
            grad_weight_fp32,
            num_weights,
            pad,
            BLOCK_D=BLOCK_D,
            EMBED_DIM=D,
        )

    # Hand the result back in the caller's dtype.
    if grad_output.dtype != torch.float32:
        return grad_weight_fp32.to(grad_output.dtype)
    return grad_weight_fp32