Coverage for src/flag_gems/ops/embedding_dense_backward.py: 53%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7import flag_gems 

8 

9logger = logging.getLogger(__name__) 

10 

11 

12@triton.jit 

13def _embedding_dense_backward_kernel( 

14 grad_output_ptr, 

15 indices_ptr, 

16 grad_weight_ptr, 

17 num_weights, 

18 padding_idx, 

19 BLOCK_D: tl.constexpr, 

20 EMBED_DIM: tl.constexpr, 

21): 

22 pid_n = tl.program_id(0) 

23 pid_d = tl.program_id(1) 

24 

25 offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) 

26 mask_d = offs_d < EMBED_DIM 

27 

28 idx = tl.load(indices_ptr + pid_n) 

29 valid = (idx != padding_idx) & (idx >= 0) & (idx < num_weights) 

30 

31 go_ptrs = grad_output_ptr + pid_n * EMBED_DIM + offs_d 

32 go = tl.load(go_ptrs, mask=mask_d, other=0).to(tl.float32) 

33 

34 gw_ptrs = grad_weight_ptr + idx * EMBED_DIM + offs_d 

35 mask = mask_d & valid 

36 tl.atomic_add(gw_ptrs, go, mask=mask) 

37 

38 

39@triton.jit 

40def _embedding_dense_backward_count_kernel( 

41 indices_ptr, 

42 counts_ptr, 

43 N, 

44 num_weights, 

45 padding_idx, 

46 BLOCK_N: tl.constexpr, 

47): 

48 pid = tl.program_id(0) 

49 offs = pid * BLOCK_N + tl.arange(0, BLOCK_N) 

50 mask = offs < N 

51 idx = tl.load(indices_ptr + offs, mask=mask, other=0).to(tl.int32) 

52 valid = mask & (idx != padding_idx) & (idx >= 0) & (idx < num_weights) 

53 tl.atomic_add(counts_ptr + idx, 1, mask=valid) 

54 

55 

56@triton.jit 

57def _embedding_dense_backward_kernel_scale_by_freq( 

58 grad_output_ptr, 

59 indices_ptr, 

60 counts_ptr, 

61 grad_weight_ptr, 

62 num_weights, 

63 padding_idx, 

64 BLOCK_D: tl.constexpr, 

65 EMBED_DIM: tl.constexpr, 

66): 

67 pid_n = tl.program_id(0) 

68 pid_d = tl.program_id(1) 

69 

70 offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) 

71 mask_d = offs_d < EMBED_DIM 

72 

73 idx = tl.load(indices_ptr + pid_n).to(tl.int32) 

74 valid = (idx != padding_idx) & (idx >= 0) & (idx < num_weights) 

75 

76 go_ptrs = grad_output_ptr + pid_n * EMBED_DIM + offs_d 

77 # go = tl.load(go_ptrs, mask=mask_d, other=0.0).to(tl.float32) 

78 go = tl.load(go_ptrs, mask=mask_d, other=0.0) 

79 

80 # cnt = tl.load(counts_ptr + idx, mask=valid, other=1).to(tl.float32) 

81 cnt = tl.load(counts_ptr + idx, mask=valid, other=1) 

82 go = go / cnt 

83 

84 gw_ptrs = grad_weight_ptr + idx * EMBED_DIM + offs_d 

85 mask = mask_d & valid 

86 tl.atomic_add(gw_ptrs, go, mask=mask) 

87 

88 

89def embedding_dense_backward( 

90 grad_output: torch.Tensor, 

91 indices: torch.Tensor, 

92 num_weights: int, 

93 padding_idx: int, 

94 scale_grad_by_freq: bool, 

95): 

96 logger.debug("GEMS: embedding_dense_backward") 

97 assert indices.dtype in ( 

98 torch.int32, 

99 torch.int64, 

100 ), "Indices must be int32 or int64." 

101 if ( 

102 grad_output.device.type != flag_gems.device 

103 or indices.device.type != flag_gems.device 

104 or grad_output.device != indices.device 

105 ): 

106 raise ValueError( 

107 f"Inputs must be {flag_gems.device} tensors on the same device." 

108 ) 

109 

110 device = grad_output.device 

111 assert ( 

112 grad_output.dim() >= 2 

113 ), "grad_output must have embedding dimension as the last dim." 

114 

115 D = grad_output.shape[-1] 

116 go = grad_output.contiguous().view(-1, D) # (N, D) 

117 idx = indices.contiguous().view(-1) 

118 N = idx.numel() 

119 

120 assert go.shape[0] == N, "indices number must match grad_output rows." 

121 grad_weight_fp32 = torch.zeros((num_weights, D), device=device, dtype=torch.float32) 

122 

123 BLOCK_D = 128 

124 grid = (N, triton.cdiv(D, BLOCK_D)) 

125 

126 if scale_grad_by_freq: 

127 counts = torch.zeros((num_weights,), device=device, dtype=torch.int32) 

128 BLOCK_N = 512 

129 _embedding_dense_backward_count_kernel[(triton.cdiv(N, BLOCK_N),)]( 

130 idx, 

131 counts, 

132 N, 

133 num_weights, 

134 padding_idx if padding_idx is not None else -1, 

135 BLOCK_N=BLOCK_N, 

136 ) 

137 

138 _embedding_dense_backward_kernel_scale_by_freq[grid]( 

139 go, 

140 idx, 

141 counts, 

142 grad_weight_fp32, 

143 num_weights, 

144 padding_idx if padding_idx is not None else -1, 

145 BLOCK_D=BLOCK_D, 

146 EMBED_DIM=D, 

147 ) 

148 else: 

149 _embedding_dense_backward_kernel[grid]( 

150 go, 

151 idx, 

152 grad_weight_fp32, 

153 num_weights, 

154 padding_idx if padding_idx is not None else -1, 

155 BLOCK_D=BLOCK_D, 

156 EMBED_DIM=D, 

157 ) 

158 

159 if grad_output.dtype != torch.float32: 

160 return grad_weight_fp32.to(grad_output.dtype) 

161 return grad_weight_fp32