Coverage for src/flag_gems/ops/mul.py: 83%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.pointwise_dynamic import ComplexMode 

8 

# Module-level logger; the mul / mul_ entry points below report their dispatch
# through it at DEBUG level.
logger = logging.getLogger(name=__name__)

10 

11 

# Elementwise product of two tensor operands. promotion_methods asks
# pointwise_dynamic for "DEFAULT" type promotion over arguments 0 and 1.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    product = x * y
    return product

16 

17 

# Elementwise product of a tensor (argument 0) with a non-tensor operand
# (argument 1, marked False in is_tensor). Type promotion over both
# arguments uses the "DEFAULT" method, as in mul_func.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    scaled = x * y
    return scaled

22 

23 

# Complex multiply on split real/imaginary parts:
#   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
# All four inputs are tensors; the kernel yields two outputs (real, imag),
# each promoted with the "DEFAULT" method over all four inputs.
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    # Expressions are kept in this exact operand order so rounding behavior
    # matches the reference formula.
    re_part = ar * br - ai * bi
    im_part = ar * bi + ai * br
    return re_part, im_part

34 

35 

# Register complex-dtype support on both kernels (ComplexMode.CROSS).
# The tensor-tensor kernel computes complex products through
# mul_complex_kernel. The tensor-scalar kernel is told to tensorize its
# scalar operand and to use mul_func as its fallback target — presumably the
# complex case is rerouted through mul_func; confirm against ComplexMode docs.
# mul_func is registered first, before it is named as the fallback target.
mul_func.register_complex(
    mode=ComplexMode.CROSS,
    cross_kernel=mul_complex_kernel,
)
mul_func_scalar.register_complex(
    mode=ComplexMode.CROSS,
    fallback_target=mul_func,
    tensorize_scalars=True,
)

41 

42 

def mul(A, B):
    """Multiply ``A`` by ``B``, dispatching on which operands are tensors.

    - tensor * tensor -> ``mul_func``
    - tensor * scalar -> ``mul_func_scalar`` with the tensor first
    - scalar * tensor -> ``mul_func_scalar`` with the operands swapped, since
      multiplication commutes
    - scalar * scalar -> the product is computed in Python and wrapped with
      ``torch.tensor``

    Args:
        A: a ``torch.Tensor`` or a scalar.
        B: a ``torch.Tensor`` or a scalar.

    Returns:
        The product as a tensor.
    """
    logger.debug("GEMS MUL")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        return mul_func(A, B)
    if a_is_tensor:
        return mul_func_scalar(A, B)
    if b_is_tensor:
        # The scalar kernel expects the tensor in position 0.
        return mul_func_scalar(B, A)
    # Neither operand is a tensor.
    return torch.tensor(A * B)

54 

55 

def mul_(A, B):
    """In-place multiply: write ``A * B`` into ``A``.

    ``A`` is passed to the kernel as ``out0``, so the result is stored in
    ``A``'s memory. The kernel is ``mul_func`` when ``B`` is a tensor and
    ``mul_func_scalar`` otherwise.

    Args:
        A: the tensor to update in place.
        B: a ``torch.Tensor`` or a scalar multiplier.

    Returns:
        Whatever the selected kernel returns for ``out0=A`` (presumably
        ``A`` itself — confirm against pointwise_dynamic).
    """
    logger.debug("GEMS MUL_")
    kernel = mul_func if isinstance(B, torch.Tensor) else mul_func_scalar
    return kernel(A, B, out0=A)