Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/fill.py: 0%

107 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8 

9logger = logging.getLogger(__name__) 

10 

11 

@triton.jit
def fill_scalar_kernel(
    out_ptr,
    value_scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Write `value_scalar` into the first `n_elements` elements behind
    # `out_ptr`. Each program instance handles one BLOCK_SIZE-wide chunk of
    # flat offsets; the mask keeps the last chunk from writing past the end.

    # Widen the program id before multiplying so the flat offsets cannot wrap
    # around in 32-bit arithmetic for very large tensors.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Take the element type straight from the pointer instead of loading from
    # the destination: the load was a wasted pass over memory, and callers may
    # hand in a freshly allocated (uninitialised) buffer.
    fill_val = tl.full([BLOCK_SIZE], value_scalar, dtype=out_ptr.dtype.element_ty)
    tl.store(out_ptr + offsets, fill_val, mask=mask)

28 

29 

@triton.jit
def fill_tensor_kernel(
    out_ptr,
    value_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Write the single element stored at `value_ptr` into the first
    # `n_elements` elements behind `out_ptr`. Each program instance handles one
    # BLOCK_SIZE-wide chunk of flat offsets; the mask guards the last chunk.

    # Widen the program id before multiplying so the flat offsets cannot wrap
    # around in 32-bit arithmetic for very large tensors.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # One scalar load of the fill value, then a masked store of it to every
    # offset in this chunk.
    val = tl.load(value_ptr)
    tl.store(out_ptr + offsets, val, mask=mask)

44 

45 

46def _as_contiguous(tensor): 

47 """Return tensor.contiguous() view for use with flat-offset kernels. 

48 

49 For non-contiguous tensors this allocates a new buffer; callers that 

50 need in-place semantics must copy back afterwards. 

51 """ 

52 if tensor.is_contiguous(): 

53 return tensor, False 

54 return tensor.contiguous(), True 

55 

56 

def fill_scalar(input, value):
    """Return a new tensor allocated like ``input`` with every element set to ``value``.

    ``input`` itself is not written to; it only supplies the allocation
    template (via ``torch.empty_like``) and the device the kernel runs on.
    """
    logger.debug("GEMS_HOPPER FILL_SCALAR")
    block_size = 1024
    result = torch.empty_like(input)
    total = result.numel()
    # One program instance per block of flat offsets.
    launch_grid = (triton.cdiv(total, block_size),)
    with torch_device_fn.device(input.device):
        fill_scalar_kernel[launch_grid](result, value, total, BLOCK_SIZE=block_size)
    return result

65 

66 

def fill_scalar_out(input, value, *, out=None):
    """Fill ``out`` with the scalar ``value`` and return it.

    When ``out`` is ``None`` this delegates to ``fill_scalar`` and returns a
    newly allocated tensor instead.

    Args:
        input: Tensor used as the allocation template when ``out`` is ``None``.
        value: Scalar to write into every element.
        out: Optional destination tensor; it is filled over its own extent.

    Returns:
        ``out`` when it was given, otherwise the tensor from ``fill_scalar``.
    """
    logger.debug("GEMS_HOPPER FILL_SCALAR_OUT")
    if out is None:
        return fill_scalar(input, value)

    # The kernel addresses memory by flat offset, so it needs a contiguous
    # target. Every element is overwritten, so for a non-contiguous `out` an
    # uninitialised scratch buffer is enough; copying `out`'s old contents
    # first would be wasted work.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty_like(out, memory_format=torch.contiguous_format)

    n_elements = target.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # Launch on the device that owns the memory being written.
    with torch_device_fn.device(out.device):
        fill_scalar_kernel[grid](target, value, n_elements, BLOCK_SIZE=1024)
    if target is not out:
        out.copy_(target)
    return out

79 

80 

def fill_tensor(input, value):
    """Return a new tensor allocated like ``input`` filled with the 0-dim tensor ``value``.

    Args:
        input: Tensor used as the allocation template and to pick the device.
        value: 0-dimensional tensor holding the fill value. A non-CUDA
            ``value`` is converted with ``.item()`` and handled by
            ``fill_scalar``.

    Returns:
        A newly allocated tensor; ``input`` is not modified.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    # Validate before the CPU fallback so a non-0-dim value is rejected the
    # same way whichever device it lives on. Otherwise a one-element 1-D CPU
    # tensor would pass through `.item()` unnoticed.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    logger.debug("GEMS_HOPPER FILL_TENSOR")
    out = torch.empty_like(input)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    with torch_device_fn.device(input.device):
        fill_tensor_kernel[grid](out, value, n_elements, BLOCK_SIZE=1024)
    return out

95 

96 

def fill_tensor_out(input, value, *, out=None):
    """Fill ``out`` with the 0-dim tensor ``value`` and return it.

    When ``out`` is ``None`` this delegates to ``fill_tensor`` and returns a
    newly allocated tensor instead. A non-CUDA ``value`` is converted with
    ``.item()`` and handled by ``fill_scalar_out``.

    Args:
        input: Tensor used as the allocation template when ``out`` is ``None``.
        value: 0-dimensional tensor holding the fill value.
        out: Optional destination tensor; it is filled over its own extent.

    Returns:
        ``out`` when it was given, otherwise the tensor from ``fill_tensor``.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    logger.debug("GEMS_HOPPER FILL_TENSOR_OUT")
    # Validate up front so every path below (new allocation, CPU fallback,
    # kernel launch) rejects a non-0-dim value in the same way.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if out is None:
        return fill_tensor(input, value)
    if not value.is_cuda:
        return fill_scalar_out(input, value.item(), out=out)

    # The kernel addresses memory by flat offset, so it needs a contiguous
    # target. Every element is overwritten, so for a non-contiguous `out` an
    # uninitialised scratch buffer is enough; copying `out`'s old contents
    # first would be wasted work.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty_like(out, memory_format=torch.contiguous_format)

    n_elements = target.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # Launch on the device that owns the memory being written.
    with torch_device_fn.device(out.device):
        fill_tensor_kernel[grid](target, value, n_elements, BLOCK_SIZE=1024)
    if target is not out:
        out.copy_(target)
    return out

115 

116 

def fill_tensor_(self, value):
    """Fill ``self`` in place with the 0-dim tensor ``value`` and return ``self``.

    A non-CUDA ``value`` is converted with ``.item()`` and handled by
    ``fill_scalar_``.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    # Validate before the CPU fallback so a non-0-dim value is rejected the
    # same way whichever device it lives on.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS_HOPPER FILL_TENSOR_")

    # The kernel addresses memory by flat offset, so it needs a contiguous
    # target. Every element is overwritten, so for a non-contiguous `self` an
    # uninitialised scratch buffer is enough; it is copied back below.
    if self.is_contiguous():
        target = self
    else:
        target = torch.empty_like(self, memory_format=torch.contiguous_format)

    n_elements = target.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    with torch_device_fn.device(self.device):
        fill_tensor_kernel[grid](target, value, n_elements, BLOCK_SIZE=1024)
    if target is not self:
        self.copy_(target)
    return self

138 

139 

def fill_scalar_(self, value):
    """Fill ``self`` in place with the scalar ``value`` and return ``self``."""
    logger.debug("GEMS_HOPPER FILL_SCALAR_")

    # The kernel addresses memory by flat offset, so it needs a contiguous
    # target. Every element is overwritten, so for a non-contiguous `self` an
    # uninitialised scratch buffer is enough; it is copied back below.
    if self.is_contiguous():
        target = self
    else:
        target = torch.empty_like(self, memory_format=torch.contiguous_format)

    n_elements = target.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    with torch_device_fn.device(self.device):
        fill_scalar_kernel[grid](target, value, n_elements, BLOCK_SIZE=1024)
    if target is not self:
        self.copy_(target)
    return self