Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addmv.py: 0%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import broadcastable_to, libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

# Per-module logger, created as a child of the shared "flag_gems" logger.
_flag_gems_root_logger = logging.getLogger("flag_gems")
logger = _flag_gems_root_logger.getChild(__name__.lstrip("."))

12 

13 

def heur_block_n(args):
    """Choose the ``BLOCK_N`` tile size from the kernel argument ``N``.

    Args:
        args: Mapping of kernel argument names to values; only ``"N"`` is
            read (treated as 0 when absent).

    Returns:
        A power-of-two block size: 1 for ``N <= 1``, the next power of two
        for ``N <= 64``, then 64 / 128 / 256 for the ``<= 256`` /
        ``<= 1024`` / larger ranges.
    """
    N = args.get("N", 0)
    # triton.next_power_of_2(0) evaluates to 0, which is not a usable tile
    # size and makes cdiv(N, BLOCK_N) in the launch grid divide by zero.
    # Clamp empty/unit sizes to the smallest valid block instead.
    if N <= 1:
        return 1
    # Use smaller BLOCK_N for more parallelism
    if N <= 64:
        return triton.next_power_of_2(N)
    if N <= 256:
        return 64
    if N <= 1024:
        return 128
    return 256

25 

26 

def heur_block_m(args):
    """Choose the ``BLOCK_M`` tile size from the kernel argument ``M``.

    Args:
        args: Mapping of kernel argument names to values; only ``"M"`` is
            read (treated as 0 when absent).

    Returns:
        The next power of two >= ``M``, capped at 4096 and never below 1.
    """
    # NOTE(review): builtins.min is used explicitly, presumably because
    # ``min`` can be shadowed elsewhere in the package -- confirm.
    import builtins

    M = args.get("M", 0)
    # triton.next_power_of_2(0) evaluates to 0, which is not a valid tile
    # size. With a block of 1 the kernel's reduction loop over M simply runs
    # zero times when M == 0.
    if M <= 1:
        return 1
    # Larger BLOCK_M for better memory coalescing
    return builtins.min(triton.next_power_of_2(M), 4096)

33 

34 

# Triton kernel computing, for each row n handled by this program:
#     Out[n] = alpha * sum_m(A[n, m] * B[m]) + beta * Inp[n]
# Each program covers BLOCK_N consecutive rows and walks the M axis in
# BLOCK_M-wide tiles. Products are accumulated in float32.
@libentry()
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "BLOCK_M": heur_block_m,
    }
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmv_kernel(
    A,
    B,
    Inp,
    Out,
    N: tl.constexpr,
    M: tl.constexpr,
    alpha,
    beta,
    stride_an: tl.constexpr,
    stride_am: tl.constexpr,
    stride_bm: tl.constexpr,
    stride_in: tl.constexpr,
    stride_outn: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Row indices as a (BLOCK_N, 1) column, column indices as a (1, BLOCK_M) row.
    block_id = ext.program_id(0)
    rows = block_id * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    cols = tl.arange(0, BLOCK_M)[None, :]
    row_in_range = rows < N

    # Pointers to the first tile of A and the matching slice of B.
    a_tile_ptrs = A + rows * stride_an + cols * stride_am
    b_tile_ptrs = B + cols * stride_bm

    # Element-wise products are summed per tile position first and reduced
    # along the M axis only once, after the loop.
    partial = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for col_start in range(0, M, BLOCK_M):
        col_in_range = col_start + cols < M
        a_tile = tl.load(
            a_tile_ptrs, mask=row_in_range & col_in_range, other=0.0
        ).to(tl.float32)
        b_tile = tl.load(b_tile_ptrs, mask=col_in_range, other=0.0).to(tl.float32)
        partial += a_tile * b_tile
        # Advance both operands to the next BLOCK_M-wide tile.
        a_tile_ptrs += BLOCK_M * stride_am
        b_tile_ptrs += BLOCK_M * stride_bm

    # Collapse the M axis; keep a trailing unit axis so shapes match `rows`.
    mat_vec = tl.sum(partial, axis=1)[:, None]
    inp_vals = tl.load(Inp + rows * stride_in, mask=row_in_range, other=0.0).to(
        tl.float32
    )
    result = mat_vec * alpha + inp_vals * beta
    tl.store(Out + rows * stride_outn, result, mask=row_in_range)

81 

82 

def addmv(self, mat, vec, *, beta=1, alpha=1):
    """Return ``beta * self + alpha * (mat @ vec)`` as a new tensor.

    Args:
        self: Tensor whose shape is broadcastable to ``(mat.shape[0],)``.
        mat: 2-D tensor of shape ``(N, M)``.
        vec: Tensor whose first dimension equals ``M``.
        beta: Scale applied to ``self``.
        alpha: Scale applied to the matrix-vector product.

    Returns:
        A new tensor of shape ``(N,)`` with ``mat``'s device and dtype.

    Raises:
        AssertionError: If the shapes of ``mat``, ``vec`` and ``self`` are
            incompatible.
    """
    logger.debug("GEMS_KUNLUNXIN ADDMV")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    if N == 0:
        # Empty output: there is nothing to compute, so skip the launch
        # instead of asking Triton for a zero-sized tile or grid.
        return out
    self = self.broadcast_to(out.shape)
    # One program per BLOCK_N rows; BLOCK_N is supplied by the kernel heuristics.
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out

108 

109 

def addmv_out(self, mat, vec, *, beta=1, alpha=1, out=None):
    """Write ``beta * self + alpha * (mat @ vec)`` into ``out`` and return it.

    Args:
        self: Tensor whose shape is broadcastable to ``(mat.shape[0],)``.
        mat: 2-D tensor of shape ``(N, M)``.
        vec: Tensor whose first dimension equals ``M``.
        beta: Scale applied to ``self``.
        alpha: Scale applied to the matrix-vector product.
        out: Optional destination tensor of shape ``(N,)``. When ``None``, a
            new tensor with ``mat``'s device and dtype is allocated.

    Returns:
        ``out`` (or the newly allocated tensor) holding the result.

    Raises:
        AssertionError: If the shapes of ``mat``, ``vec``, ``self`` or
            ``out`` are incompatible.
    """
    logger.debug("GEMS_KUNLUNXIN ADDMV OUT")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    if out is None:
        out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    else:
        assert out.shape == (N,), "Incompatible output shape"

    if N == 0:
        # Empty output: there is nothing to compute, so skip the launch
        # instead of asking Triton for a zero-sized tile or grid.
        return out

    self = self.broadcast_to(out.shape)
    # One program per BLOCK_N rows; BLOCK_N is supplied by the kernel heuristics.
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out

138 return out