Coverage for src/flag_gems/ops/prelu.py: 65%
62 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
import logging
import math

import torch
import triton
import triton.language as tl

import flag_gems
# Module-level logger; prelu() emits a "GEMS PRELU" debug record on every call.
logger = logging.getLogger(__name__)
@triton.jit
def prelu_kernel(
    x_ptr,  # *Pointer* to the (contiguous) input tensor.
    w_ptr,  # *Pointer* to the weight: a single scalar or a per-channel vector.
    out_ptr,  # *Pointer* to the (contiguous) output tensor.
    n_elements,  # Total number of elements in the input.
    S,  # Spatial size = product of dims after the channel dim (1 if none).
    C,  # Number of channels (1 when the weight is a scalar).
    w_is_scalar: tl.constexpr,  # True when a single alpha applies to every element.
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise PReLU: out = x where x >= 0, else alpha * x.
    # Each program handles one BLOCK_SIZE chunk of the flattened input.
    # For a per-channel weight, flat index i of a contiguous tensor whose
    # channel dim is followed by S spatial elements maps to channel (i // S) % C.

    # Promote to int64 *before* multiplying so the block offsets (and the
    # pointer arithmetic derived from them) cannot overflow int32 on tensors
    # with 2**31 or more elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # w_is_scalar is constexpr, so only one branch is compiled.
    if w_is_scalar:
        alpha = tl.load(w_ptr)  # one value broadcast to the whole block
    else:
        channel = (offsets // S) % C
        alpha = tl.load(w_ptr + channel, mask=mask)

    y = tl.where(x >= 0, x, alpha * x)
    tl.store(out_ptr + offsets, y, mask=mask)
def _prelu_channel_layout(x, weight):
    """Return ``(C, S, w_is_scalar)`` describing how ``weight`` maps onto ``x``.

    A single-element weight applies to every element (``C = S = 1``).
    Otherwise the weight is indexed per channel. For a 1-D input the only dim
    is the channel dim. For ndim >= 2 the channel dim is dim 1, and ``S`` is
    the product of the trailing dims. Flat index ``i`` of a contiguous ``x``
    then belongs to channel ``(i // S) % C``.

    Raises:
        AssertionError: if a multi-element weight is given for a 0-dim input,
            or if its numel does not equal the channel count.
    """
    if weight.numel() == 1:
        return 1, 1, True

    ndim = x.dim()
    if ndim == 0:
        raise AssertionError("Non-scalar weight provided for a 0-dim input.")
    if ndim == 1:
        channels, spatial = x.shape[0], 1
    else:
        # math.prod of an empty slice is 1, which covers the ndim == 2 case.
        channels, spatial = x.shape[1], math.prod(x.shape[2:])

    if weight.numel() != channels:
        raise AssertionError(
            f"Weight numel ({weight.numel()}) must equal channel dimension size ({channels})."
        )
    # Clamp to >= 1 so the kernel's `//` and `%` can never divide by zero.
    # A zero size only arises for empty inputs, which prelu() returns early for.
    return max(int(channels), 1), max(int(spatial), 1), False


def prelu(*args, **kwargs):
    """PReLU forward, ``out = x if x >= 0 else weight * x``, computed by a Triton kernel.

    The input and weight may be passed positionally, by keyword (``input`` or
    ``self`` for the input, ``weight`` for the weight), or as a mix such as
    ``prelu(x, weight=w)``. ``weight`` is either a single element shared by
    all elements, or one value per channel: dim 1 for ndim >= 2 inputs, or the
    only dim for 1-D inputs. ``weight`` is cast to the input's dtype before use.

    Returns:
        A new contiguous tensor with the same shape and dtype as the input.
        An empty input yields an empty output without launching the kernel.

    Raises:
        ValueError: if the input or the weight is missing.
        AssertionError: if either tensor is not on ``flag_gems.device``, or if
            the weight size does not match the channel dimension.
    """
    logger.debug("GEMS PRELU")

    # Take each argument positionally when provided, otherwise by keyword.
    # The previous all-or-nothing check rejected prelu(x, weight=w).
    x = args[0] if len(args) >= 1 else kwargs.get("input", kwargs.get("self"))
    weight = args[1] if len(args) >= 2 else kwargs.get("weight")
    if x is None or weight is None:
        raise ValueError("prelu expects (input, weight) as arguments.")

    if x.device.type != flag_gems.device or weight.device.type != flag_gems.device:
        raise AssertionError(f"Tensors must be {flag_gems.device} tensors.")

    # The kernel reads alpha in the input's dtype.
    if weight.dtype != x.dtype:
        weight = weight.to(dtype=x.dtype)

    # The kernel addresses both tensors as flat contiguous buffers.
    x = x.contiguous()
    weight = weight.contiguous()

    out = torch.empty_like(x)
    n_elements = x.numel()
    if n_elements == 0:
        return out

    C, S, w_is_scalar = _prelu_channel_layout(x, weight)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    prelu_kernel[grid](
        x, weight, out, n_elements, S, C, w_is_scalar=w_is_scalar, BLOCK_SIZE=BLOCK_SIZE
    )
    return out