Coverage for src/flag_gems/runtime/backend/_arm/ops/full.py: 0%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import triton_lang_extension as tle
# Triton kernel: each program instance writes one BLOCK_SIZE-wide chunk of a
# flat output buffer with a single fill value.
#   output_ptr        -- pointer to the first output element.
#   n_elements        -- count of valid elements; lanes at or beyond it are masked.
#   fill_value_or_ptr -- the fill value itself, or a pointer to it when
#                        FILL_VALUE_IS_PTR is set. Listed in do_not_specialize so
#                        Triton does not specialise compilation on this argument
#                        (presumably to avoid recompiling per fill value).
#   FILL_VALUE_IS_PTR -- compile-time switch selecting how the value is obtained.
#   BLOCK_SIZE        -- compile-time number of elements handled per program.
@triton.jit(do_not_specialize=["fill_value_or_ptr"])
def full_kernel(
    output_ptr,
    n_elements,
    fill_value_or_ptr,
    FILL_VALUE_IS_PTR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Flat indices owned by this program along grid axis 0.
    idx = tle.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # FILL_VALUE_IS_PTR is a constexpr, so only one branch is compiled.
    if FILL_VALUE_IS_PTR:
        value = tl.load(fill_value_or_ptr)  # read the scalar from memory
    else:
        value = fill_value_or_ptr
    # Out-of-range lanes of the final chunk are masked off.
    tl.store(output_ptr + idx, value, mask=idx < n_elements)
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)

# Infinities are representable in every floating dtype, so they must never be
# reported as overflow by the finite-range check in check_dtype.
_INFINITIES = (float("inf"), float("-inf"))


def check_dtype(fill_value, dtype, device):
    """Validate (and normalise) ``fill_value`` for an explicitly requested dtype.

    Args:
        fill_value: Scalar to be written into the output tensor.
        dtype: Requested ``torch.dtype``.
        device: Device on which a floating-point fill value is materialised.

    Returns:
        For dtypes in ``ALL_FLOAT_DTYPES``: a 0-dim tensor of ``dtype`` on
        ``device`` holding the value. Otherwise: the scalar itself, except that
        a ``bool`` is converted to ``int`` when ``dtype`` is not ``torch.bool``.

    Raises:
        RuntimeError: If a non-bool ``fill_value`` lies outside the range of
            an integer dtype in ``ALL_INT_DTYPES``. Also raised if it is a
            finite value outside ``[finfo.min, finfo.max]`` of a floating
            dtype in ``ALL_FLOAT_DTYPES``; +/-inf and NaN are accepted there.
    """
    if isinstance(fill_value, bool):
        # Bools are never range-checked; non-bool targets receive a plain int.
        if dtype != torch.bool:
            fill_value = int(fill_value)
    elif dtype in ALL_INT_DTYPES:
        bounds = torch.iinfo(dtype)
        if fill_value < bounds.min or fill_value > bounds.max:
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )
    elif dtype in ALL_FLOAT_DTYPES:
        bounds = torch.finfo(dtype)
        # +/-inf compares outside [min, max] but is a valid floating value
        # (torch.full accepts it), so only finite values can overflow.
        # Tuple membership is used instead of math.isinf so huge Python ints
        # still reach the comparison rather than raising OverflowError.
        if fill_value not in _INFINITIES and (
            fill_value < bounds.min or fill_value > bounds.max
        ):
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )
    if dtype in ALL_FLOAT_DTYPES:
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)
    return fill_value
def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` filled with ``fill_value``.

    Mirrors the ``torch.full`` signature. The result is built by creating a
    0-dim tensor holding ``fill_value`` and cloning its broadcast view
    (``expand(size).clone()``); ``full_kernel`` is not used on this path.

    Args:
        size: Target shape, passed to ``Tensor.expand``.
        fill_value: Scalar (or tensor) to fill with.
        dtype: Output dtype. When ``None`` it is inferred from ``fill_value``:
            ``bool`` -> ``torch.bool``, ``int`` -> ``torch.int64``,
            ``complex`` -> torch's default complex dtype, anything else ->
            ``torch.get_default_dtype()``. When given, ``fill_value`` is
            validated by ``check_dtype``.
        layout: Accepted for signature compatibility; ignored here.
        device: Output device; defaults to CPU.
        pin_memory: Accepted for signature compatibility; ignored here.

    Returns:
        A freshly allocated tensor (not a stride-0 view) of shape ``size``.

    Raises:
        RuntimeError: Propagated from ``check_dtype`` when an explicit
            ``dtype`` cannot represent ``fill_value``.
    """
    logging.debug("GEMS FULL")
    if device is None:
        device = torch.device("cpu")
    if dtype is None:
        # bool must be tested before int because bool subclasses int.
        if isinstance(fill_value, bool):
            dtype = torch.bool
        elif isinstance(fill_value, int):
            dtype = torch.int64
        elif isinstance(fill_value, complex):
            # A real default dtype cannot hold a complex scalar; let torch pick
            # its default complex dtype, matching torch.full's inference.
            dtype = torch.tensor(fill_value).dtype
        else:
            dtype = torch.get_default_dtype()
    else:
        fill_value = check_dtype(fill_value, dtype, device)

    if isinstance(fill_value, torch.Tensor):
        scalar = fill_value.to(dtype=dtype, device=device)
    else:
        scalar = torch.tensor(fill_value, dtype=dtype, device=device)
    # expand() alone would return a stride-0 view; clone() materialises it.
    return scalar.expand(size).clone()