Coverage for src/flag_gems/utils/triton_lang_helper.py: 59%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1from flag_gems.runtime import backend 

2from flag_gems.runtime.backend.device import DeviceDetector 

3 

"""
To stay compatible with different versions of Triton's math libraries,
``tl_extra_shim`` is bound below to a specific library module.
Note that the "triton.language.extra" module is only available in
Triton 2.2 and later versions.
"""

device = DeviceDetector()
backend.set_torch_backend_device_fn(device.vendor_name)
try:
    # First choice: the extra-math module the backend associates with the
    # detected vendor.
    backend.set_tl_extra_backend_module(device.vendor_name)
    tl_extra_shim = backend.get_tl_extra_backend_module()
except ImportError:
    # NOTE(review): presumably raised when no vendor-specific module can be
    # imported; fall back to whichever namespace this Triton build provides.
    import triton

    try:
        tl_extra_shim = triton.language.extra.libdevice
    except AttributeError:
        try:
            tl_extra_shim = triton.language.math
        except (AttributeError, ImportError):
            # A missing attribute on an imported module raises AttributeError,
            # not ImportError, so catching ImportError alone would make this
            # last fallback unreachable and crash the import instead.
            tl_extra_shim = triton.language.libdevice

26 

27 

def use_backend(module):
    """Build a decorator that prefers ``module``'s implementation of a function.

    The returned decorator looks up an attribute on ``module`` whose name is
    the decorated function's ``__name__``. If such an attribute exists it is
    returned in place of the decorated function; otherwise the decorated
    function itself is returned unchanged and acts as the fallback.

    Args:
        module: Object (presumably a module) that may expose
            backend-specific implementations as attributes.

    Returns:
        A decorator mapping ``func`` to either ``getattr(module, func.__name__)``
        or ``func``.
    """

    def decorator(func):
        # One getattr with a default replaces the former hasattr/getattr pair
        # and its dead ``except Exception: pass``: hasattr already performed
        # the lookup, so any non-AttributeError would have propagated there.
        return getattr(module, func.__name__, func)

    return decorator

41 

42 

def use_tl_extra(func):
    """Swap ``func`` for the same-named function in ``tl_extra_shim``, if any.

    Thin wrapper over ``use_backend`` bound to the module-level
    ``tl_extra_shim``. When the shim has no attribute named
    ``func.__name__``, ``func`` is returned as-is.
    """
    shim_decorator = use_backend(tl_extra_shim)
    return shim_decorator(func)