Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/dot.py: 0%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as ext 

9from flag_gems.utils.libentry import libentry 

10 

# Module logger, created as a child of the package-wide "flag_gems" logger so
# its records propagate through that logger's hierarchy.
_module_suffix = __name__.lstrip(".")
logger = logging.getLogger("flag_gems").getChild(_module_suffix)

12 

13 

# Single-block dot product: loads one BLOCK_SIZE-wide slice of x and y
# starting at this program's block, multiplies them in float32 and stores the
# reduced total to out_ptr. Lanes at or past N load 0.0 and add nothing.
@libentry()
@triton.jit
def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    idx = ext.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N
    lhs = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    rhs = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    tl.store(out_ptr, tl.sum(lhs * rhs))

28 

29 

# First pass of the two-kernel reduction used by dot(): every program
# multiplies one BLOCK_SIZE-wide slice of x and y in float32 and writes the
# slice total to mid_ptr[program_id]. BLOCK_SIZE is autotuned, keyed on N.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 8192}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 16384}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_SIZE": 32768}, num_warps=16, num_stages=2),
    ],
    key=["N"],
)
@triton.jit
def dot_kernel_1(x_ptr, y_ptr, mid_ptr, N, BLOCK_SIZE: tl.constexpr):
    slot = ext.program_id(0)
    idx = slot * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Out-of-range lanes load 0.0 and therefore contribute nothing to the sum.
    in_bounds = idx < N
    lhs = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    rhs = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    tl.store(mid_ptr + slot, tl.sum(lhs * rhs))

53 

54 

# Second pass: a single program sums the first M partial results in mid_ptr
# (BLOCK_MID lanes, lanes >= M read as 0.0) and stores the total to out_ptr.
@libentry()
@triton.jit
def dot_kernel_2(mid_ptr, out_ptr, M, BLOCK_MID: tl.constexpr):
    lanes = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid_ptr + lanes, mask=lanes < M, other=0.0)
    tl.store(out_ptr, tl.sum(partials))

64 

65 

def dot(x, y):
    """Return the dot product of two 1-D tensors, computed with Triton kernels.

    Products are accumulated in float32 and the result is returned as a
    0-dim tensor of ``x.dtype``.

    Args:
        x: 1-D tensor.
        y: 1-D tensor with the same shape as ``x``.

    Returns:
        A 0-dim tensor on ``x.device`` with dtype ``x.dtype``.

    Raises:
        AssertionError: if the shapes differ or the inputs are not 1-D.
    """
    logger.debug("GEMS_KUNLUNXIN DOT")

    assert x.shape == y.shape, "Input vectors must have the same shape"
    assert x.dim() == 1, "Input must be 1D tensors"

    # The kernels address elements as ``ptr + offsets``, i.e. they assume a
    # unit stride. Strided 1-D views (e.g. ``v[::2]`` or a matrix column) must
    # be materialised first; this is a no-op for already-contiguous inputs.
    x = x.contiguous()
    y = y.contiguous()

    N = x.shape[0]

    if N == 0:
        # An empty dot product is 0. Handling it here also avoids launching
        # dot_kernel with BLOCK_SIZE = triton.next_power_of_2(0) == 0.
        return torch.zeros([], dtype=x.dtype, device=x.device)

    if N >= 4096:
        # Upper bound on the number of partial sums: 4096 is the smallest
        # BLOCK_SIZE among dot_kernel_1's autotune configs.
        max_mid_size = triton.cdiv(N, 4096)
        block_mid = triton.next_power_of_2(max_mid_size)

        # BLOCK_SIZE is chosen by the autotuner, so the number of programs
        # (= number of partial sums actually written to ``mid``) is only known
        # at launch time. Record the BLOCK_SIZE each launch uses; the real
        # launch comes after any autotuning benchmark launches, so the value
        # left here afterwards is the one that produced ``mid``.
        launch_meta = {}

        def grid_1(meta):
            launch_meta["BLOCK_SIZE"] = meta["BLOCK_SIZE"]
            return (triton.cdiv(N, meta["BLOCK_SIZE"]),)

        mid = torch.empty((max_mid_size,), dtype=torch.float32, device=x.device)
        out = torch.empty([], dtype=x.dtype, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel_1[grid_1](x, y, mid, N)
            # Entries past the launched program count are uninitialised, or
            # stale results of autotuning runs with a smaller BLOCK_SIZE, so
            # reduce only the valid prefix of ``mid``.
            num_partials = triton.cdiv(N, launch_meta["BLOCK_SIZE"])
            dot_kernel_2[(1,)](mid, out, num_partials, block_mid)
    else:
        # Short vector: one program reduces the whole input in a single block.
        block_size = triton.next_power_of_2(N)
        out = torch.empty([], dtype=torch.float32, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel[(1, 1, 1)](x, y, out, N, block_size)
        out = out.to(x.dtype)

    return out