Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/dropout.py: 0%
122 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
# Module-level logger, registered as a child of the package-wide "flag_gems"
# logger so that handlers/levels configured on the package logger apply here too.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,
    Y,
    dropout_mask,
    N,
    p,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Dropout forward over a flat buffer of N elements.
    #
    # Each program instance handles UNROLL * BLOCK contiguous elements, split
    # into eight sub-blocks of BLOCK lanes (off_0 .. off_7). Every element gets
    # one random value; an element is kept when that value is > p. Kept
    # elements are written to Y as x * 1 / (1 - p), dropped ones as 0. The
    # keep decision is stored in dropout_mask.
    #
    # NOTE(review): UNROLL here must equal the host-side UNROLL that sizes the
    # launch grid (cdiv(N, BLOCK * UNROLL)); the two are not linked in code.
    UNROLL: tl.constexpr = 8
    # Split the 64-bit Philox offset into the low/high 32-bit counter words.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    # First set of 4 random numbers
    # Counters c0 + [pid*2*BLOCK, pid*2*BLOCK + BLOCK): one Philox call yields
    # four 32-bit values per counter, i.e. 4*BLOCK values for sub-blocks 0-3.
    i4_0 = tl.program_id(0) * BLOCK * 2 + tl.arange(0, BLOCK)
    c0_0 = c0 + i4_0
    # Zero tensor with the same shape/dtype as the counter, used for the
    # unused upper counter words.
    _O = c0_0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0_0, c1, _O, _O)
    # presumably maps uint32 -> uniform float in [0, 1); see random_utils
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)

    # Second set of 4 random numbers
    # Next BLOCK counters of this program: values for sub-blocks 4-7. In
    # total a program consumes 2*BLOCK counters for its 8*BLOCK elements.
    i4_1 = tl.program_id(0) * BLOCK * 2 + BLOCK + tl.arange(0, BLOCK)
    c0_1 = c0 + i4_1
    _O1 = c0_1 * 0
    r4, r5, r6, r7 = tl.philox(philox_seed, c0_1, c1, _O1, _O1)
    r4 = uint_to_uniform_float(r4)
    r5 = uint_to_uniform_float(r5)
    r6 = uint_to_uniform_float(r6)
    r7 = uint_to_uniform_float(r7)

    # Keep decisions: random value k drives sub-block k.
    mask0 = r0 > p
    mask1 = r1 > p
    mask2 = r2 > p
    mask3 = r3 > p
    mask4 = r4 > p
    mask5 = r5 > p
    mask6 = r6 > p
    mask7 = r7 > p
    # Inverse keep probability. The host wrapper handles p == 1 before
    # launching, so this never divides by zero there.
    scale = 1.0 / (1.0 - p)

    # Element offsets of the eight consecutive sub-blocks of this program.
    off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    off_4 = off_3 + BLOCK
    off_5 = off_4 + BLOCK
    off_6 = off_5 + BLOCK
    off_7 = off_6 + BLOCK

    # Out-of-range lanes (offset >= N) load 0.0 and are never stored.
    x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0)
    x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0)
    x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0)
    x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0)
    x4 = tl.load(X + off_4, mask=off_4 < N, other=0.0)
    x5 = tl.load(X + off_5, mask=off_5 < N, other=0.0)
    x6 = tl.load(X + off_6, mask=off_6 < N, other=0.0)
    x7 = tl.load(X + off_7, mask=off_7 < N, other=0.0)

    y0 = tl.where(mask0, x0 * scale, 0.0)
    y1 = tl.where(mask1, x1 * scale, 0.0)
    y2 = tl.where(mask2, x2 * scale, 0.0)
    y3 = tl.where(mask3, x3 * scale, 0.0)
    y4 = tl.where(mask4, x4 * scale, 0.0)
    y5 = tl.where(mask5, x5 * scale, 0.0)
    y6 = tl.where(mask6, x6 * scale, 0.0)
    y7 = tl.where(mask7, x7 * scale, 0.0)

    # Write the scaled outputs, then the keep mask (reused by the backward).
    tl.store(Y + off_0, y0, mask=off_0 < N)
    tl.store(Y + off_1, y1, mask=off_1 < N)
    tl.store(Y + off_2, y2, mask=off_2 < N)
    tl.store(Y + off_3, y3, mask=off_3 < N)
    tl.store(Y + off_4, y4, mask=off_4 < N)
    tl.store(Y + off_5, y5, mask=off_5 < N)
    tl.store(Y + off_6, y6, mask=off_6 < N)
    tl.store(Y + off_7, y7, mask=off_7 < N)
    tl.store(dropout_mask + off_0, mask0, mask=off_0 < N)
    tl.store(dropout_mask + off_1, mask1, mask=off_1 < N)
    tl.store(dropout_mask + off_2, mask2, mask=off_2 < N)
    tl.store(dropout_mask + off_3, mask3, mask=off_3 < N)
    tl.store(dropout_mask + off_4, mask4, mask=off_4 < N)
    tl.store(dropout_mask + off_5, mask5, mask=off_5 < N)
    tl.store(dropout_mask + off_6, mask6, mask=off_6 < N)
    tl.store(dropout_mask + off_7, mask7, mask=off_7 < N)
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,
    DX,
    dropout_mask,
    N,
    scale,
    BLOCK: tl.constexpr,
):
    # Element-wise dropout backward: DX[i] = DY[i] * dropout_mask[i] * scale.
    # Each program instance covers one contiguous chunk of BLOCK elements;
    # lanes at or past N load 0 and are not stored.
    lanes = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = lanes < N
    keep = tl.load(dropout_mask + lanes, mask=in_bounds, other=0)
    grad_out = tl.load(DY + lanes, mask=in_bounds, other=0)
    tl.store(DX + lanes, grad_out * keep * scale, mask=in_bounds)
# Elements per kernel lane in the forward pass; sizes the launch grid as
# cdiv(N, BLOCK * UNROLL). dropout_forward_kernel hard-codes its own
# `UNROLL: tl.constexpr = 8`, so both values must be changed together.
UNROLL: int = 8
def dropout(input, p, train=True):
    """Apply dropout to ``input`` on the Kunlunxin backend.

    Each element is kept when its random draw exceeds ``p``. Kept elements are
    scaled by ``1 / (1 - p)`` and dropped elements become zero.

    Args:
        input: Tensor to apply dropout to.
        p: Drop probability.
        train: When False, dropout is disabled. A copy of ``input`` is returned
            together with an all-True mask.

    Returns:
        ``(out, mask)``: ``out`` holds the dropout result. ``mask`` is a bool
        tensor that is True where the element was kept.

    Raises:
        AssertionError: when ``train`` is True and ``p`` is not in ``[0, 1]``.
    """
    logger.debug("GEMS_KUNLUNXIN NATIVE DROPOUT FORWARD")
    if not train or p == 0:
        # Nothing is dropped: return a copy and an all-True keep mask.
        out = input.clone()
        mask = torch.ones_like(input, dtype=torch.bool)
        return out, mask
    if p == 1:
        # Everything is dropped (1 / (1 - p) would be undefined in the kernel).
        out = torch.zeros_like(input)
        mask = torch.zeros_like(input, dtype=torch.bool)
        return out, mask
    assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
    device = input.device
    input = input.contiguous()
    out = torch.empty_like(input)
    mask = torch.empty_like(input, dtype=torch.bool)
    N = input.numel()
    if N == 0:
        # Empty input: nothing to compute, so skip both the launch and the
        # generator advance.
        return out, mask
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # The kernel uses philox_offset as a Philox *counter* base. Each counter
    # yields 4 random values, one per element. The kernel draws two counters
    # per lane to cover its 8 (UNROLL) elements, so one launch spans about
    # N / 4 counters. Reserving only cdiv(N, UNROLL) = N / 8 counters would let
    # the next call reuse the second half of this call's random stream. This
    # assumes philox_backend_seed_offset advances the generator by `increment`
    # counters.
    randoms_per_counter = 4
    increment = triton.cdiv(N, randoms_per_counter)
    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        dropout_forward_kernel[grid_fn](
            input, out, mask, N, p, philox_seed, philox_offset
        )
    return out, mask
def dropout_backward(grad_output, mask, scale):
    """Compute the dropout gradient ``grad_output * mask * scale``.

    Args:
        grad_output: Gradient with respect to the dropout output.
        mask: Keep mask from the forward pass (True where kept). Must have the
            same number of elements as ``grad_output``.
        scale: Factor applied to kept gradients (presumably ``1 / (1 - p)``;
            supplied by the caller).

    Returns:
        A contiguous tensor with the same shape and dtype as ``grad_output``.

    Raises:
        ValueError: if ``mask`` and ``grad_output`` differ in element count.
    """
    logger.debug("GEMS_KUNLUNXIN NATIVE DROPOUT BACKWARD")
    grad_output = grad_output.contiguous()
    # The kernel indexes grad_output and mask with the same flat offsets, so
    # the mask must also be in row-major element order. This is a no-op for the
    # contiguous masks the forward pass produces.
    mask = mask.contiguous()
    grad_input = torch.empty_like(grad_output)
    N = grad_output.numel()
    if mask.numel() != N:
        # A smaller mask would otherwise be read out of bounds by the kernel.
        raise ValueError(
            f"dropout_backward: mask has {mask.numel()} elements, "
            f"expected {N}"
        )
    if N == 0:
        return grad_input
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"]),)
    with torch_device_fn.device(grad_output.device):
        dropout_backward_kernel[grid_fn](grad_output, grad_input, mask, N, scale)
    return grad_input