天下苦英伟达久矣!PyTorch官方免CUDA加速推理,Triton时代要来?

今日应用


今日话题


天下苦英伟达久矣!PyTorch官方免CUDA加速推理,Triton时代要来?
天下苦英伟达久矣!PyTorch官方免CUDA加速推理,Triton时代要来?
 

重点标签 PyTorchCUDATriton机器学习大模型推理

文章摘要


PyTorch 官方近期分享了如何实现无 CUDA 计算,并对各个内核进行了微基准测试比较,讨论了未来如何进一步改进 Triton 内核以缩小与 CUDA 的差距。在大语言模型(LLM)的训练、微调和推理中,使用英伟达的 GPU 和 CUDA 是常见的做法,但其他工作如 OpenAI 推出的 Triton 正在向 CUDA 发起挑战。PyTorch 官宣要做「无英伟达 CUDA 参与的大模型推理」,表示 Triton 提供了一条途径,使大模型能够在不同类型的 GPU 上运行,包括英伟达、AMD、英特尔和其他基于 GPU 的加速器。此外,Triton 还在 Python 中为 GPU 编程提供了更高的抽象层,使得使用 PyTorch 能够比使用供应商特定的 API 更快地编写高性能内核。

在 PyTorch 博客中,讨论了使用流行的 LLM 模型(例如 Meta 的 Llama3-8B 和 IBM 的 Granite-8B Code)实现 FP16 推理的方法,其中计算是 100% 使用 OpenAI 的 Triton 语言执行的。对于使用基于 Triton 内核的模型生成单个 token 的时间,PyTorch 能够实现在英伟达 H100 GPU 上 Llama 和 Granite 的 CUDA 内核主导工作流程的 0.76-0.78 倍性能,以及在英伟达 A100 GPU 上的 0.62-0.82 倍。

PyTorch 团队首先对基于 Transformer 的模型中发生的计算进行细分,显示了典型 Transformer 块的「内核(kernel)」。Llama3 架构的核心操作包括均方根归一化(RMSNorm)、矩阵乘法:Fused QKV、RoPE、注意力、矩阵乘法:输出投影、RMSNorm、矩阵乘法:Fused Gate + Up Projection、激活函数:SiLU、点乘(Element Wise Multiplication)、矩阵乘法:Down Projection 等。这些操作中的每一个都是通过在 GPU 上执行一个(或多个)内核来计算的。

为了实现 100% Triton 进行端到端 Llama3-8B 和 Granite-8B 推理,需要编写和集成手写 Triton 内核以及利用 torch.compile(生成 Triton 操作)。Torch.compile 自动为 RMSNorm、RoPE、SiLU 和点乘生成 Triton 内核。对于上面的跟踪,PyTorch 团队注意到,在 Llama3-8B 样式模型中,占 E2E 延迟 80% 的两个主要操作是矩阵乘法和注意力内核,并且两者仍然是 CUDA 内核。因此,为了弥补剩余的差距,PyTorch 团队用手写的 Triton 内核替换了 matmul 和注意力内核。

对于线性层中的矩阵乘法,PyTorch 团队编写了一个自定义 FP16 Triton GEMM(通用矩阵 – 矩阵乘法)内核,该内核利用了 SplitK 工作分解。为了实现最佳性能,PyTorch 团队使用穷举搜索方法来调整 SplitK GEMM 内核。在对每个线性层进行调整后,PyTorch 能够在 Llama3-8B 和 Granite-8B 上实现相对于未调整的 Triton 内核 1.20 倍的 E2E 加速。

PyTorch 团队使用不同的配置,对现有 Triton flash attention 内核进行了评估,包括 AMD Flash、OpenAI Flash、Dao AI Lab Flash、XFormers Flash、PyTorch FlexAttention。PyTorch 团队分别在 eager 模式和编译模式下评估了每个内核的文本生成质量。为了满足 torch.compile 与 AMD flash attention 内核的兼容性,PyTorch 团队必须将它定义为 torch 自定义算子。并且封装更复杂的 flash attention 内核遵循以下两个步骤:一是将函数封装为一个 PyTorch 自定义算子;二是向该算子添加一个 FakeTensor 内核,并在给定 flash 输入张量的形状(q、k 和 v)时,计算 flash 内核的输出形状。

从图中可以看到,在集成 SplitK 矩阵乘法内核后,torch op 封装 flash attention 内核,然后运行 torch.compile,即可实现使用 100% Triton 计算内核的前向传递。总的来说,在 H100 上,Triton 模型最高可以达到 CUDA 模型性能的 78%;在 A100 上可以达到 82%。这些性能差距是由[PyTorch/CUDA/Triton/机器学习/大模型推理]/[CUDA/GPU/机器学习/PyTorch/Triton]/[PyTorch/CUDA/机器学习/大模型/推理]/[CUDA/GPU/大模型/机器学习/PyTorch]/[PyTorch/CUDA/大模型/机器学习/推理]

PyTorch 官方探索无 CUDA 计算
PyTorch 官方近期分享了如何实现无 CUDA 计算,并对各个内核进行了微基准测试比较,讨论了未来如何进一步改进 Triton 内核以缩小与 CUDA 的差距。在大语言模型(LLM)的训练、微调和推理中,使用英伟达的 GPU 和 CUDA 是常见的做法,但其他工作如 OpenAI 推出的 Triton 正在向 CUDA 发起挑战。PyTorch 官宣要做「无英伟达 CUDA 参与的大模型推理」,表示 Triton 提供了一条途径,使大模型能够在不同类型的 GPU 上运行,包括英伟达、AMD、英特尔和其他基于 GPU 的加速器。此外,Triton 还在 Python 中为 GPU 编程提供了更高的抽象层,使得使用 PyTorch 能够比使用供应商特定的 API 更快地编写高性能内核。

大模型推理的 Triton 实现
在 PyTorch 博客中,讨论了使用流行的 LLM 模型(例如 Meta 的 Llama3-8B 和 IBM 的 Granite-8B Code)实现 FP16 推理的方法,其中计算是 100% 使用 OpenAI 的 Triton 语言执行的。对于使用基于 Triton 内核的模型生成单个 token 的时间,PyTorch 能够实现在英伟达 H100 GPU 上 Llama 和 Granite 的 CUDA 内核主导工作流程的 0.76-0.78 倍性能,以及在英伟达 A100 GPU 上的 0.62-0.82 倍。

Transformer 块的组成与模型推理
PyTorch 团队首先对基于 Transformer 的模型中发生的计算进行细分,显示了典型 Transformer 块的「内核(kernel)」。Llama3 架构的核心操作包括均方根归一化(RMSNorm)、矩阵乘法:Fused QKV、RoPE、注意力、矩阵乘法:输出投影、RMSNorm、矩阵乘法:Fused Gate + Up Projection、激活函数:SiLU、点乘(Element Wise Multiplication)、矩阵乘法:Down Projection 等。这些操作中的每一个都是通过在 GPU 上执行一个(或多个)内核来计算的。

为了实现 100% Triton 进行端到端 Llama3-8B 和 Granite-8B 推理,需要编写和集成手写 Triton 内核以及利用 torch.compile(生成 Triton 操作)。Torch.compile 自动为 RMSNorm、RoPE、SiLU 和点乘生成 Triton 内核。对于上面的跟踪,PyTorch 团队注意到,在 Llama3-8B 样式模型中,占 E2E 延迟 80% 的两个主要操作是矩阵乘法和注意力内核,并且两者仍然是 CUDA 内核。因此,为了弥补剩余的差距,PyTorch 团队用手写的 Triton 内核替换了 matmul 和注意力内核。

Triton SplitK GEMM 内核与调优
对于线性层中的矩阵乘法,PyTorch 团队编写了一个自定义 FP16 Triton GEMM(通用矩阵 – 矩阵乘法)内核,该内核利用了 SplitK 工作分解。为了实现最佳性能,PyTorch 团队使用穷举搜索方法来调整 SplitK GEMM 内核。在对每个线性层进行调整后,PyTorch 能够在 Llama3-8B 和 Granite-8B 上实现相对于未调整的 Triton 内核 1.20 倍的 E2E 加速。

Flash Attention 内核的评估与封装
PyTorch 团队使用不同的配置,对现有 Triton flash attention 内核进行了评估,包括 AMD Flash、OpenAI Flash、Dao AI Lab Flash、XFormers Flash、PyTorch FlexAttention。PyTorch 团队分别在 eager 模式和编译模式下评估了每个内核的文本生成质量。为了满足 torch.compile 与 AMD flash attention 内核的兼容性,PyTorch 团队必须将它定义为 torch 自定义算子。并且封装更复杂的 flash attention 内核遵循以下两个步骤:一是将函数封装为一个 PyTorch 自定义算子;二是向该算子添加一个 FakeTensor 内核,并在给定 flash 输入张量的形状(q、k 和 v)时,计算 flash 内核的输出形状。

端到端基准测试与微基准测试
从图中可以看到,在集成 SplitK 矩阵乘法内核后PyTorch 官方探索无 CUDA 计算

PyTorch 官方近期分享了如何实现无 CUDA 计算,并对各个内核进行了微基准测试比较,讨论了未来如何进一步改进 Triton 内核以缩小与 CUDA 的差距。在大语言模型(LLM)的训练、微调和推理中,使用英伟达的 GPU 和 CUDA 是常见的做法,但其他工作如 OpenAI 推出的 Triton 正在向 CUDA 发起挑战。PyTorch 官宣要做「无英伟达 CUDA 参与的大模型推理」,表示 Triton 提供了一条途径,使大模型能够在不同类型的 GPU 上运行,包括英伟达、AMD、英特尔和其他基于 GPU 的加速器。此外,Triton 还在 Python 中为 GPU 编程提供了更高的抽象层,使得使用 PyTorch 能够比使用供应商特定的 API 更快地编写高性能内核。

大模型推理的 Triton 实现

在 PyTorch 博客中,讨论了使用流行的 LLM 模型(例如 Meta 的 Llama3-8B 和 IBM 的 Granite-8B Code)实现 FP16 推理的方法,其中计算是 100% 使用 OpenAI 的 Triton 语言执行的。对于使用基于 Triton 内核的模型生成单个 token 的时间,PyTorch 能够实现在英伟达 H100 GPU 上 Llama 和 Granite 的 CUDA 内核主导工作流程的 0.76-0.78 倍性能,以及在英伟达 A100 GPU 上的 0.62-0.82 倍。

Transformer 块的组成与模型推理

PyTorch 团队首先对基于 Transformer 的模型中发生的计算进行细分,显示了典型 Transformer 块的「内核(kernel)」。Llama3 架构的核心操作包括均方根归一化(RMSNorm)、矩阵乘法:Fused QKV、RoPE、注意力、矩阵乘法:输出投影、RMSNorm、矩阵乘法:Fused Gate + Up Projection、激活函数:SiLU、点乘(Element Wise Multiplication)、矩阵乘法:Down Projection 等。这些操作中的每一个都是通过在 GPU 上执行一个(或多个)内核来计算的。

为了实现 100% Triton 进行端到端 Llama3-8B 和 Granite-8B 推理,需要编写和集成手写 Triton 内核以及利用 torch.compile(生成 Triton 操作)。Torch.compile 自动为 RMSNorm、RoPE、SiLU 和点乘生成 Triton 内核。对于上面的跟踪,PyTorch 团队注意到,在 Llama3-8B 样式模型中,占 E2E 延迟 80% 的两个主要操作是矩阵乘法和注意力内核,并且两者仍然是 CUDA 内核。因此,为了弥补剩余的差距,PyTorch 团队用手写的 Triton 内核替换了 matmul 和注意力内核。

Triton SplitK GEMM 内核与调优

对于线性层中的矩阵乘法,PyTorch 团队编写了一个自定义 FP16 Triton GEMM(通用矩阵 – 矩阵乘法)内核,该内核利用了 SplitK 工作分解。为了实现最佳性能,PyTorch 团队使用穷举搜索方法来调整 SplitK GEMM 内核。在对每个线性层进行调整后,PyTorch 能够在 Llama3-8B 和 Granite-8B 上实现相对于未调整的 Triton 内核 1.20 倍的 E2E 加速。

Flash Attention 内核的评估与封装

PyTorch 团队使用不同的配置,对现有 Triton flash attention 内核进行了评估,包括 AMD Flash、OpenAI Flash、Dao AI Lab Flash、XFormers Flash、PyTorch FlexAttention。PyTorch 团队分别在 eager 模式和编译模式下评估了每个内核的文本生成质量。为了满足 torch.compile 与 AMD flash attention 内核的兼容性,PyTorch 团队必须将它定义为 torch 自定义算子。并且封装更复杂的 flash attention 内核遵循以下两个步骤:一是将函数封装为一个 PyTorch 自定义算子;二是向该算子添加一个 FakeTensor 内核,并在给定 flash 输入张量的形状(q、k 和 v)时,计算 flash 内核的输出形状。

端到端基准测试与微基准测试

从图中可以看到,在集成 SplitK 矩阵乘法内核后,torch op 封装 flash attention 内核,然后运行 torch.compile,即可实现使用 100% Triton 计算内核的前向传递。总的来说,在 H100 上,Triton 模型最高可以达到 CUDA 模型性能的 78%;在 A100 上可以达到 82%。这些性能差距是由 matmul 和 flash attention 的内核延迟造成的。

文章来源


原文地址: 点我阅读全文
原文作者: 机器之心

© 版权声明

相关文章

暂无评论

暂无评论...