Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/add.py: 0%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

8logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

9 

10 

@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func(x, y, alpha):
    # Tensor + tensor variant: scale the second operand by ``alpha`` and
    # accumulate it onto the first.
    scaled = y * alpha
    return x + scaled

15 

16 

@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    # Tensor + scalar variant: ``x`` is declared as the tensor operand,
    # ``y`` and ``alpha`` as non-tensor arguments.
    scaled = y * alpha
    return x + scaled

23 

24 

@pointwise_dynamic(
    is_tensor=[False, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    # Scalar + tensor variant: ``y`` is declared as the tensor operand,
    # ``x`` and ``alpha`` as non-tensor arguments.
    scaled = y * alpha
    return x + scaled

31 

32 

def _real_imag_pair(value, like, real_dtype):
    """Expand a non-complex-tensor operand into a ``(..., 2)`` real tensor.

    Args:
        value: A real tensor, or a Python scalar (real or ``complex``).
        like: Tensor whose shape/device is used when ``value`` is a scalar.
        real_dtype: Floating dtype of the real and imaginary components.

    Returns:
        A tensor whose trailing dimension holds ``(real, imag)``.
    """
    if isinstance(value, torch.Tensor):
        # A real tensor contributes a zero imaginary component.
        real_part = value.to(dtype=real_dtype)
        imag_part = torch.zeros_like(real_part)
    else:
        # ``complex()`` normalises ints/floats/complex scalars alike, so a
        # Python complex scalar keeps its imaginary part.
        scalar = complex(value)
        real_part = torch.full_like(like, fill_value=scalar.real, dtype=real_dtype)
        imag_part = torch.full_like(like, fill_value=scalar.imag, dtype=real_dtype)
    return torch.stack([real_part, imag_part], dim=-1)


def _add_complex(A, B, alpha):
    """Add operands where at least one is complex and at least one is a tensor.

    Both operands are rewritten as real tensors with a trailing
    ``(real, imag)`` dimension, added with ``add_func``, and viewed back as a
    complex tensor of dtype ``torch.result_type(A, B)``.
    """
    result_dtype = torch.result_type(A, B)
    a_complex_tensor = isinstance(A, torch.Tensor) and A.is_complex()
    b_complex_tensor = isinstance(B, torch.Tensor) and B.is_complex()

    if a_complex_tensor and b_complex_tensor:
        Ar = torch.view_as_real(A)
        Br = torch.view_as_real(B)
    elif a_complex_tensor:
        # B is a real tensor or a Python scalar (possibly complex).
        Ar = torch.view_as_real(A)
        Br = _real_imag_pair(B, Ar[..., 0], Ar.dtype)
    elif b_complex_tensor:
        # A is a real tensor or a Python scalar (possibly complex).
        Br = torch.view_as_real(B)
        Ar = _real_imag_pair(A, Br[..., 0], Br.dtype)
    else:
        # Real tensor combined with a Python complex scalar: there is no
        # complex tensor to borrow a component dtype from, so derive it from
        # the promoted result dtype.
        component_dtype = {
            torch.complex32: torch.float16,
            torch.complex64: torch.float32,
            torch.complex128: torch.float64,
        }[result_dtype]
        like = A if isinstance(A, torch.Tensor) else B
        Ar = _real_imag_pair(A, like, component_dtype)
        Br = _real_imag_pair(B, like, component_dtype)

    common_dtype = torch.promote_types(Ar.dtype, Br.dtype)
    Ar, Br = Ar.to(common_dtype), Br.to(common_dtype)
    out_real = add_func(Ar, Br, alpha)
    # view_as_complex needs a unit stride on the trailing (real, imag) dim.
    return torch.view_as_complex(out_real.contiguous()).to(result_dtype)


def add(A, B, *, alpha=1):
    """Compute ``A + B * alpha`` for tensor and/or scalar operands.

    Args:
        A: Tensor or Python scalar.
        B: Tensor or Python scalar.
        alpha: Multiplier applied to ``B``. On the complex path it is applied
            to the real and imaginary components separately, so it is
            expected to be a real number there.

    Returns:
        The result of the matching kernel wrapper; for two non-tensor
        operands, ``torch.tensor(A + B * alpha)``.
    """
    logger.debug("GEMS_KUNLUNXIN ADD")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    # Two Python scalars (complex or not) fall through to the plain scalar
    # branch below; the complex path needs at least one tensor operand.
    if (A_is_complex or B_is_complex) and (A_is_tensor or B_is_tensor):
        return _add_complex(A, B, alpha)

    if A_is_tensor and B_is_tensor:
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha)
    if A_is_tensor:
        return add_func_tensor_scalar(A, B, alpha)
    if B_is_tensor:
        return add_func_scalar_tensor(A, B, alpha)
    return torch.tensor(A + B * alpha)

87 

88 

def add_(A, B, *, alpha=1.0):
    """In-place variant of ``add``: write ``A + B * alpha`` into ``A``.

    Args:
        A: Destination tensor; it is passed to the kernel wrapper as ``out0``.
        B: Tensor or Python scalar to add.
        alpha: Multiplier applied to ``B``.

    Returns:
        Whatever the kernel wrapper returns when given ``out0=A``
        (presumably ``A`` itself; confirm against ``pointwise_dynamic``).

    Raises:
        ValueError: If ``A`` is not a tensor.
    """
    logger.debug("GEMS_KUNLUNXIN ADD_")
    if not isinstance(A, torch.Tensor):
        # A scalar left-hand side has no storage to write into, so the
        # scalar/tensor kernel is intentionally not offered in place.
        raise ValueError("Unreachable.")
    if isinstance(B, torch.Tensor):
        # Mirror ``add``: bring B onto A's device before launching the kernel.
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, out0=A)
    return add_func_tensor_scalar(A, B, alpha, out0=A)