Coverage for src/flag_gems/ops/prelu.py: 65%

62 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

# Module-scoped logger, keyed by this module's import path.
logger = logging.getLogger(name=__name__)

11 

12 

@triton.jit
def prelu_kernel(
    x_ptr,  # *Pointer* to input tensor.
    w_ptr,  # *Pointer* to weight tensor (scalar or per-channel vector).
    out_ptr,  # *Pointer* to output tensor.
    n_elements,  # Total number of elements in input.
    S,  # Spatial size = product of dims after channel dim (or 1 if none).
    C,  # Number of channels (or 1).
    w_is_scalar: tl.constexpr,  # Whether weight is a single scalar.
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise PReLU: ``out[i] = x[i] if x[i] >= 0 else alpha * x[i]``.

    Each program instance handles one run of ``BLOCK_SIZE`` consecutive flat
    indices; indices at or past ``n_elements`` are masked off. With
    ``w_is_scalar`` every element uses ``w_ptr[0]``. Otherwise the slope for
    flat index ``i`` is ``w_ptr[(i // S) % C]``. The caller is expected to pass
    contiguous tensors so that this maps index ``i`` to its channel.
    """
    # Promote to int64 before scaling so ``pid * BLOCK_SIZE`` (and every
    # derived offset) cannot wrap around for tensors with >= 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    if w_is_scalar:
        alpha = tl.load(w_ptr)  # single shared slope
    else:
        # Recover the channel index from the flat offset.
        c = (offsets // S) % C
        alpha = tl.load(w_ptr + c, mask=mask)

    y = tl.where(x >= 0, x, alpha * x)
    tl.store(out_ptr + offsets, y, mask=mask)

40 

41 

def _prelu_channel_layout(x, weight):
    """Return ``(C, S, w_is_scalar)`` describing how ``weight`` maps onto ``x``.

    ``C`` is the channel count and ``S`` the number of elements per channel
    slice (product of dims after the channel dim). Both are clamped to >= 1
    so the kernel's ``//`` and ``%`` never divide by zero.

    Raises:
        AssertionError: if a non-scalar weight is given for a 0-dim input, or
            its size does not match the channel dimension.
    """
    if weight.numel() == 1:
        return 1, 1, True

    ndim = x.dim()
    if ndim == 0:
        raise AssertionError("Non-scalar weight provided for a 0-dim input.")
    if ndim == 1:
        # NOTE(review): torch.nn.functional.prelu treats a 1-D input as having
        # a single channel (and so rejects a non-scalar weight). This path
        # instead applies one slope per element; kept for backward compat.
        C = x.shape[0]
        S = 1
    else:
        C = x.shape[1]
        S = x.shape[2:].numel()  # empty trailing shape -> 1
    if weight.numel() != C:
        raise AssertionError(
            f"Weight numel ({weight.numel()}) must equal channel dimension size ({C})."
        )
    return max(int(C), 1), max(int(S), 1), False


def prelu(*args, **kwargs):
    """PReLU forward pass on ``flag_gems.device`` tensors.

    Accepts ``(input, weight)`` positionally, by keyword (``input`` or
    ``self``, and ``weight``), or mixed, e.g. ``prelu(x, weight=w)``.

    Args:
        input: tensor to activate.
        weight: single-element tensor (one shared slope), or one slope per
            channel where the channel dim is dim 1 (dim 0 for a 1-D input).
            Cast to ``input``'s dtype when the dtypes differ.

    Returns:
        A new contiguous tensor with ``input``'s shape and dtype.

    Raises:
        ValueError: if ``input`` or ``weight`` is not supplied.
        AssertionError: if either tensor is not on ``flag_gems.device``, or
            the weight size does not match the channel dimension.
    """
    logger.debug("GEMS PRELU")
    # Resolve each argument independently so that mixed positional/keyword
    # calls are accepted instead of silently dropping ``args[0]``.
    x = args[0] if len(args) >= 1 else kwargs.get("input", kwargs.get("self"))
    weight = args[1] if len(args) >= 2 else kwargs.get("weight")
    if x is None or weight is None:
        raise ValueError("prelu expects (input, weight) as arguments.")

    if x.device.type != flag_gems.device or weight.device.type != flag_gems.device:
        raise AssertionError(f"Tensors must be {flag_gems.device} tensors.")

    if weight.dtype != x.dtype:
        weight = weight.to(dtype=x.dtype)

    # The kernel addresses both tensors through flat offsets.
    x = x.contiguous()
    weight = weight.contiguous()
    out = torch.empty_like(x)

    n_elements = x.numel()
    if n_elements == 0:
        return out

    C, S, w_is_scalar = _prelu_channel_layout(x, weight)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    prelu_kernel[grid](
        x, weight, out, n_elements, S, C, w_is_scalar=w_is_scalar, BLOCK_SIZE=BLOCK_SIZE
    )
    return out