Coverage for src/flag_gems/utils/triton_lang_helper.py: 75%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import importlib 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import backend 

7from flag_gems.runtime.backend.device import DeviceDetector 

8 

# ``tl_extra_shim`` hides the differences between math libraries across Triton versions
# and vendor backends: it is bound to whichever library module is available here.
# The vendor-specific backend module is preferred; otherwise the generic Triton
# locations are tried in turn. Note that the ``triton.language.extra`` namespace is
# only available in Triton 2.2 and later.

device = DeviceDetector()
backend.set_torch_backend_device_fn(device.vendor_name)
try:
    backend.set_tl_extra_backend_module(device.vendor_name)
    tl_extra_shim = backend.get_tl_extra_backend_module()
except ImportError:
    try:
        tl_extra_shim = triton.language.extra.libdevice
    except AttributeError:
        # A missing submodule surfaces as AttributeError on attribute access, not as
        # ImportError. It must be caught here, otherwise the final
        # ``triton.language.libdevice`` fallback is unreachable. ImportError is still
        # accepted in case the attribute is resolved by a lazy import.
        try:
            tl_extra_shim = triton.language.math
        except (AttributeError, ImportError):
            tl_extra_shim = triton.language.libdevice

29 

30 

def _import_module(module_name):
    """Import ``module_name`` and return the module, or ``None`` if it is unavailable.

    Only ``ImportError`` (including ``ModuleNotFoundError``) and ``AttributeError`` are
    treated as "unavailable". Any other exception raised while importing propagates.
    """
    imported = None
    try:
        imported = importlib.import_module(module_name)
    except (ImportError, AttributeError):
        pass
    return imported

36 

37 

def _tl_extra_candidates():
    """Yield the importable modules that may supply symbols missing from ``tl_extra_shim``.

    Candidates are tried in priority order:

    1. The vendor-specific ``triton.language.extra.<name>.libdevice``, where ``<name>`` is
       the vendor info's ``triton_extra_name``, or else its ``device_name``.
    2. The generic fallback locations, in the order listed below.

    Modules that fail to import (per ``_import_module``) are skipped.

    Yields:
        Each successfully imported module. Imports happen lazily, one per item consumed,
        so a caller that stops early does not attempt the remaining imports.
    """
    vendor_info = backend.get_vendor_info(device.vendor_name)
    extra_name = vendor_info.triton_extra_name or vendor_info.device_name

    module_names = []
    # If the vendor info provides neither name, skip the vendor-specific path rather
    # than attempting a meaningless import such as
    # "triton.language.extra.None.libdevice".
    if extra_name:
        module_names.append(f"triton.language.extra.{extra_name}.libdevice")
    module_names.extend(
        (
            "triton.language.extra.libdevice",
            "triton.language.math",
            "triton.language.libdevice",
        )
    )

    for module_name in module_names:
        module = _import_module(module_name)
        if module is not None:
            yield module

51 

52 

@triton.jit
def _fallback_pow(x, exponent):
    # Generic ``pow`` installed on ``tl_extra_shim`` only when no candidate math module
    # provides one (see _FALLBACK_SYMBOLS).
    # NOTE(review): relies on Triton lowering the ``**`` operator for the operand types
    # passed in; confirm this is supported by the targeted Triton version.
    powered = x**exponent
    return powered

56 

57 

@triton.jit
def _fallback_tanh(x):
    # Generic ``tanh`` installed on ``tl_extra_shim`` only when no candidate math module
    # provides one, using tanh(x) = 2 / (1 + exp(-2x)) - 1.
    # Under IEEE float semantics the formula saturates rather than producing NaN:
    #   - large positive x: exp(-2x) -> 0, giving 1;
    #   - large negative x: exp(-2x) -> inf, giving 2/inf - 1 = -1.
    # NOTE(review): near x == 0 the final subtraction cancels two values close to 1, so
    # relative precision is limited there. Confirm this is acceptable for callers.
    decay = tl.exp(-2.0 * x)
    return 2.0 / (1.0 + decay) - 1.0

61 

62 

# Pure-Triton implementations that _patch_missing_symbols installs when no candidate
# math module defines the symbol. Keys are the attribute names looked up on the shim.
_FALLBACK_SYMBOLS = dict(
    pow=_fallback_pow,
    tanh=_fallback_tanh,
)

67 

68 

def _patch_missing_symbols(module, names):
    """Add any of ``names`` that ``module`` lacks, mutating ``module`` in place.

    For each name not already present on ``module``, the attribute is copied from the
    first module yielded by ``_tl_extra_candidates()`` that defines it. If no candidate
    does, the entry from ``_FALLBACK_SYMBOLS`` is installed when one exists. Names with
    neither source are left missing.

    Args:
        module: The object to patch (here, the selected ``tl_extra_shim`` module).
        names: Iterable of attribute names to ensure.

    Returns:
        ``module`` itself, so the call can be used on the right-hand side of an assignment.
    """
    for name in names:
        if hasattr(module, name):
            continue
        # The candidate generator is consumed lazily, so imports stop at the first
        # module that provides ``name``.
        provider = next(
            (cand for cand in _tl_extra_candidates() if hasattr(cand, name)),
            None,
        )
        if provider is not None:
            setattr(module, name, getattr(provider, name))
            continue
        fallback = _FALLBACK_SYMBOLS.get(name)
        if fallback is not None:
            setattr(module, name, fallback)
    return module

82 

83 

# Names to ensure on ``tl_extra_shim``. Any the selected module lacks are copied from
# another candidate module, or taken from _FALLBACK_SYMBOLS when available.
_TL_EXTRA_REQUIRED_SYMBOLS = (
    "acos",
    "atan",
    "atan2",
    "div_rn",
    "div_rz",
    "erf",
    "exp",
    "exp2",
    "fast_erf",
    "fast_gelu",
    "fast_tanh",
    "finitef",
    "fmod",
    "gelu_none",
    "gelu_tanh",
    "isfinited",
    "isinf",
    "isnan",
    "log",
    "pow",
    "rint",
    "rsqrt",
    "silu",
    "tan",
    "tanh",
    "trunc",
    "xpu_trunc_div",
)
tl_extra_shim = _patch_missing_symbols(tl_extra_shim, _TL_EXTRA_REQUIRED_SYMBOLS)

116 

117 

def use_backend(module):
    """Build a decorator that prefers ``module``'s implementation of the decorated function.

    The returned decorator looks up an attribute on ``module`` with the same name as the
    decorated function. If the attribute exists, it is returned in place of the function;
    otherwise the function is returned unchanged.

    Args:
        module: Object (typically a backend module) searched for an override.

    Returns:
        A decorator mapping ``func`` to either ``module.<func.__name__>`` or ``func``.
    """

    def decorator(func):
        # getattr with a default swallows only AttributeError, which is exactly the
        # condition the previous hasattr() check tested. The old try/except around a
        # second getattr could never trigger, so it has been removed.
        return getattr(module, func.__name__, func)

    return decorator

131 

132 

def use_tl_extra(func):
    """Decorator: swap ``func`` for the same-named symbol on ``tl_extra_shim``, if present.

    When ``tl_extra_shim`` has no attribute of that name, ``func`` is returned unchanged.
    """
    override_with_shim = use_backend(tl_extra_shim)
    return override_with_shim(func)