Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mul.py: 0%

87 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import pointwise_dynamic 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    # Scalar body of the tensor-by-tensor multiply. The decorator is
    # configured to promote operands 0 and 1 with the "DEFAULT" rule.
    product = x * y
    return product

17 

18 

@triton.jit
def mul_kernel(
    x_ptr,  # Pointer to the first input vector.
    y_ptr,  # Pointer to the second input vector.
    output_ptr,  # Pointer to the output vector.
    n_elements,  # Number of elements in each vector.
    BLOCK_SIZE: tl.constexpr,  # Elements handled per program; constexpr so it can size `tl.arange`.
):
    # The launch grid is one-dimensional, so axis 0 identifies this program.
    program_index = tl.program_id(axis=0)
    # Each program owns one BLOCK_SIZE-wide slice of the flat index space.
    # E.g. for 256 elements and BLOCK_SIZE == 64 the four programs cover
    # [0:64], [64:128], [128:192] and [192:256].
    lane = tl.arange(0, BLOCK_SIZE)
    idx = program_index * BLOCK_SIZE + lane
    # The last slice may run past the end when n_elements is not a multiple
    # of BLOCK_SIZE; the mask disables those lanes for loads and the store.
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    # Write the element-wise product x * y back to the output vector.
    tl.store(output_ptr + idx, lhs * rhs, mask=in_bounds)

46 

47 

def mul_all_real_func(x: torch.Tensor, y: torch.Tensor):
    """Multiply ``x`` and ``y`` element-wise with the flat 1D ``mul_kernel``.

    The kernel indexes both inputs and the output with the same linear
    offsets, so this is only meaningful for operands that share one layout;
    ``mul`` guards the call with ``_can_use_fast_mul_all_real``.

    Args:
        x: First operand; also the template (``torch.empty_like``) for the
            output and the source of the device the launch runs on.
        y: Second operand.

    Returns:
        A newly allocated tensor holding the products. No synchronization is
        performed here, so the launch may still be in flight on return.
    """
    result = torch.empty_like(x)
    total = result.numel()

    def launch_grid(meta):
        # 1D grid: one program per BLOCK_SIZE-sized chunk, rounded up.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(x.device):
        # Tensors are handed to the kernel as pointers to their first
        # element; meta-parameters must be passed by keyword.
        mul_kernel[launch_grid](x, y, result, total, BLOCK_SIZE=1024)
    return result

66 

67 

def _can_use_fast_mul_all_real(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Return True when ``x * y`` may be routed to the flat 1D fast path.

    The fast path walks both inputs and the output with one shared linear
    offset, so every condition below must hold:

    * identical shape and dtype (no broadcasting, no type promotion);
    * same device -- both tensors are handed to a single kernel launched
      under ``x.device`` (PyTorch allows mixing e.g. a 0-dim CPU tensor with
      an accelerator tensor, which must not reach a raw-pointer kernel);
    * both contiguous, so linear offsets address matching elements;
    * fewer than 2**31 elements, because the kernel derives its offsets from
      ``program_id * BLOCK_SIZE`` without widening, which would wrap in
      32-bit arithmetic for larger tensors.

    Args:
        x: First operand.
        y: Second operand.

    Returns:
        ``True`` if the fast path is applicable, ``False`` if the caller
        should fall back to the generic pointwise implementation.
    """
    return (
        x.shape == y.shape
        and x.dtype == y.dtype
        and x.device == y.device
        and x.is_contiguous()
        and y.is_contiguous()
        and x.numel() < 2**31
    )

75 

76 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    # Scalar body of the tensor-by-scalar multiply: per `is_tensor`, operand 0
    # is a tensor and operand 1 is a non-tensor value.
    product = x * y
    return product

81 

82 

@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    # Complex product on split real/imaginary planes:
    #   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
    re_part = ar * br - ai * bi
    im_part = ar * bi + ai * br
    return re_part, im_part

93 

94 

def mul(A, B):
    """Multiply ``A`` and ``B`` element-wise.

    Either operand may be a ``torch.Tensor`` or a Python scalar. Dispatch:

    * two Python scalars -> ``torch.tensor(A * B)``;
    * any complex operand -> the product is computed on real/imaginary
      planes obtained with ``torch.view_as_real`` and viewed back as complex;
    * two real tensors -> the flat fast path when
      ``_can_use_fast_mul_all_real`` allows it, otherwise ``mul_func``;
    * one real tensor and one real scalar -> ``mul_func_scalar``.

    A Python ``complex`` scalar paired with a tensor is first lifted to a
    0-dim complex tensor on that tensor's device, because
    ``torch.view_as_real`` only accepts tensors.

    Args:
        A: First operand (tensor or Python scalar).
        B: Second operand (tensor or Python scalar).

    Returns:
        A tensor holding the product. On the complex paths the result is cast
        to ``torch.result_type(A, B)`` evaluated on the original operands.
    """
    logger.debug("GEMS MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not A_is_tensor and not B_is_tensor:
        # Both scalar (real or complex): plain Python arithmetic is enough.
        return torch.tensor(A * B)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    if A_is_complex or B_is_complex:
        # Evaluate the result dtype on the operands exactly as passed in,
        # before any scalar is replaced by a tensor.
        out_dtype = torch.result_type(A, B)

        # At least one operand is a tensor here, so at most one of these
        # fires. NOTE(review): assumes the pointwise wrappers broadcast a
        # 0-dim tensor like any other operand -- confirm.
        if A_is_complex and not A_is_tensor:
            A = torch.tensor(A, dtype=out_dtype, device=B.device)
            A_is_tensor = True
        elif B_is_complex and not B_is_tensor:
            B = torch.tensor(B, dtype=out_dtype, device=A.device)
            B_is_tensor = True

        # 1) A and B are both complex tensors.
        if A_is_complex and B_is_complex:
            Ar = torch.view_as_real(A)
            Br = torch.view_as_real(B)
            ar, ai = Ar[..., 0], Ar[..., 1]
            br, bi = Br[..., 0], Br[..., 1]
            common_dtype = torch.promote_types(ar.dtype, br.dtype)
            ar, ai = ar.to(common_dtype), ai.to(common_dtype)
            br, bi = br.to(common_dtype), bi.to(common_dtype)

            # Size the output from the broadcast of both operands, not from
            # A alone, so a smaller A against a larger B gets a big enough
            # buffer. The planes are written interleaved into one
            # (..., 2) buffer so it can be viewed as complex without a copy.
            shape = torch.broadcast_shapes(ar.shape, br.shape)
            out_buffer = torch.empty(
                (*shape, 2), dtype=common_dtype, device=ar.device
            )
            real_out = out_buffer[..., 0]
            imag_out = out_buffer[..., 1]
            mul_complex_kernel(ar, ai, br, bi, out0=real_out, out1=imag_out)
            return torch.view_as_complex(out_buffer).to(out_dtype)

        # 2) A is a complex tensor, B is real (tensor or scalar).
        if A_is_complex:
            Ar = torch.view_as_real(A)
            if B_is_tensor:
                # The trailing unit axis lets B broadcast over (real, imag).
                out_real = mul_func(Ar, B.unsqueeze(-1))
            else:
                out_real = mul_func_scalar(Ar, B)
            return torch.view_as_complex(out_real).to(out_dtype)

        # 3) A is real (tensor or scalar), B is a complex tensor.
        Br = torch.view_as_real(B)
        if A_is_tensor:
            out_real = mul_func(A.unsqueeze(-1), Br)
        else:
            out_real = mul_func_scalar(Br, A)  # tensor first, scalar second
        return torch.view_as_complex(out_real).to(out_dtype)

    if A_is_tensor and B_is_tensor:
        if _can_use_fast_mul_all_real(A, B):
            return mul_all_real_func(A, B)
        return mul_func(A, B)
    if A_is_tensor:
        return mul_func_scalar(A, B)
    # Only B is a tensor; the scalar wrapper takes the tensor first.
    return mul_func_scalar(B, A)

155 

156 

def mul_(A, B):
    """Multiply ``A`` by ``B`` in place and return ``A``.

    Real operands are written straight into ``A`` by the pointwise wrappers
    (``out0=A``). Complex operands are not passed to those wrappers directly:
    like ``mul``, which splits complex values into real/imaginary planes,
    the product is computed through ``mul`` and then copied back into ``A``.

    Args:
        A: Destination tensor; overwritten with the product.
        B: Second operand (tensor or Python scalar).

    Returns:
        ``A`` after the update.

    Raises:
        RuntimeError: If ``B`` is complex but ``A`` is not, since a complex
            product cannot be stored in a real destination.
    """
    logger.debug("GEMS MUL_")
    B_is_tensor = isinstance(B, torch.Tensor)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    if A.is_complex() or B_is_complex:
        if not A.is_complex():
            raise RuntimeError(
                f"result type {torch.result_type(A, B)} can't be cast to "
                f"the desired output type {A.dtype}"
            )
        # copy_ casts back to A's dtype and rejects a product whose shape
        # does not fit A.
        return A.copy_(mul(A, B))

    if B_is_tensor:
        return mul_func(A, B, out0=A)
    return mul_func_scalar(A, B, out0=A)