Coverage for src/flag_gems/runtime/backend/_sunrise/ops/embedding.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

# Module-level logger; the ops below emit debug traces through it.
logger = logging.getLogger(name=__name__)

12 

13 

@libentry()
@triton.jit
def embedding_kernel(
    out_ptr,  # pointer to the output, one row of N values per index
    in_ptr,  # pointer to the (contiguous) indices
    weight_ptr,  # pointer to the (contiguous) weight table, N values per row
    N: tl.constexpr,  # number of columns (embedding dim) per row
    BLOCK_SIZE: tl.constexpr,  # vector width; must be >= N to cover a full row
):
    # Gather one row per program: out[pid, :] = weight[indices[pid], :].
    pid = ext.program_id(0)
    out_ptr += pid * N
    in_ptr += pid

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Promote the index to int64 before scaling by N: with int32 indices the
    # product row_idx * N would be computed in int32 and overflow once the
    # weight table holds more than 2**31 elements.
    row_idx = tl.load(in_ptr).to(tl.int64)
    weight_ptr += row_idx * N
    embedding_weight = tl.load(weight_ptr + cols, mask, other=0.0)
    tl.store(out_ptr + cols, embedding_weight, mask)

34 

35 

@libentry()
@triton.jit
def indice_freq_kernel(
    indices_freq,  # pointer to per-row counters, zero-initialised by the caller
    indices,  # pointer to the (contiguous) indices
    elem_cnt,  # number of indices; runtime value so new batch sizes don't recompile
    INDICE_BLOCK_SIZE: tl.constexpr,  # indices handled per program
):
    # Histogram of row indices: indices_freq[i] += 1 for every occurrence of i.
    pid = ext.program_id(0)
    block_start = pid * INDICE_BLOCK_SIZE

    offsets = block_start + tl.arange(0, INDICE_BLOCK_SIZE)
    mask = offsets < elem_cnt

    index_element = tl.load(indices + offsets, mask=mask)
    # Duplicate indices within one block (or across concurrently running
    # programs) target the same counter. A vector load/add/store would read the
    # same old value for every duplicate and lose increments, so the update
    # must be atomic.
    tl.atomic_add(indices_freq + index_element, 1, mask=mask)

53 

54 

@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
    grad_in,  # pointer to the gradient accumulator, N values per weight row
    grad_out,  # pointer to the upstream gradient, one row of N values per index
    indices,  # pointer to the (contiguous) indices
    padding_idx,  # rows equal to this index receive no gradient
    HAS_PADDING_IDX: tl.constexpr,  # whether padding_idx should be checked at all
    N: tl.constexpr,  # number of columns (embedding dim) per row
    BLOCK_SIZE: tl.constexpr,  # vector width; must be >= N to cover a full row
):
    # Accumulate grad_out[pid, :] into grad_in[indices[pid], :]; one program
    # per index.
    #
    # NOTE(review): the accumulation is a plain load/add/store, not an atomic
    # add. If two programs with the same row index run concurrently, one update
    # can be lost. Confirm the target backend serializes programs, or switch
    # to tl.atomic_add once it is supported for these dtypes.
    pid = ext.program_id(0)
    grad_out += pid * N
    indices += pid

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # int64 so that row_idx * N cannot overflow for weight tables with more
    # than 2**31 elements.
    row_idx = tl.load(indices).to(tl.int64)
    if HAS_PADDING_IDX:
        # Padding rows get no gradient: switch off every lane for them.
        mask = mask & (row_idx != padding_idx)

    grad_in += row_idx * N
    embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
    if tl.constexpr(embedding_grad.dtype.is_bf16()):
        # bf16 gradients are summed in fp32; embedding_backward allocates an
        # fp32 accumulator for that case.
        embedding_grad = embedding_grad.to(tl.float32)
    current_grad = tl.load(grad_in + cols, mask, other=0.0).to(tl.float32)
    tl.store(grad_in + cols, current_grad + embedding_grad, mask=mask)

91 

92 

@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
    grad_out,  # pointer to the gradient table, N values per row, scaled in place
    indice_freq,  # pointer to per-row occurrence counts
    n_rows,  # number of rows in grad_out
    N,  # number of columns per row
    BLOCK_SIZE: tl.constexpr,  # vector width; must be >= N to cover a full row
):
    # Divide each gradient row by how often its index occurred. Rows seen at
    # most once are multiplied by 1.0. Programs step over the rows by the
    # number of launched programs, so any grid size >= 1 covers all n_rows.
    offs = tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < N

    first_row = ext.program_id(0)
    stride = ext.num_programs(0)
    for r in range(first_row, n_rows, stride):
        freq = tl.load(indice_freq + r)
        scale = tl.where(freq > 1, 1.0 / freq, 1.0)
        row_ptrs = grad_out + r * N + offs
        grad_row = tl.load(row_ptrs, mask=in_bounds)
        tl.store(row_ptrs, grad_row * scale, mask=in_bounds)

116 

117 

def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Look up one row of ``weight`` for every entry of ``indices``.

    Only the dense path exists: ``sparse=True`` fails the assertion.
    ``padding_idx`` and ``scale_grad_by_freq`` are accepted for signature
    compatibility; the forward result does not depend on them.

    Returns:
        A tensor of shape ``(*indices.shape, weight.shape[-1])`` with
        ``weight``'s dtype, allocated on ``indices``' device.
    """
    logger.debug("GEMS EMBEDDING FORWARD")
    assert not sparse, "Currently do not support sparse format"

    num_indices = indices.numel()
    embed_dim = weight.shape[-1]
    block = triton.next_power_of_2(embed_dim)

    # TODO: remove contiguous enforcement
    # (the kernel walks both tensors with flat pointer offsets)
    idx = indices.contiguous()
    table = weight.contiguous()
    out = torch.empty((*idx.shape, embed_dim), device=idx.device, dtype=table.dtype)

    with torch_device_fn.device(table.device):
        embedding_kernel[(num_indices,)](out, idx, table, embed_dim, block)
    return out

135 

136 

def embedding_backward(
    grad_outputs,
    indices,
    num_weights,
    padding_idx=-1,
    scale_grad_by_freq=False,
    sparse=False,
):
    """Dense gradient of ``embedding`` with respect to the weight table.

    Args:
        grad_outputs: upstream gradient holding one ``N``-wide row per index,
            where ``N = grad_outputs.shape[-1]``.
        indices: indices used in the forward pass.
        num_weights: number of rows in the weight table.
        padding_idx: rows whose index equals this value receive no gradient;
            ``None`` disables the check.
        scale_grad_by_freq: divide each row's gradient by the number of times
            its index occurs in ``indices``.
        sparse: must be False; sparse gradients are not supported.

    Returns:
        A ``(num_weights, N)`` tensor with ``grad_outputs``' dtype.
    """
    logger.debug("GEMS EMBEDDING BACKWARD")
    assert not sparse, "Currently do not support sparse format"

    # The kernels address both tensors with flat row-major offsets
    # (indices + pid, grad_out + pid * N). Strided or expanded inputs, such as
    # the stride-0 gradient produced by ``.sum().backward()``, must therefore
    # be materialized first. This is a no-op for already contiguous tensors.
    indices = indices.contiguous()
    grad_outputs = grad_outputs.contiguous()

    M = indices.numel()
    N = grad_outputs.shape[-1]

    # bf16 gradients are accumulated in fp32 and cast back at the end.
    accumulate_in_fp32 = grad_outputs.dtype is torch.bfloat16
    grad_inputs = torch.zeros(
        (num_weights, N),
        device=grad_outputs.device,
        dtype=torch.float32 if accumulate_in_fp32 else grad_outputs.dtype,
    )

    # With no indices the gradient is all zeros; skip the zero-sized launches.
    if M > 0:
        BLOCK_SIZE = triton.next_power_of_2(N)
        HAS_PADDING_IDX = padding_idx is not None

        with torch_device_fn.device(grad_outputs.device):
            if scale_grad_by_freq:
                indice_freq = torch.zeros(
                    (num_weights,),
                    requires_grad=False,
                    device=grad_outputs.device,
                    dtype=torch.int32,
                )
                INDICE_BLOCK_SIZE = 256
                indice_grid = (triton.cdiv(M, INDICE_BLOCK_SIZE),)
                indice_freq_kernel[indice_grid](
                    indice_freq, indices, M, INDICE_BLOCK_SIZE
                )

            embedding_backward_kernel[(M,)](
                grad_inputs,
                grad_outputs,
                indices,
                padding_idx,
                HAS_PADDING_IDX,
                N,
                BLOCK_SIZE,
            )

            if scale_grad_by_freq:
                # The scale kernel strides over all num_weights rows, so more
                # programs than rows only adds idle programs.
                scale_grid = (max(1, min(M, num_weights)),)
                embedding_grad_scale_kernel[scale_grid](
                    grad_inputs, indice_freq, num_weights, N, BLOCK_SIZE
                )

    return grad_inputs.to(torch.bfloat16) if accumulate_in_fp32 else grad_inputs