Coverage for src/flag_gems/runtime/common.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import os 

2from enum import Enum 

3 

4 

class vendors(Enum):
    """Enumeration of the vendors recognised by this runtime module.

    Each member carries a fixed integer value (0-12, in declaration order).
    """

    NVIDIA = 0
    CAMBRICON = 1
    METAX = 2
    ILUVATAR = 3
    MTHREADS = 4
    KUNLUNXIN = 5
    HYGON = 6
    AMD = 7
    AIPU = 8
    ASCEND = 9
    TSINGMICRO = 10
    SUNRISE = 11
    ENFLAME = 12

    @classmethod
    def get_all_vendors(cls) -> dict:
        """Return a new dict mapping each lower-cased member name to its member.

        Iterating the Enum class yields its canonical members in definition
        order, so the dict keys follow that order, e.g.
        ``{"nvidia": vendors.NVIDIA, "cambricon": vendors.CAMBRICON, ...}``.
        A fresh dict is built on every call.
        """
        return {vendor.name.lower(): vendor for vendor in cls}

26 

27 

# Immutable set of vendors flagged (per this table's name) as lacking float64
# support. How the flag is acted on is up to callers not shown here.
UNSUPPORT_FP64 = frozenset(
    (
        vendors.CAMBRICON,
        vendors.ILUVATAR,
        vendors.KUNLUNXIN,
        vendors.MTHREADS,
        vendors.AIPU,
        vendors.ASCEND,
        vendors.TSINGMICRO,
        vendors.SUNRISE,
        vendors.ENFLAME,
    )
)

41 

# Immutable set of vendors flagged (per this table's name) as lacking bfloat16
# support. How the flag is acted on is up to callers not shown here.
UNSUPPORT_BF16 = frozenset(
    (
        vendors.AIPU,
        vendors.SUNRISE,
    )
)

48 

# Immutable set of vendors flagged (per this table's name) as lacking int64
# support. How the flag is acted on is up to callers not shown here.
UNSUPPORT_INT64 = frozenset(
    (
        vendors.AIPU,
        vendors.TSINGMICRO,
        vendors.SUNRISE,
        vendors.ENFLAME,
    )
)

57 

# Location of the bundled "general ops expand" YAML config, resolved relative
# to this module: <this module's dir>/../utils/configs/<file>. The ".." is
# collapsed lexically by normpath; the file's existence is not checked here.
DEFAULT_EXPAND_CONFIG_PATH = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),  # directory containing this module
        "..",  # step up to the parent directory
        "utils",
        "configs",
        "general_ops_expand_configs.yaml",
    )
)

67 

68 

# Per-op list of strategy names. Each list has the same length as the
# matching OP_KEY_ORDERS entry; presumably strategy i applies to key i
# (confirm with the consumer). Every value is a fresh list object.
DEFAULT_STRATEGIES = {
    "bmm": ["align32"] * 5,
    "addmm": ["align32"] * 3,
    "baddbmm": ["align32"] * 3,
    "mv": ["align32"] * 2,
    "w8a8_block_fp8_general": ["align32"] * 5,
    "w8a8_block_fp8_general_splitk": ["align32"] * 5,
    "w8a8_block_fp8_general_tma": ["align32"] * 5 + ["default"],
    "mm_general_tma": ["align32"] * 5 + ["default"],
    "gemv": ["align32"] * 3 + ["default"],
    "sparse_attention": ["align32"] * 3,
    "mm": ["align32"] * 5,
    "bmm_sqmma": ["align32"] * 3,
    "addmm_sqmma": ["align32"] * 3,
}

110 

# Per-op ordered list of key names. Each list is built by splitting a
# space-separated string, so every value is a fresh list of str.
OP_KEY_ORDERS = {
    "bmm": "M N K stride_am stride_bk".split(),
    "addmm": "M N K".split(),
    "baddbmm": "M N K".split(),
    "mv": "M N".split(),
    "w8a8_block_fp8_general": "M N K stride_am stride_bk".split(),
    "w8a8_block_fp8_general_splitk": "M N K stride_am stride_bk".split(),
    "w8a8_block_fp8_general_tma": "M N K stride_am stride_bk dtype".split(),
    "mm_general_tma": "M N K stride_am stride_bk dtype".split(),
    "gemv": "M K stride_am stride_bk".split(),
    "sparse_attention": "topk H_ACTUAL D".split(),
    "mm": "M N K stride_am stride_bk".split(),
    "bmm_sqmma": "M N K".split(),
    "addmm_sqmma": "M N K".split(),
}

126 

127 

128# Mapping from vendor name to torch attribute for quick detection 

129_VENDOR_TORCH_ATTR = { 

130 "cambricon": "mlu", 

131 "mthreads": "musa", 

132 "iluvatar": "corex", 

133 "ascend": "npu", 

134 "sunrise": "ptpu", 

135 "enflame": "gcu", 

136} 

137 

# Names exported by ``from <this module> import *``. DEFAULT_EXPAND_CONFIG_PATH
# is a public module constant like the others, so it is listed as well;
# previously it was silently dropped by star-imports.
# _VENDOR_TORCH_ATTR is exported explicitly despite its leading underscore;
# it is kept for backward compatibility with existing star-importers.
__all__ = [
    "vendors",
    "UNSUPPORT_FP64",
    "UNSUPPORT_BF16",
    "UNSUPPORT_INT64",
    "DEFAULT_EXPAND_CONFIG_PATH",
    "DEFAULT_STRATEGIES",
    "OP_KEY_ORDERS",
    "_VENDOR_TORCH_ATTR",
]