Coverage for src/flag_gems/ops/arcsinh.py: 69%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def arcsinh_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Compute asinh element-wise over a flat buffer of ``n_elements`` values.

    Each program instance handles one ``BLOCK_SIZE``-wide slice; lanes past
    ``n_elements`` are masked out on both the load and the store.

    The math is done in float32 regardless of the input dtype, and the store
    converts the float32 result to the element type of ``out_ptr``.

    The textbook identity ``log(x + sqrt(x*x + 1))`` is not evaluated
    directly because it breaks down in float32:
      * for large negative x the sum ``x + sqrt(x*x + 1)`` cancels to 0,
        yielding ``-inf``;
      * for very large ``|x|`` the square overflows, yielding ``inf``;
      * for tiny ``|x|`` the value ``1 + |x|`` rounds to 1, yielding 0.
    Instead the magnitude is computed from ``|x|`` with three regimes and the
    sign is restored afterwards (asinh is an odd function).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    x_f32 = x.to(tl.float32)

    ax = tl.abs(x_f32)
    ax2 = ax * ax

    # |x| < 0.125: odd Taylor series x - x^3/6 + 3x^5/40 - 15x^7/336.
    # The first omitted term is below float32 resolution in this range.
    small = ax * (
        1.0 + ax2 * (-0.16666667 + ax2 * (0.075 - 0.044642857 * ax2))
    )
    # Mid range: the direct formula is well conditioned for non-negative input.
    mid = tl.log(ax + tl.sqrt(ax2 + 1.0))
    # |x| > 1e8: sqrt(x^2 + 1) == |x| in float32, so asinh(|x|) = log(2|x|).
    # Written as log|x| + ln(2) so that neither x*x nor 2*x can overflow.
    large = tl.log(ax) + 0.6931471805599453

    # NaN compares false against both thresholds, falls into `mid`, and
    # propagates as NaN; +/-inf takes the `large` branch and stays infinite.
    magnitude = tl.where(ax < 0.125, small, tl.where(ax > 1.0e8, large, mid))
    y_f32 = tl.where(x_f32 < 0.0, -magnitude, magnitude)

    # Store result; the value is cast to the output element type as needed.
    tl.store(out_ptr + offsets, y_f32, mask=mask)

30 

31 

32def _ensure_cuda_tensor(t): 

33 if not isinstance(t, torch.Tensor): 

34 raise TypeError("Expected a torch.Tensor") 

35 if t.device.type != flag_gems.device: 

36 raise ValueError(f"Input tensors must be on {flag_gems.device} device") 

37 if t.is_complex(): 

38 raise NotImplementedError( 

39 "Complex dtypes are not supported by this Triton kernel" 

40 ) 

41 

42 

def _arcsinh_impl(input_tensor: torch.Tensor, out_tensor: torch.Tensor = None):
    """Run the arcsinh Triton kernel on ``input_tensor``.

    Args:
        input_tensor: Real-valued tensor on the flag_gems device.
        out_tensor: Optional destination. It must hold the same number of
            elements as ``input_tensor`` and have the promoted result dtype.
            It may be non-contiguous.

    Returns:
        ``out_tensor`` when one was given, otherwise a newly allocated tensor
        shaped like ``input_tensor``.

    Raises:
        TypeError: An argument is not a tensor, or ``out_tensor`` has an
            unexpected dtype.
        ValueError: An argument is on the wrong device, or ``out_tensor`` has
            a different element count.
        NotImplementedError: An argument has a complex dtype.
    """
    _ensure_cuda_tensor(input_tensor)

    # Floating-point inputs keep their dtype. Everything else (integers, bool)
    # promotes to the process-wide default floating dtype, as PyTorch's own
    # unary float ops do, rather than a hard-coded float32.
    if input_tensor.is_floating_point():
        result_dtype = input_tensor.dtype
    else:
        result_dtype = torch.get_default_dtype()

    n_elements = input_tensor.numel()

    if out_tensor is None:
        out = torch.empty_like(input_tensor, dtype=result_dtype)
    else:
        _ensure_cuda_tensor(out_tensor)
        if out_tensor.numel() != n_elements:
            raise ValueError(
                "Output tensor must have the same number of elements as input"
            )
        # Enforce dtype consistent with promotion
        if out_tensor.dtype != result_dtype:
            raise TypeError(
                f"Output tensor has dtype {out_tensor.dtype}, expected {result_dtype}"
            )
        out = out_tensor

    # Nothing to compute; avoid launching a kernel with an empty grid.
    if n_elements == 0:
        return out

    # The kernel indexes both buffers linearly, so it needs contiguous memory.
    x_contig = input_tensor.contiguous()
    if out.is_contiguous():
        out_buffer = out
    else:
        # Scratch destination: every element is overwritten by the kernel, so
        # there is no reason to copy the current contents of `out` into it.
        out_buffer = torch.empty(out.shape, dtype=out.dtype, device=out.device)

    # Launch kernel
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    arcsinh_kernel[grid](x_contig, out_buffer, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    # If a scratch buffer was used, scatter the result into the strided output.
    if out_buffer is not out:
        out.copy_(out_buffer)

    return out

84 

85 

def arcsinh(input_tensor: torch.Tensor):
    """Return the inverse hyperbolic sine of ``input_tensor`` as a new tensor.

    Logs a debug marker and delegates to ``_arcsinh_impl`` without an
    explicit output tensor.
    """
    logger.debug("GEMS ARCSINH")
    result = _arcsinh_impl(input_tensor)
    return result

89 

90 

def arcsinh_out(input_tensor: torch.Tensor, out: torch.Tensor):
    """Write the inverse hyperbolic sine of ``input_tensor`` into ``out``.

    Logs a debug marker and delegates to ``_arcsinh_impl``, passing ``out`` as
    the destination; the value returned by that call is handed back unchanged.
    """
    logger.debug("GEMS ARCSINH_OUT")
    result = _arcsinh_impl(input_tensor, out_tensor=out)
    return result