Coverage for src/flag_gems/ops/cauchy.py: 53%

62 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils.random_utils import ( 

9 philox_backend_seed_offset, 

10 uint_to_uniform_float, 

11) 

12from flag_gems.utils.shape_utils import volume 

13 

logger = logging.getLogger(__name__)

# π wrapped in tl.constexpr because it is read from @triton.jit code
# (uniform_to_cauchy); module-level values used inside Triton kernels are
# expected to be compile-time constants.
PI = tl.constexpr(3.14159265358979323846)

17 

18 

@triton.jit
def uniform_to_cauchy(u, median, sigma):
    """Map uniform samples ``u`` in [0, 1) to Cauchy(median, sigma) samples.

    Inverse CDF: X = median + sigma * tan(PI * (u - 0.5)). ``u`` is clamped to
    [1e-7, 1 - 1e-7] first so the angle never reaches ±PI/2, where tan is
    undefined.
    """
    u_safe = tl.minimum(1.0 - 1.0e-7, tl.maximum(1.0e-7, u))
    theta = PI * (u_safe - 0.5)
    # tan(theta) written as sin/cos.
    tan_theta = tl.sin(theta) / tl.cos(theta)
    return median + sigma * tan_theta

29 

30 

# Autotuning search space for cauchy_kernel (re-tuned per distinct N).
# Each tuple is (BLOCK, num_warps, num_stages); BLOCK is the number of lanes
# per program, and every lane emits four output elements.
configs = [
    triton.Config({"BLOCK": block}, num_warps=warps, num_stages=stages)
    for block, warps, stages in (
        (256, 8, 2),
        (512, 4, 2),
        (512, 8, 3),
        (1024, 4, 2),
        (1024, 8, 3),
        (1024, 8, 4),
    )
]

40 

41 

@triton.autotune(configs=configs, key=["N"])
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "median", "sigma"])
def cauchy_kernel(
    out_ptr,
    N,
    median,
    sigma,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill out_ptr[0:N] with Cauchy(median, sigma) samples.

    Each lane runs one Philox round, producing four random words that become
    four output elements at BLOCK-strided offsets. A program therefore covers
    4 * BLOCK consecutive elements; the host launches cdiv(N, 4 * BLOCK)
    programs and reserves cdiv(N, 4) Philox counter values.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit Philox offset into the two low 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    # Do index math in int64: with int32, pid * BLOCK * 4 overflows once
    # N >= 2**31 and wraps to negative offsets that still pass `off < N`,
    # producing out-of-bounds stores.
    pid = tl.program_id(0).to(tl.int64)
    lane = tl.arange(0, BLOCK)

    # Per-lane Philox counter. Truncating to uint32 keeps the counter word's
    # 32-bit wrap-around and yields the same stream as before for N < 2**31.
    c0 += (pid * BLOCK + lane).to(tl.uint32)
    zeros = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, zeros, zeros)
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    y0 = uniform_to_cauchy(r0, median, sigma)
    y1 = uniform_to_cauchy(r1, median, sigma)
    y2 = uniform_to_cauchy(r2, median, sigma)
    y3 = uniform_to_cauchy(r3, median, sigma)

    # The factor 4 here must match UNROLL on the host side.
    off_0 = pid * BLOCK * 4 + lane
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    # Outputs are write-once, so evict them early to spare the cache.
    tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")

78 

79 

# Output elements produced per kernel lane: one tl.philox call yields four
# random words (r0..r3), stored at four BLOCK-strided offsets. The host uses
# this to size the launch grid (BLOCK * UNROLL per program) and the Philox
# increment (cdiv(N, UNROLL)). It must stay equal to the literal 4 used in
# cauchy_kernel's offset computation.
UNROLL = 4

81 

82 

def cauchy_(self, median=0, sigma=1, *, generator=None):
    """
    In-place Cauchy distribution sampler.

    Fills ``self`` with elements drawn from the Cauchy distribution:
        f(x) = 1 / (π * sigma * (1 + ((x - median) / sigma)^2))

    Uses inverse transform sampling: X = median + sigma * tan(π * (U - 0.5))
    where U ~ Uniform(0, 1).

    Args:
        self: Floating-point tensor to fill in place; may be non-contiguous.
        median: Location parameter of the distribution.
        sigma: Scale parameter; must be strictly positive.
        generator: Optional generator forwarded to
            ``philox_backend_seed_offset`` to obtain the Philox seed/offset.

    Returns:
        ``self``, after being filled.

    Raises:
        RuntimeError: If ``sigma`` is not > 0 (including NaN) or ``self`` is
            not a floating-point tensor, mirroring ATen's ``cauchy_`` checks.
    """
    logger.debug("GEMS CAUCHY_")
    # `not sigma > 0.0` (rather than `sigma <= 0.0`) also rejects NaN.
    if not sigma > 0.0:
        raise RuntimeError(f"cauchy_ expects sigma > 0.0, but found sigma={sigma}")
    if not self.is_floating_point():
        raise RuntimeError(
            "Cauchy distribution is a continuous probability distribution. "
            f"dtype must be a floating point but you specified {self.dtype}"
        )

    N = volume(self.shape)
    if N == 0:
        return self

    # The kernel writes N elements linearly starting at the base pointer,
    # which is only correct for a contiguous tensor. For strided views
    # (slices, transposes) sample into a contiguous scratch buffer and copy
    # back, instead of clobbering memory outside the view.
    inplace = self.is_contiguous()
    out = (
        self
        if inplace
        else torch.empty_like(self, memory_format=torch.contiguous_format)
    )

    # Each program covers BLOCK * UNROLL elements; each Philox counter value
    # yields UNROLL outputs, so cdiv(N, UNROLL) counter values are consumed.
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(self.device):
        cauchy_kernel[grid_fn](out, N, median, sigma, philox_seed, philox_offset)

    if not inplace:
        self.copy_(out)
    return self

107 

108 

def cauchy(self, median=0, sigma=1, *, generator=None):
    """
    Out-of-place Cauchy distribution sampler.

    Allocates a tensor shaped like ``self`` (via ``torch.empty_like``), fills
    it through ``cauchy_`` with the same ``median``, ``sigma`` and
    ``generator``, and returns it. ``self`` itself is not modified.
    """
    logger.debug("GEMS CAUCHY")
    # cauchy_ returns the tensor it filled, so its result is the new sample.
    return cauchy_(torch.empty_like(self), median, sigma, generator=generator)