Coverage for src/flag_gems/runtime/backend/_ascend/ops/full.py: 0%
66 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
11from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
# Module-level alias for the imported runtime device object; `full` has its
# own `device` keyword argument, which would otherwise hide the import.
device_ = device

# Logger named after the last component of this module's dotted name.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.split(".")[-1]
)

# Dtypes whose finite range is checked by `check_dtype`.
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)

# Element count below which `full` uses the pointwise_dynamic functions
# (small tensors) instead of the hand-written multi-core kernel
# (large tensors).
SMALL_TENSOR_THRESHOLD = 100000
def check_dtype(fill_value, dtype, device):
    """Check that ``fill_value`` fits in ``dtype`` and normalise booleans.

    Args:
        fill_value: Value to validate. A ``bool`` is converted to ``int``
            unless ``dtype`` is ``torch.bool``; booleans are never range
            checked.
        dtype: Target torch dtype.
        device: Not used by this function.

    Returns:
        The fill value, converted to ``int`` in the boolean case above and
        unchanged otherwise.

    Raises:
        RuntimeError: If a non-bool value lies outside the range reported by
            ``torch.iinfo`` (dtypes in ``ALL_INT_DTYPES``) or
            ``torch.finfo`` (dtypes in ``ALL_FLOAT_DTYPES``). For floating
            dtypes, infinite and NaN values are exempt from the range check.
    """
    if isinstance(fill_value, bool):
        return fill_value if dtype == torch.bool else int(fill_value)

    if dtype in ALL_INT_DTYPES:
        limits = torch.iinfo(dtype)
        overflows = fill_value < limits.min or fill_value > limits.max
    elif dtype in ALL_FLOAT_DTYPES:
        if math.isinf(fill_value) or math.isnan(fill_value):
            # Non-finite values are representable in every float dtype.
            overflows = False
        else:
            limits = torch.finfo(dtype)
            overflows = fill_value < limits.min or fill_value > limits.max
    else:
        # Any other dtype is passed through without a range check.
        overflows = False

    if overflows:
        raise RuntimeError(
            f"value cannot be converted to type {dtype} without overflow"
        )
    return fill_value
# Small tensor path: pointwise_dynamic has lower launch overhead
# Variant used when the fill value is passed as a tensor (is_tensor=[True, True]).
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def full_func(out, fill_value):
    # `out` is not read here; the result is the fill value itself.
    filled = fill_value
    return filled
# Variant used when the fill value is a Python scalar (is_tensor=[True, False]).
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def full_func_scalar(out, fill_value):
    # Build a block with the shape and dtype of `out`, filled with the scalar.
    filled = tl.full(out.shape, fill_value, out.dtype)
    return filled
# Large tensor path: hand-written multi-core kernel for better throughput
@libentry()
@triton.jit(do_not_specialize=["fill_value"])
def full_kernel(
    out_ptr,
    N,
    fill_value,
    BLOCK_SIZE: tl.constexpr,
    SUBBLOCK_SIZE: tl.constexpr,
):
    # Fill `N` elements starting at `out_ptr` with `fill_value`.
    #
    # Program `pid` owns the slice [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)
    # and writes it SUBBLOCK_SIZE elements per loop iteration. The launch
    # must satisfy grid * BLOCK_SIZE >= N so the slices together cover the
    # whole output.
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE
    block_end = block_start + BLOCK_SIZE
    cols = tl.arange(0, SUBBLOCK_SIZE)
    num_loop = triton.cdiv(BLOCK_SIZE, SUBBLOCK_SIZE)
    for iloop in tl.range(num_loop):
        offset = block_start + iloop * SUBBLOCK_SIZE + cols
        # Mask against both the end of the tensor and the end of this
        # program's slice: when BLOCK_SIZE is not a multiple of SUBBLOCK_SIZE
        # the last iteration would otherwise run into the next program's
        # slice and store those elements a second time.
        mask = (offset < N) & (offset < block_end)
        tl.store(out_ptr + offset, fill_value, mask=mask)
def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Create a tensor of shape ``size`` with every element set to ``fill_value``.

    Args:
        size: Shape passed to ``torch.empty``.
        fill_value: Python scalar or ``torch.Tensor`` holding the fill value.
        dtype: Output dtype. When ``None`` it is inferred from ``fill_value``:
            ``bool`` -> ``torch.bool``, ``int`` -> ``torch.int64``, anything
            else -> ``torch.get_default_dtype()``. When given, ``fill_value``
            is validated with ``check_dtype``.
        layout: Accepted but not used.
        device: Output device; defaults to the runtime device (``device_``).
        pin_memory: Accepted but not used.

    Returns:
        The filled tensor.

    Raises:
        RuntimeError: Propagated from ``check_dtype`` when ``fill_value``
            does not fit in an explicitly requested ``dtype``.
    """
    logger.debug("GEMS_ASCEND FULL")
    if device is None:
        device = torch.device(device_.name)

    if dtype is not None:
        fill_value = check_dtype(fill_value, dtype, device)
    elif isinstance(fill_value, bool):
        # bool must be tested before int because bool is a subclass of int.
        dtype = torch.bool
    elif isinstance(fill_value, int):
        dtype = torch.int64
    else:
        dtype = torch.get_default_dtype()

    out = torch.empty(size, device=device, dtype=dtype)
    numel = out.numel()
    if numel == 0:
        # Nothing to fill.
        return out

    value_is_tensor = isinstance(fill_value, torch.Tensor)

    if numel < SMALL_TENSOR_THRESHOLD:
        # Small tensor: use pointwise_dynamic for lower launch overhead
        fill_small = full_func if value_is_tensor else full_func_scalar
        return fill_small(out, fill_value, out0=out)

    # Large tensor: use hand-written multi-core kernel, which takes the fill
    # value as a scalar argument.
    if value_is_tensor:
        fill_value = fill_value.item()

    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    num_programs = min(40, numel)
    # Ceiling division: elements handled by each program.
    block_size = (numel + num_programs - 1) // num_programs
    sub_block_size = min(8192, block_size)

    with torch_device_fn.device(device):
        full_kernel[(num_programs,)](
            out, numel, fill_value, block_size, sub_block_size
        )
    return out