Coverage for src/flag_gems/ops/_euclidean_dist.py: 62%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.jit
def _euclidean_dist_kernel(
    x1_ptr,
    x2_ptr,
    output_ptr,
    N,
    M,
    D,
    stride_x1,
    stride_x2,
    stride_out,
    BLOCK_D: tl.constexpr,
):
    """Write one entry of the pairwise Euclidean distance matrix.

    Each program instance is identified by (program_id(0), program_id(1)) and
    handles a single pair: row program_id(0) of x1 against row program_id(1)
    of x2. The squared differences are accumulated in float32 over chunks of
    BLOCK_D columns, reduced to a scalar, square-rooted and stored.

    Args:
        x1_ptr: Pointer to the x1 tensor, shape (N, D).
        x2_ptr: Pointer to the x2 tensor, shape (M, D).
        output_ptr: Pointer to the output tensor, shape (N, M).
        N: Number of rows in x1 (not read by the kernel body).
        M: Number of rows in x2 (not read by the kernel body).
        D: Number of columns shared by x1 and x2.
        stride_x1: Element stride between consecutive rows of x1.
        stride_x2: Element stride between consecutive rows of x2.
        stride_out: Element stride between consecutive rows of the output.
        BLOCK_D: Number of columns processed per loop iteration.

    Columns of x1, x2 and the output are addressed with unit stride.
    """
    row_idx = tle.program_id(0)
    col_idx = tle.program_id(1)

    # Destination element and the start of the two rows being compared.
    out_elem_ptr = output_ptr + row_idx * stride_out + col_idx
    x1_base = x1_ptr + row_idx * stride_x1
    x2_base = x2_ptr + col_idx * stride_x2

    # Lane offsets within a chunk do not depend on the loop iteration.
    lanes = tl.arange(0, BLOCK_D)

    # Per-lane running sum of squared differences, kept in float32.
    sq_acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    for chunk_start in range(0, D, BLOCK_D):
        cols = chunk_start + lanes
        in_bounds = cols < D

        # Out-of-range lanes load 0.0 from both rows, so they contribute 0.
        lhs = tl.load(x1_base + cols, mask=in_bounds, other=0.0).to(tl.float32)
        rhs = tl.load(x2_base + cols, mask=in_bounds, other=0.0).to(tl.float32)

        delta = lhs - rhs
        sq_acc += delta * delta

    # Collapse the lanes to a scalar, take the root, and write the result.
    tl.store(out_elem_ptr, tl.sqrt(tl.sum(sq_acc, axis=0)))

74 

75 

def _euclidean_dist(x1, x2):
    """Compute pairwise Euclidean distances between rows of x1 and x2.

    Args:
        x1: Tensor of shape (N, D).
        x2: Tensor of shape (M, D).

    Returns:
        Tensor of shape (N, M), created with x1's dtype and device, where
        output[i, j] = ||x1[i] - x2[j]||_2. If N, M or D is zero the result
        is an (N, M) tensor of zeros (empty when N or M is zero) and no
        kernel is launched.

    Raises:
        AssertionError: If either input is not 2D or the column counts differ.
    """
    logger.debug("GEMS _EUCLIDEAN_DIST")

    assert x1.ndim == 2, "x1 must be a 2D tensor"
    assert x2.ndim == 2, "x2 must be a 2D tensor"
    assert x1.shape[1] == x2.shape[1], "x1 and x2 must have the same number of columns"

    N, D = x1.shape
    M = x2.shape[0]

    # Degenerate shapes never reach the kernel. With D == 0 the block size
    # computed below would be 0, which is not a usable block width, and the
    # correct answer is all zeros (an empty sum of squares). With N == 0 or
    # M == 0 there is nothing to compute and the result is simply empty.
    if N == 0 or M == 0 or D == 0:
        return torch.zeros((N, M), dtype=x1.dtype, device=x1.device)

    # The kernel steps through columns with unit stride, so both inputs are
    # made contiguous before their row strides are passed in.
    x1 = x1.contiguous()
    x2 = x2.contiguous()

    output = torch.empty((N, M), dtype=x1.dtype, device=x1.device)

    # Columns per kernel loop iteration: D rounded up to a power of two,
    # capped at 1024. Larger D is covered by the kernel's loop over chunks.
    BLOCK_D = min(triton.next_power_of_2(D), 1024)

    with torch_device_fn.device(x1.device):
        # One program instance per output element.
        # NOTE(review): grid axis 1 equals M; backends that cap the second
        # grid dimension would reject very large M - confirm against targets.
        grid = (N, M)
        _euclidean_dist_kernel[grid](
            x1,
            x2,
            output,
            N,
            M,
            D,
            x1.stride(0),
            x2.stride(0),
            output.stride(0),
            BLOCK_D=BLOCK_D,
        )

    return output