Coverage for src/flag_gems/runtime/backend/_cambricon/ops/dropout.py: 0%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import torch_mlu # noqa: F401 

5import triton 

6import triton.language as tl 

7from triton.language.extra.mlu.libdevice import philox as _philox 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils.random_utils import ( 

12 philox_backend_seed_offset, 

13 uint_to_uniform_float, 

14) 

15 

16from ..utils import TOTAL_CORE_NUM 

17 

# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module.
_LOGGER_SUFFIX = __name__.lstrip(".")
logger = logging.getLogger("flag_gems").getChild(_LOGGER_SUFFIX)

# Unroll factor used by the host wrappers below to size the launch grid and
# the Philox offset increment. Both kernels declare their own local
# ``UNROLL: tl.constexpr = 4``; those must stay equal to this value.
UNROLL = 4

21 

22 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK": 1024}, num_stages=3, num_warps=1),
        triton.Config(kwargs={"BLOCK": 4096}, num_stages=3, num_warps=1),
        triton.Config(kwargs={"BLOCK": 16384}, num_stages=3, num_warps=1),
        triton.Config(kwargs={"BLOCK": 32768}, num_stages=3, num_warps=1),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,
    Y,
    dropout_mask,
    N,
    p,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Apply dropout to the flat buffer ``X``, writing ``Y`` and the keep-mask.

    Each program processes tiles of ``UNROLL * BLOCK`` elements in a
    grid-stride loop (stride ``num_programs * UNROLL * BLOCK``) until ``N``
    elements are covered. For each tile a single ``_philox`` call produces the
    random bits. These are converted with ``uint_to_uniform_float``. An element
    is kept when its sample is ``> p``, and kept values are scaled by
    ``1 / (1 - p)``.

    Args:
        X: pointer to the input (the host wrapper makes it contiguous).
        Y: pointer to the output.
        dropout_mask: pointer to the boolean keep-mask output.
        N: number of elements.
        p: drop probability; the host wrapper only launches for 0 < p < 1.
        philox_seed: 64-bit Philox seed.
        philox_offset: 64-bit Philox base counter.
        BLOCK: first argument to ``_philox``; chosen by the tuner.
    """
    # Must equal the module-level UNROLL used by the host wrapper for the grid
    # size and the Philox offset increment.
    UNROLL: tl.constexpr = 4
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    # Split the 64-bit seed and counter into 32-bit halves once. They do not
    # depend on the loop variable, so there is no reason to recompute them
    # for every tile.
    sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
    sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # i4_start always equals block_offset // UNROLL. It is added to the low
    # counter word so that each tile draws from its own counter range
    # (presumably BLOCK consecutive counters per _philox call; confirm with
    # the MLU libdevice docs).
    i4_start = pid * BLOCK
    block_start = pid * UNROLL * BLOCK
    step = num_jobs * BLOCK * UNROLL
    mp = 1.0 / (1.0 - p)

    for block_offset in range(block_start, N, step):
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        r = uint_to_uniform_float(r)

        mask = r > p
        # Flatten to one sample per element of the tile.
        # NOTE(review): can_reorder=True lets the compiler pick the
        # element<->sample mapping. The host increment of cdiv(N, UNROLL)
        # assumes the samples used by the first N elements come from the first
        # cdiv(N, UNROLL) counters; confirm the MLU backend keeps row-major
        # order here.
        mask_reshaped = tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True)

        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        valid = off < N
        x = tl.load(X + off, mask=valid, other=0.0)
        y = tl.where(mask_reshaped, x * mp, 0.0)
        tl.store(dropout_mask + off, mask_reshaped, mask=valid)
        tl.store(Y + off, y, mask=valid)
        i4_start += num_jobs * BLOCK

73 

74 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK": block}, num_stages=3, num_warps=1)
        for block in (1024, 4096, 16384, 32768)
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,
    DX,
    dropout_mask,
    N,
    scale,
    BLOCK: tl.constexpr,
):
    """Dropout backward: ``DX[i] = DY[i] * dropout_mask[i] * scale``.

    Each program handles tiles of ``UNROLL * BLOCK`` consecutive elements. It
    then jumps ahead by the combined tile span of the whole grid until ``N``
    is exhausted. Lanes at or beyond ``N`` are masked on every load and store.
    """
    # Local unroll factor; it has to agree with the module-level UNROLL that
    # the host wrapper uses to size the grid.
    UNROLL: tl.constexpr = 4
    prog = tl.program_id(0)
    n_progs = tl.num_programs(0)
    first_tile = prog * UNROLL * BLOCK
    grid_span = n_progs * UNROLL * BLOCK
    for base in range(first_tile, N, grid_span):
        offsets = base + tl.arange(0, UNROLL * BLOCK)
        in_bounds = offsets < N
        # Every element is read and written exactly once, so all accesses
        # hint early cache eviction.
        grad_out = tl.load(
            DY + offsets, mask=in_bounds, other=0.0, eviction_policy="evict_first"
        )
        keep = tl.load(
            dropout_mask + offsets,
            mask=in_bounds,
            other=0,
            eviction_policy="evict_first",
        )
        # Multiply (rather than select) so NaN/inf in dropped lanes propagate
        # exactly as before.
        grad_in = grad_out * keep * scale
        tl.store(DX + offsets, grad_in, mask=in_bounds, eviction_policy="evict_first")

108 

109 

def dropout(input, p, train=True):
    """Native dropout forward for the Cambricon backend.

    Args:
        input: tensor to apply dropout to.
        p: probability of zeroing an element.
        train: when false, dropout is the identity (``p`` is not checked).

    Returns:
        ``(out, mask)``: ``out`` has the same shape as ``input``; ``mask`` is a
        boolean tensor that is True for kept elements.

    Raises:
        AssertionError: if ``train`` is true and ``p`` is outside [0, 1]
            (or NaN).
    """
    logger.debug("GEMS_CAMBRICON NATIVE DROPOUT FORWARD")
    if not train or p == 0:
        out = input.clone()
        mask = torch.ones_like(input, dtype=torch.bool)
        return out, mask
    if p == 1:
        out = torch.zeros_like(input)
        mask = torch.zeros_like(input, dtype=torch.bool)
        return out, mask
    # Explicit check instead of ``assert`` so the validation survives
    # ``python -O``. Without it the kernel silently computes with a negative
    # or >1 probability. The exception type stays AssertionError for callers
    # that relied on the original assert.
    if not (0.0 < p < 1.0):
        raise AssertionError("p must be in (0, 1)")
    device = input.device
    # The kernel indexes with flat offsets, so it needs a contiguous buffer.
    input = input.contiguous()
    out = torch.empty_like(input)
    mask = torch.empty_like(input, dtype=torch.bool)
    N = input.numel()
    if N == 0:
        # Nothing to compute. Skip the zero-size launch (and the autotuner).
        # NOTE(review): this also skips philox_backend_seed_offset(0), which
        # presumably does not advance the generator anyway. Confirm.
        return out, mask
    grid_fn = lambda meta: (
        min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),
    )
    # Reserve cdiv(N, UNROLL) Philox counter values for this call. The kernel
    # offsets the counter by block_offset // UNROLL for each tile.
    increment = triton.cdiv(N, UNROLL)
    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        dropout_forward_kernel[grid_fn](
            input,
            out,
            mask,
            N,
            p,
            philox_seed,
            philox_offset,
        )
    return out, mask

142 

143 

def dropout_backward(grad_output, mask, scale):
    """Native dropout backward: ``grad_input = grad_output * mask * scale``.

    Args:
        grad_output: upstream gradient.
        mask: boolean keep-mask from the forward pass. It must have as many
            elements as ``grad_output``.
        scale: factor applied to kept gradients (presumably the forward's
            ``1 / (1 - p)``; it is supplied by the caller).

    Returns:
        A contiguous tensor shaped like ``grad_output``.

    Raises:
        ValueError: if ``mask`` and ``grad_output`` differ in element count.
    """
    logger.debug("GEMS_CAMBRICON NATIVE DROPOUT BACKWARD")
    # The kernel reads both buffers with the same flat offsets, so both must
    # be laid out contiguously. Otherwise a strided mask would be read in
    # storage order instead of logical order. Both calls are no-ops for
    # already contiguous tensors.
    grad_output = grad_output.contiguous()
    mask = mask.contiguous()
    grad_input = torch.empty_like(grad_output)
    N = grad_output.numel()
    if mask.numel() != N:
        # A smaller mask would be read out of bounds by the kernel.
        raise ValueError(
            f"dropout_backward: mask has {mask.numel()} elements, "
            f"expected {N} to match grad_output"
        )
    if N == 0:
        # Nothing to compute. Skip the zero-size launch (and the autotuner).
        return grad_input
    grid_fn = lambda meta: (
        min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),
    )
    with torch_device_fn.device(grad_output.device):
        dropout_backward_kernel[grid_fn](
            grad_output,
            grad_input,
            mask,
            N,
            scale,
        )
    return grad_input