Coverage for src/flag_gems/runtime/backend/_sunrise/ops/dropout.py: 0%
90 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
# Module logger: a child of the "flag_gems" logger, named after this module
# with any leading dots stripped from __name__.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,
    Y,
    dropout_mask,
    N,
    p,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Forward dropout over a flat buffer of N elements.
    #
    # X            -- input pointer
    # Y            -- output pointer; receives x * 1/(1-p) * keep
    # dropout_mask -- pointer receiving the keep flag of every element as int32
    # N            -- number of elements
    # p            -- drop probability (an element is kept when its draw > p)
    # philox_seed / philox_offset -- Philox seed and starting counter
    # BLOCK        -- tile width; one program handles UNROLL tiles of BLOCK
    #
    # One Philox call produces 128 random bits (four 32-bit words), so a
    # single counter value feeds four elements, one per tile.
    UNROLL: tl.constexpr = 4
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK)

    # Split the 64-bit offset into two 32-bit counter words and give every
    # lane of every program its own low counter word.
    seed = philox_seed.to(tl.int64)
    offset64 = philox_offset.to(tl.int64)
    ctr_lo = (offset64 & 0xFFFFFFFF).to(tl.uint32)
    ctr_hi = ((offset64 >> 32) & 0xFFFFFFFF).to(tl.uint32)
    ctr_lo += pid * BLOCK + lane
    zero_ctr = ctr_lo * 0  # zero block with the same shape as the counters
    w0, w1, w2, w3 = tl.philox(seed, ctr_lo, ctr_hi, zero_ctr, zero_ctr)

    # Turn the raw words into floats and compare against p to get keep flags.
    keep0 = uint_to_uniform_float(w0) > p
    keep1 = uint_to_uniform_float(w1) > p
    keep2 = uint_to_uniform_float(w2) > p
    keep3 = uint_to_uniform_float(w3) > p

    # Surviving elements are rescaled by 1 / (1 - p).
    scale = 1.0 / (1.0 - p)

    # Element offsets of the four consecutive tiles owned by this program,
    # with the matching bounds predicates.
    off_0 = pid * BLOCK * UNROLL + lane
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    ok_0 = off_0 < N
    ok_1 = off_1 < N
    ok_2 = off_2 < N
    ok_3 = off_3 < N

    x0 = tl.load(X + off_0, mask=ok_0, other=0.0, eviction_policy="evict_first")
    x1 = tl.load(X + off_1, mask=ok_1, other=0.0, eviction_policy="evict_first")
    x2 = tl.load(X + off_2, mask=ok_2, other=0.0, eviction_policy="evict_first")
    x3 = tl.load(X + off_3, mask=ok_3, other=0.0, eviction_policy="evict_first")

    # Multiply by the keep flag (rather than select) to zero dropped elements.
    y0 = x0 * scale * keep0
    y1 = x1 * scale * keep1
    y2 = x2 * scale * keep2
    y3 = x3 * scale * keep3

    # Tile 0: keep flags (as int32) and scaled output.
    tl.store(
        dropout_mask + off_0,
        keep0.to(tl.int32),
        mask=ok_0,
        eviction_policy="evict_first",
    )
    tl.store(Y + off_0, y0, mask=ok_0, eviction_policy="evict_first")

    # Tile 1.
    tl.store(
        dropout_mask + off_1,
        keep1.to(tl.int32),
        mask=ok_1,
        eviction_policy="evict_first",
    )
    tl.store(Y + off_1, y1, mask=ok_1, eviction_policy="evict_first")

    # Tile 2.
    tl.store(
        dropout_mask + off_2,
        keep2.to(tl.int32),
        mask=ok_2,
        eviction_policy="evict_first",
    )
    tl.store(Y + off_2, y2, mask=ok_2, eviction_policy="evict_first")

    # Tile 3.
    tl.store(
        dropout_mask + off_3,
        keep3.to(tl.int32),
        mask=ok_3,
        eviction_policy="evict_first",
    )
    tl.store(Y + off_3, y3, mask=ok_3, eviction_policy="evict_first")
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,
    DX,
    dropout_mask,
    N,
    scale,
    BLOCK: tl.constexpr,
):
    # Backward dropout: DX = DY * dropout_mask * scale, element-wise over N
    # elements, with each program handling one tile of BLOCK elements.
    idx = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = idx < N

    grad_out = tl.load(
        DY + idx, mask=in_bounds, other=0, eviction_policy="evict_first"
    )
    keep = tl.load(
        dropout_mask + idx, mask=in_bounds, other=0, eviction_policy="evict_first"
    )

    grad_in = grad_out * keep * scale
    tl.store(DX + idx, grad_in, mask=in_bounds, eviction_policy="evict_first")
# Number of BLOCK-sized tiles each forward program covers. dropout() uses it
# to size the launch grid and the Philox offset increment; it mirrors the
# UNROLL constant declared inside dropout_forward_kernel.
UNROLL = 4
def dropout(input, p, train=True):
    """Apply dropout to ``input`` and return ``(output, mask)``.

    Shortcut paths (no kernel launch):
      * ``train`` is false or ``p == 0``: output is a clone of ``input`` and
        the mask is an all-ones ``torch.bool`` tensor.
      * ``p == 1``: output and mask are all zeros (mask is ``torch.bool``).

    Otherwise ``p`` must lie strictly between 0 and 1 (checked with an
    ``assert``); the forward Triton kernel fills the output and a
    ``torch.int`` mask shaped like the contiguous input.
    """
    logger.debug("GEMS NATIVE DROPOUT FORWARD")

    if not train or p == 0:
        return input.clone(), torch.ones_like(input, dtype=torch.bool)
    if p == 1:
        return torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool)

    assert 0.0 < p < 1.0, "p must be in (0, 1)"

    target_device = input.device
    # TODO: remove contiguous enforcement
    input = input.contiguous()
    result = torch.empty_like(input)
    keep_mask = torch.empty_like(input, dtype=torch.int)
    numel = input.numel()

    def grid_fn(meta):
        # Each program covers UNROLL tiles of meta["BLOCK"] elements.
        return (triton.cdiv(numel, meta["BLOCK"] * UNROLL),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    philox_increment = triton.cdiv(numel, UNROLL)
    with torch_device_fn.device(target_device):
        seed, offset = philox_backend_seed_offset(philox_increment)
        dropout_forward_kernel[grid_fn](
            input, result, keep_mask, numel, p, seed, offset
        )
    return result, keep_mask
def dropout_backward(grad_output, mask, scale):
    """Return the input gradient ``grad_output * mask * scale``.

    ``grad_output`` is made contiguous first; the result is a new tensor
    shaped like it, filled by the backward Triton kernel launched on
    ``grad_output``'s device.
    """
    logger.debug("GEMS NATIVE DROPOUT BACKWARD")

    grad_output = grad_output.contiguous()
    numel = grad_output.numel()
    grad_input = torch.empty_like(grad_output)

    def grid_fn(meta):
        # One program per tile of meta["BLOCK"] elements.
        return (triton.cdiv(numel, meta["BLOCK"]),)

    with torch_device_fn.device(grad_output.device):
        dropout_backward_kernel[grid_fn](
            grad_output, grad_input, mask, numel, scale
        )
    return grad_input