Coverage for src/flag_gems/runtime/common.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import os 

2from enum import Enum 

3 

4 

class vendors(Enum):
    """Enumeration of the hardware vendors known to this module.

    Member names are upper-case identifiers with fixed integer values;
    ``get_all_vendors`` exposes the same members keyed by lower-cased name.
    """

    NVIDIA = 0
    CAMBRICON = 1
    METAX = 2
    ILUVATAR = 3
    MTHREADS = 4
    KUNLUNXIN = 5
    HYGON = 6
    AMD = 7
    AIPU = 8
    ASCEND = 9
    TSINGMICRO = 10
    SUNRISE = 11
    ENFLAME = 12
    SPACEMIT = 13
    THEAD = 14

    @classmethod
    def get_all_vendors(cls) -> "dict[str, vendors]":
        """Return every member of the enum keyed by its lower-cased name.

        Entries follow definition order, e.g. ``"nvidia"`` maps to
        ``vendors.NVIDIA`` and ``"thead"`` maps to ``vendors.THEAD``.

        Returns:
            A new dict on every call, so callers may mutate it freely.
        """
        # String annotation keeps this importable on interpreters where
        # subscripting the builtin ``dict`` at definition time is unsupported.
        return {member.name.lower(): member for member in cls}

28 

29 

# Immutable vendor groups that, per their names, mark dtypes a vendor's
# backend does not support. This module only declares them; the code that
# consults them lives outside this file.

# Vendors flagged as lacking float64 (FP64) support.
UNSUPPORT_FP64 = frozenset(
    (
        vendors.AIPU,
        vendors.ASCEND,
        vendors.CAMBRICON,
        vendors.ENFLAME,
        vendors.ILUVATAR,
        vendors.KUNLUNXIN,
        vendors.MTHREADS,
        vendors.SUNRISE,
        vendors.SPACEMIT,
        vendors.TSINGMICRO,
    )
)

# Vendors flagged as lacking bfloat16 (BF16) support.
UNSUPPORT_BF16 = frozenset(
    (
        vendors.AIPU,
        vendors.SUNRISE,
        vendors.SPACEMIT,
    )
)

# Vendors flagged as lacking int64 support.
UNSUPPORT_INT64 = frozenset(
    (
        vendors.AIPU,
        vendors.ENFLAME,
        vendors.SPACEMIT,
        vendors.SUNRISE,
        vendors.TSINGMICRO,
    )
)

62 

# Location of the bundled general-ops expand-config YAML, resolved relative
# to this module: one directory up from here, then utils/configs/.
# normpath collapses the ".." lexically (no filesystem access, no symlink
# resolution).
DEFAULT_EXPAND_CONFIG_PATH = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), "..", "utils", "configs",
        "general_ops_expand_configs.yaml",
    )
)

72 

73 

# Default strategy names per op. Each op's list has exactly as many entries
# as OP_KEY_ORDERS has key names for that op (presumably applied position by
# position — confirm at the call sites). "default" appears only in the last
# slot of gemv, mm_general_tma and w8a8_block_fp8_general_tma; every other
# slot is "align32". Each op gets its own freshly built list.
DEFAULT_STRATEGIES = {
    "addmm": ["align32"] * 3,
    "addmm_sqmma": ["align32"] * 3,
    "baddbmm": ["align32"] * 3,
    "bmm": ["align32"] * 5,
    "bmm_sqmma": ["align32"] * 3,
    "gemv": ["align32"] * 3 + ["default"],
    "mm": ["align32"] * 5,
    "mm_general_tma": ["align32"] * 5 + ["default"],
    "mv": ["align32"] * 2,
    "sparse_attention": ["align32"] * 3,
    "w8a8_block_fp8_general": ["align32"] * 5,
    "w8a8_block_fp8_general_splitk": ["align32"] * 5,
    "w8a8_block_fp8_general_tma": ["align32"] * 5 + ["default"],
    "mm_splitk": ["align32"] * 5,
}

116 

# Per-op ordered list of key names. Judging by the constant's name this is
# the order in which those values are combined into a lookup key; confirm at
# the call sites, which are outside this file. Every op listed here also has
# an entry of the same length in DEFAULT_STRATEGIES.
OP_KEY_ORDERS = {
    # Matmul-family ops keyed only on problem size.
    "addmm": ["M", "N", "K"],
    "addmm_sqmma": ["M", "N", "K"],
    # Size plus the strides of the two operands.
    "bmm": ["M", "N", "K", "stride_am", "stride_bk"],
    "bmm_sqmma": ["M", "N", "K"],
    "baddbmm": ["M", "N", "K"],
    "gemv": ["M", "K", "stride_am", "stride_bk"],
    "mm": ["M", "N", "K", "stride_am", "stride_bk"],
    # TMA variants additionally key on dtype.
    "mm_general_tma": ["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    "mv": ["M", "N"],
    "sparse_attention": ["topk", "H_ACTUAL", "D"],
    "w8a8_block_fp8_general": ["M", "N", "K", "stride_am", "stride_bk"],
    "w8a8_block_fp8_general_splitk": [
        "M", "N", "K", "stride_am", "stride_bk",
    ],
    "w8a8_block_fp8_general_tma": [
        "M", "N", "K", "stride_am", "stride_bk", "dtype",
    ],
    "mm_splitk": ["M", "N", "K", "stride_am", "stride_bk"],
}

133 

134 

# Lower-cased vendor name (the same spelling as the keys returned by
# vendors.get_all_vendors()) -> name of a torch attribute used for quick
# vendor detection. Presumably the detector checks whether ``torch`` has
# this attribute; that code is outside this file, so confirm there.
# Vendors not listed here have no attribute-based shortcut in this table.
_VENDOR_TORCH_ATTR = {
    "ascend": "npu",
    "cambricon": "mlu",
    "enflame": "gcu",
    "hygon": "__hcu_version__",
    "iluvatar": "corex",
    "mthreads": "musa",
    "sunrise": "ptpu",
}

145 

# Names exported by ``from <this module> import *``, in definition order.
# Every module-level constant is listed. That includes DEFAULT_EXPAND_CONFIG_PATH,
# which a star-import would otherwise silently drop once ``__all__`` exists.
# It also includes the underscore-prefixed _VENDOR_TORCH_ATTR, which a
# star-import only picks up because it is named here explicitly.
__all__ = [
    "vendors",
    "UNSUPPORT_FP64",
    "UNSUPPORT_BF16",
    "UNSUPPORT_INT64",
    "DEFAULT_EXPAND_CONFIG_PATH",
    "DEFAULT_STRATEGIES",
    "OP_KEY_ORDERS",
    "_VENDOR_TORCH_ATTR",
]