Coverage for src/flag_gems/runtime/backend/_arm/ops/exponential_.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils.random_utils import ( 

9 philox_backend_seed_offset, 

10 uint_to_uniform_float, 

11) 

12 

13 

# Writes exponential samples into out_ptr[0:N].
#
# Every lane draws four 32-bit Philox words. When `is_double` is set the words
# are paired into two 64-bit values (two outputs per lane, UNROLL = 2);
# otherwise each word yields its own output (four outputs per lane,
# UNROLL = 4). Stores are masked against N, so a trailing partial tile is safe.
#
# NOTE(review): names bound in both branches of the `if` below are kept
# identical on purpose; names private to one branch are kept distinct.
@triton.heuristics(runtime.get_heuristic_config("exponential_"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    lanes = tl.arange(0, BLOCK)

    # Split the 64-bit Philox offset into two 32-bit counter words.
    seed = philox_seed.to(tl.int64)
    offset = philox_offset.to(tl.int64)
    ctr_lo = (offset & 0xFFFFFFFF).to(tl.uint32)
    ctr_hi = ((offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    # Give each lane its own counter by adding its global lane index.
    ctr_lo += pid * BLOCK + lanes
    zeros = ctr_lo * 0
    r0, r1, r2, r3 = tl.philox(seed, ctr_lo, ctr_hi, zeros, zeros)

    if is_double:
        # Two 64-bit words per lane -> two outputs per lane.
        dbl0 = uint_to_uniform_float(paste_u64(r0, r2))
        dbl1 = uint_to_uniform_float(paste_u64(r1, r3))
        out0 = transform_exponential(dbl0, lambd, eps)
        out1 = transform_exponential(dbl1, lambd, eps)
        UNROLL = 2
        base = pid.to(tl.uint64) * BLOCK * UNROLL
        idx0 = base + lanes
        idx1 = idx0 + BLOCK
        tl.store(out_ptr + idx0, out0, mask=idx0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + idx1, out1, mask=idx1 < N, eviction_policy="evict_first")
    else:
        # Four 32-bit words per lane -> four outputs per lane.
        flt0 = uint_to_uniform_float(r0)
        flt1 = uint_to_uniform_float(r1)
        flt2 = uint_to_uniform_float(r2)
        flt3 = uint_to_uniform_float(r3)
        out0 = transform_exponential(flt0, lambd, eps)
        out1 = transform_exponential(flt1, lambd, eps)
        out2 = transform_exponential(flt2, lambd, eps)
        out3 = transform_exponential(flt3, lambd, eps)
        UNROLL = 4
        base = pid.to(tl.uint64) * BLOCK * UNROLL
        idx0 = base + lanes
        idx1 = idx0 + BLOCK
        idx2 = idx1 + BLOCK
        idx3 = idx2 + BLOCK
        tl.store(out_ptr + idx0, out0, mask=idx0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + idx1, out1, mask=idx1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + idx2, out2, mask=idx2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + idx3, out3, mask=idx3 < N, eviction_policy="evict_first")

64 

65 

# Combines two 32-bit words into one 64-bit word: `hi` fills bits 63..32 and
# `lo` fills bits 31..0.
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    upper = hi.to(tl.uint64) << 32
    lower = lo.to(tl.uint64)
    return upper | lower

71 

72 

# Maps a sample `u` to -log(u) / lambd.
#
# When u >= 1 - eps / 2 the logarithm is replaced by -eps / 2, so the result
# in that case is eps / (2 * lambd).
# NOTE(review): presumably this keeps the output strictly positive when `u`
# rounds to 1 - confirm against the reference implementation.
@triton.jit
def transform_exponential(u, lambd, eps):
    neg_half_eps = -0.5 * eps
    near_one = u >= 1.0 + neg_half_eps
    log_u = tl.where(near_one, neg_half_eps, tl.math.log(u))
    return -1.0 / lambd * log_u

80 

81 

def exponential_(x, lambd: float = 1.0, *, gen=None):
    """Fill ``x`` in place with samples produced by ``fused_exponential_kernel``.

    The kernel writes ``-log(u) / lambd`` for each element. A contiguous ``x``
    is written directly; otherwise the kernel fills a freshly allocated
    tensor of the same size, dtype and device, which is then copied into ``x``.

    Args:
        x: Tensor to overwrite. Its dtype must be float16, bfloat16, float32
            or float64 (checked with ``assert``).
        lambd: Divisor applied to ``-log(u)``; forwarded to the kernel.
        gen: Accepted for signature compatibility.
            NOTE(review): this argument is never referenced in the body -
            confirm whether it should be forwarded to the RNG-state helper.

    Returns:
        The same tensor object ``x``.
    """
    logging.debug("GEMS EXPONENTIAL_")

    dtype = x.dtype
    device = x.device
    write_direct = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    # The kernel emits 2 outputs per lane for float64 and 4 otherwise.
    is_double = dtype == torch.float64
    if is_double:
        UNROLL = 2
    else:
        UNROLL = 4
    N = x.numel()

    def grid_fn(meta):
        # One program covers BLOCK * UNROLL output elements.
        return (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

    # TODO: the Triton autotuner hides the kernel parameters from the caller,
    # so the per-thread offset cannot be obtained the way PyTorch does it.
    increment = triton.cdiv(N, UNROLL)
    # Presumably reserves `increment` counter values from the backend RNG
    # state - confirm against flag_gems.utils.random_utils.
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    eps = torch.finfo(dtype).eps

    if write_direct:
        target = x
    else:
        target = torch.empty(x.size(), dtype=dtype, device=device)

    # NOTE: the launch is not wrapped in a torch_device_fn.device(device)
    # context in this backend.
    fused_exponential_kernel[grid_fn](
        target, N, is_double, lambd, eps, philox_seed, philox_offset
    )

    if not write_direct:
        x.copy_(target)
    return x