Coverage for src/flag_gems/runtime/backend/_ascend/ops/full_like.py: 0%
21 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import math
4import torch
5import triton
7from flag_gems.runtime import torch_device_fn
9from .full import check_dtype, full_kernel
# Module logger, named after this module's basename under the Ascend ops
# namespace (e.g. "flag_gems.runtime._ascend.ops.full_like" for this file).
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)
def full_like(
    x,
    fill_value,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Return a new tensor shaped like ``x`` with every element set to ``fill_value``.

    Ascend backend implementation of ``torch.full_like``: the output is
    allocated with ``torch.empty_like`` and filled by the Triton ``full_kernel``.

    Args:
        x: Reference tensor whose shape (and, by default, dtype and device)
            the result copies.
        fill_value: Value written to every element. It is first normalized by
            ``check_dtype``; when that yields a ``torch.Tensor`` the kernel is
            launched with ``FILL_VALUE_IS_PTR=True``.
        dtype: Output dtype; defaults to ``x.dtype``.
        layout: Accepted for signature compatibility with ``torch.full_like``;
            ignored.
        device: Output device; defaults to ``x.device``.
        pin_memory: Accepted for signature compatibility; ignored.
        memory_format: Forwarded to ``torch.empty_like``; ``None`` means
            ``torch.preserve_format``.

    Returns:
        The newly allocated, filled tensor.
    """
    logger.debug("GEMS_ASCEND FULL_LIKE")
    if device is None:
        device = x.device
    if dtype is None:
        dtype = x.dtype
    fill_value = check_dtype(fill_value, dtype, device)
    if memory_format is None:
        memory_format = torch.preserve_format
    out = torch.empty_like(x, device=device, dtype=dtype, memory_format=memory_format)
    N = out.numel()
    # Nothing to fill. Launching anyway would fail: next_power_of_2(0) is 0,
    # so BLOCK_SIZE would be 0 and cdiv(N, 0) in the grid divides by zero.
    if N == 0:
        return out
    # Tile size heuristic: about sqrt(N), rounded up to a power of two,
    # capped at 2048.
    BLOCK_SIZE = min(triton.next_power_of_2(math.ceil(math.sqrt(N))), 2048)
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    # Launch in the context of the device that owns `out`; it differs from
    # x.device when the caller passes an explicit `device`.
    with torch_device_fn.device(out.device):
        full_kernel[grid_fn](
            out,
            N,
            fill_value,
            FILL_VALUE_IS_PTR=isinstance(fill_value, torch.Tensor),
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return out