Coverage for src/flag_gems/utils/triton_lang_helper.py: 75%
64 statements
Created by coverage.py v7.6.9 at 2026-05-27 08:02 +0800
1import importlib
3import triton
4import triton.language as tl
6from flag_gems.runtime import backend
7from flag_gems.runtime.backend.device import DeviceDetector
# To stay compatible with different versions of the math libraries,
# tl_extra_shim is bound to one specific library module below.
# Note: the "triton.language.extra" module is only available in
# Triton 2.2 and later versions.

device = DeviceDetector()
backend.set_torch_backend_device_fn(device.vendor_name)
try:
    # Preferred: the tl-extra module registered for this vendor's backend.
    backend.set_tl_extra_backend_module(device.vendor_name)
    tl_extra_shim = backend.get_tl_extra_backend_module()
except ImportError:
    # Generic Triton layouts, newest first. These are plain attribute
    # accesses, so a missing submodule raises AttributeError (never
    # ImportError); every fallback level must therefore catch AttributeError,
    # otherwise the last fallback is unreachable.
    try:
        tl_extra_shim = triton.language.extra.libdevice
    except AttributeError:
        try:
            tl_extra_shim = triton.language.math
        except AttributeError:
            tl_extra_shim = triton.language.libdevice
def _import_module(module_name):
    """Import ``module_name`` by dotted path, or return ``None`` if unavailable.

    Only ``ImportError`` (including ``ModuleNotFoundError``) and
    ``AttributeError`` are treated as "unavailable"; any other exception
    raised during import propagates to the caller.
    """
    try:
        loaded = importlib.import_module(module_name)
    except (ImportError, AttributeError):
        loaded = None
    return loaded
def _tl_extra_candidates():
    """Lazily yield the importable math-library modules, most specific first.

    Search order:
      1. ``triton.language.extra.<name>.libdevice``, where ``<name>`` is the
         vendor's ``triton_extra_name``, or its ``device_name`` when that is
         falsy;
      2. ``triton.language.extra.libdevice``;
      3. ``triton.language.math``;
      4. ``triton.language.libdevice``.

    Paths that fail to import are silently skipped. Nothing (including the
    vendor-info lookup) runs until the generator is first advanced.
    """
    info = backend.get_vendor_info(device.vendor_name)
    vendor_dir = info.triton_extra_name or info.device_name
    search_order = [
        f"triton.language.extra.{vendor_dir}.libdevice",
        "triton.language.extra.libdevice",
        "triton.language.math",
        "triton.language.libdevice",
    ]
    imported = (_import_module(path) for path in search_order)
    yield from (mod for mod in imported if mod is not None)
@triton.jit
def _fallback_pow(x, exponent):
    """Generic ``pow`` installed only when no candidate library module has one.

    Computes ``x ** exponent`` with Triton's own operator lowering.
    """
    # NOTE(review): confirm that the targeted Triton versions lower ``**`` for
    # tensor operands inside @triton.jit; if not, this fallback fails to
    # compile the first time a kernel uses it.
    return x**exponent
@triton.jit
def _fallback_tanh(x):
    """Generic ``tanh`` installed only when no candidate library module has one.

    Uses the identity tanh(x) = 2 / (1 + exp(-2x)) - 1.
    """
    # For large |x| the exp term saturates to 0 (x -> +inf) or inf
    # (x -> -inf), so the result settles at +1 / -1 rather than inf or NaN.
    # Near x == 0 the final "- 1.0" cancels, so relative accuracy there is
    # lower than a library tanh. NOTE(review): input dtype support depends on
    # tl.exp; confirm behaviour for fp16/bf16 inputs.
    return 2.0 / (1.0 + tl.exp(-2.0 * x)) - 1.0
# Pure-Triton stand-ins that _patch_missing_symbols installs when none of the
# candidate library modules provides the symbol.
_FALLBACK_SYMBOLS = dict(
    pow=_fallback_pow,
    tanh=_fallback_tanh,
)
def _patch_missing_symbols(module, names):
    """Back-fill the attributes in ``names`` that ``module`` lacks.

    For each name not already present on ``module``:

    * copy the attribute from the first module yielded by
      ``_tl_extra_candidates()`` that has it; otherwise
    * install the ``_FALLBACK_SYMBOLS`` entry for that name, if one exists.

    Names found in neither place are left missing.

    Args:
        module: Object patched in place via ``setattr``.
        names: Iterable of attribute names to ensure.

    Returns:
        The same ``module`` object.
    """
    # NOTE(review): ``module`` may be a shared library module (for example
    # triton.language.extra.libdevice), so these setattr calls are visible to
    # every other importer of that module.
    candidates = None
    for name in names:
        if hasattr(module, name):
            continue
        if candidates is None:
            # Resolve the candidate modules once, on the first missing name.
            # Re-running the generator per name would repeat the vendor lookup
            # and up to four import attempts, and failed imports are not cached
            # in sys.modules, so they would be retried from scratch each time.
            candidates = list(_tl_extra_candidates())
        for candidate in candidates:
            if hasattr(candidate, name):
                setattr(module, name, getattr(candidate, name))
                break
        else:
            # No candidate module has it: use the pure-Triton fallback if any.
            fallback = _FALLBACK_SYMBOLS.get(name)
            if fallback is not None:
                setattr(module, name, fallback)
    return module
# Symbols expected on tl_extra_shim; whichever the selected module lacks are
# back-filled from the other candidate modules or the pure-Triton fallbacks.
_TL_EXTRA_REQUIRED_SYMBOLS = (
    "acos",
    "atan",
    "atan2",
    "div_rn",
    "div_rz",
    "erf",
    "exp",
    "exp2",
    "fast_erf",
    "fast_gelu",
    "fast_tanh",
    "finitef",
    "fmod",
    "gelu_none",
    "gelu_tanh",
    "isfinited",
    "isinf",
    "isnan",
    "log",
    "pow",
    "rint",
    "rsqrt",
    "silu",
    "tan",
    "tanh",
    "trunc",
    "xpu_trunc_div",
)
tl_extra_shim = _patch_missing_symbols(tl_extra_shim, _TL_EXTRA_REQUIRED_SYMBOLS)
def use_backend(module):
    """Build a decorator that prefers ``module``'s implementation of a function.

    The returned decorator looks up an attribute on ``module`` with the same
    name as the decorated function. If it exists, that attribute is returned
    in place of the function; otherwise the function itself is returned.

    Args:
        module: Object (a module in this file's own usage) searched for a
            same-named implementation. ``None`` is accepted and never matches.

    Returns:
        A decorator mapping ``func`` to the backend attribute or ``func``.
    """

    def decorator(func):
        # A single getattr with a default covers the "not present" case; no
        # second lookup or blanket exception guard is needed, so unexpected
        # errors from the lookup are no longer silently discarded.
        return getattr(module, func.__name__, func)

    return decorator
def use_tl_extra(func):
    """Decorator: swap ``func`` for the same-named attribute of ``tl_extra_shim``.

    When the shim has no attribute with that name, ``func`` is returned as-is.
    """
    shim_decorator = use_backend(tl_extra_shim)
    return shim_decorator(func)