Coverage for src/flag_gems/runtime/backend/_ascend/ops/full_like.py: 0%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6 

7from flag_gems.runtime import torch_device_fn 

8 

9from .full import check_dtype, full_kernel 

10 

# Module logger: fixed Ascend ops prefix plus the final component of this
# module's dotted import path.
logger = logging.getLogger(
    f'flag_gems.runtime._ascend.ops.{__name__.rpartition(".")[2]}'
)

12 

13 

def full_like(
    x,
    fill_value,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Return a new tensor shaped like ``x`` with every element set to ``fill_value``.

    Args:
        x: Reference tensor. Supplies the output shape and, when not
            overridden, the output dtype and device.
        fill_value: Value to write. It is passed through ``check_dtype``
            first; if the result is a ``torch.Tensor`` the kernel is launched
            with ``FILL_VALUE_IS_PTR=True``.
        dtype: Output dtype. Defaults to ``x.dtype``.
        layout: Accepted for signature compatibility; not used here.
        device: Output device. Defaults to ``x.device``.
        pin_memory: Accepted for signature compatibility; not used here.
        memory_format: Accepted for signature compatibility; not used here.

    Returns:
        The freshly allocated tensor filled by ``full_kernel``. For an empty
        ``x`` the (empty) allocation is returned without launching a kernel.
    """
    logger.debug("GEMS_ASCEND FULL_LIKE")
    if device is None:
        device = x.device
    if dtype is None:
        dtype = x.dtype
    fill_value = check_dtype(fill_value, dtype, device)
    out = torch.empty_like(x, device=device, dtype=dtype)
    N = x.numel()
    if N == 0:
        # Nothing to fill. This also avoids a zero BLOCK_SIZE (sqrt(0) -> 0),
        # which would make the cdiv in the grid computation divide by zero.
        return out

    # Block size is the next power of two >= ceil(sqrt(N)), capped at 2048.
    BLOCK_SIZE = min(triton.next_power_of_2(math.ceil(math.sqrt(N))), 2048)

    def grid_fn(meta):
        # One program per BLOCK_SIZE-sized slice of the flattened output.
        return (triton.cdiv(N, meta["BLOCK_SIZE"]),)

    # The kernel writes into ``out``, so launch under the device that owns it;
    # this can differ from ``x.device`` when ``device`` is passed explicitly.
    with torch_device_fn.device(out.device):
        full_kernel[grid_fn](
            out,
            N,
            fill_value,
            FILL_VALUE_IS_PTR=isinstance(fill_value, torch.Tensor),
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return out