Coverage for src/flag_gems/runtime/backend/_spacemit/ops/softmax.py: 0%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.softmax import softmax_backward as common_softmax_backward
8from flag_gems.utils import tl_extra_shim
# Module-level logger; the softmax entry points in this module emit a
# debug line on every call.
logger = logging.getLogger(__name__)
# `exp` implementation taken from flag_gems' tl_extra_shim and called
# inside the Triton kernel of this module.
exp = tl_extra_shim.exp
@triton.jit
def softmax_kernel_spacemit(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,
    n_cols,
    ROW_SIZE: tl.constexpr,
    COL_SIZE: tl.constexpr,
):
    # Row-wise softmax of a 2-D buffer whose columns are contiguous (the
    # block pointers below use a column stride of 1).  Program i handles rows
    # [i * ROW_SIZE, (i + 1) * ROW_SIZE) and skips rows >= n_rows.  Each row
    # is walked in COL_SIZE-wide tiles, in three read passes over the input:
    #   1. running max of the row,
    #   2. sum of exp(x - max), accumulated in float32,
    #   3. out = exp(x - max) / sum, rounded to the output dtype exactly once.
    # The output buffer is written only in pass 3.  It is never used as
    # scratch for unnormalised exponentials, which would round them to a
    # possibly low-precision dtype before the division.
    row_start = tl.program_id(0) * ROW_SIZE
    element_ty = output_ptr.type.element_ty

    for row_idx in range(row_start, row_start + ROW_SIZE):
        if row_idx < n_rows:
            # Widen to int64 before scaling by the stride, so the row offset
            # cannot wrap around in int32 arithmetic on very large inputs.
            row_id = row_idx.to(tl.int64)
            in_row_ptr = input_ptr + row_id * input_row_stride
            out_row_ptr = output_ptr + row_id * output_row_stride

            # Pass 1: per-lane running max, reduced to a scalar after the
            # loop.  Out-of-range lanes are padded with -inf, so they never
            # become the max.
            max_acc = tl.full((COL_SIZE,), value=-float("inf"), dtype=tl.float32)
            for col_idx in range(0, n_cols, COL_SIZE):
                in_block = tl.make_block_ptr(
                    base=in_row_ptr,
                    shape=(n_cols,),
                    strides=(1,),
                    offsets=(col_idx,),
                    block_shape=(COL_SIZE,),
                    order=(0,),
                )
                x = tl.load(
                    in_block, boundary_check=(0,), padding_option="neg_inf"
                ).to(tl.float32)
                max_acc = tl.maximum(x, max_acc)
            row_max = tl.max(max_acc, axis=0)

            # Pass 2: per-lane float32 sum of exp(x - max).  Padded lanes
            # contribute exp(-inf) == 0.
            sum_acc = tl.zeros((COL_SIZE,), dtype=tl.float32)
            for col_idx in range(0, n_cols, COL_SIZE):
                in_block = tl.make_block_ptr(
                    base=in_row_ptr,
                    shape=(n_cols,),
                    strides=(1,),
                    offsets=(col_idx,),
                    block_shape=(COL_SIZE,),
                    order=(0,),
                )
                x = tl.load(
                    in_block, boundary_check=(0,), padding_option="neg_inf"
                ).to(tl.float32)
                sum_acc += exp(x - row_max)
            inv_denom = 1.0 / tl.sum(sum_acc, axis=0)

            # Pass 3: recompute exp(x - max) from the input and write the
            # normalised value.  The store is boundary-checked, so padded
            # lanes are never written.
            for col_idx in range(0, n_cols, COL_SIZE):
                in_block = tl.make_block_ptr(
                    base=in_row_ptr,
                    shape=(n_cols,),
                    strides=(1,),
                    offsets=(col_idx,),
                    block_shape=(COL_SIZE,),
                    order=(0,),
                )
                out_block = tl.make_block_ptr(
                    base=out_row_ptr,
                    shape=(n_cols,),
                    strides=(1,),
                    offsets=(col_idx,),
                    block_shape=(COL_SIZE,),
                    order=(0,),
                )
                x = tl.load(
                    in_block, boundary_check=(0,), padding_option="neg_inf"
                ).to(tl.float32)
                tl.store(
                    out_block,
                    (exp(x - row_max) * inv_denom).to(element_ty),
                    boundary_check=(0,),
                )
def _spacemit_softmax_lastdim(inp, out):
    """Launch ``softmax_kernel_spacemit`` on a 2-D ``(n_rows, n_cols)`` input.

    The kernel normalises each row of ``inp`` into the matching row of
    ``out``.  The number of rows handled per program grows with the row
    count: 1 row for fewer than 2 rows, 2 rows for fewer than 8, otherwise 4.
    Columns are processed in tiles of 64.
    """
    n_rows, n_cols = inp.shape

    if n_rows >= 8:
        rows_per_program = 4
    elif n_rows >= 2:
        rows_per_program = 2
    else:
        rows_per_program = 1

    launch_grid = (triton.cdiv(n_rows, rows_per_program),)
    softmax_kernel_spacemit[launch_grid](
        out,
        inp,
        inp.stride(0),
        out.stride(0),
        n_rows,
        n_cols,
        ROW_SIZE=rows_per_program,
        COL_SIZE=64,
    )
def softmax(self, dim, half_to_float=False):
    """Compute ``softmax(self, dim)`` with the SpacemiT Triton kernel.

    The kernel only reduces over the innermost, contiguous axis.  When
    ``dim`` is not already the last dimension, it is moved last, which
    copies the input, and the result is moved back afterwards.

    Args:
        self: input tensor.
        dim: dimension to normalise over; negative values count from the end.
            A 0-d input is treated as a 1-element vector, so -1 and 0 are
            accepted.
        half_to_float: when True the result is float32; otherwise it keeps
            the input's dtype.

    Returns:
        A tensor with the same shape as ``self``.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_SPACEMIT SOFTMAX")

    # Scalars wrap dims as if they had one dimension.
    ndim = max(self.ndim, 1)
    assert dim >= -ndim and dim < ndim, "Invalid dim"
    dim = dim % ndim

    dtype = torch.float32 if half_to_float else self.dtype

    if self.numel() == 0:
        # Nothing to normalise.  This also avoids dividing by a zero-sized
        # last dimension when computing the row count below.
        return torch.empty_like(self, dtype=dtype)

    src = self.reshape(1) if self.ndim == 0 else self
    last = src.ndim - 1
    # Bring the reduced dimension innermost and make it contiguous, as the
    # kernel expects.
    moved = src if dim == last else torch.movedim(src, dim, -1)
    inp = moved.contiguous()

    n_cols = inp.shape[-1]
    n_rows = inp.numel() // n_cols
    inp_2d = inp.view(n_rows, n_cols)
    out_2d = torch.empty_like(inp_2d, dtype=dtype)
    _spacemit_softmax_lastdim(inp_2d, out_2d)

    out = out_2d.view(inp.shape)
    if dim != last:
        # Restore the original dimension order with a contiguous layout.
        out = torch.movedim(out, -1, dim).contiguous()
    return out.view(self.shape)
def softmax_backward(grad_output, output, dim, input_dtype):
    """Gradient of softmax with respect to its input.

    This backend has no specialised backward kernel.  It logs the call and
    forwards all arguments unchanged to flag_gems' generic
    ``softmax_backward``, returning that result.
    """
    logger.debug("GEMS_SPACEMIT SOFTMAX_VJP")
    grad_input = common_softmax_backward(grad_output, output, dim, input_dtype)
    return grad_input