炼数成金 门户 商业智能 深度学习 查看内容

模型剪枝,不可忽视的推断效率提升方法

2019-8-9 10:24| 发布者: 炼数成金_小数| 查看: 45930| 评论: 0|来自: 机器学习研究会订阅号

摘要: 目前,深度学习模型需要大量算力、内存和电量。当我们需要执行实时推断、在设备端运行模型、在计算资源有限的情况下运行浏览器时,这就是瓶颈。能耗是人们对于当前深度学习模型的主要担忧。而解决这一问题的方法之一 ...

网络 模型 架构 神经网络 深度学习 权重

剪枝是常用的模型压缩方法之一,本文对剪枝的原理、效果进行了简单介绍。

目前,深度学习模型需要大量算力、内存和电量。当我们需要执行实时推断、在设备端运行模型、在计算资源有限的情况下运行浏览器时,这就是瓶颈。能耗是人们对于当前深度学习模型的主要担忧。而解决这一问题的方法之一是提高推断效率。

大模型 => 更多内存引用 => 更多能耗
剪枝正是提高推断效率的方法之一,它可以高效生成规模更小、内存利用率更高、能耗更低、推断速度更快、推断准确率损失最小的模型,此类技术还包括权重共享和量化。深度学习从神经科学中汲取过灵感,而剪枝同样受到生物学的启发。

随着深度学习的发展,当前最优的模型准确率越来越高,但这一进步伴随的是成本的增加。本文将对此进行讨论。

挑战 1:模型规模越来越大
我们很难通过无线更新(over-the-air update)分布大模型。


来自 Bill Dally 在 NIPS 2016 workshop on Efficient Methods for Deep Neural Networks 的演讲。

挑战 2:速度

使用 4 块 M40 GPU 训练 ResNet 的时间,所有模型遵循 fb.resnet.torch 训练。

训练时间之长限制了机器学习研究者的生产效率。

挑战 3:能耗
AlphaGo 使用了 1920 块 CPU 和 280 块 GPU,每场棋局光电费就需要 3000 美元。

这对于移动设备意味着:电池耗尽

对于数据中心意味着:总体拥有成本(TCO)上升

解决方案:高效推断算法
剪枝
权重共享
低秩逼近
二值化网络(Binary Net)/三值化网络(Ternary Net)
Winograd 变换

剪枝所受到的生物学启发
人工神经网络中的剪枝受启发于人脑中的突触修剪(Synaptic Pruning)。突触修剪即轴突和树突完全衰退和死亡,是许多哺乳动物幼年期和青春期间发生的突触消失过程。突触修剪从公出生时就开始了,一直持续到 20 多岁。


Christopher A Walsh. Peter Huttenlocher (1931–2013). Nature, 502(7470):172–172, 2013.

修剪深度神经网络


[Lecun et al. NIPS 89] [Han et al. NIPS 15]

神经网络通常如上图左所示:下层中的每个神经元与上一层有连接,但这意味着我们必须进行大量浮点相乘操作。完美情况下,我们只需将每个神经元与几个其他神经元连接起来,不用进行其他浮点相乘操作,这叫做「稀疏」网络。

稀疏网络更容易压缩,我们可以在推断期间跳过 zero,从而改善延迟情况。

如果你可以根据网络中神经元但贡献对其进行排序,那么你可以将排序较低的神经元移除,得到规模更小且速度更快的网络。

速度更快/规模更小的网络对于在移动设备上运行它们非常重要。

如果你根据神经元权重的 L1/L2 范数进行排序,那么剪枝后模型准确率会下降(如果排序做得好的话,可能下降得稍微少一点),网络通常需要经过训练-剪枝-训练-剪枝的迭代才能恢复。如果我们一次性修剪得太多,则网络可能严重受损,无法恢复。因此,在实践中,剪枝是一个迭代的过程,这通常叫做「迭代式剪枝」(Iterative Pruning):修剪-训练-重复(Prune / Train / Repeat)。

想更多地了解迭代式剪枝,可参考 TensorFlow 团队的代码:
https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb

权重修剪
将权重矩阵中的多个权重设置为 0,这对应上图中的删除连接。为了使稀疏度达到 k%,我们根据权重大小对权重矩阵 W 中的权重进行排序,然后将排序最末的 k% 设置为 0。
f = h5py.File("model_weights.h5",'r+')
for k in [.25, .50, .60, .70, .80, .90, .95, .97, .99]: 
  ranks = {} 
  for l in list(f[『model_weights』])[:-1]: 
    data = f[『model_weights』][l][l][『kernel:0』] 
    w = np.array(data) 
    ranks[l]=(rankdata(np.abs(w),method= 'dense')—1).astype(int).reshape(w.shape) 
    lower_bound_rank = np.ceil(np.max(ranks[l])*k).astype(int) 
    ranks[l][ranks[l]<=lower_bound_rank] = 0 
    ranks[l][ranks[l]>lower_bound_rank] = 1 
    w = w*ranks[l] 
    data[…] = w

单元/神经元修剪
将权重矩阵中的多个整列设置为 0,从而删除对应的输出神经元。
为使稀疏度达到 k%,我们根据 L2 范数对权重矩阵中的列进行排序,并删除排序最末的 k%。

f = h5py.File("model_weights.h5",'r+')
for k in [.25, .50, .60, .70, .80, .90, .95, .97, .99]: 
  ranks = {} 
  for l in list(f['model_weights'])[:-1]: 
    data = f['model_weights'][l][l]['kernel:0'] 
    w = np.array(data) 
    norm = LA.norm(w,axis=0) 
    norm = np.tile(norm,(w.shape[0],1)) 
    ranks[l] = (rankdata(norm,method='dense')—1).astype(int).reshape(norm.shape) 
    lower_bound_rank = np.ceil(np.max(ranks[l])*k).astype(int) 
    ranks[l][ranks[l]<=lower_bound_rank] = 0 
    ranks[l][ranks[l]>lower_bound_rank] = 1 
    w = w*ranks[l]
    data[…] = w

随着稀疏度的增加、网络删减越来越多,任务性能会逐渐下降。那么你觉得稀疏度 vs. 性能的下降曲线是怎样的呢?

我们来看一个例子,使用简单的图像分类神经网络架构在 MNIST 数据集上执行任务,并对该网络进行剪枝操作。

下图展示了神经网络的架构:

参考代码中使用的模型架构。

稀疏度 vs. 准确率。读者可使用代码复现上图(https://drive.google.com/open?id=1GBLFxyFQtTTve_EE5y1Ulo0RwnKk_h6J)。


总结
很多研究者认为剪枝方法被忽视了,它需要得到更多关注和实践。本文展示了如何在小型数据集上使用非常简单的神经网络架构获取不错的结果。我认为深度学习在实践中用来解决的许多问题与之类似,因此这些问题也可以从剪枝方法中获益。

参考资料
本文相关代码:https://drive.google.com/open?id=1GBLFxyFQtTTve_EE5y1Ulo0RwnKk_h6J
To prune, or not to prune: exploring the efficacy of pruning for model compression, Michael H. Zhu, Suyog Gupta, 2017(https://arxiv.org/pdf/1710.01878.pdf)
Learning to Prune Filters in Convolutional Neural Networks, Qiangui Huang et. al, 2018(https://arxiv.org/pdf/1801.07365.pdf)
Pruning deep neural networks to make them fast and small(https://jacobgil.github.io/deeplearning/pruning-deep-learning)
使用 Tensorflow 模型优化工具包优化机器学习模型(https://www.tensorflow.org/model_optimization)

声明:本文版权归原作者所有,文章收集于网络,为传播信息而发,如有侵权,请联系小编及时处理,谢谢!

欢迎加入本站公开兴趣群
商业智能与数据分析群
兴趣范围包括:各种让数据产生价值的办法,实际应用案例分享与讨论,分析工具,ETL工具,数据仓库,数据挖掘工具,报表系统等全方位知识
QQ群:81035754

鲜花

握手

雷人

路过

鸡蛋

相关阅读

最新评论

热门频道

  • 大数据
  • 商业智能
  • 量化投资
  • 科学探索
  • 创业

即将开课

 

GMT+8, 2019-10-20 14:36 , Processed in 0.097586 second(s), 24 queries .