Coverage for src/flag_gems/ops/_euclidean_dist.py: 62%
42 statements
coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def _euclidean_dist_kernel(
    x1_ptr,
    x2_ptr,
    output_ptr,
    N,
    M,
    D,
    stride_x1,
    stride_x2,
    stride_out,
    BLOCK_D: tl.constexpr,
):
    """Compute a single pairwise Euclidean distance per program instance.

    The program at grid position (i, j) reduces over the D elements of row i
    of x1 and row j of x2, then writes sqrt(sum((x1[i] - x2[j]) ** 2)) to
    output slot (i, j).

    Elements inside a row are addressed with unit stride (the column offset
    is added directly to the row pointer), and the output column offset is
    likewise added without a stride, so callers must pass row-contiguous
    buffers.

    Args:
        x1_ptr: Pointer to the first input, laid out as (N, D).
        x2_ptr: Pointer to the second input, laid out as (M, D).
        output_ptr: Pointer to the result buffer, laid out as (N, M).
        N: Row count of x1 (not read by the kernel body).
        M: Row count of x2 (not read by the kernel body).
        D: Number of elements per row.
        stride_x1: Element stride between consecutive rows of x1.
        stride_x2: Element stride between consecutive rows of x2.
        stride_out: Element stride between consecutive rows of the output.
        BLOCK_D: Compile-time number of elements processed per loop step.
    """
    row_i = tle.program_id(0)
    row_j = tle.program_id(1)

    # Base addresses of the two rows being compared and of the output slot.
    lhs_base = x1_ptr + row_i * stride_x1
    rhs_base = x2_ptr + row_j * stride_x2
    dst = output_ptr + row_i * stride_out + row_j

    # Lane indices are loop-invariant, so build them once.
    lanes = tl.arange(0, BLOCK_D)

    # Per-lane float32 accumulator of squared differences.
    partial = tl.zeros([BLOCK_D], dtype=tl.float32)

    for chunk_begin in range(0, D, BLOCK_D):
        cols = chunk_begin + lanes
        in_bounds = cols < D

        # Out-of-range lanes load 0.0 from both rows and so contribute nothing.
        lhs = tl.load(lhs_base + cols, mask=in_bounds, other=0.0).to(tl.float32)
        rhs = tl.load(rhs_base + cols, mask=in_bounds, other=0.0).to(tl.float32)

        delta = lhs - rhs
        partial += delta * delta

    # Collapse the lanes into one squared distance, take the root, and store.
    tl.store(dst, tl.sqrt(tl.sum(partial, axis=0)))
def _euclidean_dist(x1, x2):
    """Compute pairwise Euclidean distances between rows of x1 and x2.

    Args:
        x1: Tensor of shape (N, D).
        x2: Tensor of shape (M, D).

    Returns:
        Tensor of shape (N, M), with x1's dtype and device, where
        output[i, j] = ||x1[i] - x2[j]||_2. When N or M is zero the result is
        an empty (N, M) tensor; when D is zero every distance is 0.

    Raises:
        AssertionError: If either input is not 2D or the column counts differ.
    """
    logger.debug("GEMS _EUCLIDEAN_DIST")

    assert x1.ndim == 2, "x1 must be a 2D tensor"
    assert x2.ndim == 2, "x2 must be a 2D tensor"
    assert x1.shape[1] == x2.shape[1], "x1 and x2 must have the same number of columns"

    N, D = x1.shape
    M = x2.shape[0]

    # The kernel addresses elements within a row with unit stride.
    x1 = x1.contiguous()
    x2 = x2.contiguous()

    output = torch.empty((N, M), dtype=x1.dtype, device=x1.device)

    # A grid with a zero-sized axis cannot be launched; there is nothing to
    # compute in that case anyway.
    if N == 0 or M == 0:
        return output

    # With no columns, next_power_of_2(0) would give BLOCK_D == 0, which is not
    # a valid block shape. The sum over an empty row is 0, so every distance
    # is 0.
    if D == 0:
        return output.zero_()

    BLOCK_D = min(triton.next_power_of_2(D), 1024)

    # CUDA caps the second grid axis at 65535 programs, so x2 rows are
    # processed in chunks of at most that size. Other backends may allow more;
    # staying under this bound is safe either way. For M <= 65535 this is a
    # single launch over the full tensors.
    max_rows_per_launch = 65535

    with torch_device_fn.device(x1.device):
        for m_start in range(0, M, max_rows_per_launch):
            m_end = min(m_start + max_rows_per_launch, M)
            rows_in_chunk = m_end - m_start

            # Views only: the row slice of x2 stays contiguous, and the column
            # slice of output keeps the full row stride with unit column
            # stride, which is the layout the kernel expects.
            x2_chunk = x2[m_start:m_end]
            out_chunk = output[:, m_start:m_end]

            grid = (N, rows_in_chunk)
            _euclidean_dist_kernel[grid](
                x1,
                x2_chunk,
                out_chunk,
                N,
                rows_in_chunk,
                D,
                x1.stride(0),
                x2_chunk.stride(0),
                out_chunk.stride(0),
                BLOCK_D=BLOCK_D,
            )

    return output