Coverage for src/flag_gems/runtime/backend/_ascend/ops/full.py: 0%
48 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.shape_utils import volume
# Per-op logger: the last component of this module's dotted name is appended
# to the Ascend ops logger namespace.
logger = logging.getLogger(
    f'flag_gems.runtime._ascend.ops.{__name__.rpartition(".")[2]}'
)
@triton.jit(do_not_specialize=["fill_value_or_ptr"])
def full_kernel(
    output_ptr,
    n_elements,
    fill_value_or_ptr,
    FILL_VALUE_IS_PTR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Fill kernel: each program instance writes the fill value into one
    # BLOCK_SIZE-wide slice of the output buffer.
    #
    # output_ptr        -- base pointer of the buffer being filled
    # n_elements        -- total number of elements to write
    # fill_value_or_ptr -- either the scalar itself or a pointer to it,
    #                      selected by FILL_VALUE_IS_PTR
    block_id = tle.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the buffer (tail of the last block) are masked off.
    in_bounds = idx < n_elements
    # FILL_VALUE_IS_PTR is a constexpr, so this branch is chosen at compile time.
    if FILL_VALUE_IS_PTR:
        fill_value = tl.load(fill_value_or_ptr)
    else:
        fill_value = fill_value_or_ptr
    tl.store(output_ptr + idx, fill_value, mask=in_bounds)
# Integer dtypes whose representable range is checked against the fill value
# in check_dtype.
ALL_INT_DTYPES = (
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
)
# Floating-point dtypes for which check_dtype wraps the fill value in a tensor.
ALL_FLOAT_DTYPES = (
    torch.bfloat16,
    torch.float16,
    torch.float32,
    torch.float64,
)
def check_dtype(fill_value, dtype, device):
    """Validate ``fill_value`` against ``dtype`` and normalise its form.

    Args:
        fill_value: Python scalar requested as the fill value.
        dtype: Explicit target ``torch.dtype``.
        device: Device on which a tensor-wrapped fill value is created.

    Returns:
        ``fill_value`` converted to ``int`` when it is a ``bool`` and the
        target is not ``torch.bool``; a tensor of ``dtype`` on ``device``
        holding the value when ``dtype`` is in ``ALL_FLOAT_DTYPES``;
        otherwise the value unchanged.

    Raises:
        RuntimeError: If a non-bool ``fill_value`` lies outside the range of
            a ``dtype`` listed in ``ALL_INT_DTYPES``.
    """
    if isinstance(fill_value, bool):
        # Booleans skip the range check; they only become 0/1 for non-bool
        # targets.
        if dtype != torch.bool:
            fill_value = int(fill_value)
    elif dtype in ALL_INT_DTYPES:
        limits = torch.iinfo(dtype)
        if fill_value < limits.min or fill_value > limits.max:
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )

    if dtype in ALL_FLOAT_DTYPES:
        # Floating-point fills are handed on as a tensor of the target dtype.
        return torch.tensor(fill_value, dtype=dtype, device=device)
    return fill_value
def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Create a tensor of shape ``size`` with every element set to ``fill_value``.

    Args:
        size: Shape of the result; passed to ``torch.empty`` and ``volume``.
        fill_value: Scalar written into every element.
        dtype: Result dtype. When omitted it is inferred from ``fill_value``:
            ``bool`` gives ``torch.bool``, ``int`` gives ``torch.int64`` and
            anything else gives ``torch.get_default_dtype()``.
        layout: Accepted for signature compatibility; not used by this
            function.
        device: Target device; defaults to ``torch.device("cpu")`` when
            omitted.
        pin_memory: Accepted for signature compatibility; not used by this
            function.

    Returns:
        The newly allocated tensor after it has been filled.

    Raises:
        RuntimeError: Raised by ``check_dtype`` when an explicit integer
            ``dtype`` cannot represent ``fill_value``.
    """
    logger.debug("GEMS_ASCEND FULL")
    if device is None:
        device = torch.device("cpu")
    if dtype is None:
        if isinstance(fill_value, bool):
            dtype = torch.bool
        elif isinstance(fill_value, int):
            dtype = torch.int64
        else:
            dtype = torch.get_default_dtype()
    else:
        # Only an explicitly requested dtype is validated / normalised.
        fill_value = check_dtype(fill_value, dtype, device)

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to fill. Skipping the launch also avoids the block-size and
        # grid arithmetic below, which degenerates (BLOCK_SIZE of 0, division
        # by zero in cdiv) when there are no elements.
        return out

    # Block size grows with sqrt(N), rounded up to a power of two, capped at 2048.
    BLOCK_SIZE = min(triton.next_power_of_2(math.ceil(math.sqrt(N))), 2048)
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(device):
        full_kernel[grid_fn](
            out,
            N,
            fill_value,
            # A tensor fill value is loaded through its pointer in the kernel.
            FILL_VALUE_IS_PTR=isinstance(fill_value, torch.Tensor),
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return out