深圳Yolo 视觉
深圳
立即加入
词汇表

模型剪枝

使用模型剪枝优化机器学习模型。实现更快的推理、更少的内存使用以及更高的能源效率,以用于资源受限的部署。

模型剪枝是一种 模型优化技术 是一种模型优化技术,旨在通过去除不必要的 神经网络的规模和计算复杂度。 参数。随着人工智能模型不断增大以实现更高的性能,它们往往会变得 过度参数化,包含许多对最终输出贡献甚微的连接或神经元。通过识别 并消除这些冗余组件,开发人员就能创建更精简的模型,从而减少对内存和能源的需求,同时提供更快的实时推理。 同时提供更快的实时推理。 这一过程对于部署复杂的架构尤为重要,例如 YOLO11这样的复杂架构尤为重要。 手机或嵌入式传感器。

核心概念和机制

修剪过程通常包括三个主要阶段:训练、修剪和微调。 微调。最初,要对一个大型模型进行训练 收敛,以捕捉复杂的特征。在剪枝阶段,算法会评估特定参数的重要性。 参数的重要性,通常是 weights and biases-基于 大小或灵敏度等标准来评估特定参数(通常是权重和偏差)的重要性。被认为不重要的参数会被设置为零或完全删除。

然而,简单地切除网络的某些部分会降低其准确性。 准确性。为了解决这个问题,模型需要经过一 一轮称为微调的再训练。这一步骤允许剩余参数进行调整并 补偿缺失的连接,通常能将模型的性能恢复到接近原始水平。这种方法的 这种方法的有效性得到了 彩票假说支持这种方法的有效性。 该假说认为,密集网络包含更小的子网络,在单独训练时也能达到相当的准确性。

模型剪枝的类型

剪枝策略一般按被删除组件的结构进行分类:

  • 非结构化修剪:这种方法针对的是单个权重,而不考虑其位置、 值为零。其结果是形成一个 "稀疏 "矩阵,有价值的连接分散在矩阵中。 分散。非结构化剪枝虽然能有效减小模型大小,但通常需要专门的硬件或软件库,才能实现实际的速度。 软件库才能实现实际的速度提升,因为标准的 CPUGPU针对密集矩阵操作进行了优化。 操作进行了优化。
  • 结构化剪枝:这种方法不是去除单个权重,而是去除整个几何 结构,如通道、滤波器或卷积神经网络(CNN)中的层。 卷积神经网络(CNN)中的通道、滤波器或层。通过保持矩阵的密集结构,结构化剪枝可以让标准硬件更高效地处理模型,从而直接降低成本。 模型,直接降低推理延迟,而无需专门的 推理延迟,而无需专门的 稀疏加速工具。

剪枝与量化

虽然两者都是流行的优化技术,但重要的是要区分剪枝和模型量化。 模型量化。剪枝侧重于减少 参数(连接或神经元)的数量,从而有效地改变模型的架构。相比之下 相反,量化会降低这些参数的精度,例如将 32 位浮点数转换为 8 位浮点数。 浮点数转换为 8 位整数。这些方法通常是相辅相成的;开发人员可能会先对模型进行修剪以去除冗余,然后再对模型进行量化。 去掉冗余,然后对其进行量化,以进一步减少部署时的内存占用。 部署

实际应用

剪枝技术在使先进的计算机视觉技术在实际 计算机视觉在实际应用 的关键作用:

  1. 移动物体检测:智能手机上运行的应用程序,如增强现实应用程序或 照片整理器等智能手机上运行的应用,使用剪枝模型在本地执行物体检测。 物体检测。这样既能节省电池 并通过避免云处理来确保用户数据隐私。 处理。
  2. 汽车安全系统: 自动驾驶汽车依靠快速处理 视觉数据来detect 行人和障碍物。剪枝模型使车载 推理引擎做出瞬间决策 无需服务器级GPU 的大量功耗。

实施实例

框架,如 PyTorch等框架提供了内置实用程序,以 剪枝。下面的示例演示了如何对卷积层应用非结构化剪枝。 层应用非结构化剪枝,这是将模型导出为优化格式(如 ONNX.

import torch
import torch.nn.utils.prune as prune

# Initialize a standard convolutional layer
layer = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)

# Apply L1 unstructured pruning to remove 30% of the connections
# This sets the smallest 30% of weights (by absolute value) to zero
prune.l1_unstructured(layer, name="weight", amount=0.3)

# Verify sparsity: calculate the percentage of zero parameters
sparsity = float(torch.sum(layer.weight == 0)) / layer.weight.nelement()
print(f"Layer sparsity: {sparsity:.2%}")

加入Ultralytics 社区

加入人工智能的未来。与全球创新者联系、协作和共同成长

立即加入