Coverage for src/flag_gems/runtime/backend/_arm/ops/full.py: 0%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import triton_lang_extension as tle 

8 

9 

@triton.jit(do_not_specialize=["fill_value_or_ptr"])
def full_kernel(
    output_ptr,
    n_elements,
    fill_value_or_ptr,
    FILL_VALUE_IS_PTR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance writes one BLOCK_SIZE-wide run of offsets,
    # starting at (program id * BLOCK_SIZE).
    block_id = tle.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Offsets at or beyond n_elements are masked out of the store.
    in_bounds = idx < n_elements

    # The fill value is either passed directly or loaded through a pointer,
    # selected at compile time by the FILL_VALUE_IS_PTR constexpr.
    if FILL_VALUE_IS_PTR:
        value = tl.load(fill_value_or_ptr)
    else:
        value = fill_value_or_ptr

    tl.store(output_ptr + idx, value, mask=in_bounds)

27 

28 

# Dtypes whose representable range is validated by check_dtype.
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)


def check_dtype(fill_value, dtype, device):
    """Validate ``fill_value`` against ``dtype`` and normalise it.

    Args:
        fill_value: The scalar to validate (bool, int or float).
        dtype: The requested torch dtype.
        device: Device on which a floating-point fill value is materialised.

    Returns:
        * ``fill_value`` converted to ``int`` when it is a bool and ``dtype``
          is not ``torch.bool``;
        * a 0-dim tensor of ``dtype`` on ``device`` when ``dtype`` is one of
          ``ALL_FLOAT_DTYPES``;
        * otherwise ``fill_value`` unchanged.

    Raises:
        RuntimeError: If a non-bool ``fill_value`` lies outside the finite
            range of an integer or floating-point ``dtype``. Infinities are
            accepted for floating-point dtypes because those dtypes can
            represent them.
    """
    if isinstance(fill_value, bool):
        # Bools are never range-checked; they are only widened to int when
        # the target dtype is not bool.
        if dtype != torch.bool:
            fill_value = int(fill_value)
    else:
        if dtype in ALL_INT_DTYPES:
            limits = torch.iinfo(dtype)
            overflows = fill_value < limits.min or fill_value > limits.max
        elif dtype in ALL_FLOAT_DTYPES:
            limits = torch.finfo(dtype)
            # +/-inf compare as outside [finfo.min, finfo.max] but are valid
            # floating-point values, so they must not be reported as overflow.
            # NaN already passes because every comparison with it is False.
            is_infinite = fill_value == float("inf") or fill_value == float("-inf")
            overflows = not is_infinite and (
                fill_value < limits.min or fill_value > limits.max
            )
        else:
            overflows = False
        if overflows:
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )

    if dtype in ALL_FLOAT_DTYPES:
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)
    return fill_value

50 

51 

def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Create a tensor of shape ``size`` in which every element is ``fill_value``.

    Args:
        size: Target shape, forwarded to ``Tensor.expand``.
        fill_value: Scalar (or tensor) used for every element.
        dtype: Result dtype. When omitted it is inferred from ``fill_value``:
            bool -> ``torch.bool``, int -> ``torch.int64``, anything else ->
            ``torch.get_default_dtype()``.
        layout: Accepted for signature compatibility; not used here.
        device: Result device; falls back to CPU when omitted.
        pin_memory: Accepted for signature compatibility; not used here.

    Returns:
        A newly allocated tensor produced by expanding a single-element
        tensor to ``size`` and cloning it.

    Raises:
        RuntimeError: Propagated from ``check_dtype`` when an explicit
            ``dtype`` cannot hold ``fill_value``.
    """
    logging.debug("GEMS FULL")
    target_device = torch.device("cpu") if device is None else device

    if dtype is not None:
        # An explicit dtype is validated (and may turn the value into a tensor).
        fill_value = check_dtype(fill_value, dtype, target_device)
    elif isinstance(fill_value, bool):
        # bool is tested before int because bool is a subclass of int.
        dtype = torch.bool
    elif isinstance(fill_value, int):
        dtype = torch.int64
    else:
        dtype = torch.get_default_dtype()

    if not isinstance(fill_value, torch.Tensor):
        seed = torch.tensor(fill_value, dtype=dtype, device=target_device)
    else:
        seed = fill_value.to(dtype=dtype, device=target_device)

    # expand() only creates a view; clone() materialises real storage.
    return seed.expand(size).clone()

70 return scalar.expand(size).clone()