选择性启用算子

选择性启用算子#

在启用 FlagGems 库时,你可以使用若干可选的参数来精细控制在你的应用中如何使用算子加速。 这些参数的存在使得用户能够很灵活地实现各种集成任务,并且在工作流很复杂的情况下, 也可以很方便地进行故障调试和性能分析。

目前,FlagGems 提供三种方式供你有选择地启用或者禁用某些算子。

  • API 接口 flag_gems.enable() 可以接受一个 unused 参数。

    参数名称

    数据类型

    描述

    unused

    List[str]

    禁用特定的算子

    使用这个参数,你可以有选择地禁用某些算子,尤其是某些算子在你的平台上表现不及预期时。 例如,下面的代码会启用 sumadd 之外的所有算子。换言之,参数所列出的算子不会被 FlagGems 加速。当应用调用到这些算子时,会自动回退到 PyTorch 原生的算子实现。

    flag_gems.enable(unused=["sum", "add"])
    
  • 接口 flag_gems.only_enable() 可以接受一个 include 参数,如下所示。

    参数名称

    数据类型

    描述

    include

    List[str]

    显式地启用指定的算子

    当指定了 include 参数时,只有参数值中所列出的算子会在 FlagGems 中被注册以执行加速版本。所有其他算子都会使用 PyTorch 原生的实现。

    flag_gems.only_enable(include=["rms_norm", "softmax"])
    

    [!WARNING] 警告

    API 接口 only_enable() 是实验性质的,可能在未来版本中被移除。 请谨慎使用。

  • 除此之外,还有另外一种方式来选择性地启用算子,那就是使用 use_gems() 上下文管理器。use_gems() 上下文管理器支持下面所列的两个参数, 用来选择性地启用、禁用算子。

    参数名称

    数据类型

    描述

    include

    List[str]

    显式地启用指定的算子

    exclude

    List[str]

    显式地禁用指定的算子

    如果设置了 include 参数,其行为与 only_enable(include=...) 接口的行为完全相同。 类似的,如果设置了 exclude 参数,其行为与 enable(unused=...) 接口的行为一致。 例如,下面的代码仅启用 FlagGems 中对 sumand 算子的加速:

    # 仅在给定范围内启用指定的算子
    with flag_gems.use_gems(include=["sum", "add"]):
        # 只有 sum 和 add 会使用加速版本
        ...
    

    下面的代码会启用 FlagGems所有的加速算子,除了要排除的 muldiv 之外:

    with flag_gems.use_gems(exclude=["mul", "div"]):
        # mul 和 div 算子之外的所有算子都会被加速
        ...
    

    [!TIP] 提示

    参数 include 的优先级高于参数 exclude。 如果两个参数都被指定,则 exclude 参数的设置会被忽略。