Coverage for src/flag_gems/runtime/backend/_sunrise/ops/fill.py: 0%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import pointwise_dynamic 

9from flag_gems.utils.pointwise_dynamic import CodeGenConfig 

10 

# Module logger, created as a child of the shared "flag_gems" logger.
# Leading dots are stripped from ``__name__`` before it is used as the
# child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


# Per-axis grid limit handed to the code generator via ``max_grid_size``.
MAX_GRID_SIZES = (65535, 65535, 65535)

# Code-generation settings passed to every ``pointwise_dynamic`` kernel
# declared in this module.
config = CodeGenConfig(
    prefer_1d_tile=True,
    prefer_block_pointer=True,
    max_num_warps_per_cta=32,
    max_grid_size=MAX_GRID_SIZES,
    max_tile_size=1024,
)

22 

23 

# Pointwise kernel: ``inp`` is a tensor operand, ``value_scalar`` is a
# non-tensor operand (see ``is_tensor``). The result is a block with the
# shape and dtype of ``inp`` where every element equals ``value_scalar``.
@pointwise_dynamic(
    config=config,
    num_outputs=1,
    promotion_methods=[(0, "DEFAULT")],
    is_tensor=[True, False],
)
@triton.jit
def fill_scalar_func(inp, value_scalar):
    # ``inp`` is only consulted for its shape and dtype; its contents are
    # never read.
    filled = tl.full(inp.shape, value_scalar, dtype=inp.dtype)
    return filled

33 

34 

# Pointwise kernel: both operands are tensors (see ``is_tensor``). The
# output is simply ``value``; ``inp`` takes part only through the
# decorator's configuration (promotion is keyed on argument 0).
@pointwise_dynamic(
    config=config,
    num_outputs=1,
    promotion_methods=[(0, "DEFAULT")],
    is_tensor=[True, True],
)
@triton.jit
def fill_tensor_func(inp, value):
    # ``inp`` is intentionally unused in the body.
    result = value
    return result

44 

45 

def fill_scalar(input, value):
    """Fill a freshly allocated tensor shaped like ``input`` with ``value``.

    Args:
        input: Tensor used as the template for the output allocation
            (``torch.empty_like``) and to select the active device.
        value: Scalar fill value forwarded to the kernel.

    Returns:
        Whatever ``fill_scalar_func`` returns when writing into the newly
        allocated buffer.
    """
    logger.debug("GEMS FILL (Dynamic)")
    buffer = torch.empty_like(input)
    # Launch under the device that owns ``input``.
    with torch_device_fn.device(input.device):
        filled = fill_scalar_func(input, value, out0=buffer)
    return filled

51 

52 

def fill_scalar_out(input, value, *, out=None):
    """Fill ``out`` with the scalar ``value``, allocating when ``out`` is None.

    Args:
        input: Tensor passed to the kernel as its tensor operand; also
            selects the active device.
        value: Scalar fill value.
        out: Optional destination tensor. When omitted, the call is
            delegated to ``fill_scalar``.

    Returns:
        ``out`` when it was supplied, otherwise the result of
        ``fill_scalar(input, value)``.
    """
    logger.debug("GEMS FILL_SCALAR_OUT")
    if out is not None:
        # Caller supplied the destination: write into it directly.
        with torch_device_fn.device(input.device):
            fill_scalar_func(input, value, out0=out)
        return out
    return fill_scalar(input, value)

60 

61 

def fill_tensor(input, value):
    """Fill a freshly allocated tensor shaped like ``input`` with ``value``.

    Args:
        input: Tensor used as the template for the output allocation
            (``torch.empty_like``) and to select the active device.
        value: 0-dimensional tensor holding the fill value. When it does
            not report ``is_cuda``, its Python scalar (``value.item()``)
            is forwarded to ``fill_scalar`` instead.

    Returns:
        The result of the underlying fill kernel call.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    # Validate first. Previously the host shortcut below ran before this
    # check, so a one-element host tensor with ndim >= 1 was silently
    # accepted and a multi-element one failed inside ``.item()`` with an
    # unrelated message.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL (Dynamic)")
    out = torch.empty_like(input)
    with torch_device_fn.device(input.device):
        return fill_tensor_func(input, value, out0=out)

73 

74 

def fill_tensor_out(input, value, *, out=None):
    """Fill ``out`` with the 0-dim tensor ``value``, allocating if needed.

    Args:
        input: Tensor passed to the kernel as its first operand; also
            selects the active device.
        value: 0-dimensional tensor holding the fill value. When it does
            not report ``is_cuda``, its Python scalar (``value.item()``)
            is forwarded to ``fill_scalar_out`` instead.
        out: Optional destination tensor. When omitted, the call is
            delegated to ``fill_tensor``.

    Returns:
        ``out`` when it was supplied, otherwise the result of
        ``fill_tensor(input, value)``.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    # Validate up front so every path (delegation, host shortcut, device
    # kernel) rejects a non-0-dim ``value`` with the same error. Previously
    # the delegation and the host shortcut skipped this check.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    logger.debug("GEMS FILL_TENSOR_OUT")
    if out is None:
        return fill_tensor(input, value)
    if not value.is_cuda:
        return fill_scalar_out(input, value.item(), out=out)
    with torch_device_fn.device(input.device):
        fill_tensor_func(input, value, out0=out)
    return out

88 

89 

def fill_tensor_(self, value):
    """Fill ``self`` in place with the 0-dim tensor ``value``.

    Args:
        self: Tensor to overwrite; it is passed to the kernel both as the
            input operand and as the output buffer.
        value: 0-dimensional tensor holding the fill value. When it does
            not report ``is_cuda``, its Python scalar (``value.item()``)
            is forwarded to ``fill_scalar_`` instead.

    Returns:
        ``self`` on the device path, or the result of ``fill_scalar_`` on
        the host-value path.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    # Validate first. Previously the host shortcut below ran before this
    # check, so a one-element host tensor with ndim >= 1 was silently
    # accepted and a multi-element one failed inside ``.item()`` with an
    # unrelated message.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_")
    with torch_device_fn.device(self.device):
        fill_tensor_func(self, value, out0=self)
    return self

101 

102 

def fill_scalar_(self, value):
    """Fill ``self`` in place with the scalar ``value``.

    Args:
        self: Tensor to overwrite; it is passed to the kernel both as the
            input operand and as the output buffer.
        value: Scalar fill value.

    Returns:
        ``self``.
    """
    logger.debug("GEMS FILL_SCALAR_")
    owning_device = self.device
    with torch_device_fn.device(owning_device):
        fill_scalar_func(self, value, out0=self)
    return self