Coverage for src/flag_gems/runtime/backend/_sunrise/ops/prelu.py: 0%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)

9 

10 

@triton.jit
def prelu_kernel(
    x_ptr,  # Pointer to the input tensor (read as a flat buffer).
    w_ptr,  # Pointer to the weight: a single value or one value per channel.
    out_ptr,  # Pointer to the output tensor (written as a flat buffer).
    n_elements,  # Total number of input elements.
    S,  # Product of the dims after the channel dim (1 if there are none).
    C,  # Number of channels (1 for a scalar weight).
    w_is_scalar: tl.constexpr,  # True when one weight applies to every element.
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise PReLU: ``out[i] = x[i] if x[i] >= 0 else alpha * x[i]``.

    Each program instance handles one BLOCK_SIZE-wide tile of the flattened
    input; lanes past ``n_elements`` are masked off. For a per-channel weight,
    the channel of flat index ``i`` is ``(i // S) % C``, which is only correct
    for a contiguous input (the ``prelu`` wrapper calls ``.contiguous()``
    before launching).
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    x = tl.load(x_ptr + idx, mask=in_bounds)

    # w_is_scalar is a compile-time constant, so only one branch is compiled.
    if w_is_scalar:
        alpha = tl.load(w_ptr)
    else:
        alpha = tl.load(w_ptr + (idx // S) % C, mask=in_bounds)

    tl.store(out_ptr + idx, tl.where(x >= 0, x, alpha * x), mask=in_bounds)

38 

39 

def _prelu_resolve_args(args, kwargs):
    """Extract ``(input, weight)`` from positional and/or keyword arguments.

    Each argument is resolved independently, so mixed calls such as
    ``prelu(x, weight=w)`` work. ``input`` may also be passed as ``self``.

    Raises:
        ValueError: if either tensor is missing or ``None``.
    """
    if args:
        x = args[0]
    else:
        x = kwargs.get("input", kwargs.get("self"))
    weight = args[1] if len(args) >= 2 else kwargs.get("weight")
    if x is None or weight is None:
        raise ValueError("prelu expects (input, weight) as arguments.")
    return x, weight


def _prelu_channel_layout(x, weight):
    """Return ``(C, S, w_is_scalar)`` describing how ``weight`` maps onto ``x``.

    ``C`` is the channel count and ``S`` the product of the dims after the
    channel dim. Both are clamped to at least 1 so the kernel's
    ``(i // S) % C`` never divides by zero.

    Raises:
        AssertionError: if a non-scalar weight is given for a 0-dim input, or
            if ``weight.numel()`` does not equal the channel count.
    """
    if weight.numel() == 1:
        return 1, 1, True

    ndim = x.dim()
    if ndim == 0:
        raise AssertionError("Non-scalar weight provided for a 0-dim input.")
    if ndim == 1:
        # NOTE(review): upstream PyTorch appears to treat 1-D input as a single
        # channel; this backend applies one weight per element. Kept as-is.
        C, S = x.shape[0], 1
    else:
        C = x.shape[1]
        S = x.shape[2:].numel()  # an empty trailing shape yields 1
    if weight.numel() != C:
        raise AssertionError(
            f"Weight numel ({weight.numel()}) must equal channel dimension size ({C})."
        )
    return max(int(C), 1), max(int(S), 1), False


def prelu(*args, **kwargs):
    """Apply PReLU on PTPU tensors: ``out = x if x >= 0 else weight * x``.

    Accepts ``(input, weight)`` positionally, by keyword (``input``/``self``
    and ``weight``), or mixed. ``weight`` is either a single value applied to
    every element, or one value per channel. The channel dim is dim 1 for
    inputs with ``ndim >= 2`` and dim 0 for 1-D inputs.

    Returns:
        A new tensor with the shape and dtype of ``input``.

    Raises:
        ValueError: if ``input`` or ``weight`` is missing.
        AssertionError: if either tensor is not a PTPU tensor, or if the
            weight does not match the input's channel layout.
    """
    logger.debug("GEMS PRELU")
    x, weight = _prelu_resolve_args(args, kwargs)

    if not (x.is_ptpu and weight.is_ptpu):
        raise AssertionError("Tensors must be PTPU tensors.")

    # Cast the weight to the input's dtype before launching the kernel.
    if weight.dtype != x.dtype:
        weight = weight.to(dtype=x.dtype)

    # The kernel indexes both tensors as flat contiguous buffers.
    x = x.contiguous()
    weight = weight.contiguous()

    out = torch.empty_like(x)
    n_elements = x.numel()
    if n_elements == 0:
        return out

    C, S, w_is_scalar = _prelu_channel_layout(x, weight)

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    prelu_kernel[grid](
        x, weight, out, n_elements, S, C, w_is_scalar=w_is_scalar, BLOCK_SIZE=BLOCK_SIZE
    )
    return out