Coverage for src/flag_gems/ops/feature_dropout.py: 59%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils.random_utils import ( 

9 philox_backend_seed_offset, 

10 uint_to_uniform_float, 

11) 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def generate_feature_mask_kernel(
    MASK,
    N,  # batch size
    C,  # number of channels
    p,
    scale,
    philox_seed,
    philox_offset,
    BLOCK_N: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    """
    Fill the (N, C) feature-dropout mask MASK, stored row-major as n * C + c.

    Every (n, c) entry draws its own Philox random number at counter
    ``philox_offset + n * C + c``. The raw bits are converted to a float with
    ``uint_to_uniform_float``. The entry is set to ``scale`` when that value
    is greater than ``p`` (kept) and to ``0.0`` otherwise (dropped).

    Grid layout: axis 0 covers the batch dimension in tiles of BLOCK_N rows,
    and axis 1 covers the channel dimension in tiles of BLOCK_C columns.
    """
    seed = philox_seed.to(tl.int64)
    offset64 = philox_offset.to(tl.int64)

    rows = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    cols = tl.program_id(1) * BLOCK_C + tl.arange(0, BLOCK_C)
    row_ok = rows < N
    col_ok = cols < C

    # Row-major position of each (n, c) in the mask. It serves both as the
    # per-element Philox counter increment and as the store offset.
    linear = rows[:, None] * C + cols[None, :]
    in_bounds = row_ok[:, None] & col_ok[None, :]

    # Split the 64-bit Philox offset into two 32-bit counter words. Only the
    # low word is advanced per element; uint32 addition wraps without carrying
    # into the high word.
    lo_word = (offset64 & 0xFFFFFFFF).to(tl.uint32)
    hi_word = ((offset64 >> 32) & 0xFFFFFFFF).to(tl.uint32)
    counter = lo_word + linear.to(tl.uint32)
    zero_words = counter * 0
    bits, _, _, _ = tl.philox(seed, counter, hi_word, zero_words, zero_words)
    draw = uint_to_uniform_float(bits)

    # Keep (write scale) when draw > p; drop (write 0.0) otherwise.
    keep_factor = tl.where(draw > p, scale, 0.0)
    tl.store(MASK + linear, keep_factor, mask=in_bounds)

65 

66 

@triton.jit
def apply_feature_mask_kernel(
    X,
    Y,
    MASK,
    numel,
    N,  # batch size (not read by the kernel; kept for call-site compatibility)
    C,  # channels (not read by the kernel; kept for call-site compatibility)
    spatial_size,  # H * W or D1 * D2 * ...
    BLOCK: tl.constexpr,
):
    """
    Multiply each element of X by its (batch, channel) mask entry into Y.

    X and Y are flat buffers of ``numel`` elements in contiguous
    (N, C, *spatial) order. MASK holds one factor per (n, c) pair at
    index n * C + c.

    For flat index i = (n * C + c) * spatial_size + s with 0 <= s < spatial_size,
    the mask index n * C + c equals i // spatial_size. So a single division
    replaces the separate batch and channel decomposition.
    """
    # Promote to int64 before forming offsets. With int32, pid * BLOCK wraps
    # once numel approaches 2**31 and produces out-of-bounds addresses.
    pid = tl.program_id(0).to(tl.int64)
    offset = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offset < numel

    # Equivalent to (offset // (C * spatial_size)) * C
    #   + (offset % (C * spatial_size)) // spatial_size,
    # without ever forming the product C * spatial_size (which can overflow).
    mask_idx = offset // spatial_size

    x = tl.load(X + offset, mask=mask, other=0.0)
    m = tl.load(MASK + mask_idx, mask=mask, other=0.0)

    tl.store(Y + offset, x * m, mask=mask)

109 

110 

def feature_dropout(input, p, train=True):
    """
    Apply feature (channel) dropout to ``input``.

    Each (batch, channel) slice of an input shaped (N, C, ...) is either zeroed
    or scaled by ``1 / (1 - p)``. The keep/drop decision is drawn once per
    (n, c) pair and applied to every trailing (spatial) position of that pair.
    Every batch element therefore gets its own independent channel mask.

    Args:
        input: Tensor of shape (N, C, ...) where N is the batch size and C is
            the number of channels. At least 2 dimensions are required when
            dropout is actually applied (train=True, 0 < p < 1, non-empty).
        p: Probability of a channel being zeroed. Must lie in [0, 1]; there is
            no default.
        train: If False, no dropout is applied and a copy of ``input`` is
            returned.

    Returns:
        A new tensor with the same shape and dtype as ``input``.

    Raises:
        RuntimeError: If ``p`` is outside [0, 1] (including NaN), or if dropout
            is applied to an input with fewer than 2 dimensions.
    """
    logger.debug("GEMS FEATURE_DROPOUT")

    # Validate before any early exit, as PyTorch's dropout does. An ``assert``
    # would be stripped under ``python -O``.
    if not 0.0 <= p <= 1.0:
        raise RuntimeError(
            f"dropout probability has to be between 0 and 1, but got {p}"
        )

    # Nothing to drop. Empty tensors must exit here: with N == 0 or C == 0,
    # triton.next_power_of_2 returns 0 and the grid computation below would
    # divide by zero.
    if not train or p == 0 or input.numel() == 0:
        return input.clone()

    if p == 1:
        return torch.zeros_like(input)

    if input.ndim < 2:
        raise RuntimeError(
            "Feature dropout requires at least 2 dimensions in the input"
        )

    device = input.device
    # The apply kernel indexes the input as a flat contiguous buffer.
    input = input.contiguous()
    out = torch.empty_like(input)

    N = input.shape[0]
    C = input.shape[1]
    # Product of the trailing dimensions (1 when the input is exactly 2-D).
    spatial_size = input.shape[2:].numel()
    numel = input.numel()
    scale = 1.0 / (1.0 - p)

    # One float32 keep/drop factor per (batch, channel) pair.
    mask = torch.empty(N, C, device=device, dtype=torch.float32)

    BLOCK_N = min(triton.next_power_of_2(N), 64)
    BLOCK_C = min(triton.next_power_of_2(C), 64)
    grid_mask = (triton.cdiv(N, BLOCK_N), triton.cdiv(C, BLOCK_C))

    # The mask kernel uses one Philox counter per (n, c) pair; reserve N * C
    # counters, rounded up to a multiple of 4.
    increment = triton.cdiv(N * C, 4) * 4

    BLOCK = 1024
    grid_apply = (triton.cdiv(numel, BLOCK),)

    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        generate_feature_mask_kernel[grid_mask](
            mask, N, C, p, scale, philox_seed, philox_offset, BLOCK_N, BLOCK_C
        )
        apply_feature_mask_kernel[grid_apply](
            input, out, mask, numel, N, C, spatial_size, BLOCK
        )

    return out

183 

184 

def feature_dropout_(input, p, train=True):
    """
    In-place version of ``feature_dropout``.

    Overwrites ``input`` with the feature-dropout result and returns it.

    Args:
        input: Tensor of shape (N, C, ...) to modify in place.
        p: Probability of a channel being zeroed. Must lie in [0, 1].
        train: If False, ``input`` is returned untouched.

    Returns:
        ``input`` itself.

    Raises:
        RuntimeError: If ``p`` is outside [0, 1] (including NaN). Also
            propagated from ``feature_dropout`` when dropout is applied to an
            input with fewer than 2 dimensions.
    """
    logger.debug("GEMS FEATURE_DROPOUT_")

    # Validate before any early exit, as PyTorch's dropout does.
    if not 0.0 <= p <= 1.0:
        raise RuntimeError(
            f"dropout probability has to be between 0 and 1, but got {p}"
        )

    # Empty tensors have nothing to drop. Returning here also avoids the
    # zero-sized launch configuration inside feature_dropout.
    if not train or p == 0 or input.numel() == 0:
        return input
    if p == 1:
        input.zero_()
        return input

    # feature_dropout returns a new tensor; copy_ writes it back into
    # ``input``'s own storage, whatever its strides.
    input.copy_(feature_dropout(input, p, train))
    return input