Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/exponential_.py: 0%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from triton.language.extra.xpu.libdevice import log2 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils.random_utils import ( 

10 philox_backend_seed_offset, 

11 uint_to_uniform_float, 

12) 

13 

# Module logger: a child of the "flag_gems" logger, named after this module
# (any leading dots in ``__name__`` are stripped before it is used as the child name).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Not referenced anywhere in the visible part of this file.
# NOTE(review): presumably the cluster count of the target XPU device -- confirm
# whether this constant is still needed.
CLUSTER_NUM = 12

17 

18 

def heur_block(args):
    """Choose the ``BLOCK`` size for the kernel from the element count ``N``.

    ``args`` is a mapping of kernel arguments; a missing ``"N"`` counts as 0.
    Returns 256 when N <= 4096, 512 when N <= 65536, and 1024 otherwise.
    """
    numel = args.get("N", 0)
    # (inclusive upper bound on N, block size), checked in ascending order.
    for upper_bound, block in ((4096, 256), (65536, 512)):
        if numel <= upper_bound:
            return block
    return 1024

27 

28 

def heur_num_warps(args):
    """Choose ``num_warps`` for the kernel from the element count ``N``.

    ``args`` is a mapping of kernel arguments; a missing ``"N"`` counts as 0.
    Returns 4 when N <= 4096, 8 when N <= 65536, and 16 otherwise.
    """
    numel = args.get("N", 0)
    if numel > 65536:
        return 16
    if numel > 4096:
        return 8
    return 4

37 

38 

# Fused kernel: draws Philox random words, converts them to floats with
# ``uint_to_uniform_float`` and maps them through ``transform_exponential``
# before storing to ``out_ptr``.
#
# Each program handles UNROLL * BLOCK consecutive output elements
# (UNROLL = 2 when ``is_double`` is set, 4 otherwise); stores are masked with
# ``offset < N``. ``BLOCK`` and ``num_warps`` come from the heuristics below,
# and ``philox_seed``, ``philox_offset`` and ``N`` are excluded from
# specialization.
@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double: tl.constexpr,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into its low and high 32-bit words; they become
    # the first two Philox counter words.
    counter_lo = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    counter_hi = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    pid = tl.program_id(0)
    lanes = tl.arange(0, BLOCK)
    # One distinct counter value per lane of every program.
    lane_index = pid * BLOCK + lanes
    counter_lo = counter_lo + lane_index
    # Zero tensor with the same shape/dtype as the counter, used for the two
    # remaining counter words.
    zeros = counter_lo * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, counter_lo, counter_hi, zeros, zeros)
    if is_double:
        # Pair up the four 32-bit outputs into two 64-bit words, giving two
        # samples per lane.
        u0 = uint_to_uniform_float(paste_u64(r0, r2))
        u1 = uint_to_uniform_float(paste_u64(r1, r3))
        s0 = transform_exponential(u0, lambd, eps)
        s1 = transform_exponential(u1, lambd, eps)
        UNROLL = 2
        base = pid.to(tl.uint64) * BLOCK * UNROLL
        off_0 = base + lanes
        off_1 = off_0 + BLOCK
        tl.store(out_ptr + off_0, s0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, s1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # Each 32-bit output yields one sample, giving four samples per lane.
        u0 = uint_to_uniform_float(r0)
        u1 = uint_to_uniform_float(r1)
        u2 = uint_to_uniform_float(r2)
        u3 = uint_to_uniform_float(r3)
        s0 = transform_exponential(u0, lambd, eps)
        s1 = transform_exponential(u1, lambd, eps)
        s2 = transform_exponential(u2, lambd, eps)
        s3 = transform_exponential(u3, lambd, eps)
        UNROLL = 4
        base = pid.to(tl.uint64) * BLOCK * UNROLL
        off_0 = base + lanes
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, s0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, s1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, s2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, s3, mask=off_3 < N, eviction_policy="evict_first")

94 

95 

# Combine two 32-bit words into one 64-bit word: ``hi`` occupies the upper
# 32 bits and ``lo`` the lower 32 bits.
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)

101 

102 

# Map ``u`` to ``-ln(u) / lambd``.
#
# The natural log is obtained from the device ``log2`` as log2(u) * ln(2),
# where ln(2) is written as 1 / log2(e) = 1 / 1.4426950408889634.
# When ``u >= 1 - eps / 2`` the log term is replaced by ``-eps / 2``, so the
# result for those inputs is ``eps / (2 * lambd)`` instead of the log-based value.
@triton.jit
def transform_exponential(u, lambd, eps):
    neg_half_eps = -0.5 * eps
    ln2 = 1.0 / 1.4426950408889634
    near_one = u >= 1.0 + neg_half_eps
    ln_u = tl.where(near_one, neg_half_eps, log2(u) * ln2)
    return -1.0 / lambd * ln_u

111 

112 

def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Overwrite ``x`` with samples produced by ``fused_exponential_kernel``.

    Args:
        x: Tensor to fill. Its dtype must be float16, bfloat16, float32 or
            float64 (checked with an ``assert``).
        lambd: Passed to the kernel, which divides ``-ln(u)`` by it.
        generator: Forwarded to ``philox_backend_seed_offset`` to obtain the
            Philox seed and offset.

    Returns:
        ``x`` itself. A contiguous ``x`` is written directly; otherwise the
        kernel writes into a freshly allocated tensor of the same size, dtype
        and device, which is then copied into ``x``.
    """
    logger.debug("GEMS_KUNLUNXIN EXPONENTIAL_")
    dtype, device = x.dtype, x.device
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = dtype == torch.float64
    # The kernel emits 2 values per Philox counter for float64 and 4 otherwise.
    UNROLL = 2 if is_double else 4
    N = x.numel()

    def grid_fn(meta):
        # One program per BLOCK * UNROLL output elements.
        return (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    eps = torch.finfo(dtype).eps

    # The kernel writes through a flat pointer, so a non-contiguous ``x`` is
    # filled via a temporary buffer and copied back afterwards.
    if x.is_contiguous():
        out = x
    else:
        out = torch.empty(x.size(), dtype=dtype, device=device)

    with torch_device_fn.device(device):
        fused_exponential_kernel[grid_fn](
            out, N, is_double, lambd, eps, philox_seed, philox_offset
        )

    if out is not x:
        x.copy_(out)
    return x