Coverage for src/flag_gems/ops/leaky_relu.py: 76%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
10logger = logging.getLogger(__name__)
def _leaky_relu_autotune_configs():
    """Build the candidate launch configurations for the leaky-ReLU autotuner.

    Returns:
        A list of ``triton.Config`` objects. Each one fixes the ``BLOCK_SIZE``
        meta-parameter together with a ``num_warps`` / ``num_stages`` pair.
        The list is ordered by increasing ``BLOCK_SIZE``.
    """
    # Each row is (BLOCK_SIZE, ((num_warps, num_stages), ...)).
    # The size ranges in the comments are the workloads each group is aimed at.
    search_space = (
        # Tiny tensors (n <= 32K): small blocks.
        (256, ((4, 2), (8, 2))),
        (512, ((4, 2), (8, 2))),
        # Small-medium tensors (n ~ 64K-4M): 1024-element blocks.
        (1024, ((4, 2), (8, 2), (4, 3), (8, 3), (4, 4), (8, 4), (16, 4))),
        # Medium-large tensors (n ~ 4M-16M): 2048-element blocks.
        (2048, ((4, 3), (8, 3), (8, 4), (16, 4), (4, 5), (8, 5), (16, 5))),
        # Large tensors (n >= 16M): 4096-element blocks for max bandwidth.
        (4096, ((4, 3), (8, 3), (8, 4), (16, 4), (4, 5), (8, 5), (16, 5))),
    )
    return [
        triton.Config({"BLOCK_SIZE": block_size}, num_warps=warps, num_stages=stages)
        for block_size, launch_params in search_space
        for warps, stages in launch_params
    ]
@libentry()
@triton.autotune(
    configs=_leaky_relu_autotune_configs(),
    key=["n_elements"],
    # The autotuner benchmarks each config by launching the kernel repeatedly
    # on the caller's real buffers. When the kernel is launched in place
    # (input_ptr and output_ptr are the same tensor), every benchmark launch
    # would rescale the negative entries again and corrupt the data before the
    # final launch. Snapshotting and restoring the output buffer around each
    # benchmark launch keeps tuning free of side effects.
    restore_value=["output_ptr"],
)
@triton.jit(do_not_specialize=["negative_slope"])
def _leaky_relu_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    negative_slope,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise leaky ReLU over a flat buffer of ``n_elements`` values.

    Each program handles one ``BLOCK_SIZE``-wide slice of the buffer and
    writes ``x`` where ``x >= 0`` and ``x * negative_slope`` elsewhere.

    Args:
        input_ptr: Pointer to the first input element; addressed with flat
            offsets, so the buffer must be densely packed.
        output_ptr: Pointer to the first output element; may be the same
            buffer as ``input_ptr`` for in-place use.
        n_elements: Number of elements to process; lanes past it are masked.
        negative_slope: Multiplier applied to negative inputs. Excluded from
            specialization so a new slope does not trigger a recompile.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Widen to 64-bit before scaling: pid * BLOCK_SIZE would otherwise be
    # computed in 32-bit and can wrap for buffers close to or above 2**31
    # elements.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last program may run past the end of the buffer; mask those lanes.
    mask = offsets < n_elements

    x = tl.load(input_ptr + offsets, mask=mask)
    output = tl.where(x >= 0, x, x * negative_slope)
    tl.store(output_ptr + offsets, output, mask=mask)
def leaky_relu(A, negative_slope=0.01):
    """Return a new tensor holding the leaky ReLU of ``A``.

    Args:
        A: Input tensor. A non-contiguous input is first copied into a
            contiguous buffer because the kernel addresses memory with flat
            offsets.
        negative_slope: Multiplier applied to negative elements.

    Returns:
        A tensor allocated with ``torch.empty_like`` on the (contiguous)
        input. For an empty input it is returned without launching the kernel.
    """
    logger.debug("GEMS LEAKY_RELU")
    src = A if A.is_contiguous() else A.contiguous()
    result = torch.empty_like(src)
    total = src.numel()
    if total == 0:
        # Nothing to compute; avoid launching a zero-sized grid.
        return result

    def launch_grid(meta):
        # One program per BLOCK_SIZE-wide slice, rounding up.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(src.device.index):
        _leaky_relu_kernel[launch_grid](src, result, total, negative_slope)
    return result
def leaky_relu_(A, negative_slope=0.01):
    """Apply leaky ReLU to ``A`` in place and return ``A``.

    Non-contiguous tensors are supported: the kernel addresses memory with
    flat offsets, so a strided ``A`` is processed through a contiguous staging
    copy whose result is then written back into ``A``.

    Args:
        A: Tensor to modify in place.
        negative_slope: Multiplier applied to negative elements.

    Returns:
        The same tensor object ``A``, now holding the activated values.
    """
    logger.debug("GEMS LEAKY_RELU_")
    n_elements = A.numel()
    if n_elements == 0:
        return A
    # Contiguous tensors are updated directly; strided ones go through a
    # dense temporary so the flat-offset kernel touches the right elements.
    work = A if A.is_contiguous() else A.contiguous()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(A.device.index):
        _leaky_relu_kernel[grid](work, work, n_elements, negative_slope)
    if work is not A:
        # Scatter the staged result back through A's own strides.
        A.copy_(work)
    return A
def leaky_relu_out(A, negative_slope=0.01, *, out=None):
    """Compute the leaky ReLU of ``A`` and store the result in ``out``.

    Args:
        A: Input tensor. A non-contiguous input is copied into a contiguous
            buffer before the kernel launch.
        negative_slope: Multiplier applied to negative elements.
        out: Optional destination tensor. When ``None`` a new tensor is
            allocated via ``leaky_relu``. When its shape differs from ``A`` it
            is resized to ``A.shape``. A non-contiguous ``out`` is filled via
            a contiguous temporary.

    Returns:
        ``out`` when it was supplied, otherwise the newly allocated result.
    """
    logger.debug("GEMS LEAKY_RELU_OUT")
    if out is None:
        return leaky_relu(A, negative_slope)
    if out.shape != A.shape:
        # The kernel writes A.numel() elements at flat offsets; a destination
        # with a different shape would be under-filled or written out of
        # bounds, so resize it to match the input first.
        out.resize_(A.shape)
    n_elements = A.numel()
    if n_elements == 0:
        return out
    if not A.is_contiguous():
        A = A.contiguous()
    # Flat-offset stores are only correct for a dense destination; stage the
    # result in a contiguous buffer when ``out`` is strided.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty_like(out, memory_format=torch.contiguous_format)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(A.device.index):
        _leaky_relu_kernel[grid](A, target, n_elements, negative_slope)
    if target is not out:
        out.copy_(target)
    return out