Coverage for src/flag_gems/ops/leaky_relu.py: 29%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

# Module-level logger; the leaky_relu entry points below each emit a
# "GEMS ..." debug record naming the op when they are called.
logger = logging.getLogger(__name__)

11 

12 

def _leaky_relu_autotune_configs():
    """Build the candidate launch configurations for ``_leaky_relu_kernel``.

    Each candidate fixes ``BLOCK_SIZE`` (elements handled per program) plus
    ``num_warps`` and ``num_stages``. ``triton.autotune`` benchmarks the whole
    list for every distinct ``n_elements`` key and keeps the fastest, so the
    size-regime groupings below record the intended use of each block size;
    they are not a rule that picks configs by tensor size.

    Returns:
        list[triton.Config]: 25 configurations in a fixed order.
    """
    # (num_warps, num_stages) pairs tried for the 1024-element blocks
    # (intended for small-medium tensors, n ~ 64K-4M).
    mid_shapes = ((4, 2), (8, 2), (4, 3), (8, 3), (4, 4), (8, 4), (16, 4))
    # (num_warps, num_stages) pairs shared by the 2048- and 4096-element
    # blocks (intended for n ~ 4M-16M and n >= 16M respectively).
    wide_shapes = ((4, 3), (8, 3), (8, 4), (16, 4), (4, 5), (8, 5), (16, 5))

    plan = [
        # Tiny tensors (n <= 32K): small blocks.
        (256, ((4, 2), (8, 2))),
        (512, ((4, 2), (8, 2))),
        (1024, mid_shapes),
        (2048, wide_shapes),
        (4096, wide_shapes),
    ]
    return [
        triton.Config({"BLOCK_SIZE": block}, num_warps=warps, num_stages=stages)
        for block, shapes in plan
        for warps, stages in shapes
    ]

45 

46 

@libentry()
@triton.autotune(configs=_leaky_relu_autotune_configs(), key=["n_elements"])
@triton.jit(do_not_specialize=["negative_slope"])
def _leaky_relu_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    negative_slope,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise LeakyReLU over a flat buffer of ``n_elements`` values.

    Program ``i`` covers indices ``[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)``.
    Indices at or beyond ``n_elements`` are masked out of both the load and
    the store. Non-negative inputs pass through unchanged; negative inputs are
    multiplied by ``negative_slope``, which is listed in
    ``do_not_specialize`` so Triton does not specialize on its value.
    """
    # NOTE(review): program_id and the offsets below are int32; tensors with
    # more than 2**31 elements may overflow the index math — confirm.
    block_start = tl.program_id(0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(input_ptr + idx, mask=in_bounds)
    # NaN fails `>= 0` and takes the scaled branch, which still yields NaN.
    result = tl.where(vals >= 0, vals, vals * negative_slope)
    tl.store(output_ptr + idx, result, mask=in_bounds)

64 

65 

def leaky_relu(A, negative_slope=0.01):
    """Return LeakyReLU of ``A`` as a newly allocated tensor.

    Args:
        A: Input tensor. A non-contiguous input is first materialized as a
            contiguous copy because the kernel indexes memory as a flat array.
        negative_slope: Multiplier applied to negative elements.

    Returns:
        A new tensor created with ``torch.empty_like`` from the contiguous
        input, so it has the same shape and dtype as ``A``.
    """
    logger.debug("GEMS LEAKY_RELU")
    src = A if A.is_contiguous() else A.contiguous()
    result = torch.empty_like(src)
    total = src.numel()
    if total == 0:
        # Nothing to launch for an empty tensor.
        return result

    def grid(meta):
        # One program per BLOCK_SIZE chunk chosen by the autotuner.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(src.device.index):
        _leaky_relu_kernel[grid](src, result, total, negative_slope)
    return result

78 

79 

def leaky_relu_(A, negative_slope=0.01):
    """Apply LeakyReLU to ``A`` in place and return ``A``.

    The kernel indexes memory as a flat array, so it can only write into a
    contiguous buffer. Contiguous tensors are updated directly. Non-contiguous
    tensors (e.g. transposed or sliced views) are processed in a contiguous
    scratch copy whose values are then written back with ``A.copy_``. That
    call honours ``A``'s own strides, so ``A`` keeps its storage and layout.

    Args:
        A: Tensor to modify in place.
        negative_slope: Multiplier applied to negative elements.

    Returns:
        ``A`` itself.
    """
    logger.debug("GEMS LEAKY_RELU_")
    n_elements = A.numel()
    if n_elements == 0:
        return A
    # Work directly on A when its memory is flat; otherwise on a flat copy.
    target = A if A.is_contiguous() else A.contiguous()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(A.device.index):
        _leaky_relu_kernel[grid](target, target, n_elements, negative_slope)
    if target is not A:
        # Scatter the results back into the original strided view.
        A.copy_(target)
    return A

93 

94 

def leaky_relu_out(A, negative_slope=0.01, *, out=None):
    """Compute LeakyReLU of ``A`` into ``out`` and return ``out``.

    Args:
        A: Input tensor. A non-contiguous input is copied to contiguous memory
            first because the kernel reads it as a flat array.
        negative_slope: Multiplier applied to negative elements.
        out: Destination tensor.
            - If ``None``, a new tensor is returned via ``leaky_relu``.
            - If its shape differs from ``A``'s, it is resized with
              ``Tensor.resize_``, following PyTorch's ``out=`` convention.
            - If it is non-contiguous, the result is computed into a
              contiguous scratch buffer (same dtype as ``out``) and copied in
              with ``copy_``, so values land at the correct strided positions.

    Returns:
        ``out`` (or a new tensor when ``out`` is ``None``).
    """
    logger.debug("GEMS LEAKY_RELU_OUT")
    if out is None:
        return leaky_relu(A, negative_slope)
    if not A.is_contiguous():
        A = A.contiguous()
    # The kernel stores A.numel() elements at flat offsets of the destination;
    # an undersized `out` would be written past its end, so match shapes.
    if out.shape != A.shape:
        out.resize_(A.shape)
    n_elements = A.numel()
    if n_elements == 0:
        return out
    # Flat offsets only map to the right elements of a contiguous destination.
    # The scratch buffer keeps out's dtype so the kernel's store-time cast is
    # the same as when writing into `out` directly.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty(A.shape, dtype=out.dtype, device=A.device)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(A.device.index):
        _leaky_relu_kernel[grid](A, dst, n_elements, negative_slope)
    if dst is not out:
        out.copy_(dst)
    return out