Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/normal.py: 0%
61 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
6from flag_gems.runtime import torch_device_fn
7from flag_gems.utils.random_utils import philox_backend_seed_offset
8from flag_gems.utils.shape_utils import broadcast_shapes, volume
10from ..utils.pointwise_dynamic import pointwise_dynamic
11from .randn import randn_kernel
# Module logger; every op entry point below emits a debug line through it.
logger = logging.getLogger(name=__name__)
# Unroll factor shared by the launch grid (BLOCK * UNROLL elements per program)
# and the Philox offset increment (ceil(N / UNROLL)) in normal_distribution.
UNROLL: int = 4
# Affine map of standard-normal samples: val * std + mean, where both std and
# mean are tensors (is_tensor flags follow the (val, std, mean) argument order).
@pointwise_dynamic(
    is_tensor=[True, True, True],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def transform_func_tensor_tensor(val, std, mean):
    # Scale first, then shift.
    scaled = val * std
    return scaled + mean
# Affine map val * std + mean where std is a non-tensor scalar and mean is a
# tensor (is_tensor flags follow the (val, std, mean) argument order).
@pointwise_dynamic(
    is_tensor=[True, False, True],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def transform_func_tensor_float(val, std, mean):
    # Scale first, then shift.
    scaled = val * std
    return scaled + mean
# Affine map val * std + mean where std is a tensor and mean is a non-tensor
# scalar (is_tensor flags follow the (val, std, mean) argument order).
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def transform_func_float_tensor(val, std, mean):
    # Scale first, then shift.
    scaled = val * std
    return scaled + mean
# Affine map val * std + mean where both std and mean are non-tensor scalars
# (is_tensor flags follow the (val, std, mean) argument order).
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def transform_func_float_float(val, std, mean):
    # Scale first, then shift.
    scaled = val * std
    return scaled + mean
def normal_distribution(shape, device, *, generator=None, out=None):
    """Fill a buffer of ``shape`` with standard-normal samples via ``randn_kernel``.

    Args:
        shape: Shape of the sample buffer; its volume is the element count.
        device: Device used for allocation (when ``out`` is None) and launch.
        generator: Optional RNG generator, forwarded to
            ``philox_backend_seed_offset`` to obtain the Philox seed/offset.
        out: Optional destination tensor. When omitted, a new
            ``torch.float32`` tensor is allocated. NOTE(review): the kernel is
            given only ``out`` and the element count (no strides), so ``out``
            is presumably expected to be contiguous — confirm.

    Returns:
        The tensor the kernel wrote into (``out`` itself when supplied).
    """
    if out is None:
        buf = torch.empty(shape, device=device, dtype=torch.float32)
    else:
        buf = out
    numel = volume(shape)

    def grid(meta):
        # One program instance covers BLOCK * UNROLL elements.
        return (triton.cdiv(numel, meta["BLOCK"] * UNROLL),)

    # Reserve ceil(numel / UNROLL) Philox offsets for this launch.
    seed, offset = philox_backend_seed_offset(
        triton.cdiv(numel, UNROLL), generator=generator
    )
    with torch_device_fn.device(device):
        randn_kernel[grid](buf, numel, seed, offset)
    return buf
def normal_tensor_tensor(mean, std, *, generator=None):
    """Sample N(mean, std**2) elementwise with tensor ``mean`` and ``std``.

    One standard-normal sample is drawn per element of the broadcast shape of
    ``mean`` and ``std``, then mapped to ``sample * std + mean``.

    Args:
        mean: Tensor of means; its device is where sampling runs.
        std: Tensor of standard deviations, broadcastable with ``mean``.
        generator: Optional RNG generator forwarded to the sampler.

    Returns:
        A new tensor with the broadcast shape. NOTE(review): samples are
        float32 and the transform uses DEFAULT promotion, so half-precision
        inputs presumably yield a float32 result — confirm against torch.normal.
    """
    logger.debug("GEMS_TSINGMICRO NORMAL_TENSOR_TENSOR")
    shape = broadcast_shapes([mean.shape, std.shape])
    device = mean.device
    # Forward the caller's generator so seeded sampling is honored.
    out = normal_distribution(shape, device, generator=generator)
    return transform_func_tensor_tensor(out, std, mean)
def normal_tensor_float(mean, std, *, generator=None):
    """Sample N(mean, std**2) elementwise with tensor ``mean`` and scalar ``std``.

    One standard-normal sample is drawn per element of ``mean``, then mapped
    to ``sample * std + mean``.

    Args:
        mean: Tensor of means; fixes the output shape and the device used.
        std: Scalar standard deviation applied to every element.
        generator: Optional RNG generator forwarded to the sampler.

    Returns:
        A new tensor with ``mean.shape``.
    """
    logger.debug("GEMS_TSINGMICRO NORMAL_TENSOR_FLOAT")
    shape = mean.shape
    device = mean.device
    # Forward the caller's generator so seeded sampling is honored.
    out = normal_distribution(shape, device, generator=generator)
    return transform_func_tensor_float(out, std, mean)
def normal_float_tensor(mean, std, *, generator=None):
    """Sample N(mean, std**2) elementwise with scalar ``mean`` and tensor ``std``.

    One standard-normal sample is drawn per element of ``std``, then mapped
    to ``sample * std + mean``.

    Args:
        mean: Scalar mean added to every element.
        std: Tensor of standard deviations; fixes the output shape and device.
        generator: Optional RNG generator forwarded to the sampler.

    Returns:
        A new tensor with ``std.shape``.
    """
    logger.debug("GEMS_TSINGMICRO NORMAL_FLOAT_TENSOR")
    shape = std.shape
    device = std.device
    # Forward the caller's generator so seeded sampling is honored.
    out = normal_distribution(shape, device, generator=generator)
    return transform_func_float_tensor(out, std, mean)
def normal_(self, mean=0, std=1, *, generator=None):
    """Overwrite ``self`` in place with samples of N(mean, std**2).

    Args:
        self: Tensor to fill.
        mean: Scalar mean added to each sample (default 0).
        std: Scalar standard deviation each sample is scaled by (default 1).
        generator: Optional RNG generator forwarded to the sampler.

    Returns:
        ``self``.
    """
    logger.debug("GEMS_TSINGMICRO NORMAL_")
    shape = self.shape
    device = self.device
    if self.is_contiguous():
        # randn_kernel is launched with only the destination and the element
        # count (no strides), so it may write straight into ``self`` only when
        # ``self`` is densely laid out.
        normal_distribution(shape, device, generator=generator, out=self)
        transform_func_float_float(self, std, mean, out0=self)
    else:
        # Strided view: draw into a dense scratch buffer, then let the
        # pointwise transform write into ``self``. NOTE(review): assumes
        # pointwise_dynamic honors the strides of ``out0`` — confirm.
        samples = normal_distribution(shape, device, generator=generator)
        transform_func_float_float(samples, std, mean, out0=self)
    return self