Coverage for src/flag_gems/runtime/backend/_cambricon/ops/dropout.py: 0%
88 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import torch_mlu # noqa: F401
5import triton
6import triton.language as tl
7from triton.language.extra.mlu.libdevice import philox as _philox
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils.random_utils import (
12 philox_backend_seed_offset,
13 uint_to_uniform_float,
14)
16from ..utils import TOTAL_CORE_NUM
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (``__name__`` with any leading dots stripped).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Host-side unroll factor used to size the launch grid and the Philox counter
# reservation in the wrappers below. The kernels declare their own
# ``UNROLL: tl.constexpr = 4``, so the two values must be kept in sync by hand.
UNROLL = 4
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK": 1024}, num_stages=3, num_warps=1),
        triton.Config(kwargs={"BLOCK": 4096}, num_stages=3, num_warps=1),
        triton.Config(kwargs={"BLOCK": 16384}, num_stages=3, num_warps=1),
        triton.Config(kwargs={"BLOCK": 32768}, num_stages=3, num_warps=1),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,
    Y,
    dropout_mask,
    N,
    p,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Apply dropout to a flat buffer of ``N`` elements.

    Each program walks the buffer in tiles of ``UNROLL * BLOCK`` elements
    (grid-stride loop over tiles). For every tile it draws random values
    with Philox, keeps the elements whose uniform value is greater than
    ``p``, scales kept elements by ``1 / (1 - p)``, and writes zeros
    elsewhere. The boolean keep-mask is written to ``dropout_mask``.

    Args:
        X: pointer to the input elements, read as a flat buffer.
        Y: pointer to the output elements.
        dropout_mask: pointer to the boolean keep-mask output (True = kept).
        N: total number of elements.
        p: drop probability (not specialized, so each value reuses the binary).
        philox_seed: 64-bit seed; split into two 32-bit key words below.
        philox_offset: 64-bit offset; split into two 32-bit counter words below.
        BLOCK: Philox counters consumed per tile; picked by ``libtuner``.
    """
    # Fixed at 4: the reshape below assumes the Philox call yields
    # UNROLL * BLOCK values.
    UNROLL: tl.constexpr = 4
    # Widen to int64 so the ``>> 32`` splits below are well defined.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # Philox counter base for this program's first tile; advanced by
    # num_jobs * BLOCK per tile so every tile uses its own counter range.
    i4_start = pid * BLOCK
    # Element offset of this program's first tile, and the stride between
    # consecutive tiles handled by the same program.
    block_start = pid * UNROLL * BLOCK
    step = num_jobs * BLOCK * UNROLL
    # Inverse keep-probability applied to kept elements.
    mp = 1.0 / (1.0 - p)

    for block_offset in range(block_start, N, step):
        # Low/high 32-bit halves of the seed (key) and offset (counter).
        sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
        sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
        c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
        c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
        # NOTE(review): presumably generates BLOCK x 4 uint32 values from
        # BLOCK consecutive counters starting at c0 + i4_start, with the
        # trailing ``10`` being the round count — confirm against the MLU
        # libdevice docs. ``c0 + i4_start`` wraps in 32 bits without carrying
        # into c1.
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        # Assumed to map each uint32 to a uniform float in [0, 1) — verify.
        r = uint_to_uniform_float(r)

        # Keep an element when its random value exceeds p.
        mask = r > p
        # Flatten to one value per element of the tile; ``can_reorder=True``
        # lets the compiler permute elements, which is harmless for
        # independent random draws.
        mask_reshaped = tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True)

        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        # Tail guard for the last, possibly partial, tile.
        valid = off < N
        x = tl.load(X + off, mask=valid, other=0.0)
        y = tl.where(mask_reshaped, x * mp, 0.0)
        tl.store(dropout_mask + off, mask_reshaped, mask=valid)
        tl.store(Y + off, y, mask=valid)
        i4_start += num_jobs * BLOCK
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK": block_size}, num_stages=3, num_warps=1)
        for block_size in (1024, 4096, 16384, 32768)
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,
    DX,
    dropout_mask,
    N,
    scale,
    BLOCK: tl.constexpr,
):
    """Compute ``DX = DY * dropout_mask * scale`` over ``N`` flat elements.

    Programs stride over the buffer in tiles of ``UNROLL * BLOCK`` elements.
    The mask is applied by multiplication rather than a select, so a
    non-finite ``DY`` value at a dropped position produces NaN (x * 0).

    Args:
        DY: pointer to the incoming gradient.
        DX: pointer to the gradient to write.
        dropout_mask: pointer to the boolean keep-mask from the forward pass.
        N: total number of elements.
        scale: multiplier for kept gradients (not specialized).
        BLOCK: tile granularity picked by ``libtuner``.
    """
    UNROLL: tl.constexpr = 4
    tile = UNROLL * BLOCK
    # Lane offsets inside one tile; identical for every iteration.
    lanes = tl.arange(0, UNROLL * BLOCK)
    first = tl.program_id(0) * tile
    stride = tl.num_programs(0) * tile
    for base in range(first, N, stride):
        idx = base + lanes
        in_bounds = idx < N
        # Every element is touched exactly once, hence the evict_first hints.
        keep = tl.load(
            dropout_mask + idx,
            mask=in_bounds,
            other=0,
            eviction_policy="evict_first",
        )
        grad = tl.load(
            DY + idx, mask=in_bounds, other=0.0, eviction_policy="evict_first"
        )
        tl.store(
            DX + idx,
            grad * keep * scale,
            mask=in_bounds,
            eviction_policy="evict_first",
        )
def dropout(input, p, train=True):
    """Native dropout forward for the Cambricon backend.

    Args:
        input: tensor to apply dropout to.
        p: probability of zeroing an element. In training mode it must lie in
            ``[0, 1]``.
        train: only an explicit falsy value selects the identity path.
            ``None`` is treated as training mode, matching ATen's
            ``native_dropout(Tensor input, float p, bool? train)``.

    Returns:
        ``(out, mask)``: ``mask`` is a ``torch.bool`` tensor that is True for
        kept elements; ``out`` holds kept elements scaled by ``1 / (1 - p)``
        and zeros elsewhere. In the identity / ``p == 0`` path ``out`` is a
        clone of ``input`` and ``mask`` is all True; for ``p == 1`` both are
        all zeros.

    Raises:
        AssertionError: in training mode, if ``p`` is outside ``[0, 1]`` or NaN.
    """
    logger.debug("GEMS_CAMBRICON NATIVE DROPOUT FORWARD")
    # ATen only takes the evaluation short-cut for an explicit train=False;
    # train=None means "train".
    is_eval = train is not None and not train
    if is_eval or p == 0:
        out = input.clone()
        mask = torch.ones_like(input, dtype=torch.bool)
        return out, mask
    if p == 1:
        out = torch.zeros_like(input)
        mask = torch.zeros_like(input, dtype=torch.bool)
        return out, mask
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept so existing callers see the same exception type.
    if not (0.0 < p < 1.0):
        raise AssertionError("p must be in (0, 1)")
    device = input.device
    input = input.contiguous()
    out = torch.empty_like(input)
    mask = torch.empty_like(input, dtype=torch.bool)
    N = input.numel()
    if N == 0:
        # Nothing to compute: skip autotuning and the kernel launch. The
        # Philox reservation below would be cdiv(0, UNROLL) == 0 counters.
        return out, mask
    # One program per UNROLL * BLOCK tile, capped at the number of cores; the
    # kernel strides over any remaining tiles.
    grid_fn = lambda meta: (
        min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),
    )
    # The kernel draws UNROLL values per Philox counter (presumably), so
    # reserve one counter per UNROLL elements.
    increment = triton.cdiv(N, UNROLL)
    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        dropout_forward_kernel[grid_fn](
            input,
            out,
            mask,
            N,
            p,
            philox_seed,
            philox_offset,
        )
    return out, mask
def dropout_backward(grad_output, mask, scale):
    """Native dropout backward for the Cambricon backend.

    Computes ``grad_input = grad_output * mask * scale`` elementwise.

    Args:
        grad_output: gradient with respect to the dropout output.
        mask: boolean keep-mask from the forward pass. NOTE(review): assumed
            to have the same number of elements as ``grad_output``; this is
            not checked here.
        scale: multiplier applied to kept gradients.

    Returns:
        A contiguous tensor with the shape and dtype of ``grad_output``.
    """
    logger.debug("GEMS_CAMBRICON NATIVE DROPOUT BACKWARD")
    grad_output = grad_output.contiguous()
    # The kernel reads ``mask`` with the same flat offsets as ``grad_output``,
    # so both must be laid out contiguously; otherwise a strided mask (e.g. a
    # transposed view) is read in the wrong element order. This is a no-op for
    # an already-contiguous mask.
    mask = mask.contiguous()
    grad_input = torch.empty_like(grad_output)
    N = grad_output.numel()
    if N == 0:
        # Nothing to compute: skip autotuning and the kernel launch.
        return grad_input
    # One program per UNROLL * BLOCK tile, capped at the number of cores.
    grid_fn = lambda meta: (
        min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),
    )
    with torch_device_fn.device(grad_output.device):
        dropout_backward_kernel[grid_fn](
            grad_output,
            grad_input,
            mask,
            N,
            scale,
        )
    return grad_input