Coverage for src/flag_gems/runtime/backend/_arm/ops/lt.py: 0%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")])
@triton.jit
def lt_func(x, y):
    # Element-wise "less than": the left operand is cast to float32 first,
    # then compared against y as-is.
    lhs = x.to(tl.float32)
    return lhs < y
@triton.jit
def lt_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr = 16,
):
    # Flat 1-D comparison kernel: each program instance covers one
    # BLOCK_SIZE-wide slice of the inputs and writes x < y into out_ptr.
    program_index = tl.program_id(0)
    lane_offsets = program_index * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Lanes at or beyond n_elements are masked out of both loads and the store.
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(x_ptr + lane_offsets, mask=in_bounds)
    rhs = tl.load(y_ptr + lane_offsets, mask=in_bounds)

    tl.store(out_ptr + lane_offsets, lhs < rhs, mask=in_bounds)
def lt_block(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the element-wise ``x < y`` of two broadcastable tensors.

    Both operands are cast to float32, broadcast to a common shape,
    flattened, and compared by ``lt_kernel``.

    Args:
        x: Left-hand operand.
        y: Right-hand operand; must live on the same device as ``x``.

    Returns:
        A contiguous ``torch.bool`` tensor with the broadcast shape of
        ``x`` and ``y``.

    Raises:
        AssertionError: If ``x`` and ``y`` are on different devices.

    Errors raised by ``torch.broadcast_tensors`` for incompatible shapes
    propagate unchanged.
    """
    assert x.device == y.device, "Tensors must be on the same device"

    # NOTE(review): comparing in float32 can be inexact for values that
    # float32 cannot represent exactly (large integers, float64) -- confirm
    # this cast is required by the backend before widening it.
    dtype = torch.float32
    x_b, y_b = torch.broadcast_tensors(x.to(dtype), y.to(dtype))

    # Allocate an explicitly contiguous output. ``torch.empty_like`` would
    # copy the strides of a dense non-contiguous input (e.g. a transposed
    # tensor), which makes ``view(-1)`` fail and would not match the
    # row-major order of the flattened inputs below.
    out = torch.empty(x_b.shape, dtype=torch.bool, device=x_b.device)

    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to compare; avoid launching a kernel with an empty grid.
        return out

    # Flatten to 1-D row-major buffers for the kernel.
    x_b_flat = x_b.contiguous().view(-1)
    y_b_flat = y_b.contiguous().view(-1)
    out_flat = out.view(-1)

    # One program instance per block_size-wide slice.
    block_size = 16
    grid = (triton.cdiv(n_elements, block_size),)
    lt_kernel[grid](x_b_flat, y_b_flat, out_flat, n_elements, BLOCK_SIZE=block_size)

    return out
def lt(A, B):
    """Compute element-wise ``A < B`` through the block-kernel path.

    Emits a debug log entry and delegates to ``lt_block``; the
    ``pointwise_dynamic`` variant ``lt_func`` is currently bypassed.
    """
    logging.debug("GEMS LT")
    result = lt_block(A, B)
    return result
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")])
@triton.jit
def lt_func_scalar(x, y):
    # Tensor-vs-scalar "less than": x (declared as the tensor operand) is
    # cast to float32 before being compared with the non-tensor operand y.
    lhs = x.to(tl.float32)
    return lhs < y
def lt_scalar(A, B):
    """Compute element-wise ``A < B`` where ``B`` is a non-tensor operand.

    Emits a debug log entry and delegates to ``lt_func_scalar``, which is
    declared with ``is_tensor=[True, False]``.
    """
    logging.debug("GEMS LT SCALAR")
    outcome = lt_func_scalar(A, B)
    return outcome