Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mul.py: 0%
85 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import pointwise_dynamic
10logger = logging.getLogger(__name__)
# Tensor-by-tensor elementwise product. `pointwise_dynamic` wraps the scalar
# Triton function below; the output dtype follows the "DEFAULT" promotion of
# arguments 0 and 1.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    product = x * y
    return product
@triton.jit
def mul_kernel(
    x_ptr,  # Pointer to the first input vector.
    y_ptr,  # Pointer to the second input vector.
    output_ptr,  # Pointer to the output vector.
    n_elements,  # Number of elements in each vector.
    BLOCK_SIZE: tl.constexpr,  # Elements handled per program; constexpr so it
    # can be used as a shape value.
):
    # Elementwise product over a 1D launch grid: each program handles one
    # BLOCK_SIZE-wide slice. With 256 elements and BLOCK_SIZE=64 the programs
    # cover [0:64], [64:128], [128:192] and [192:256].
    program_index = tl.program_id(axis=0)
    offsets = program_index * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the end when n_elements is not a multiple
    # of BLOCK_SIZE; the mask guards both the loads and the store.
    in_bounds = offsets < n_elements
    lhs = tl.load(x_ptr + offsets, mask=in_bounds)
    rhs = tl.load(y_ptr + offsets, mask=in_bounds)
    # Write x * y back to the output buffer.
    tl.store(output_ptr + offsets, lhs * rhs, mask=in_bounds)
def mul_all_real_func(x: torch.Tensor, y: torch.Tensor):
    """Multiply two tensors elementwise with the flat 1D ``mul_kernel``.

    ``mul_kernel`` reads ``x`` and ``y`` and writes the output using the same
    linear offsets, and the output is allocated with ``x``'s shape and dtype.
    That is only correct when both operands have the same shape, dtype and
    device and are contiguous. Any other combination (broadcasting, mixed
    dtypes, strided views) is delegated to ``mul_func``.

    Args:
        x: First operand.
        y: Second operand.

    Returns:
        A new tensor holding ``x * y``. The kernel is launched asynchronously;
        no device synchronization is performed here.
    """
    flat_compatible = (
        x.shape == y.shape
        and x.dtype == y.dtype
        and x.device == y.device
        and x.is_contiguous()
        and y.is_contiguous()
    )
    if not flat_compatible:
        # Linear indexing would read the wrong (or out-of-bounds) elements
        # and the output would have the wrong shape or dtype.
        return mul_func(x, y)

    output = torch.empty_like(x)
    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return output

    block_size = 1024
    # 1D launch grid: one program per block of `block_size` elements.
    grid = (triton.cdiv(n_elements, block_size),)
    with torch_device_fn.device(x.device):
        # Tensors are passed as pointers to their first element; the
        # meta-parameter BLOCK_SIZE must be passed by keyword.
        mul_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)
    return output
# Tensor-by-scalar elementwise product: argument 0 is a tensor, argument 1 is
# a non-tensor scalar. The output dtype follows the "DEFAULT" promotion of
# both arguments.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    scaled = x * y
    return scaled
# Complex product on split real/imaginary planes:
#   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
# All four inputs are tensors; two outputs (real, imag) are produced.
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    re_part = ar * br - ai * bi
    im_part = ar * bi + ai * br
    return re_part, im_part
def _mul_complex(A, B, A_is_complex, B_is_complex):
    """Multiply when at least one operand is complex and at least one is a tensor.

    Complex tensors are processed through their ``torch.view_as_real`` view
    (trailing dimension of size 2 holding real and imaginary parts).
    """
    # Computed from the original operands so that Python scalars keep their
    # usual (weaker) role in type promotion.
    result_dtype = torch.result_type(A, B)

    # torch.view_as_real only accepts tensors, so a Python complex scalar is
    # lifted to a 0-dim complex tensor on the other operand's device.
    if A_is_complex and not isinstance(A, torch.Tensor):
        A = torch.tensor(A, dtype=result_dtype, device=B.device)
    if B_is_complex and not isinstance(B, torch.Tensor):
        B = torch.tensor(B, dtype=result_dtype, device=A.device)

    # 1) Both operands are complex.
    if A_is_complex and B_is_complex:
        Ar = torch.view_as_real(A)
        Br = torch.view_as_real(B)
        ar, ai = Ar[..., 0], Ar[..., 1]
        br, bi = Br[..., 0], Br[..., 1]
        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        br, bi = br.to(common_dtype), bi.to(common_dtype)

        # The result has the broadcast shape of both operands, not just A's.
        out_shape = torch.broadcast_shapes(ar.shape, br.shape)
        # Real and imaginary planes are written straight into one interleaved
        # buffer so that it can be reinterpreted as complex without a copy.
        out_buffer = torch.empty(
            (*out_shape, 2), dtype=common_dtype, device=ar.device
        )
        mul_complex_kernel(
            ar, ai, br, bi, out0=out_buffer[..., 0], out1=out_buffer[..., 1]
        )
        return torch.view_as_complex(out_buffer).to(result_dtype)

    # 2) A complex, B real: scale both planes of A by B.
    if A_is_complex:
        Ar = torch.view_as_real(A)
        if isinstance(B, torch.Tensor):
            # The trailing unit dimension broadcasts B over the (real, imag) axis.
            out_real = mul_func(Ar, B.unsqueeze(-1))
        else:
            out_real = mul_func_scalar(Ar, B)
        return torch.view_as_complex(out_real).to(result_dtype)

    # 3) A real, B complex.
    Br = torch.view_as_real(B)
    if isinstance(A, torch.Tensor):
        out_real = mul_func(A.unsqueeze(-1), Br)
    else:
        out_real = mul_func_scalar(Br, A)  # Br is the tensor, A the scalar
    return torch.view_as_complex(out_real).to(result_dtype)


def mul(A, B):
    """Return the elementwise product ``A * B``.

    Each operand may be a ``torch.Tensor`` or a Python scalar. Dispatch:

    * both scalars: the Python product wrapped in a tensor;
    * any complex operand: the split real/imaginary path (``_mul_complex``);
    * two real tensors: the flat 1D kernel when both have the same shape,
      dtype and device and are contiguous, otherwise ``mul_func``;
    * one real tensor and one real scalar: ``mul_func_scalar``.
    """
    logger.debug("GEMS MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not A_is_tensor and not B_is_tensor:
        # Both scalar (real or complex): no kernel needed.
        return torch.tensor(A * B)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)
    if A_is_complex or B_is_complex:
        return _mul_complex(A, B, A_is_complex, B_is_complex)

    if A_is_tensor and B_is_tensor:
        # The flat kernel indexes both inputs and the output linearly, so it
        # is only valid without broadcasting, dtype promotion or strides.
        # Equal rank alone does not guarantee that.
        if (
            A.shape == B.shape
            and A.dtype == B.dtype
            and A.device == B.device
            and A.is_contiguous()
            and B.is_contiguous()
        ):
            return mul_all_real_func(A, B)
        return mul_func(A, B)
    if A_is_tensor:
        return mul_func_scalar(A, B)
    return mul_func_scalar(B, A)
def mul_(A, B):
    """Multiply tensor ``A`` by ``B`` in place and return ``A``.

    ``B`` may be a tensor or a Python scalar. Real operands are written
    directly into ``A`` by the pointwise kernels. Complex operands are routed
    through ``mul`` (which has the complex handling) and the result is copied
    back into ``A``.

    Raises:
        RuntimeError: if ``B`` is complex but ``A`` is not, because a complex
            result cannot be stored in a real tensor.
    """
    logger.debug("GEMS MUL_")
    B_is_tensor = isinstance(B, torch.Tensor)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    if A.is_complex() or B_is_complex:
        if not A.is_complex():
            raise RuntimeError(
                f"result type {torch.result_type(A, B)} can't be cast to the "
                f"desired output type {A.dtype}"
            )
        # NOTE(review): assumes the pointwise kernels cannot take complex
        # tensors directly, as the dedicated complex path in `mul` suggests.
        A.copy_(mul(A, B))
        return A

    if B_is_tensor:
        return mul_func(A, B, out0=A)
    return mul_func_scalar(A, B, out0=A)