本文介绍我们组在CVPR 2023的工作:Stitchable Neural Networks,下文简称SN-Net。一种全新的模型部署方法,利用现有的model family直接做少量epoch finetune就可以得到大量插值般存在的子网络,运行时任意切换网络结构满足不同resource constraint。
论文链接:
https://arxiv.org/pdf/2302.06586.pdf
代码链接:
https://github.com/ziplab/SN-Net
一、背景
去年一次组会上,在和导师们讨论未来的research方向的时候,偶然聊到一个问题:
视频网站的视频播放会自动根据网络带宽调整画质,如网速好的时候到4K,网速差就720P甚至更低。那同一个神经网络能不能随时根据计算资源的变化调整推理速度?
从2012的AlexNet到2023年火出圈的ChatGPT, AI/ML这一社区在十年间少说已经训练了上百万个模型。截至这篇文章写作时,HuggingFace上可以直接下载的模型就有14万个,涵盖各个模态和任务。每个模型各司其职,用自己在训练中学到的知识去处理某一种场景,互不叨扰。
模型虽然越来越多,但是资源浪费也越来越严重。训练一个模型的成本很高,尤其是大模型训练,耗费数个节点和几天的算力才能得到一个好权重,但最后却受限于应用场景只能重新调整结构,然后再重新训练,如网络backbone设计中通常会有不同scale来满足不同的推理速度要求: ResNet-18/50/101,DeiT-Ti/S/B,Swin-Ti/S/B等等。
传统方法当然能加速模型推理,如pruning,distillation,quantization。但问题是这些方法一次大都只能针对一个模型,一个资源场景。我们也可以用NAS搜出来若干个子网络来满足不同推理速度需求,即使如此,NAS中训练一个Supernet的成本也是巨大的,典型的如OFA和BigNAS,花费上千GPU hours才得到一个好网络,资源消耗巨大。
看着huggingface上这么大的model zoo,我们不禁想,整个社区花了大量时间,金钱和人力资源去训练网络,得到了这么多的 pretrained model,但是能不能有效利用起来? 况且这些模型已经训练好了,当需要他们的时候,能不能用少量计算资源就可以满足目标场景?
对这一问题的思考也是随着模型被工业界越推越大引出的。几年前一张1080就能跑完的实验,现在8张卡都很难train得动一个model,特别是Transformer出来之后。最新的ViT已经scale到22B,BAAI的EVA也把ViT扩展到了1B的参数级别。留给小组的空间越来越小,在资源有限(缺卡)的场景下,我们需要寻求新的突破方向。
二、Stitchable Neural Networks
Industry和Academia所关注的问题可以有些区别。既然大模型不是所有人能做得起的,那我们不如去利用好已有的pretrained model。现在我们有了一组训练好的model family,如DeiT-Tiny/Small/Base。不同模型有不同大小,推理速度,显存占用。那么能不能利用这些已有的weights和结构快速得到一批新网络来满足不同的资源场景?
我们在CVPR 2023最新的工作Stitchable Neural Network (SN-Net) 给出了一个非常具有潜力的方案。
SN-Net的主要思想是:在一组已经训练好的model family中插入若干个stitching layer (即1x1 conv), 使得forward时activation可以在模型间的不同位置游走。当模型在不同位置缝合的时候,一个个新网络结构就出来了!!!
此时,我们把原先model family中的网络叫做anchors,缝合出来的新网络叫做stitches。单个SN-Net可以cover众多FLOPs-accuracy的trade-off,如在基于Swin的实验中,一个SN-Net的可以挑战timm中200个独立的模型,整个实验不过是50 epochs,八张V100上训练不到一天。
下面会介绍详细的做法,以及我们当时方法设计时候的考虑。想直接看效果的朋友可以移步最后的结果展示。
1. 模型这么多,怎么去选择
这里主要考虑了几个地方:
· 不同模型结构在网络中各层学习到的representation会有较大差别,缝合出来的网络不一定保证较好的performance。
· 不同数据集学到的东西差别也很大,为了保证性能最好保持在相同pretrained的dataset下。
· 不同网络的实现和训练方式有差别,工程上很难权衡超参和data augmentation的选择。而同一个结构通常在一个repo里,更容易实现。
因此,我们初步关注在相同dataset上训练好的model family上, 即结构相似,但是模型scale不一样,如DeiT-Ti/S/B。
不同family能不能缝合?也能,我们paper里有展示结果,但是工程上会比较麻烦,需要combine不同repo并且权衡超参。
2. 怎么去做缝合?
model stitching在原先工作中大都是以研究representation similarity的形式呈现的,如
· Lenc, Karel, and Andrea Vedaldi. “Understanding image representations by measuring their equivariance and equivalence.” CVPR 2015.· Kornblith, Simon, et al. “Similarity of neural network representations revisited.” ICML, 2019.· Csiszárik, Adrián, et al. “Similarity and matching of neural network representations.” NeurIPS 2021.
总结过去这些工作:同一个网络,用不同seed训练之后可以在某些位置缝合起来,此时性能不会掉的很离谱。后续的研究发现结构不一样的网络甚至也能缝合。
而stitching能够work在于,假设前一个网络出来的feature map属于activation 空间A,而另一个网络在此位置的输入feature map属于activation空间B,那么stitching layer做的事情就是把feature map从A空间映射到B空间,使得此时的feature map能模拟下一网络在这个位置的输入。
当网络是已经是pretrained,那么stitching这一过程完全可以formulate成一个求解least squares的问题。也就是说stitching layer这个weights的matrix是可以直接求出来的 (参考Csiszárik, Adrián, et al 这篇)。所以此时求解出来的matrix可以天然作为stitching layer的初始化。
3. 缝合方向的设定
现在我们有一个大模型:性能好但是推理速度慢,还有一个小模型:性能差点但是推理速度快。我们怎么决定谁stitch到谁呢?我们主要考虑了两个方面:
· 参考当前backbone设计的惯例,随着网络不断深入,channel dimension是在不断增大的。Fast-to-Slow这方向比较符合常见的网络设计。
· 实验验证Fast-to-Slow得到的curve要比Slow-to-Fast要smooth一点,详见论文。
所以目前SN-Net在方向上是从小模型缝合到大模型。同时我们提出一个constraint: nearest stitching,限制stitching只在复杂度(FLOPs)相邻的两个anchor之间。如补充材料中的Figure 10所示,以DeiT-Ti/S/B为例,我们的方法目前限制在(a), (b)两个case。
这个限制是因为我们发现anchor的gap比较大的时候,缝合出来的网络并不在一个optimal的区间。实验部分也证明直接stitch DeiT-Ti和DeiT-B效果不如中间加一个DeiT-S。
4. 怎么配置Stitching Layer
网络设计地千奇百怪,怎么去缝合是个问题。
我们以DeiT为例,在相同depth的缝合实验上采取了Paired Stitching这种策略。这种策略的启发来自于过去一些工作发现:相邻layer之间的representation是有较高的相似度的。所以我们选择在DeiT得相邻blocks中share同一个stitching layer,如滑窗一般进行stitching。
share的情况下,原先的初始化方法就是简单地对不同solution得到的matrix做一个average。选择share stitching layer还有其他好处,如减少过多stitching layer带来的参数量,同时扩大缝合出来的结构数量,即扩大stitching space。
另外一种情况是两个模型的depth不一样,小模型一般比较浅,block的数量要比大模型少。比如Swin-Ti的第三个stage只有6个block,而Swin-S在第三个stage有18个block。此时我们进行Unpaired Stitching,每个小模型的block都stitch到大模型的若干个block中。这样两个case就都解决了。
5. SN-Net能缝出来多少网络?
这个由多种因素决定。
· 看选择的model family,即anchors的depth。显然anchor越深,那么能stitch的位置就越多,新网络结构也会更多。
· 相同depth下看stitching时sliding window的设置。
· 不加nearest stitching的时候得到的网络更多 (DeiT上的实验是十倍的差距,71 vs. 731)。但是此时不optimal。后续潜力尚待挖掘。
对比NAS中 级别的search space,SN-Net在基于同一组model family得到的网络数量是有限的。但有一点不得不提,纵使search space再大,真正需要的时候也只是用pareto frontier上的网络结构,而SN-Net缝合出来的网络几乎天然落在pareto frontier上,同时部署的时候完全可以直接查表,几乎没有什么search cost。
另外一点是,SN-Net的潜力在于整个pretrained model zoo。有多少model familiy,就有多少潜在的SN-Net变种。这是NAS的单一supernet所不能比拟的。这意味着我们可以轻易缝合已有的model family达到NAS耗费大量计算资源搜出来的网络性能,比如简单缝合两个LeViT就可以用更低的FLOPs(977M vs. 1040M) 达到媲美BigNASModel-XL的性能(80.7% vs. 80.9%),如下图所示
6. 简单的训练策略
训练SN-Net尤为简单。先提前把所有需要训练的stitches定义好,训练中每次iteration都随机sample出来一个stitch,后面和正常的训练一样进行loss回传,梯度下降。为了进一步提升stitches的性能,我们初步实验同时采用了RegNetY-160作为teacher model去做distillation。
三、结果展示
为了验证Joint Training和原有网络从头train的差距,我们选择了若干个和stitches相同的网络结构,然后在ImageNet上训满300 epochs。从下表可以看到,对比用了大量计算资源训练出来的网络,SN-Net利用已有的DeiT family只用50个epoch就可以得到比肩甚至更好的性能。同时整个网络只要118.4M的参数,而这71个stitches的总量如果单独训练需要2630M,耗费 71 × 300 epochs,和SN-Net比是22倍的差距。
基于DeiT和Swin Transformer, 我们验证了缝合plain ViT和hierarchical ViT的可行性。性能曲线如在anchors中进行插值一般。
值得一提的是,图中不同点所表示的子网络,即stitch,是可以在运行时随时切换的。这意味着网络在runtime完全可以依靠查表进行瞬时推理速度调整。这个是诸多网络无法实现的,但颇具现实意义。比如现在很多手机都有省电模式,一旦进行power saving, 手机掉帧,系统运行速度变慢,而此时neural network也可以调整推理速度,做一个speed-accuracy的trade-off。
我们当然也尝试了stitch cnn,甚至不同的family,结果非常promising。
更多实验内容和分析请移步我们的arxiv论文:Stitchable Neural Networks。
四、SN-Net的可扩展空间
SN-Net生于large model zoo的时代。我们初版方法给出了一个最简单的baseline,相信未来有很大的扩展空间,比如
1. 当前的训练策略比较简单,每次iteration sample出来一个stitch,但是当stitches特别多的时候,可能导致某些stitch训练的不够充分,除非增加训练时间。所以训练策略上可以继续改进。
2. anchor的performance会比之前下降一些,虽然不大。直觉上,在joint training过程中,anchor为了保证众多stitches的性能在自身weights上做了一些trade-off。目前补充材料里发现finetune更多epoch可以把这部分损失补回来。
3. 不用nearest stitching可以明显扩大space,但此时大部分网络不在pareto frontier上,未来可以结合训练策略进行改进,或者在其他地方发现advantage。
4. 未来能否有个更好方法和统一的框架去缝合任意网络。到那时,整个model zoo就像积木一样,可操作空间更大,玩法更多,这一点NUS的Xingyi Yang (https://adamdad.github.io/) 之前有尝试,参考Deep Model Reassembly.(https://arxiv.org/abs/2210.17409)。
更多探索就留给future work了。代码已经开源至 https://github.com/ziplab/SN-Net,硬件要求十分友好,50个epoch (用8卡V100大约半天时间) 就可以复现结果。欢迎有兴趣的同学进行尝试!
作者:潘梓正
文章来源:知乎专栏【https://zhuanlan.zhihu.com/p/611257510】
Illustration by Twin Rizki from IconScout