untested - Serbipunk/notes GitHub Wiki

pruning

PyTorch 剪枝(Pruning)的常用库与大纲

在 PyTorch 中,剪枝(Pruning)是一种常用的模型优化技术,用于减少模型参数量、加速推理并降低模型复杂度。以下是几种常用的 PyTorch 剪枝库及其主要功能概览:

1. torch.nn.utils.prune(官方库)

简介:

PyTorch 自带的剪枝工具位于 torch.nn.utils.prune,提供了基础剪枝功能,适合需要自定义剪枝策略的开发者。

主要功能:

	•	剪枝类型:
	•	全局剪枝(global pruning): 基于全模型参数的重要性进行剪枝。
	•	局部剪枝(layer-wise pruning): 仅对指定层进行剪枝。
	•	不规则剪枝(unstructured pruning): 移除单个权重。
	•	结构化剪枝(structured pruning): 移除整个神经元或通道。
	•	剪枝方法:
	•	随机剪枝(RandomPruning): 随机移除部分权重。
	•	L1剪枝(L1Unstructured): 基于权重的 L1 范数进行排序后剪枝。
	•	自定义剪枝: 支持用户定义剪枝规则。

常用代码示例:

import torch
import torch.nn.utils.prune as prune

# 创建模型
model = torch.nn.Linear(10, 5)

# 对权重进行剪枝
prune.random_unstructured(model, name="weight", amount=0.3)

# 检查剪枝后参数
print(model.weight)  # 剪枝后的权重
print(model.weight_mask)  # 剪枝的掩码

2. torch_pruning

简介:

torch_pruning 是一个第三方库,专注于剪枝的自动化和简化,支持多种剪枝策略,特别适合需要对复杂模型(如 ResNet、Transformer 等)进行结构化剪枝的场景。

主要功能:

	•	剪枝类型:
	•	支持不规则剪枝和结构化剪枝(如剪枝卷积通道)。
	•	自动依赖分析:
	•	自动计算剪枝后的层间依赖关系,确保剪枝不会破坏模型结构。
	•	预训练权重支持:
	•	剪枝后支持加载预训练模型权重。

常用代码示例:

import torch
import torch_pruning as tp

# 定义模型
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, 3)
)

# 剪枝工具
pruner = tp.pruner.MagnitudePruner()
pruner.prune(model, amount=0.2, prune_type='filter')  # 剪枝20%的卷积滤波器

3. NNCF(Intel OpenVINO Neural Network Compression Framework)

简介:

NNCF 是 Intel 提供的神经网络压缩工具,集成了剪枝、量化和蒸馏功能,支持 PyTorch 和 TensorFlow。

主要功能:

	•	支持不规则和结构化剪枝。
	•	与 OpenVINO 推理优化结合,适合部署场景。
	•	提供全面的压缩功能(剪枝、量化、蒸馏等)。

常用代码示例:

from nncf import NNCFConfig
from nncf import create_compressed_model

# 读取配置
config = NNCFConfig.from_json("config.json")

# 压缩模型
model = create_compressed_model(original_model, config)

4. SparseML

简介:

SparseML 是专为深度学习模型稀疏化(包括剪枝、量化等)设计的库,适合大规模稀疏模型的训练和部署。

主要功能:

	•	支持渐进式剪枝(在训练过程中逐步剪枝)。
	•	提供预定义剪枝策略模板,减少手动调整的复杂性。
	•	与 Hugging Face Transformers 深度集成。

常用代码示例:

from sparseml.pytorch.optim import ScheduledModifierManager

manager = ScheduledModifierManager.from_yaml("pruning_schedule.yaml")
model, optimizer, lr_scheduler = manager.modify(model, optimizer, lr_scheduler)

5. PyTorch Lightning Pruning

简介:

PyTorch Lightning 通过扩展的方式支持剪枝,适合使用 Lightning 框架的用户,结合训练和剪枝的集成体验良好。

主要功能:

	•	支持 Lightning 的 Callback 机制,通过钩子函数动态控制剪枝。
	•	简化复杂模型的训练与剪枝过程。

常用代码示例:

from pytorch_lightning.callbacks import PruningCallback
from pytorch_lightning import Trainer

# 添加剪枝回调
trainer = Trainer(callbacks=[PruningCallback(amount=0.3)])
trainer.fit(model)

剪枝库对比总结

剪枝库	剪枝类型	特点	适用场景
torch.nn.utils.prune	不规则剪枝、结构化剪枝	原生库,灵活性高	适合简单模型和自定义剪枝策略
torch_pruning	不规则剪枝、通道剪枝	自动依赖分析,支持复杂模型	复杂网络(ResNet、Transformer)
NNCF	不规则剪枝、结构化剪枝	与 OpenVINO 集成,压缩全流程支持	部署优化
SparseML	渐进式剪枝	提供预定义模板,支持大模型	大规模模型稀疏化
PyTorch Lightning	动态剪枝	结合训练的剪枝支持	Lightning 项目

选择建议

	•	如果是 简单剪枝 或 自定义规则,使用 PyTorch 官方库(torch.nn.utils.prune)。
	•	如果需要 复杂模型剪枝,如卷积网络或 Transformer,建议选择 torch_pruning。
	•	如果目标是 部署优化,可以使用 NNCF 或 SparseML。
	•	使用 PyTorch Lightning 框架的项目,可以直接集成其剪枝功能。

剪枝的选择取决于模型的复杂度、使用场景和部署需求! 😊