即插即用 HorNet|递归门控卷积

2022-10-13 13:31 955 阅读 ID:407
计算机视觉论文速递
计算机视觉论文速递

    总结

    本文提出的gnConv可以看作是即插即用模块,通过通道划分特征和特征间相乘的方式提出了一个较为新颖的结构。

    论文来源:arxir

    论文地址:https://arxiv.org/abs/2207.14284

    论文代码:https://github.com/raoyongming/HorNet

    • 提出了递归门控卷积(gnConv),它通过门控卷积和递归设计来执行高阶空间交互,具有高度的灵活性和可定制性,兼容各种卷积变量,并将自注意的两阶交互扩展到任意阶,而不引入显著的额外计算。
    • gnConv可以作为一个即插即用的模块,以改进各种视觉Transformer和基于卷积的模型。在此基础上构建了一个新的通用视觉骨干家族,名为HorNet。

    前言

    图1展示了几张不同卷积的结构,并说明了优劣:

    1. 标准的卷积运算并没有明确地考虑空间间的相互作用。
    2. 动态卷积和SE引入了动态权值,提高具有额外空间交互的卷积的建模能力。
    3. 自注意操作通过两个连续的矩阵乘法进行二阶空间交互。
    4. gnConv使用门控卷积和递归对的高效实现实现任意顺序的空间交互。
                                                                                   不同卷积结构

    方法

    gnConv 递归门控卷积

    gnConv是用标准卷积、线性投影和元素乘法构建的,但具有类似于自注意的输入-自适应空间混合函数。

    与门控卷积之间的输入-自适应交互作用

    图片的大小阻碍着视觉Transformer的应用,特别是分割和大分辨率检测。本文并没有寻求降低自注意的复杂性,而是寻求一种更有效的方法来通过卷积和全连接层等简单的操作来执行空间交互。

    设x∈RHW×C为输入特征,门控卷积y=gConv(x)的输出可以写为

    其中,φin,φout是执行通道混合的投影层,f是深度卷积。gConv中的交互作用是一阶交互作用,因为每个p0与它的邻居特征q0只有交互作用一次。

    与递归门控的高阶交互作用

    在与gConv实现有效的一阶空间交互作用后设计了gnConv,这是一种递归门控卷积,通过引入高阶交互作用进一步提高模型容量。

    我们首先使用φin来获得一组投影特征p0和{qk}n−1k=0:

    然后递归地执行门控卷积

    我们将输出缩放为1/α来稳定训练。是一组基于深度的卷积层,并用于以不同的顺序匹配维度:

    最后,我们将最后一个递归步骤qn的输出输入给投影层φout,得到gnConv的结果。

    与大型核卷积的长期交互作用

    传统的CNNs通常在整个网络中使用3×3卷积,而视觉Transformer在整个特征图或一个相对较大的局部窗口(例如7×7)内计算自注意。受此设计的启发,最近有一些努力将大型内核卷积引入cnn的。为了使我们的gnConv能够捕获长期的交互,我们采用了两种深度卷积的实现f:

    • 7 * 7卷积
    • 全局滤波器(Global Filter)

    实验

    实验占了整篇论文的大半,这里不贴太多图了。

    通过ImageNet w.r.t.上的前1个精度来比较模型的权衡(a)个参数数;(b)FLOPs;(c)延迟。延迟是用一个单一的NVIDIA RTX 3090 GPU来测量的。

    模块代码

    gnConv

     class gnconv(nn.Module):
         def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
             super().__init__()
             self.order = order
             self.dims = [dim // 2 ** i for i in range(order)]
             self.dims.reverse()
             self.proj_in = nn.Conv2d(dim, 2*dim, 1)
     ​
             if gflayer is None:
                 self.dwconv = get_dwconv(sum(self.dims), 7, True)
             else:
                 self.dwconv = gflayer(sum(self.dims), h=h, w=w)
             
             self.proj_out = nn.Conv2d(dim, dim, 1)
     ​
             self.pws = nn.ModuleList(
                 [nn.Conv2d(self.dims[i], self.dims[i+1], 1) for i in range(order-1)]
             )
     ​
             self.scale = s
             print('[gnconv]', order, 'order with dims=', self.dims, 'scale=%.4f'%self.scale)
     ​
         def forward(self, x, mask=None, dummy=False):
             B, C, H, W = x.shape
     ​
             fused_x = self.proj_in(x)
             pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
     ​
             dw_abc = self.dwconv(abc) * self.scale
     ​
             dw_list = torch.split(dw_abc, self.dims, dim=1)
             x = pwa * dw_list[0]
     ​
             for i in range(self.order -1):
                 x = self.pws[i](x) * dw_list[i+1]
     ​
             x = self.proj_out(x)
     ​
             return x

    全局滤波器

    class GlobalLocalFilter(nn.Module):
         def __init__(self, dim, h=14, w=8):
             super().__init__()
             self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2)
             self.complex_weight = nn.Parameter(torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
             trunc_normal_(self.complex_weight, std=.02)
             self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
             self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
     ​
         def forward(self, x):
             x = self.pre_norm(x)
             x1, x2 = torch.chunk(x, 2, dim=1)
             x1 = self.dw(x1)
     ​
             x2 = x2.to(torch.float32)
             B, C, a, b = x2.shape
             x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')
     ​
             weight = self.complex_weight
             if not weight.shape[1:3] == x2.shape[2:4]:
                 weight = F.interpolate(weight.permute(3,0,1,2), size=x2.shape[2:4], mode='bilinear', align_corners=True).permute(1,2,3,0)
     ​
             weight = torch.view_as_complex(weight.contiguous())
     ​
             x2 = x2 * weight
             x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')
     ​
             x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B, 2 * C, a, b)
             x = self.post_norm(x)
             return x
    
    免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
    反馈
    to-top--btn