Coverage for src/flag_gems/runtime/backend/_sunrise/ops/bitwise_left_shift.py: 0%
36 statements
Generated by coverage.py v7.6.9 on 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
# Module-level logger named after this module; the entry points in this
# file emit their debug traces through it.
logger = logging.getLogger(name=__name__)
# Elementwise ``x << y`` where both operands are tensors.
# ``pointwise_dynamic`` (a flag_gems utility) presumably generates the
# launch wrapper around this Triton function; the output dtype follows the
# "DEFAULT" promotion rule applied to arguments 0 and 1.
# NOTE(review): the shift amount ``y`` is not clamped or validated here, so
# negative or >= bit-width shifts follow Triton semantics, which may differ
# from torch's -- confirm.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def bitwise_left_shift_func(x, y):
    return x << y
# Elementwise ``x << y`` where ``x`` is a tensor and ``y`` is a non-tensor
# (scalar) shift amount, per ``is_tensor=[True, False]``. The output dtype
# follows the "DEFAULT" promotion rule over arguments 0 and 1.
# Operand order matters: left shift is not commutative, so this kernel must
# only be called as (tensor, scalar).
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def bitwise_left_shift_func_scalar(x, y):
    return x << y
# Elementwise ``x << y`` where ``x`` is a non-tensor (scalar) and ``y`` is a
# tensor of shift amounts, per ``is_tensor=[False, True]``. Needed because
# left shift is not commutative: the (tensor, scalar) kernel cannot be
# reused by swapping the operands.
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def bitwise_left_shift_func_scalar_tensor(x, y):
    return x << y


def bitwise_left_shift(A, B):
    """Return ``A << B`` computed elementwise.

    Args:
        A: Tensor or Python scalar holding the values to shift.
        B: Tensor or Python scalar holding the shift amounts.

    Returns:
        A new tensor containing ``A << B``. If neither operand is a tensor,
        the Python result is wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS BITWISE_LEFT_SHIFT")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return bitwise_left_shift_func(A, B)
    if a_is_tensor:
        return bitwise_left_shift_func_scalar(A, B)
    if b_is_tensor:
        # Scalar on the left: keep the operand order (A << B, not B << A).
        return bitwise_left_shift_func_scalar_tensor(A, B)
    return torch.tensor(A << B)


def bitwise_left_shift_out(A, B, out):
    """Compute ``A << B`` elementwise and write the result into ``out``.

    Args:
        A: Tensor or Python scalar holding the values to shift.
        B: Tensor or Python scalar holding the shift amounts.
        out: Destination tensor, passed to the kernels as ``out0``.

    Returns:
        The tensor returned by the kernel call. If neither operand is a
        tensor, this is ``out`` filled with the scalar ``A << B``.
    """
    logger.debug("GEMS BITWISE_LEFT_SHIFT OUT")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return bitwise_left_shift_func(A, B, out0=out)
    if a_is_tensor:
        return bitwise_left_shift_func_scalar(A, B, out0=out)
    if b_is_tensor:
        # Scalar on the left: keep the operand order (A << B, not B << A).
        return bitwise_left_shift_func_scalar_tensor(A, B, out0=out)
    return out.fill_(A << B)
def bitwise_left_shift_(A, B):
    """In-place ``A <<= B``: shift tensor ``A`` by ``B`` and store into ``A``.

    ``B`` may be a tensor (routed to the tensor-tensor kernel) or a
    non-tensor scalar (routed to the tensor-scalar kernel). The kernel
    writes its output into ``A`` via ``out0`` and its return value is
    passed back to the caller.
    """
    logger.debug("GEMS BITWISE_LEFT_SHIFT_")
    kernel = (
        bitwise_left_shift_func
        if isinstance(B, torch.Tensor)
        else bitwise_left_shift_func_scalar
    )
    return kernel(A, B, out0=A)