Coverage for src/flag_gems/runtime/backend/_cambricon/ops/threshold.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, libtuner 

9 

10from ..utils import TOTAL_CORE_NUM 

11from ..utils.pointwise_dynamic import pointwise_dynamic 

12 

# Per-module logger created as a child of the "flag_gems" logger, so its
# records propagate to whatever handlers are attached to "flag_gems".
# Any leading dots in __name__ are stripped before building the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
)
@triton.jit(do_not_specialize=["threshold_val", "value_val"])
def threshold_kernel(
    X_ptr,
    OUT_ptr,
    n_elements,
    threshold_val,
    value_val,
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise ``out = value_val if x <= threshold_val else x`` over a flat buffer.

    Program ``pid`` starts at element ``pid * BLOCK_SIZE`` and advances by
    ``num_programs * BLOCK_SIZE`` each iteration until ``n_elements`` is
    reached, so a grid smaller than ``cdiv(n_elements, BLOCK_SIZE)`` still
    covers the whole input.

    Args:
        X_ptr: pointer to the input elements, addressed as a flat 1-D buffer.
        OUT_ptr: pointer to the output buffer, addressed the same way.
        n_elements: total number of elements to process.
        threshold_val: runtime scalar threshold (excluded from specialization).
        value_val: runtime scalar replacement value (excluded from specialization).
        BLOCK_SIZE: elements handled per program per iteration (autotuned).
    """
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # Widen to int64 before multiplying so element offsets never wrap for
    # very large tensors.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    for off in range(block_start, n_elements, step):
        offsets = off + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(X_ptr + offsets, mask=mask)
        # Select on `x <= threshold` (as ATen does) instead of `x > threshold`:
        # comparisons with NaN are False, so this form leaves NaN inputs as NaN
        # rather than replacing them with value_val.
        result = tl.where(x <= threshold_val, value_val, x)
        tl.store(OUT_ptr + offsets, result, mask=mask)

46 

47 

# The backward pass stays on pointwise_dynamic: grad_output and self are
# tensor arguments and threshold is a scalar (is_tensor), and the result dtype
# is promoted from arguments 0 and 1 with the "DEFAULT" method.
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def threshold_backward_kernel(grad_output, self, threshold):
    """Per-element gradient of threshold: 0 where ``self <= threshold``, else grad_output.

    Testing ``self <= threshold`` (as ATen does) rather than
    ``self > threshold`` means a NaN in ``self`` passes ``grad_output``
    through instead of zeroing it, matching PyTorch.
    """
    return tl.where(self <= threshold, 0, grad_output)

53 

54 

def threshold(self, threshold_val, value_val):
    """Forward threshold op for the Cambricon backend.

    Elements above ``threshold_val`` are kept and the rest are replaced by
    ``value_val`` (the exact comparison lives in ``threshold_kernel``).

    Args:
        self: input tensor; a contiguous copy is used if it is not contiguous.
        threshold_val: scalar threshold.
        value_val: scalar written where an element does not pass the threshold.

    Returns:
        A new contiguous tensor shaped like ``self``. An empty input returns
        immediately without launching the kernel.
    """
    logger.debug("GEMS_CAMBRICON THRESHOLD FORWARD")
    src = self.contiguous()
    dst = torch.empty_like(src)
    numel = src.numel()
    if not numel:
        return dst

    def _grid(meta):
        # Launch at most TOTAL_CORE_NUM programs; the kernel loops over any
        # remaining blocks.
        num_blocks = triton.cdiv(numel, meta["BLOCK_SIZE"])
        return (min(num_blocks, TOTAL_CORE_NUM),)

    with torch_device_fn.device(src.device):
        threshold_kernel[_grid](src, dst, numel, threshold_val, value_val)
    return dst

66 

67 

def threshold_backward(grad_output, self, threshold_val):
    """Backward threshold op for the Cambricon backend.

    Args:
        grad_output: incoming gradient tensor.
        self: the forward input that the threshold test is applied to.
        threshold_val: scalar threshold used in the forward pass.

    Returns:
        The gradient with respect to ``self``, as computed by
        ``threshold_backward_kernel``.
    """
    logger.debug("GEMS_CAMBRICON THRESHOLD BACKWARD")
    grad_input = threshold_backward_kernel(grad_output, self, threshold_val)
    return grad_input