Coverage for src/flag_gems/runtime/backend/_sunrise/ops/arcsinh.py: 0%
54 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
# Module-level logger; the public entry points in this module emit
# debug-level messages through it.
logger = logging.getLogger(__name__)
@triton.jit
def arcsinh_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise inverse hyperbolic sine over a flat buffer.

    Each program handles one BLOCK_SIZE-wide chunk of ``n_elements`` values.
    Math is done in float32; ``tl.store`` casts the result to the element
    type of ``out_ptr``. NOTE(review): float64 inputs are therefore only
    float32-accurate - confirm whether that is acceptable for this backend.

    The naive ``log(x + sqrt(x*x + 1))`` is avoided because it returns -inf
    for x below about -4096 (cancellation), 0 for tiny |x|, and inf once
    ``x*x`` overflows. Instead asinh is evaluated on |x| (it is an odd
    function) and the sign is restored at the end.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0).to(tl.float32)
    a = tl.abs(x)

    # Moderate range: asinh(a) = log1p(a + a^2 / (1 + sqrt(1 + a^2))).
    # This form has no cancellation for a >= 0 and stays accurate near 0.
    a2 = a * a
    u = a + a2 / (1.0 + tl.sqrt(1.0 + a2))
    # log1p(u) via Goldberg's correction: (w - 1) is the exactly rounded u,
    # so log(w) * u / (w - 1) recovers the bits lost when forming 1 + u.
    # When 1 + u rounds to 1, log1p(u) == u to working precision.
    w = 1.0 + u
    log1p_u = tl.where(w == 1.0, u, tl.log(w) * (u / (w - 1.0)))

    # Large range (a > 2^28): sqrt(1 + a^2) == a in float32, so
    # asinh(a) = log(2a) = log(a) + ln(2). This also keeps a*a from
    # overflowing to inf for a > ~1.8e19 (that lane's NaN is discarded).
    y = tl.where(a > 268435456.0, tl.log(a) + 0.6931471805599453, log1p_u)

    # Restore the sign (asinh(-x) == -asinh(x)).
    y = tl.where(x < 0, -y, y)

    # Implicit cast to the output element type happens in tl.store.
    tl.store(out_ptr + offsets, y, mask=mask)
32def _ensure_cuda_tensor(t):
33 if not isinstance(t, torch.Tensor):
34 raise TypeError("Expected a torch.Tensor")
35 if t.device.type != flag_gems.device:
36 raise ValueError(f"Input tensors must be on {flag_gems.device} device")
37 if t.is_complex():
38 raise NotImplementedError(
39 "Complex dtypes are not supported by this Triton kernel"
40 )
def _arcsinh_impl(input_tensor: torch.Tensor, out_tensor: torch.Tensor = None):
    """Compute ``asinh(input_tensor)`` with ``arcsinh_kernel``.

    Args:
        input_tensor: Tensor on ``flag_gems.device``; complex dtypes are
            rejected.
        out_tensor: Optional destination. It must pass the same device and
            dtype checks, hold exactly ``input_tensor.numel()`` elements and
            have the promoted result dtype. It may be non-contiguous.

    Returns:
        ``out_tensor`` when given, otherwise a new tensor allocated with
        ``torch.empty_like(input_tensor, dtype=<result dtype>)``.

    Raises:
        TypeError / ValueError / NotImplementedError: from
            ``_ensure_cuda_tensor`` on either tensor.
        ValueError: ``out_tensor`` has a different number of elements.
        TypeError: ``out_tensor`` has a dtype other than the result dtype.
    """
    _ensure_cuda_tensor(input_tensor)

    # Promotion: floating inputs keep their dtype; everything else
    # (integers, bool) produces float32.
    result_dtype = (
        input_tensor.dtype if input_tensor.is_floating_point() else torch.float32
    )

    x = input_tensor
    n_elements = x.numel()

    if out_tensor is None:
        out = torch.empty_like(x, dtype=result_dtype, device=x.device)
    else:
        _ensure_cuda_tensor(out_tensor)
        if out_tensor.numel() != n_elements:
            raise ValueError(
                "Output tensor must have the same number of elements as input"
            )
        # Enforce dtype consistent with promotion.
        if out_tensor.dtype != result_dtype:
            raise TypeError(
                f"Output tensor has dtype {out_tensor.dtype}, expected {result_dtype}"
            )
        out = out_tensor

    # Nothing to compute: skip the launch entirely rather than rely on the
    # backend launcher accepting a zero-sized grid.
    if n_elements == 0:
        return out

    # The kernel walks flat row-major buffers. ``contiguous()`` returns the
    # tensor itself when it is already contiguous, so a temporary is only
    # allocated when actually needed.
    x_contig = x.contiguous()
    out_contig = out.contiguous()

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    arcsinh_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    # A non-contiguous ``out`` was computed into a temporary; write it back.
    if out_contig is not out:
        out.copy_(out_contig)

    return out
def arcsinh(input_tensor: torch.Tensor):
    """Return ``asinh(input_tensor)`` element-wise in a freshly allocated tensor.

    Floating inputs keep their dtype; non-floating inputs yield float32
    (the promotion rule lives in ``_arcsinh_impl``).
    """
    logger.debug("GEMS ARCSINH")
    result = _arcsinh_impl(input_tensor, None)
    return result
def arcsinh_out(input_tensor: torch.Tensor, out: torch.Tensor):
    """Write ``asinh(input_tensor)`` into ``out`` and return ``out``.

    Device, element-count and dtype validation of ``out`` is performed by
    ``_arcsinh_impl``.
    """
    logger.debug("GEMS ARCSINH_OUT")
    destination = _arcsinh_impl(input_tensor, out_tensor=out)
    return destination