Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/vdot.py: 0%

81 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

14@triton.jit 

15def compute_vdot( 

16 inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj 

17): 

18 # # Given inp storage: [inp_real, inp_imag], other: [other_real, other_imag] 

19 

20 # # Case 1: inp_is_conj = False, other_is_conj = False 

21 # out_real = inp_real * other_real + inp_imag * other_imag 

22 # out_imag = inp_real * other_imag - inp_imag * other_real 

23 

24 # # Case 2: inp_is_conj = True, other_is_conj = False 

25 # out_real = inp_real * other_real - inp_imag * other_imag 

26 # out_imag = inp_real * other_imag + inp_imag * other_real 

27 

28 # # Case 3: inp_is_conj = False, other_is_conj = True 

29 # out_real = inp_real * other_real - inp_imag * other_imag 

30 # out_imag = -inp_real * other_imag - inp_imag * other_real 

31 

32 # # Case 4: inp_is_conj = True, other_is_conj = True 

33 # out_real = inp_real * other_real + inp_imag * other_imag 

34 # out_imag = inp_real * other_imag - inp_imag * other_real 

35 if not inp_is_conj and not other_is_conj: # Case 1 

36 out_real = tl.sum(inp_real * other_real + inp_imag * other_imag) 

37 out_imag = tl.sum(inp_real * other_imag - inp_imag * other_real) 

38 elif inp_is_conj and not other_is_conj: # Case 2 

39 out_real = tl.sum(inp_real * other_real - inp_imag * other_imag) 

40 out_imag = tl.sum(inp_real * other_imag + inp_imag * other_real) 

41 elif not inp_is_conj and other_is_conj: # Case 3 

42 out_real = tl.sum(inp_real * other_real - inp_imag * other_imag) 

43 out_imag = tl.sum(-inp_real * other_imag - inp_imag * other_real) 

44 else: # Case 4 

45 out_real = tl.sum(inp_real * other_real + inp_imag * other_imag) 

46 out_imag = tl.sum(-inp_real * other_imag + inp_imag * other_real) 

47 

48 return out_real, out_imag 

49 

50 

51# support old version triton which do not support tl.split 

52@libentry() 

53@triton.heuristics(runtime.get_heuristic_config("vdot")) 

54@triton.jit() 

55def vdot_kernel_complex( 

56 inp_ptr, 

57 other_ptr, 

58 out_ptr, 

59 n_elements, 

60 inp_is_conj: tl.constexpr, 

61 other_is_conj: tl.constexpr, 

62 inp_stride: tl.constexpr, 

63 other_stride: tl.constexpr, 

64 BLOCK_SIZE: tl.constexpr, 

65): 

66 pid = tl.program_id(0) 

67 

68 base_offset = 2 * pid * BLOCK_SIZE + 2 * tl.arange(0, BLOCK_SIZE) + tl.arange(0, 1) 

69 

70 inp_real_offset = inp_stride * base_offset 

71 inp_imag_offset = inp_real_offset + 1 

72 

73 other_real_offset = other_stride * base_offset 

74 other_imag_offset = other_real_offset + 1 

75 

76 mask = base_offset < n_elements 

77 

78 inp_real = tl.load(inp_ptr + inp_real_offset, mask=mask) 

79 inp_imag = tl.load(inp_ptr + inp_imag_offset, mask=mask) 

80 

81 other_real = tl.load(other_ptr + other_real_offset, mask=mask) 

82 other_imag = tl.load(other_ptr + other_imag_offset, mask=mask) 

83 

84 # Compute based on conjugate flags 

85 out_real, out_imag = compute_vdot( 

86 inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj 

87 ) 

88 

89 tl.atomic_add(out_ptr, out_real) 

90 tl.atomic_add(out_ptr + 1, out_imag) 

91 

92 

93# only support real number 

94@libentry() 

95@triton.heuristics(runtime.get_heuristic_config("vdot")) 

96@triton.jit() 

97def dot_kernel( 

98 inp_ptr, 

99 other_ptr, 

100 out_ptr, 

101 n_elements, 

102 inp_stride: tl.constexpr, 

103 other_stride: tl.constexpr, 

104 BLOCK_SIZE: tl.constexpr, 

105): 

106 pid = tl.program_id(0) 

107 offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

108 mask = offset < n_elements 

109 

110 inp = tl.load(inp_ptr + inp_stride * offset, mask=mask).to(tl.float32) 

111 other = tl.load(other_ptr + other_stride * offset, mask=mask).to(tl.float32) 

112 

113 out = tl.sum(inp * other) 

114 tl.atomic_add(out_ptr, out) 

115 

116 

117def vdot(input: Tensor, other: Tensor): 

118 logger.debug("GEMS_TSINGMICRO VDOT") 

119 

120 assert ( 

121 input.dtype == other.dtype 

122 ), f"Input tensors must have the same dtype. Got {input.dtype} and {other.dtype}." 

123 assert ( 

124 input.ndim == 1 and other.ndim == 1 

125 ), f"Input tensors must be 1D. Got {input.ndim}D and {other.ndim}D." 

126 assert ( 

127 input.size() == other.size() 

128 ), f"Input tensors must have the same size. Got {input.size()} and {other.size()}." 

129 

130 inp = input 

131 inp_stride = inp.stride()[0] 

132 other_stride = other.stride()[0] 

133 

134 if inp.is_complex(): 

135 inp_is_conj = False 

136 other_is_conj = False 

137 

138 if inp.is_conj(): 

139 inp_is_conj = True 

140 inp = inp.conj() 

141 

142 if other.is_conj(): 

143 other_is_conj = True 

144 other = other.conj() 

145 

146 inp_real = torch.view_as_real(inp) 

147 other_real = torch.view_as_real(other) 

148 

149 n_elements = inp_real.numel() 

150 n_complex = inp.numel() 

151 

152 output_real = torch.zeros(2, dtype=inp_real.dtype, device=inp.device) 

153 

154 grid = lambda meta: (triton.cdiv(n_complex, meta["BLOCK_SIZE"]),) 

155 

156 vdot_kernel_complex[grid]( 

157 inp_real, 

158 other_real, 

159 output_real, 

160 n_elements=n_elements, 

161 inp_is_conj=inp_is_conj, 

162 other_is_conj=other_is_conj, 

163 inp_stride=inp_stride, 

164 other_stride=other_stride, 

165 ) 

166 

167 return torch.view_as_complex(output_real) 

168 else: 

169 output = torch.zeros([], dtype=torch.float32, device=inp.device) 

170 n_elements = inp.numel() 

171 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

172 dot_kernel[grid]( 

173 inp, 

174 other, 

175 output, 

176 n_elements=n_elements, 

177 inp_stride=inp_stride, 

178 other_stride=other_stride, 

179 ) 

180 return output.to(inp.dtype)