Coverage for src/flag_gems/runtime/common.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1from enum import Enum 

2 

3 

class vendors(Enum):
    """Enumeration of the hardware vendors known to this module.

    Each member pairs an upper-case vendor name with a distinct integer
    value (0 through 15, in definition order).
    """

    NVIDIA = 0
    CAMBRICON = 1
    METAX = 2
    ILUVATAR = 3
    MTHREADS = 4
    KUNLUNXIN = 5
    HYGON = 6
    AMD = 7
    AIPU = 8
    ASCEND = 9
    TSINGMICRO = 10
    SUNRISE = 11
    ENFLAME = 12
    SPACEMIT = 13
    THEAD = 14
    ARM = 15

    @classmethod
    def get_all_vendors(cls) -> dict:
        """Return a dict mapping each lower-cased member name to its member.

        A new dict is built on every call, so callers may mutate the result
        freely. Keys follow the definition order of the members.
        """
        # Iterating an Enum class yields its members in definition order.
        return {member.name.lower(): member for member in cls}

28 

29 

# Vendors flagged as lacking FP64 support (meaning taken from the constant's
# name; the code that consults this set is outside this file).
# A frozenset, so membership tests are cheap and the set cannot be mutated.
UNSUPPORT_FP64 = frozenset(
    {
        vendors.AIPU,
        vendors.ASCEND,
        vendors.CAMBRICON,
        vendors.ENFLAME,
        vendors.ILUVATAR,
        vendors.KUNLUNXIN,
        vendors.MTHREADS,
        vendors.SUNRISE,
        vendors.SPACEMIT,
        vendors.TSINGMICRO,
    }
)

44 

# Vendors flagged as lacking BF16 support (meaning taken from the constant's
# name; the code that consults this set is outside this file).
UNSUPPORT_BF16 = frozenset(
    {
        vendors.AIPU,
        vendors.SUNRISE,
        vendors.SPACEMIT,
    }
)

52 

# Vendors flagged as lacking INT64 support (meaning taken from the constant's
# name; the code that consults this set is outside this file).
UNSUPPORT_INT64 = frozenset(
    {
        vendors.AIPU,
        vendors.ENFLAME,
        vendors.SPACEMIT,
        vendors.SUNRISE,
        vendors.TSINGMICRO,
    }
)

62 

# Per-operator lists of strategy names; every entry is either "align32" or
# "default". For each operator the list has exactly as many entries as that
# operator's key list in OP_KEY_ORDERS.
# NOTE(review): presumably strategy i applies to key i of OP_KEY_ORDERS --
# confirm against the code that consumes these tables.
DEFAULT_STRATEGIES = {
    "addmm": ["align32", "align32", "align32"],
    "addmm_sqmma": ["align32", "align32", "align32"],
    "baddbmm": ["align32", "align32", "align32"],
    "bmm": ["align32", "align32", "align32", "align32", "align32"],
    "bmm_sqmma": ["align32", "align32", "align32"],
    "gemv": ["align32", "align32", "align32", "default"],
    "mm": ["align32", "align32", "align32", "align32", "align32"],
    "mm_general_tma": [
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
        "default",
    ],
    "mv": ["align32", "align32"],
    "sparse_attention": ["align32", "align32", "align32"],
    "w8a8_block_fp8_general": [
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
    ],
    "w8a8_block_fp8_general_splitk": [
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
    ],
    "w8a8_block_fp8_general_tma": [
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
        "default",
    ],
    "mm_splitk": ["align32", "align32", "align32", "align32", "align32"],
}

105 

# Per-operator ordered lists of key names (problem sizes such as "M"/"N"/"K",
# stride names, "dtype", ...). The operator names are the same set used as
# keys of DEFAULT_STRATEGIES.
# NOTE(review): presumably the order in which these values are assembled into
# a lookup/tuning key -- confirm against the code that consumes this table.
OP_KEY_ORDERS = {
    "addmm": ["M", "N", "K"],
    "addmm_sqmma": ["M", "N", "K"],
    "bmm": ["M", "N", "K", "stride_am", "stride_bk"],
    "bmm_sqmma": ["M", "N", "K"],
    "baddbmm": ["M", "N", "K"],
    "gemv": ["M", "K", "stride_am", "stride_bk"],
    "mm": ["M", "N", "K", "stride_am", "stride_bk"],
    "mm_general_tma": ["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    "mv": ["M", "N"],
    "sparse_attention": ["topk", "H_ACTUAL", "D"],
    "w8a8_block_fp8_general": ["M", "N", "K", "stride_am", "stride_bk"],
    "w8a8_block_fp8_general_splitk": ["M", "N", "K", "stride_am", "stride_bk"],
    "w8a8_block_fp8_general_tma": ["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    "mm_splitk": ["M", "N", "K", "stride_am", "stride_bk"],
}

122 

123 

# Mapping from vendor name to torch attribute for quick detection.
# Keys are lower-case vendor names (the same spelling that
# vendors.get_all_vendors() produces); values are attribute names.
# NOTE(review): presumably probed on the torch module (e.g. via hasattr) --
# the detection code is outside this file; confirm there.
_VENDOR_TORCH_ATTR = {
    "ascend": "npu",
    "cambricon": "mlu",
    "enflame": "gcu",
    "hygon": "__hcu_version__",
    "iluvatar": "corex",
    "mthreads": "musa",
    "sunrise": "ptpu",
}

134 

# Names exported by `from ... import *`. The underscore-prefixed
# _VENDOR_TORCH_ATTR is listed explicitly, so it is exported as well.
__all__ = [
    "vendors",
    "UNSUPPORT_FP64",
    "UNSUPPORT_BF16",
    "UNSUPPORT_INT64",
    "DEFAULT_STRATEGIES",
    "OP_KEY_ORDERS",
    "_VENDOR_TORCH_ATTR",
]