Coverage for src/flag_gems/runtime/backend/_cambricon/ops/relu.py: 0%
50 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, libtuner
10from ..utils import TOTAL_CORE_NUM
11from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the package-wide "flag_gems" logger, suffixed with
# this module's import name.
_gems_parent_logger = logging.getLogger("flag_gems")
logger = _gems_parent_logger.getChild(__name__.lstrip("."))
# Forward ReLU kernel: OUT[i] = X[i] if X[i] > 0 else 0.
#
# Each program begins at its own BLOCK_SIZE-sized chunk and then advances by
# (number of programs * BLOCK_SIZE) until n_elements is exhausted. Because of
# this, a launch grid smaller than cdiv(n_elements, BLOCK_SIZE) still visits
# every element. The starting offset is promoted to int64 before driving the
# loop, presumably so that large tensors do not overflow 32-bit offsets.
#
# NOTE(review): a NaN input yields 0 here, because `x > 0` is False for NaN.
# torch.relu is believed to propagate NaN; confirm whether NaN parity matters.
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block}, num_stages=1, num_warps=1)
        for block in (4096, 16384, 65536, 131072)
    ],
    key=["n_elements"],
)
@triton.jit
def relu_kernel(
    X_ptr,
    OUT_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    core_id = tl.program_id(0)
    core_count = tl.num_programs(0)
    # Distance between consecutive chunks handled by the same program.
    stride = core_count * BLOCK_SIZE
    # First chunk owned by this program, as an int64 offset.
    first = (core_id * BLOCK_SIZE).to(tl.int64)
    for base in range(first, n_elements, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        # The tail chunk may extend past the end of the tensor.
        in_bounds = idx < n_elements
        val = tl.load(X_ptr + idx, mask=in_bounds)
        tl.store(OUT_ptr + idx, tl.where(val > 0, val, 0), mask=in_bounds)
# The backward pass is still generated through pointwise_dynamic, not a
# hand-written kernel like the forward pass.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def relu_backward(x, dy):
    # Let the upstream gradient through where the forward input was positive;
    # zero it everywhere else.
    passes = x > 0
    return tl.where(passes, dy, 0)
def relu(self):
    """Out-of-place ReLU for the Cambricon backend.

    The input is first made contiguous. ``relu_kernel`` then writes
    ``where(x > 0, x, 0)`` into a freshly allocated tensor created with
    ``torch.empty_like``.

    Args:
        self: Input tensor.

    Returns:
        A new contiguous tensor holding the result. For an empty input, the
        empty output is returned without launching the kernel.
    """
    logger.debug("GEMS_CAMBRICON RELU FORWARD")
    src = self.contiguous()
    result = torch.empty_like(src)
    numel = src.numel()
    if not numel:
        # Nothing to compute; skip the kernel launch entirely.
        return result

    def grid(meta):
        # One program per BLOCK_SIZE chunk, capped at TOTAL_CORE_NUM programs.
        # The kernel loops over any chunks beyond that cap.
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(src.device):
        relu_kernel[grid](src, result, numel)
    return result
def relu_(A):
    """In-place ReLU for the Cambricon backend.

    The kernel runs on a contiguous view of ``A``, which is ``A`` itself when
    it is already contiguous, otherwise a contiguous copy. If a copy was used,
    the results are copied back into ``A``.

    Args:
        A: Tensor to overwrite with ``where(A > 0, A, 0)``.

    Returns:
        ``A`` itself. An empty ``A`` is returned untouched, without launching
        the kernel.
    """
    logger.debug("GEMS_CAMBRICON RELU_ FORWARD")
    work = A.contiguous()
    numel = work.numel()
    if not numel:
        return A

    def grid(meta):
        # One program per BLOCK_SIZE chunk, capped at TOTAL_CORE_NUM programs.
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        # Input and output are the same buffer, so the kernel overwrites it.
        relu_kernel[grid](work, work, numel)
    # A non-contiguous A was processed through a separate contiguous copy;
    # write the results back into A's own storage.
    needs_writeback = not A.is_contiguous()
    if needs_writeback:
        A.copy_(work)
    return A