Coverage for src/flag_gems/runtime/backend/_arm/ops/div.py: 0%

115 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.ops import div as base_div 

9 

10 

@triton.jit
def _div_tensor_scalar_kernel(
    x_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise ``out[i] = x[i] / scalar`` over a flat buffer of n_elements.
    # Program p starts at tile p * BLOCK_SIZE and then advances by the whole
    # launch width (num_programs * BLOCK_SIZE). A grid smaller than
    # cdiv(n_elements, BLOCK_SIZE) therefore still covers every element.
    # Lanes at or past n_elements are masked out of both the load and the store.
    lane = tl.arange(0, BLOCK_SIZE)
    first_tile = tl.program_id(0) * BLOCK_SIZE
    tile_stride = tl.num_programs(0) * BLOCK_SIZE
    for tile_start in range(first_tile, n_elements, tile_stride):
        idx = tile_start + lane
        in_bounds = idx < n_elements
        vals = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
        tl.store(out_ptr + idx, vals / scalar, mask=in_bounds)

29 

30 

31def _select_block_size(n_elements, dtype): 

32 if n_elements >= (1 << 20): 

33 return 512 if dtype in (torch.float16, torch.bfloat16) else 256 

34 if n_elements >= (1 << 18): 

35 return 256 if dtype in (torch.float16, torch.bfloat16) else 128 

36 return 256 if dtype in (torch.float16, torch.bfloat16) else 128 

37 

38 

39def _maybe_contiguous(x, out): 

40 if x.is_contiguous(): 

41 return x, out, False 

42 if out is None: 

43 return x.contiguous(), out, True 

44 if out.is_contiguous(): 

45 return x.contiguous(), out, True 

46 return x, out, False 

47 

48 

49def _div_tensor_scalar_triton(x, scalar, out=None): 

50 n_elements = x.numel() 

51 if n_elements == 0: 

52 return x if out is None else out 

53 if n_elements == 1 and x.dtype is torch.bfloat16: 

54 val = x.item() / scalar 

55 if out is None: 

56 out = torch.empty_like(x) 

57 out.fill_(val) 

58 return out 

59 

60 block_size = _select_block_size(n_elements, x.dtype) 

61 block_size = min(block_size, triton.next_power_of_2(max(n_elements, 1))) 

62 num_blocks = triton.cdiv(n_elements, block_size) 

63 grid = (num_blocks,) 

64 x_contig, out_contig, _ = _maybe_contiguous(x, out) 

65 if out_contig is None: 

66 out_contig = torch.empty_like(x_contig) 

67 num_warps = 1 

68 _div_tensor_scalar_kernel[grid]( 

69 x_contig, 

70 out_contig, 

71 scalar, 

72 n_elements, 

73 BLOCK_SIZE=block_size, 

74 num_warps=num_warps, 

75 ) 

76 return out_contig 

77 

78 

79def _maybe_get_scalar_tensor(val): 

80 if isinstance(val, torch.Tensor) and val.numel() == 1: 

81 return val.item() 

82 return None 

83 

84 

def true_divide(A, B):
    """True (floating-point) division ``A / B`` for the ARM backend.

    The local Triton kernel handles one case: a floating-point tensor ``A``
    divided by a real Python number or a 0-dim real tensor. Every other
    combination goes to ``flag_gems.ops.div.true_divide``. That covers
    integer/bool/complex ``A``, bool/complex divisors, tensors with
    dimensions (broadcasting), and non-tensor ``A``.

    Setting the environment variable ``GEMS_DEBUG_DIV=1`` prints the operand
    shapes on every call.
    """
    logging.debug("GEMS_ARM TRUE_DIVIDE")
    if os.environ.get("GEMS_DEBUG_DIV") == "1":
        a_shape = tuple(A.shape) if isinstance(A, torch.Tensor) else None
        b_shape = tuple(B.shape) if isinstance(B, torch.Tensor) else None
        print(f"[GEMS_DEBUG_DIV] true_divide: A={a_shape} B={b_shape}")
    if isinstance(A, torch.Tensor) and A.is_floating_point():
        scalar = B
        if isinstance(B, torch.Tensor):
            # Only 0-dim tensors behave like Python scalars under
            # broadcasting and type promotion.
            scalar = _maybe_get_scalar_tensor(B) if B.dim() == 0 else None
        # bool is a subclass of int; keep bools (and complex) off the kernel.
        if isinstance(scalar, (int, float)) and not isinstance(scalar, bool):
            return _div_tensor_scalar_triton(A, scalar)
    return base_div.true_divide(A, B)

98 

99 

def true_divide_(A, B):
    """In-place true division ``A /= B`` for the ARM backend.

    The Triton kernel writes directly into ``A`` only in one case. ``A`` must
    be a contiguous floating-point tensor. ``B`` must be a real Python number
    or a 0-dim real tensor. Everything else goes to
    ``flag_gems.ops.div.true_divide_``, including integer ``A``, whose
    float result cannot be stored in place.
    """
    logging.debug("GEMS_ARM TRUE_DIVIDE_")
    if (
        isinstance(A, torch.Tensor)
        and A.is_floating_point()
        and A.is_contiguous()
    ):
        scalar = B
        if isinstance(B, torch.Tensor):
            # Only 0-dim tensors behave like Python scalars under
            # broadcasting and type promotion.
            scalar = _maybe_get_scalar_tensor(B) if B.dim() == 0 else None
        # bool is a subclass of int; keep bools (and complex) off the kernel.
        if isinstance(scalar, (int, float)) and not isinstance(scalar, bool):
            return _div_tensor_scalar_triton(A, scalar, out=A)
    return base_div.true_divide_(A, B)

110 

111 

def trunc_divide(A, B):
    """Backend entry for ``rounding_mode='trunc'`` division.

    Forwards to ``flag_gems.ops.div.trunc_divide`` unchanged; this backend
    only adds a debug log line.
    """
    logging.debug("GEMS_ARM TRUNC_DIVIDE")
    impl = base_div.trunc_divide
    return impl(A, B)

115 

116 

def trunc_divide_(A, B):
    """In-place counterpart of ``trunc_divide``.

    Forwards to ``flag_gems.ops.div.trunc_divide_`` and returns its result.
    """
    logging.debug("GEMS_ARM TRUNC_DIVIDE_")
    impl = base_div.trunc_divide_
    return impl(A, B)

120 

121 

def floor_divide(A, B):
    """Backend entry for ``rounding_mode='floor'`` division.

    Forwards to ``flag_gems.ops.div.floor_divide`` unchanged; this backend
    only adds a debug log line.
    """
    logging.debug("GEMS_ARM FLOOR_DIVIDE")
    impl = base_div.floor_divide
    return impl(A, B)

125 

126 

def floor_divide_(A, B):
    """In-place counterpart of ``floor_divide``.

    Forwards to ``flag_gems.ops.div.floor_divide_`` and returns its result.
    """
    logging.debug("GEMS_ARM FLOOR_DIVIDE_")
    impl = base_div.floor_divide_
    return impl(A, B)

130 

131 

def div_mode(A, B, rounding_mode=None):
    """Dispatch ``div(A, B, rounding_mode=...)`` to the matching kernel.

    ``None`` selects ``true_divide``, ``"trunc"`` selects ``trunc_divide``
    and ``"floor"`` selects ``floor_divide``.

    Raises:
        ValueError: for any other ``rounding_mode``.
    """
    if rounding_mode is None:
        handler = true_divide
    elif rounding_mode == "trunc":
        handler = trunc_divide
    elif rounding_mode == "floor":
        handler = floor_divide
    else:
        raise ValueError(
            "div expected rounding_mode to be one of None, 'trunc', or 'floor' "
            f"but found {rounding_mode}."
        )
    return handler(A, B)

144 

145 

def div_mode_(A, B, rounding_mode=None):
    """In-place dispatch for ``div_(A, B, rounding_mode=...)``.

    ``None`` selects ``true_divide_``, ``"trunc"`` selects ``trunc_divide_``
    and ``"floor"`` selects ``floor_divide_``.

    Raises:
        ValueError: for any other ``rounding_mode``.
    """
    if rounding_mode is None:
        handler = true_divide_
    elif rounding_mode == "trunc":
        handler = trunc_divide_
    elif rounding_mode == "floor":
        handler = floor_divide_
    else:
        raise ValueError(
            "div expected rounding_mode to be one of None, 'trunc', or 'floor' "
            f"but found {rounding_mode}."
        )
    return handler(A, B)

158 

159 

def remainder(A, B):
    """Backend entry for ``remainder(A, B)``.

    Forwards to ``flag_gems.ops.div.remainder`` unchanged; this backend only
    adds a debug log line.
    """
    logging.debug("GEMS_ARM REMAINDER")
    impl = base_div.remainder
    return impl(A, B)

163 

164 

def remainder_(A, B):
    """In-place counterpart of ``remainder``.

    Forwards to ``flag_gems.ops.div.remainder_`` and returns its result.
    """
    logging.debug("GEMS_ARM REMAINDER_")
    impl = base_div.remainder_
    return impl(A, B)