Coverage for src/flag_gems/runtime/backend/_sunrise/ops/prelu.py: 0%
61 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
@triton.jit
def prelu_kernel(
    x_ptr,  # *Pointer* to input tensor.
    w_ptr,  # *Pointer* to weight tensor (scalar or per-channel vector).
    out_ptr,  # *Pointer* to output tensor.
    n_elements,  # Total number of elements in input.
    S,  # Spatial size = product of dims after channel dim (or 1 if none).
    C,  # Number of channels (or 1).
    w_is_scalar: tl.constexpr,  # Whether weight is a single scalar.
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise PReLU over a flat buffer: out = x if x >= 0 else alpha * x.

    Each program instance processes one run of BLOCK_SIZE consecutive flat
    offsets. When ``w_is_scalar`` is true a single weight value is read from
    ``w_ptr``; otherwise the weight index for a flat offset is
    ``(offset // S) % C``.
    """
    program_idx = tl.program_id(axis=0)
    idx = program_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes at or beyond n_elements are masked out of every load and store.
    in_bounds = idx < n_elements

    x = tl.load(x_ptr + idx, mask=in_bounds)

    # w_is_scalar is a constexpr, so only one of these branches is compiled.
    if w_is_scalar:
        slope = tl.load(w_ptr)  # scalar
    else:
        channel = (idx // S) % C
        slope = tl.load(w_ptr + channel, mask=in_bounds)

    result = tl.where(x >= 0, x, slope * x)
    tl.store(out_ptr + idx, result, mask=in_bounds)
def prelu(*args, **kwargs):
    """Compute PReLU of an input tensor with a scalar or per-channel weight.

    The operands may be passed positionally, by keyword, or mixed:
    ``prelu(x, w)``, ``prelu(input=x, weight=w)`` (``self`` is accepted as an
    alias of ``input``) and ``prelu(x, weight=w)`` are all supported.

    Returns:
        A new tensor created with ``torch.empty_like`` on the contiguous input
        and filled by ``prelu_kernel``. It is returned unfilled when the input
        has zero elements.

    Raises:
        ValueError: if the input or the weight is missing.
        AssertionError: if either tensor is not a PTPU tensor, if a non-scalar
            weight is given for a 0-dim input, or if the weight element count
            does not match the channel dimension size.
    """
    logger.debug("GEMS PRELU")

    # Resolve each operand on its own (positional first, then keyword) so a
    # mixed call such as prelu(x, weight=w) does not lose the positional input.
    if len(args) >= 1:
        x = args[0]
    else:
        x = kwargs.get("input", kwargs.get("self"))
    if len(args) >= 2:
        weight = args[1]
    else:
        weight = kwargs.get("weight")
    if x is None or weight is None:
        raise ValueError("prelu expects (input, weight) as arguments.")

    if not (x.is_ptpu and weight.is_ptpu):
        raise AssertionError("Tensors must be PTPU tensors.")

    # The kernel multiplies weight by input directly, so align the dtypes.
    if weight.dtype != x.dtype:
        weight = weight.to(dtype=x.dtype)

    # The kernel indexes both buffers with flat offsets.
    x = x.contiguous()
    weight = weight.contiguous()

    out = torch.empty_like(x)

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch entirely.
        return out

    # Determine channel count C and spatial size S.
    ndim = x.dim()
    if weight.numel() == 1:
        C = 1
        S = 1
        w_is_scalar = True
    else:
        if ndim == 0:
            raise AssertionError("Non-scalar weight provided for a 0-dim input.")
        if ndim == 1:
            # NOTE(review): a 1-D input uses its only dimension as the channel
            # dimension here; confirm this matches the intended reference
            # semantics.
            C = x.shape[0]
            S = 1
        else:
            C = x.shape[1]
            # Spatial size is the product of every dimension after the channel.
            S = 1
            for dim_size in x.shape[2:]:
                S *= dim_size
        if weight.numel() != C:
            raise AssertionError(
                f"Weight numel ({weight.numel()}) must equal channel dimension size ({C})."
            )
        w_is_scalar = False

    # Make sure S and C are at least 1 to avoid div/mod by zero in kernel math.
    C = max(int(C), 1)
    S = max(int(S), 1)

    BLOCK_SIZE = 1024
    # One program per BLOCK_SIZE-sized chunk of the flattened input.
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    prelu_kernel[grid](
        x, weight, out, n_elements, S, C, w_is_scalar=w_is_scalar, BLOCK_SIZE=BLOCK_SIZE
    )
    return out