Coverage for src/flag_gems/runtime/backend/_arm/ops/min.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3from collections import namedtuple 

4 

5import numpy as np 

6import torch 

7import triton 

8import triton.language as tl 

9 

10from flag_gems import runtime 

11 

12# from ..runtime import torch_device_fn 

13# from ..utils import libentry 

14from flag_gems.utils import triton_lang_extension as tle 

15 

16 

17# @libentry() 

@triton.jit
def min_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First reduction stage: write one partial minimum per program.

    Each program loads up to ``BLOCK_SIZE`` consecutive elements of ``inp``
    starting at ``program_id(0) * BLOCK_SIZE``, reduces them with ``tl.min``
    and stores the result to ``mid[program_id(0)]``.

    Args:
        inp: Pointer to the input elements, addressed linearly.
        mid: Pointer to the buffer receiving one partial minimum per program.
        M: Number of valid elements reachable through ``inp``.
        BLOCK_SIZE: Compile-time number of elements handled by one program.
    """
    block_id = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_SIZE)
    idx = block_id * BLOCK_SIZE + lanes
    in_bounds = idx < M
    # Lanes past the end read +inf so they cannot win the reduction.
    vals = tl.load(inp + idx, mask=in_bounds, other=float("inf"))
    partial_min = tl.min(vals)
    tl.store(mid + block_id, partial_min)

33 

34 

35# @libentry() 

@triton.jit
def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second reduction stage: fold the partial minima into one value.

    A single program loads the first ``mid_size`` entries of ``mid`` (padded
    with +inf up to ``BLOCK_MID`` lanes), reduces them with ``tl.min`` and
    stores the result at ``out``.

    Args:
        mid: Pointer to the partial minima.
        out: Pointer to the location receiving the final minimum.
        mid_size: Number of valid entries in ``mid``.
        BLOCK_MID: Compile-time lane count; the load covers lanes
            ``0 .. BLOCK_MID - 1`` only.
    """
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size
    # Padding lanes read +inf so they never become the minimum.
    partials = tl.load(mid + lanes, mask=valid, other=float("inf"))
    result = tl.min(partials)
    tl.store(out, result)

44 

45 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 8}, num_warps=1),
        triton.Config({"BLOCK_SIZE": 2}, num_warps=2),
        triton.Config({"BLOCK_SIZE": 16}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 32}, num_warps=4),
    ],
    key=["M"],  # re-tune when tensor size changes
)
# @libentry()
@triton.jit
def min_kernel_3(inp, out, M, BLOCK_SIZE: tl.constexpr):
    """Single-pass minimum: each program folds its tile into ``out`` atomically.

    Every program loads up to ``BLOCK_SIZE`` consecutive elements of ``inp``,
    reduces them with ``tl.min`` and combines the result with the value
    already stored at ``out`` through ``tl.atomic_min``. Because the result is
    merged with the existing content of ``out``, the caller has to initialise
    ``out`` before launching the kernel.

    Args:
        inp: Pointer to the input elements, addressed linearly.
        out: Pointer to the single accumulator updated with ``tl.atomic_min``.
        M: Number of valid elements reachable through ``inp``.
        BLOCK_SIZE: Compile-time number of elements handled by one program;
            chosen by the autotuner from the configs above.
    """
    pid = tl.program_id(0)
    start = pid * BLOCK_SIZE
    offsets = start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < M
    # Masked-off lanes must be given a neutral value: without ``other`` their
    # contents are undefined and could be picked up by the reduction below.
    x = tl.load(inp + offsets, mask=mask, other=float("inf"))
    min_val = tl.min(x, axis=None)
    tl.atomic_min(out, min_val)

65 

66 

def heur_block_n(args):
    """Return the smallest power of two that is >= ``args["N"]``.

    Args:
        args: Mapping of kernel argument names to values; only ``"N"`` is read.
    """
    n = args["N"]
    return triton.next_power_of_2(n)

69 

70 

71# @libentry() 

@triton.autotune(
    configs=runtime.get_tuned_config("min"),
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def min_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Compute the minimum and its index along the middle axis of an (M, N, K) layout.

    ``inp`` is addressed as ``m * N * K + n * K + k``. Program axis 0 selects
    a tile of ``BLOCK_M`` rows and program axis 1 selects ``k``. The ``N``
    axis is scanned in tiles of ``BLOCK_N``, keeping a running minimum (held
    as float32) and the index where it was found. Results are written to
    ``out_value`` / ``out_index`` at ``m * K + k``.

    Args:
        inp: Pointer to the input elements.
        out_value: Pointer receiving the per-(m, k) minimum.
        out_index: Pointer receiving the per-(m, k) index along ``N``.
        M: Extent of the leading axis.
        N: Extent of the reduced axis.
        K: Extent of the trailing axis.
        BLOCK_M: Compile-time number of rows handled per program.
        BLOCK_N: Compile-time tile width along the reduced axis.
    """
    # Locate this program's rows and trailing-axis position.
    row_block = tle.program_id(0)
    k_idx = tle.program_id(1)
    rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)

    # Running results, one entry per row of the tile.
    best_vals = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
    best_idx = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        flat = rows[:, None] * N * K + cols[None, :] * K + k_idx
        in_bounds = rows[:, None] < M and cols[None, :] < N
        # Out-of-range positions read +inf so they never win the reduction.
        tile = tl.load(inp + flat, mask=in_bounds, other=float("inf"))
        tile_min, tile_argmin = tl.min(tile, 1, return_indices=True)
        # Strict comparison: on ties the index from the earlier tile is kept.
        improved = tile_min < best_vals
        best_vals = tl.where(improved, tile_min, best_vals)
        best_idx = tl.where(improved, start_n + tile_argmin, best_idx)

    out_pos = rows * K + k_idx
    row_ok = rows < M
    tl.store(out_value + out_pos, best_vals, mask=row_ok)
    tl.store(out_index + out_pos, best_idx, mask=row_ok)

116 

117 

def min(inp):
    """Return the minimum over all elements of ``inp`` as a 0-d tensor.

    The reduction runs in two kernel launches: ``min_kernel_1`` writes one
    partial minimum per block into a temporary buffer, then ``min_kernel_2``
    reduces that buffer to a single value.

    Args:
        inp: Input tensor. It is made contiguous first because the kernels
            index it linearly from its base pointer.

    Returns:
        A 0-d tensor with the dtype and device of ``inp``.
    """
    logging.debug("GEMS MIN")
    # The kernels read ``inp + offset`` for offset in [0, numel); that only
    # visits the right elements when the storage is dense. ``contiguous()`` is
    # a no-op for tensors that already are.
    inp = inp.contiguous()
    M = inp.numel()
    # Split the work into roughly sqrt(M) blocks of roughly sqrt(M) elements.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)
    # Use two-stage reduction for broader dtype support on Triton CPU.
    min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
    min_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out

132 

133 

def min_dim(inp, dim=None, keepdim=False):
    """Return the minimum values and their indices along ``dim``.

    The computation is done on the host with NumPy: the tensor is detached,
    copied to CPU, reduced with ``np.argmin`` and the results are moved back
    to ``inp.device``.

    Args:
        inp: Input tensor with at least one dimension.
        dim: Axis to reduce; negative values count from the end. An integer
            is required: the ``None`` default fails the range check below.
        keepdim: When true, the reduced axis is kept with size 1 in both
            outputs; otherwise it is removed.

    Returns:
        A namedtuple ``min(values, indices)``; ``indices`` has dtype int64.

    Raises:
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logging.debug("GEMS MIN DIM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    inp_np = inp.detach().cpu().numpy()
    # np.argmin yields a NumPy scalar (not an ndarray) when the input is 1-D,
    # and torch.from_numpy rejects scalars; np.asarray normalises both cases
    # to an int64 ndarray.
    out_index_np = np.asarray(np.argmin(inp_np, axis=dim), dtype=np.int64)
    gather_index = np.expand_dims(out_index_np, axis=dim)
    # The gathered values keep the reduced axis with size 1.
    out_value_np = np.take_along_axis(inp_np, gather_index, axis=dim)
    out_index = torch.from_numpy(out_index_np).to(inp.device)
    out_value = torch.from_numpy(out_value_np).to(inp.device)
    if keepdim:
        out_index = out_index.unsqueeze(dim)
    else:
        out_value = out_value.squeeze(dim)
    Min_out = namedtuple("min", ["values", "indices"])
    out = Min_out(values=out_value, indices=out_index)
    return out

152 return out