Coverage for src/flag_gems/ops/asinh_.py: 41%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import torch 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7 

8 

@triton.jit
def _asinh_compute(x):
    # Elementwise asinh of a floating-point block `x` (fp32 or fp64), in x's dtype.
    #
    # Two regimes are blended with tl.where, because the textbook formula
    # log(|x| + sqrt(x*x + 1)) is inaccurate at both ends:
    #   * x*x overflows for |x| > sqrt(max finite), giving inf instead of a
    #     finite result;
    #   * 1 + |x| rounds to 1 for tiny |x|, giving 0 instead of ~x.
    abs_x = tl.abs(x)

    # |x| >= 1: factor |x| out of the square root so nothing squares |x|:
    #   asinh(|x|) = log(|x|) + log(1 + sqrt(1 + (1/|x|)^2))
    inv = 1.0 / abs_x
    y_large = tl.log(abs_x) + tl.log(1.0 + tl.sqrt(1.0 + inv * inv))

    # |x| < 1: asinh(|x|) = log1p(z), z = |x| + x^2 / (1 + sqrt(1 + x^2)),
    # which is algebraically equal to the textbook form but has no cancellation.
    sq = abs_x * abs_x
    z = abs_x + sq / (1.0 + tl.sqrt(1.0 + sq))
    # log1p(z) without a dedicated intrinsic (Goldberg's trick): u - 1 is
    # computed exactly here, so log(u) * z / (u - 1) corrects the rounding in u.
    u = 1.0 + z
    y_small = tl.log(u) * (z / (u - 1.0))

    # Lanes of the unselected branch may hold inf/nan; tl.where discards them.
    y = tl.where(abs_x >= 1.0, y_large, y_small)
    y = tl.where(x < 0.0, -y, y)
    # If 1 + z rounds to 1, asinh(x) == x to working precision. Returning x
    # also avoids the 0/0 in the log1p correction and keeps the sign of -0.0.
    return tl.where(u == 1.0, x, y)


@triton.jit
def asinh_kernel_(
    x_ptr, n_elements, BLOCK_SIZE: tl.constexpr, COMPUTE_FP32: tl.constexpr
):
    # In-place elementwise asinh over the first `n_elements` entries of x_ptr.
    # Each program handles one BLOCK_SIZE-wide chunk; the tail is masked.
    # COMPUTE_FP32 upcasts the loaded values to fp32 for the math and casts
    # the result back to the input dtype before storing.
    #
    # int64 so pid * BLOCK_SIZE cannot wrap for tensors with >= 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    if COMPUTE_FP32:
        y = _asinh_compute(x.to(tl.float32)).to(x.dtype)
    else:
        y = _asinh_compute(x)

    tl.store(x_ptr + offsets, y, mask=mask)

32 

33 

def _asinh_inplace_contiguous(t):
    # Run asinh_kernel_ in place over a contiguous floating-point tensor `t`.
    n_elements = t.numel()
    if n_elements == 0:
        # Nothing to compute; skip the kernel launch entirely.
        return
    block_size = 1024
    grid = (triton.cdiv(n_elements, block_size),)
    with torch_device_fn.device(t.device):
        asinh_kernel_[grid](
            t,
            n_elements,
            BLOCK_SIZE=block_size,
            # Half-precision inputs are upcast to fp32 inside the kernel.
            COMPUTE_FP32=t.dtype in (torch.float16, torch.bfloat16),
        )


def asinh_(*args, **kwargs):
    """Apply the inverse hyperbolic sine to a tensor in place and return it.

    The tensor is taken from the first positional argument when that is a
    Tensor; otherwise it is looked up among the keyword arguments ``input``,
    ``self`` and ``x``, in that order.

    Floating-point tensors (float16, bfloat16, float32, float64) are handled
    by the Triton kernel. Any other dtype is forwarded to
    ``torch.ops.aten.asinh_``.

    Raises:
        ValueError: if no Tensor argument is found.
    """
    x = None
    if len(args) > 0 and isinstance(args[0], torch.Tensor):
        x = args[0]
    else:
        for key in ("input", "self", "x"):
            val = kwargs.get(key, None)
            if isinstance(val, torch.Tensor):
                x = val
                break
    if x is None:
        raise ValueError("asinh_: expected a Tensor as the first argument")

    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        # NOTE(review): if this function is registered as the aten::asinh_
        # implementation for x's device, this call may dispatch straight back
        # here -- confirm the registration excludes this path.
        return torch.ops.aten.asinh_(x)

    if x.is_contiguous():
        _asinh_inplace_contiguous(x)
    else:
        # The kernel assumes flat contiguous storage: compute on a contiguous
        # copy, then write the result back through x's own strides.
        buf = x.contiguous()
        _asinh_inplace_contiguous(buf)
        x.copy_(buf)
    return x