Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/dropout.py: 0%

122 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils.random_utils import ( 

10 philox_backend_seed_offset, 

11 uint_to_uniform_float, 

12) 

13 

# Module logger: a child of the "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

# Forward dropout over a flat buffer of N elements.
#
# Each program covers UNROLL (= 8) consecutive BLOCK-sized segments starting at
# element program_id * BLOCK * UNROLL. Random values come from two tl.philox
# calls over BLOCK counters each, and every counter yields 4 values. So one
# program consumes 2 * BLOCK Philox counters: relative counters
# [program_id * 2 * BLOCK, program_id * 2 * BLOCK + 2 * BLOCK) on top of the
# low 32-bit word of philox_offset. That is one counter per 4 elements.
#
# Arguments:
#   X             -- pointer to the input, read as a flat buffer.
#   Y             -- pointer to the output. Kept elements are written as
#                    x * scale with scale = 1 / (1 - p); dropped elements as 0.
#   dropout_mask  -- pointer to the keep-mask output (True where r > p).
#   N             -- number of valid elements; accesses at/after N are masked.
#   p             -- drop probability. NOTE: p == 1 would divide by zero when
#                    computing `scale`, so callers must not launch with p == 1.
#   philox_seed   -- Philox key.
#   philox_offset -- base Philox counter, split into two 32-bit words.
#   BLOCK         -- segment length. Callers in this file do not pass it;
#                    presumably it is chosen by the "dropout" heuristic config.
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,
    Y,
    dropout_mask,
    N,
    p,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Segments per program; must stay in sync with the module-level UNROLL
    # that the host uses to size the launch grid.
    UNROLL: tl.constexpr = 8
    # Widen to int64 so the 64-bit offset can be split into two 32-bit words.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Low / high 32 bits of the base counter.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    # First set of 4 random numbers: counters program_id * 2 * BLOCK + [0, BLOCK).
    # r0..r3 feed segments 0..3 of this program.
    i4_0 = tl.program_id(0) * BLOCK * 2 + tl.arange(0, BLOCK)
    c0_0 = c0 + i4_0
    # Zeros with c0_0's shape/dtype, used for the two unused counter words.
    _O = c0_0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0_0, c1, _O, _O)
    # uint_to_uniform_float is imported from flag_gems.utils.random_utils;
    # presumably it maps uint32 samples to uniform floats in [0, 1).
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)

    # Second set of 4 random numbers: counters program_id * 2 * BLOCK + BLOCK
    # + [0, BLOCK). r4..r7 feed segments 4..7 of this program.
    i4_1 = tl.program_id(0) * BLOCK * 2 + BLOCK + tl.arange(0, BLOCK)
    c0_1 = c0 + i4_1
    _O1 = c0_1 * 0
    r4, r5, r6, r7 = tl.philox(philox_seed, c0_1, c1, _O1, _O1)
    r4 = uint_to_uniform_float(r4)
    r5 = uint_to_uniform_float(r5)
    r6 = uint_to_uniform_float(r6)
    r7 = uint_to_uniform_float(r7)

    # Keep an element when its uniform sample exceeds p.
    mask0 = r0 > p
    mask1 = r1 > p
    mask2 = r2 > p
    mask3 = r3 > p
    mask4 = r4 > p
    mask5 = r5 > p
    mask6 = r6 > p
    mask7 = r7 > p
    # Inverted-dropout rescale so kept elements preserve the expected value.
    scale = 1.0 / (1.0 - p)

    # Element offsets of the 8 consecutive BLOCK-sized segments of this program.
    off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    off_4 = off_3 + BLOCK
    off_5 = off_4 + BLOCK
    off_6 = off_5 + BLOCK
    off_7 = off_6 + BLOCK

    # Out-of-range lanes (offset >= N) read 0.0 and are never stored.
    x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0)
    x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0)
    x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0)
    x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0)
    x4 = tl.load(X + off_4, mask=off_4 < N, other=0.0)
    x5 = tl.load(X + off_5, mask=off_5 < N, other=0.0)
    x6 = tl.load(X + off_6, mask=off_6 < N, other=0.0)
    x7 = tl.load(X + off_7, mask=off_7 < N, other=0.0)

    # Kept elements are rescaled; dropped elements become 0.
    y0 = tl.where(mask0, x0 * scale, 0.0)
    y1 = tl.where(mask1, x1 * scale, 0.0)
    y2 = tl.where(mask2, x2 * scale, 0.0)
    y3 = tl.where(mask3, x3 * scale, 0.0)
    y4 = tl.where(mask4, x4 * scale, 0.0)
    y5 = tl.where(mask5, x5 * scale, 0.0)
    y6 = tl.where(mask6, x6 * scale, 0.0)
    y7 = tl.where(mask7, x7 * scale, 0.0)

    # Write the outputs, then the keep-mask needed by the backward pass.
    tl.store(Y + off_0, y0, mask=off_0 < N)
    tl.store(Y + off_1, y1, mask=off_1 < N)
    tl.store(Y + off_2, y2, mask=off_2 < N)
    tl.store(Y + off_3, y3, mask=off_3 < N)
    tl.store(Y + off_4, y4, mask=off_4 < N)
    tl.store(Y + off_5, y5, mask=off_5 < N)
    tl.store(Y + off_6, y6, mask=off_6 < N)
    tl.store(Y + off_7, y7, mask=off_7 < N)
    tl.store(dropout_mask + off_0, mask0, mask=off_0 < N)
    tl.store(dropout_mask + off_1, mask1, mask=off_1 < N)
    tl.store(dropout_mask + off_2, mask2, mask=off_2 < N)
    tl.store(dropout_mask + off_3, mask3, mask=off_3 < N)
    tl.store(dropout_mask + off_4, mask4, mask=off_4 < N)
    tl.store(dropout_mask + off_5, mask5, mask=off_5 < N)
    tl.store(dropout_mask + off_6, mask6, mask=off_6 < N)
    tl.store(dropout_mask + off_7, mask7, mask=off_7 < N)

108 

109 

# Backward dropout: DX[i] = DY[i] * dropout_mask[i] * scale for i < N.
# Each program handles one BLOCK-sized slice of the flat buffers; lanes at or
# beyond N load 0 and are not stored. BLOCK is not passed by callers in this
# file and presumably comes from the "dropout" heuristic config.
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,
    DX,
    dropout_mask,
    N,
    scale,
    BLOCK: tl.constexpr,
):
    block_start = tl.program_id(0) * BLOCK
    idx = block_start + tl.arange(0, BLOCK)
    in_bounds = idx < N
    # Keep-mask from the forward pass (0 where the element was dropped).
    keep = tl.load(dropout_mask + idx, mask=in_bounds, other=0)
    grad_out = tl.load(DY + idx, mask=in_bounds, other=0)
    tl.store(DX + idx, grad_out * keep * scale, mask=in_bounds)

126 

127 

# Number of BLOCK-sized segments each dropout_forward_kernel program processes.
# The forward launch grid is sized as cdiv(N, BLOCK * UNROLL). Must match the
# `UNROLL: tl.constexpr = 8` hard-coded inside dropout_forward_kernel.
UNROLL = 8

129 

130 

def dropout(input, p, train=True):
    """Native dropout forward on the Kunlunxin backend.

    Args:
        input: tensor to apply dropout to.
        p: probability of dropping (zeroing) each element; must be in [0, 1].
        train: when False, dropout is disabled.

    Returns:
        ``(out, mask)``. When ``train`` is False or ``p == 0``, ``out`` is a
        copy of ``input`` and ``mask`` is all True. When ``p == 1``, both are
        all zeros/False. Otherwise ``out`` holds kept elements scaled by
        ``1 / (1 - p)`` and dropped elements set to 0, and ``mask`` is a bool
        tensor that is True for kept elements. In this last case both are
        contiguous.

    Raises:
        AssertionError: if ``p`` is outside [0, 1].
    """
    logger.debug("GEMS_KUNLUNXIN NATIVE_DROPOUT_FORWARD")
    if not train or p == 0:
        out = input.clone()
        mask = torch.ones_like(input, dtype=torch.bool)
        return out, mask
    if p == 1:
        out = torch.zeros_like(input)
        mask = torch.zeros_like(input, dtype=torch.bool)
        return out, mask
    assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
    device = input.device
    input = input.contiguous()
    out = torch.empty_like(input)
    mask = torch.empty_like(input, dtype=torch.bool)
    N = input.numel()
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # Advance the generator by the number of Philox counters the kernel
    # consumes. Each program draws 2 * BLOCK counters (4 random values each)
    # for BLOCK * UNROLL = 8 * BLOCK elements, i.e. one counter per 4
    # elements. The previous cdiv(N, UNROLL) (~N / 8) advanced the offset by
    # only half that amount, so the next call reused half of this call's
    # random stream and produced correlated masks.
    randoms_per_counter = 4
    increment = triton.cdiv(N, randoms_per_counter)
    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        dropout_forward_kernel[grid_fn](
            input, out, mask, N, p, philox_seed, philox_offset
        )
    return out, mask

155 

156 

def dropout_backward(grad_output, mask, scale):
    """Native dropout backward on the Kunlunxin backend.

    Computes ``grad_input = grad_output * mask * scale`` elementwise.

    Args:
        grad_output: gradient with respect to the dropout output.
        mask: keep-mask from the forward pass. NOTE(review): it is assumed,
            but not checked, to have the same number of elements as
            ``grad_output``.
        scale: scalar multiplier, presumably ``1 / (1 - p)`` from the forward.

    Returns:
        A contiguous tensor with the same shape and dtype as ``grad_output``.
    """
    logger.debug("GEMS_KUNLUNXIN NATIVE_DROPOUT_BACKWARD")
    # The kernel reads both buffers with the same flat offset, so both must be
    # in contiguous (row-major) order. A non-contiguous mask would otherwise
    # be paired with the wrong gradient elements. Both calls are no-ops for
    # tensors that are already contiguous.
    grad_output = grad_output.contiguous()
    mask = mask.contiguous()
    grad_input = torch.empty_like(grad_output)
    N = grad_output.numel()
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"]),)
    with torch_device_fn.device(grad_output.device):
        dropout_backward_kernel[grid_fn](grad_output, grad_input, mask, N, scale)
    return grad_input