Coverage for src/flag_gems/runtime/backend/_sunrise/ops/eq.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7import flag_gems 

8from flag_gems.runtime import device 

9from flag_gems.utils import pointwise_dynamic 

10from flag_gems.utils.pointwise_dynamic import CodeGenConfig 

11 

# Module logger: first obtain the package-wide "flag_gems" logger, then rebind
# ``logger`` to its child named after this module (leading dots stripped).
logger = logging.getLogger("flag_gems")
logger = logger.getChild(__name__.lstrip("."))

# Rebind ``device`` from the imported runtime device object to its ``name``.
# ``eq`` compares ``tensor.device.type`` against this value, so it is
# presumably a device-type string — TODO confirm against flag_gems.runtime.
device = device.name

14 

# Grid-size limit handed to the code generator as ``max_grid_size``: three
# entries, each 65535 — presumably the per-dimension launch-grid cap of this
# backend (TODO confirm against the backend's hardware limits).
MAX_GRID_SIZES = (65535,) * 3

# Code-generation settings shared by the ``pointwise_dynamic`` kernels below.
config = CodeGenConfig(
    prefer_1d_tile=True,
    prefer_block_pointer=True,
    max_num_warps_per_cta=32,
    max_grid_size=MAX_GRID_SIZES,
    max_tile_size=1024,
)

23 

24 

# Elementwise ``x == y`` kernel body for two tensor operands, wrapped by
# ``pointwise_dynamic`` with the module-level ``config``. The
# "ALWAYS_BOOL" promotion on arguments (0, 1) presumably makes the output
# dtype bool regardless of input dtypes.
#
# The operands are compared in their own dtypes, and Triton's binary type
# promotion picks the common type. Do NOT cast both sides to float32 here:
# float32 has a 24-bit mantissa, so distinct integers above 2**24 (e.g.
# 16777216 and 16777217) and distinct float64 values would compare equal,
# which disagrees with ``torch.eq``. Upcasting fp16/bf16 was exact anyway,
# so those dtypes give the same results as before.
# NOTE(review): confirm the sunrise backend lowers int64/fp64 comparisons
# natively.
@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")], config=config)
@triton.jit
def eq_func(x, y):
    return x == y

29 

30 

def eq(A, B):
    """Compute elementwise ``A == B`` for two tensors via ``eq_func``.

    If the operands are on different devices, both are placed on one common
    device before the kernel runs. ``A``'s device is used when its type
    matches this module's ``device`` value; otherwise ``B``'s device is used.
    ``Tensor.to`` returns the tensor itself when it already sits on the
    target device, so only the mismatched operand is actually moved.
    """
    if A.device != B.device:
        target = A.device if A.device.type == device else B.device
        A, B = A.to(target), B.to(target)
    logger.debug("GEMS EQ")
    return eq_func(A, B)

39 

40 

# Elementwise ``x == y`` kernel body where ``x`` is a tensor and ``y`` is a
# non-tensor value (``is_tensor=[True, False]``), wrapped by
# ``pointwise_dynamic`` with the module-level ``config``. The "ALWAYS_BOOL"
# promotion presumably makes the output dtype bool.
#
# The operands are compared in their own dtypes, and Triton's type promotion
# picks the common type. Casting both sides to float32 would make distinct
# integers above 2**24 (e.g. 16777216 and 16777217) and distinct float64
# values compare equal, unlike ``torch.eq``.
# NOTE(review): confirm the sunrise backend lowers int64/fp64 comparisons
# natively.
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")], config=config
)
@triton.jit
def eq_func_scalar(x, y):
    return x == y

47 

48 

def eq_scalar(A, B):
    """Compute elementwise ``A == B`` where ``A`` is a tensor and ``B`` is not.

    Logs the dispatch and delegates to the ``eq_func_scalar`` kernel, which
    treats its second argument as a non-tensor value.
    """
    logger.debug("GEMS EQ SCALAR")
    result = eq_func_scalar(A, B)
    return result

52 

53 

def equal(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Return whether ``x`` and ``y`` have the same shape and equal elements.

    Tensors whose shapes differ are reported unequal without launching a
    kernel. Otherwise the elementwise ``eq`` mask is reduced with
    ``flag_gems.all`` and converted to a Python ``bool``.
    """
    logger.debug("GEMS EQUAL")
    same_shape = x.shape == y.shape
    if not same_shape:
        return False
    all_equal = flag_gems.all(eq(x, y))
    return bool(all_equal.item())