Coverage for src/flag_gems/runtime/backend/_arm/ops/lt.py: 0%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

9 

# Tensor-vs-tensor "less than", wrapped by pointwise_dynamic with the
# ALWAYS_BOOL promotion method applied to arguments 0 and 1.
@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")])
@triton.jit
def lt_func(x, y):
    # The left operand is cast to float32 before the comparison; y is used as passed.
    x_f32 = x.to(tl.float32)
    return x_f32 < y

14 

15 

# Triton kernel: writes x < y element-wise over flat buffers of n_elements items.
@triton.jit
def lt_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr = 16,
):
    # Each program instance covers one BLOCK_SIZE-wide slice, selected by its
    # index along grid axis 0.
    first = tl.program_id(0) * BLOCK_SIZE
    idx = first + tl.arange(0, BLOCK_SIZE)

    # Lanes at or beyond n_elements are masked out of every load and store.
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)

    tl.store(out_ptr + idx, lhs < rhs, mask=in_bounds)

34 

35 

def lt_block(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the element-wise ``x < y`` by launching ``lt_kernel``.

    Both operands are cast to ``float32``, broadcast against each other and
    flattened into contiguous 1-D buffers before the kernel is launched.

    Args:
        x: Left-hand operand.
        y: Right-hand operand; must be on the same device as ``x``.

    Returns:
        A contiguous ``torch.bool`` tensor with the broadcast shape of
        ``x`` and ``y``.

    Raises:
        AssertionError: If ``x`` and ``y`` are on different devices.
    """
    assert x.device == y.device, "Tensors must be on the same device"

    # NOTE(review): casting to float32 can lose precision for int64/float64
    # inputs with large magnitudes -- confirm this is intended for this backend.
    dtype = torch.float32

    # Broadcast tensors to the same shape.
    x_b, y_b = torch.broadcast_tensors(x.to(dtype), y.to(dtype))

    # Allocate the output contiguously on purpose. torch.empty_like would
    # preserve the strides of a dense but non-contiguous input (e.g. a
    # transposed tensor); the flat view below would then fail, and the output
    # layout would not match the contiguous copies of the inputs.
    out = torch.empty(x_b.shape, dtype=torch.bool, device=x_b.device)

    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to compare; avoid launching a kernel with an empty grid.
        return out

    # Flatten everything to 1-D so the kernel can use linear offsets.
    x_b_flat = x_b.contiguous().view(-1)
    y_b_flat = y_b.contiguous().view(-1)
    out_flat = out.view(-1)

    # One program instance per block_size-wide slice of the flat buffers.
    block_size = 16
    grid = (triton.cdiv(n_elements, block_size),)
    lt_kernel[grid](x_b_flat, y_b_flat, out_flat, n_elements, BLOCK_SIZE=block_size)

    return out

58 

59 

def lt(A, B):
    """Return the element-wise ``A < B`` computed by ``lt_block``."""
    logging.debug("GEMS LT")
    # lt_func is the pointwise_dynamic-based alternative; it is not used here.
    result = lt_block(A, B)
    return result

64 

65 

# "Less than" where, per is_tensor=[True, False], the first argument is a
# tensor and the second is not; uses the ALWAYS_BOOL promotion method.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")])
@triton.jit
def lt_func_scalar(x, y):
    # The tensor operand is cast to float32 before the comparison; y is used as passed.
    x_f32 = x.to(tl.float32)
    return x_f32 < y

70 

71 

def lt_scalar(A, B):
    """Return the element-wise ``A < B`` computed by ``lt_func_scalar``."""
    logging.debug("GEMS LT SCALAR")
    result = lt_func_scalar(A, B)
    return result