Coverage for src/flag_gems/ops/feature_dropout.py: 59%
86 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils.random_utils import (
9 philox_backend_seed_offset,
10 uint_to_uniform_float,
11)
13logger = logging.getLogger(__name__)
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def generate_feature_mask_kernel(
    MASK,
    N,  # batch size
    C,  # number of channels
    p,
    scale,
    philox_seed,
    philox_offset,
    BLOCK_N: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    """
    Fill the (N, C) buffer MASK with a per-(sample, channel) dropout mask.

    The launch grid is 2D: program axis 0 walks the batch dimension in tiles
    of BLOCK_N rows, program axis 1 walks the channel dimension in tiles of
    BLOCK_C columns. Every (n, c) entry draws its own random value and is
    written as ``scale`` when that value is greater than ``p`` (kept) and as
    0 otherwise (dropped).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    # Rows / columns of the (N, C) mask handled by this program.
    rows = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    cols = tl.program_id(1) * BLOCK_C + tl.arange(0, BLOCK_C)
    rows_ok = rows < N
    cols_ok = cols < C

    # Row-major position n * C + c. It serves both as the per-element
    # random-number counter increment and as the store offset into MASK.
    linear = rows[:, None] * C + cols[None, :]

    # Split the 64-bit philox offset into two 32-bit counter words and add
    # the element position to the low word so each entry gets its own counter.
    offset_lo = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    offset_hi = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    counter = offset_lo + linear.to(tl.uint32)
    zero_words = counter * 0
    r0, _, _, _ = tl.philox(philox_seed, counter, offset_hi, zero_words, zero_words)
    # NOTE(review): uint_to_uniform_float presumably maps the raw bits to a
    # uniform float in [0, 1) - confirm against flag_gems.utils.random_utils.
    rand_vals = uint_to_uniform_float(r0)

    # Keep (value = scale) when rand > p, drop (value = 0) when rand <= p.
    mask_vals = tl.where(rand_vals > p, scale, 0.0)

    # Only entries inside the (N, C) bounds are written.
    in_bounds = rows_ok[:, None] & cols_ok[None, :]
    tl.store(MASK + linear, mask_vals, mask=in_bounds)
@triton.jit
def apply_feature_mask_kernel(
    X,
    Y,
    MASK,
    numel,
    N,  # batch size (unused; kept for call compatibility)
    C,  # channels (unused; kept for call compatibility)
    spatial_size,  # H * W or D1 * D2 * ...
    BLOCK: tl.constexpr,
):
    """
    Multiply every element of X by its (sample, channel) mask value into Y.

    X and Y are addressed as flat buffers of ``numel`` elements; each program
    handles one tile of BLOCK consecutive elements. MASK is addressed as a
    flat (N, C) buffer.

    For a contiguous (N, C, *spatial) layout the flat index of an element is
        i = n * (C * spatial_size) + c * spatial_size + s,  0 <= s < spatial_size
    so the mask index n * C + c is exactly i // spatial_size. Using that
    single division avoids forming C * spatial_size, which is a 32-bit
    product that can overflow on its own.

    Assumes X and Y are contiguous, as the host wrapper in this file ensures.
    """
    # tl.program_id is a 32-bit value: widen it before scaling by BLOCK so the
    # flat offsets do not wrap around for tensors with >= 2**31 elements.
    pid = tl.program_id(0).to(tl.int64)
    offset = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offset < numel

    # mask_idx == n * C + c (see docstring for the derivation)
    mask_idx = offset // spatial_size

    # Load input and mask; lanes past the end are masked out
    x = tl.load(X + offset, mask=mask, other=0.0)
    m = tl.load(MASK + mask_idx, mask=mask, other=0.0)

    # Apply mask
    y = x * m

    tl.store(Y + offset, y, mask=mask)
def feature_dropout(input, p, train=True):
    """
    Applies feature dropout to the input tensor.

    Randomly zeroes out entire channels of the input tensor with probability p.
    Each batch element has its own independent channel mask, and channels that
    are kept are multiplied by 1 / (1 - p).

    Args:
        input: Input tensor of shape (N, C, ...) where N is batch size, C is channels
        p: Probability of a channel to be zeroed, in [0, 1]
        train: If True, applies dropout. If False, returns a copy of the input.

    Returns:
        New tensor of same shape as input. With ``train=False`` or ``p == 0``
        it is a clone of the input; with ``p == 1`` it is all zeros.

    Raises:
        RuntimeError: If dropout is actually applied (``train`` is true and
            ``p`` is neither 0 nor 1) and the input has fewer than 2
            dimensions or ``p`` lies outside [0, 1].
    """
    logger.debug("GEMS FEATURE_DROPOUT")

    if not train or p == 0:
        return input.clone()

    if p == 1:
        return torch.zeros_like(input)

    if input.ndim < 2:
        raise RuntimeError(
            "Feature dropout requires at least 2 dimensions in the input"
        )

    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, which would let an invalid p reach the kernels.
    if not (0.0 < p < 1.0):
        raise RuntimeError(
            f"dropout probability has to be between 0 and 1, but got {p}"
        )

    numel = input.numel()
    if numel == 0:
        # Nothing to drop. Returning here also avoids the launch setup below,
        # where a zero-sized N or C yields a block size of 0 and makes
        # triton.cdiv divide by zero.
        return torch.empty_like(input)

    device = input.device
    input = input.contiguous()
    out = torch.empty_like(input)

    # Get dimensions
    N = input.shape[0]
    C = input.shape[1]
    # Product of all trailing dims; exact because numel == N * C * spatial
    # and N * C > 0 for a non-empty tensor.
    spatial_size = numel // (N * C)
    scale = 1.0 / (1.0 - p)

    # Mask tensor of shape (N, C): one value per (sample, channel) pair
    mask = torch.empty(N, C, device=device, dtype=torch.float32)

    # Launch configuration for mask generation
    BLOCK_N = min(triton.next_power_of_2(N), 64)
    BLOCK_C = min(triton.next_power_of_2(C), 64)
    grid_mask = (triton.cdiv(N, BLOCK_N), triton.cdiv(C, BLOCK_C))

    # Need N * C random numbers (rounded up to a multiple of 4)
    increment = triton.cdiv(N * C, 4) * 4

    # Launch configuration for applying the mask
    BLOCK = 1024
    grid_apply = (triton.cdiv(numel, BLOCK),)

    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        generate_feature_mask_kernel[grid_mask](
            mask, N, C, p, scale, philox_seed, philox_offset, BLOCK_N, BLOCK_C
        )
        apply_feature_mask_kernel[grid_apply](
            input, out, mask, numel, N, C, spatial_size, BLOCK
        )

    return out
def feature_dropout_(input, p, train=True):
    """
    Apply feature dropout to ``input`` in place and return ``input``.

    The tensor is left untouched when ``train`` is false or ``p == 0`` and is
    zeroed when ``p == 1``. Otherwise the result of ``feature_dropout`` is
    copied back into ``input``.
    """
    logger.debug("GEMS FEATURE_DROPOUT_")
    if train and p != 0:
        if p == 1:
            input.zero_()
        else:
            # Compute out of place, then overwrite the caller's tensor.
            dropped = feature_dropout(input, p, train)
            input.copy_(dropped)
    return input