Coverage for src/flag_gems/utils/triton_lang_helper.py: 59%
29 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1from flag_gems.runtime import backend
2from flag_gems.runtime.backend.device import DeviceDetector
"""
To stay compatible with different versions of the math libraries,
tl_extra_shim is bound to one specific library at import time.
Note that the "triton.language.extra" module is only available in
Triton 2.2 and later versions.
"""
# Detect the vendor once at import time and point the backend at it.
device = DeviceDetector()
backend.set_torch_backend_device_fn(device.vendor_name)
try:
    # Prefer the vendor-specific tl-extra module when the backend provides one.
    backend.set_tl_extra_backend_module(device.vendor_name)
    tl_extra_shim = backend.get_tl_extra_backend_module()
except ImportError:
    # Fall back to Triton's own math library. Its location differs across
    # Triton releases, so probe the known locations in turn.
    import triton

    try:
        # "triton.language.extra" only exists in Triton 2.2 and later.
        tl_extra_shim = triton.language.extra.libdevice
    except AttributeError:
        try:
            tl_extra_shim = triton.language.math
        except (AttributeError, ImportError):
            # A missing submodule surfaces as AttributeError on attribute
            # access, not ImportError. Catching only ImportError (as before)
            # meant this last fallback could never be reached.
            tl_extra_shim = triton.language.libdevice
def use_backend(module):
    """Build a decorator that prefers ``module``'s implementation of a function.

    The decorated function's ``__name__`` is looked up on ``module``. If
    ``module`` has an attribute with that name, that attribute is returned in
    place of the decorated function (even if its value is ``None``).
    Otherwise the decorated function is returned unchanged.

    Args:
        module: Object (e.g. a backend module) searched for an override.

    Returns:
        A decorator mapping ``func`` to ``getattr(module, func.__name__)`` when
        that attribute exists, else to ``func`` itself.
    """
    # Private sentinel so an override whose value is None is still honoured.
    _missing = object()

    def decorator(func):
        # One getattr with a default replaces the hasattr()+getattr() double
        # lookup. hasattr() itself only suppresses AttributeError, so the old
        # broad `except Exception: pass` protected nothing but a repeat lookup.
        override = getattr(module, func.__name__, _missing)
        return func if override is _missing else override

    return decorator
def use_tl_extra(func):
    """Swap ``func`` for the same-named function on ``tl_extra_shim``, if any.

    Thin wrapper over ``use_backend`` bound to the module-level
    ``tl_extra_shim`` selected at import time.
    """
    shim_decorator = use_backend(tl_extra_shim)
    return shim_decorator(func)