Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mul.py: 0%
87 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import pointwise_dynamic
10logger = logging.getLogger(__name__)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    # Scalar function for the tensor-by-tensor case: the product of both operands.
    # Both arguments take part in the "DEFAULT" promotion rule declared above.
    product = x * y
    return product
@triton.jit
def mul_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # Flat elementwise multiply: every program handles one BLOCK_SIZE-wide slice
    # of the inputs, addressed purely by linear offset.
    #
    # We use a 1D launch grid, so the program index lives on axis 0. It is widened
    # to int64 before being scaled: the index is a 32-bit value, and
    # `pid * BLOCK_SIZE` would otherwise wrap around for tensors holding more than
    # 2**31 - 1 elements, producing negative offsets that still pass the mask.
    pid = tl.program_id(axis=0).to(tl.int64)
    # For instance, with a vector of length 256 and a block size of 64, the
    # programs access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that `offsets` holds element indices, not pointers; the pointers are
    # formed below by adding the offsets to each base pointer.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y, masking out any extra elements in case the input size is not
    # a multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x * y
    # Write x * y back to the output buffer.
    tl.store(output_ptr + offsets, output, mask=mask)
def mul_all_real_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Multiply two tensors elementwise with the flat 1D ``mul_kernel``.

    The kernel addresses ``x``, ``y`` and the output with the same linear
    offsets, so the caller must make sure both inputs have the same shape and
    dtype and are contiguous before taking this path.

    Args:
        x: First operand; the output is allocated with ``torch.empty_like(x)``.
        y: Second operand.

    Returns:
        A new tensor holding ``x * y``. No synchronization is issued here, so
        the kernel may still be running when this function returns.
    """
    # The output must be preallocated; the kernel only writes into it.
    output = torch.empty_like(x)
    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute: skip building a zero-sized launch grid entirely.
        return output

    # The SPMD launch grid is the number of kernel instances that run in
    # parallel. Here it is a 1D grid with one program per BLOCK_SIZE elements.
    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # NOTE:
    #  - Each torch.Tensor argument is implicitly converted into a pointer to
    #    its first element.
    #  - Meta-parameters such as BLOCK_SIZE must be passed as keyword arguments.
    with torch_device_fn.device(x.device):
        mul_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
68def _can_use_fast_mul_all_real(x: torch.Tensor, y: torch.Tensor) -> bool:
69 return (
70 x.shape == y.shape
71 and x.dtype == y.dtype
72 and x.is_contiguous()
73 and y.is_contiguous()
74 )
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    # Scalar function for the tensor-by-scalar case: per `is_tensor` above, `x`
    # is the tensor operand and `y` is the non-tensor operand.
    scaled = x * y
    return scaled
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    # Complex product on split components:
    #   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
    # The first output is the real part, the second the imaginary part.
    re_part = ar * br - ai * bi
    im_part = ar * bi + ai * br
    return re_part, im_part
def _mul_with_complex_operand(A, B, A_is_complex, B_is_complex):
    """Multiply when at least one operand is complex and at least one is a tensor.

    Complex tensors are processed through their real views (trailing dimension
    of size 2 holding the real and imaginary parts), since the pointwise
    functions used here operate on real-valued data.
    """
    # Must be computed from the ORIGINAL operands, before any scalar is wrapped
    # into a tensor below, so scalar-vs-tensor promotion rules still apply.
    out_dtype = torch.result_type(A, B)

    # `torch.view_as_real` only accepts tensors, so a Python complex scalar is
    # materialized as a 0-dim complex tensor on the other operand's device.
    # complex64 is used unless the result needs complex128; the final `.to`
    # below takes care of any remaining dtype difference.
    scalar_dtype = (
        torch.complex128 if out_dtype == torch.complex128 else torch.complex64
    )
    if A_is_complex and not isinstance(A, torch.Tensor):
        A = torch.tensor(A, dtype=scalar_dtype, device=B.device)
    if B_is_complex and not isinstance(B, torch.Tensor):
        B = torch.tensor(B, dtype=scalar_dtype, device=A.device)

    # 1) A and B are both complex.
    if A_is_complex and B_is_complex:
        Ar = torch.view_as_real(A)
        Br = torch.view_as_real(B)
        ar, ai = Ar[..., 0], Ar[..., 1]
        br, bi = Br[..., 0], Br[..., 1]
        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        br, bi = br.to(common_dtype), bi.to(common_dtype)

        # The result has the BROADCAST shape of both operands, not A's shape.
        out_shape = torch.broadcast_shapes(ar.shape, br.shape)
        # One interleaved (real, imag) buffer lets the result be reinterpreted
        # as complex without an extra stack/copy.
        out_buffer = torch.empty(
            (*out_shape, 2), dtype=common_dtype, device=ar.device
        )
        real_out = out_buffer[..., 0]
        imag_out = out_buffer[..., 1]
        mul_complex_kernel(ar, ai, br, bi, out0=real_out, out1=imag_out)
        return torch.view_as_complex(out_buffer).to(out_dtype)

    # 2) A complex, B real: scale both components of A by B.
    if A_is_complex:
        Ar = torch.view_as_real(A)
        if isinstance(B, torch.Tensor):
            # Trailing unit dimension so B broadcasts over (real, imag).
            out_real = mul_func(Ar, B.unsqueeze(-1))
        else:
            out_real = mul_func_scalar(Ar, B)
        return torch.view_as_complex(out_real).to(out_dtype)

    # 3) A real, B complex: scale both components of B by A.
    Br = torch.view_as_real(B)
    if isinstance(A, torch.Tensor):
        # Trailing unit dimension so A broadcasts over (real, imag).
        out_real = mul_func(A.unsqueeze(-1), Br)
    else:
        out_real = mul_func_scalar(Br, A)  # Br is the tensor, A is the scalar
    return torch.view_as_complex(out_real).to(out_dtype)


def mul(A, B):
    """Return the elementwise product of ``A`` and ``B``.

    Each operand may be a ``torch.Tensor`` or a Python scalar (including
    ``complex``). Dispatch is by operand kind:

    * two scalars: the product is computed in Python and wrapped in a tensor;
    * any complex operand: handled through real views of the complex data;
    * two real tensors: the flat 1D kernel when eligible, otherwise the
      general pointwise function;
    * one real tensor and one real scalar: the tensor-by-scalar function.
    """
    logger.debug("GEMS MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not A_is_tensor and not B_is_tensor:
        # Both scalar. Handled first so that Python complex scalars never reach
        # the tensor-only complex path.
        return torch.tensor(A * B)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)
    if A_is_complex or B_is_complex:
        return _mul_with_complex_operand(A, B, A_is_complex, B_is_complex)

    if A_is_tensor and B_is_tensor:
        if _can_use_fast_mul_all_real(A, B):
            return mul_all_real_func(A, B)
        return mul_func(A, B)
    if A_is_tensor:
        return mul_func_scalar(A, B)
    # Only B is a tensor; multiplication commutes, so the tensor goes first.
    return mul_func_scalar(B, A)
def mul_(A, B):
    """Multiply tensor ``A`` by ``B`` in place and return ``A``.

    ``B`` may be a tensor or a Python scalar. Real operands are written
    directly into ``A`` by the pointwise functions. When a complex operand is
    involved, the product is computed out of place with ``mul`` (which works on
    real views of the complex data) and then copied into ``A``.

    Raises:
        RuntimeError: if the product is complex but ``A`` is not, since a
            complex result cannot be stored in a real tensor.
    """
    logger.debug("GEMS MUL_")
    B_is_tensor = isinstance(B, torch.Tensor)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    if A.is_complex() or B_is_complex:
        if not A.is_complex():
            # Copying would silently drop the imaginary part; refuse instead.
            raise RuntimeError(
                f"result type {torch.result_type(A, B)} can't be cast to the "
                f"desired output type {A.dtype}"
            )
        # `copy_` also rejects a result whose broadcast shape is larger than A.
        return A.copy_(mul(A, B))

    if B_is_tensor:
        return mul_func(A, B, out0=A)
    return mul_func_scalar(A, B, out0=A)