untested - Serbipunk/notes GitHub Wiki
pruning
PyTorch 剪枝(Pruning)的常用库与大纲
在 PyTorch 中,剪枝(Pruning)是一种常用的模型优化技术,用于减少模型参数量、加速推理并降低模型复杂度。以下是几种常用的 PyTorch 剪枝库及其主要功能概览:
1. torch.nn.utils.prune(官方库)
简介:
PyTorch 自带的剪枝工具位于 torch.nn.utils.prune,提供了基础剪枝功能,适合需要自定义剪枝策略的开发者。
主要功能:
• 剪枝类型:
• 全局剪枝(global pruning): 基于全模型参数的重要性进行剪枝。
• 局部剪枝(layer-wise pruning): 仅对指定层进行剪枝。
• 不规则剪枝(unstructured pruning): 移除单个权重。
• 结构化剪枝(structured pruning): 移除整个神经元或通道。
• 剪枝方法:
• 随机剪枝(RandomPruning): 随机移除部分权重。
• L1剪枝(L1Unstructured): 基于权重的 L1 范数进行排序后剪枝。
• 自定义剪枝: 支持用户定义剪枝规则。
常用代码示例:
import torch
import torch.nn.utils.prune as prune
# 创建模型
model = torch.nn.Linear(10, 5)
# 对权重进行剪枝
prune.random_unstructured(model, name="weight", amount=0.3)
# 检查剪枝后参数
print(model.weight) # 剪枝后的权重
print(model.weight_mask) # 剪枝的掩码
2. torch_pruning
简介:
torch_pruning 是一个第三方库,专注于剪枝的自动化和简化,支持多种剪枝策略,特别适合需要对复杂模型(如 ResNet、Transformer 等)进行结构化剪枝的场景。
主要功能:
• 剪枝类型:
• 支持不规则剪枝和结构化剪枝(如剪枝卷积通道)。
• 自动依赖分析:
• 自动计算剪枝后的层间依赖关系,确保剪枝不会破坏模型结构。
• 预训练权重支持:
• 剪枝后支持加载预训练模型权重。
常用代码示例:
import torch
import torch_pruning as tp
# 定义模型
model = torch.nn.Sequential(
torch.nn.Conv2d(3, 16, 3),
torch.nn.ReLU(),
torch.nn.Conv2d(16, 32, 3)
)
# 剪枝工具
pruner = tp.pruner.MagnitudePruner()
pruner.prune(model, amount=0.2, prune_type='filter') # 剪枝20%的卷积滤波器
3. NNCF(Intel OpenVINO Neural Network Compression Framework)
简介:
NNCF 是 Intel 提供的神经网络压缩工具,集成了剪枝、量化和蒸馏功能,支持 PyTorch 和 TensorFlow。
主要功能:
• 支持不规则和结构化剪枝。
• 与 OpenVINO 推理优化结合,适合部署场景。
• 提供全面的压缩功能(剪枝、量化、蒸馏等)。
常用代码示例:
from nncf import NNCFConfig
from nncf import create_compressed_model
# 读取配置
config = NNCFConfig.from_json("config.json")
# 压缩模型
model = create_compressed_model(original_model, config)
4. SparseML
简介:
SparseML 是专为深度学习模型稀疏化(包括剪枝、量化等)设计的库,适合大规模稀疏模型的训练和部署。
主要功能:
• 支持渐进式剪枝(在训练过程中逐步剪枝)。
• 提供预定义剪枝策略模板,减少手动调整的复杂性。
• 与 Hugging Face Transformers 深度集成。
常用代码示例:
from sparseml.pytorch.optim import ScheduledModifierManager
manager = ScheduledModifierManager.from_yaml("pruning_schedule.yaml")
model, optimizer, lr_scheduler = manager.modify(model, optimizer, lr_scheduler)
5. PyTorch Lightning Pruning
简介:
PyTorch Lightning 通过扩展的方式支持剪枝,适合使用 Lightning 框架的用户,结合训练和剪枝的集成体验良好。
主要功能:
• 支持 Lightning 的 Callback 机制,通过钩子函数动态控制剪枝。
• 简化复杂模型的训练与剪枝过程。
常用代码示例:
from pytorch_lightning.callbacks import PruningCallback
from pytorch_lightning import Trainer
# 添加剪枝回调
trainer = Trainer(callbacks=[PruningCallback(amount=0.3)])
trainer.fit(model)
剪枝库对比总结
剪枝库 剪枝类型 特点 适用场景
torch.nn.utils.prune 不规则剪枝、结构化剪枝 原生库,灵活性高 适合简单模型和自定义剪枝策略
torch_pruning 不规则剪枝、通道剪枝 自动依赖分析,支持复杂模型 复杂网络(ResNet、Transformer)
NNCF 不规则剪枝、结构化剪枝 与 OpenVINO 集成,压缩全流程支持 部署优化
SparseML 渐进式剪枝 提供预定义模板,支持大模型 大规模模型稀疏化
PyTorch Lightning 动态剪枝 结合训练的剪枝支持 Lightning 项目
选择建议
• 如果是 简单剪枝 或 自定义规则,使用 PyTorch 官方库(torch.nn.utils.prune)。
• 如果需要 复杂模型剪枝,如卷积网络或 Transformer,建议选择 torch_pruning。
• 如果目标是 部署优化,可以使用 NNCF 或 SparseML。
• 使用 PyTorch Lightning 框架的项目,可以直接集成其剪枝功能。
剪枝的选择取决于模型的复杂度、使用场景和部署需求! 😊