Implementing Custom Convolutions in PyTorch: 2.5D Convolution as an Example


When implementing convolutional neural networks in PyTorch, the convolution layers we use are usually the built-in 2D convolution torch.nn.Conv2d, or, less often, the 1D convolution torch.nn.Conv1d and the 3D convolution torch.nn.Conv3d. These layers expose enough parameters that special convolutions such as atrous convolution and depthwise separable convolution can be implemented with ease. Sometimes, however, a classic convolution layer cannot be used directly and a new convolution operation has to be implemented from scratch, for example deformable convolution. Learning how to implement a custom convolution layer from the ground up is therefore both necessary and worthwhile.

This post tries to provide a simple tutorial on custom convolution layers. To make it concrete and practical, it implements the 2.5D convolution (from the RGB-D semantic segmentation paper: 2.5D Convolution for RGB-D Semantic Segmentation) as an example. The implementation builds on the open-source project pacnet accompanying the paper Pixel-Adaptive Convolutional Neural Networks; thanks to its authors.

All the code in this post is shown below and can also be found at [GitHub: to be released].

1. How 2.5D Convolution Works

For a 2D convolution with a k \times k kernel, denote its receptive field size by k\Delta p\times k\Delta p. The standard 2D convolution centered at (u_i, v_i) is computed as:


The standard 2D convolution formula; source: 2.5D Convolution for RGB-D Semantic Segmentation (the same applies below).
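In standard notation (which may differ slightly from the paper's own symbols), this is
y(p_i)=\sum_{p_n\in\mathcal{R}_{k\times k}} w(p_n)\,x(p_i+p_n)
where p_i=(u_i, v_i), \mathcal{R}_{k\times k} enumerates the k\times k offsets around the center, and w are the kernel weights.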

Similarly, the standard 3D convolution (with a k \times k \times k kernel) is computed as:


The standard 3D convolution formula.
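In the same (approximate) notation, with p_i now a 3-D position and \mathcal{R}_{k\times k\times k} the 3-D neighborhood:
y(p_i)=\sum_{p_n\in\mathcal{R}_{k\times k\times k}} w(p_n)\,x(p_i+p_n)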

2D convolution (3D convolution) for image (video) data is very mature and widely used. For semantic segmentation of RGB-D images, which carry depth information, one could simply treat depth as an extra channel and build the segmentation model with ordinary 2D convolutions. Doing so, however, ignores the geometric structure hidden in the depth information, so it is worth designing a new convolution that makes full use of these geometric cues. The authors of the paper (2.5D Convolution for RGB-D Semantic Segmentation) designed such an operation, called the 2.5D convolution:


The 2.5D convolution formula.
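In a form consistent with the implementation below (the paper's Eq. (4) may use different symbols), the 2.5D convolution stacks k masked 2D convolutions:
y(p_i)=\sum_{l=0}^{k-1}\sum_{p_n\in\mathcal{R}_{k\times k}} m_l(p_i, p_n)\,w_l(p_n)\,x(p_i+p_n)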

where z denotes the depth information, w_l (l=0,\dots,k-1) are the parameters of the k 2D convolution kernels, and the mask m_l is computed as:


The mask operation.
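Written in a form that matches the implementation below (the paper's Eq. (7) may differ in notation): with s(p_i)=z(p_i)/f and \hat{z}_l(p_i)=z(p_i)+\left(l-\frac{k-1}{2}\right)s(p_i),
m_l(p_i, p_n)=\begin{cases}1, & \hat{z}_l(p_i)-s(p_i)\le z(p_i+p_n)<\hat{z}_l(p_i)+s(p_i)\\ 0, & \text{otherwise}\end{cases}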

From the formulas above, if the number of input feature channels is m and the number of output channels is n, it is easy to see that:

  • Number of parameters of a 2D convolution kernel: m\times n\times k\times k
  • Number of parameters of a 3D convolution kernel: m\times n\times k\times k\times k
  • Number of parameters of a 2.5D convolution kernel: (m\times n\times k\times k)\times k

If both the input and output resolutions are r\times s (or r\times s\times t in the 3D case), then (approximately):

  • Computational cost of a 2D convolution: m\times n\times k\times k\times r\times s
  • Computational cost of a 3D convolution: m\times n\times k\times k\times k \times r\times s\times t
  • Computational cost of a 2.5D convolution: (m\times n\times k\times k\times r\times s)\times k

Clearly, the 2.5D convolution costs more than the 2D convolution in both parameters and computation, but compared with the 3D convolution it is far cheaper to compute at the same parameter count. Seen from this angle (better performance than 2D convolution, lower computational cost than 3D convolution), the 2.5D convolution is worthwhile.
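As a concrete illustration (the numbers are chosen arbitrarily): with m = n = 64 and k = 3, a 2D kernel has 64\times 64\times 3\times 3 = 36864 parameters, while the 3D and 2.5D kernels both have 36864\times 3 = 110592. On a 2D/2.5D output of r\times s and a 3D output of r\times s\times t, the 2.5D convolution then costs roughly 110592\times r\times s multiply-adds versus 110592\times r\times s\times t for the 3D convolution, i.e. about t times cheaper at the same parameter count.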

2. Implementing the 2.5D Convolution

Implementing strictly according to Eqs. (4) to (7) of the paper, the 2.5D convolution code is as follows (file name: conv2_5d.py):

# -*- coding: utf-8 -*-
"""
Created on Wed Nov 20 18:58:19 2019
@author: lijingxiong
Implementation of 2.5D convolution:
    paper: 2.5D Convolution for RGB-D Semantic Segmentation.
Reference: https://github.com/NVlabs/pacnet/blob/master/pac.py
"""
import math
import torch
        
        
class RepeatKernelConvFn(torch.autograd.function.Function):
    """2.5D convolution with kernel.
    """
        
    @staticmethod
    def forward(ctx, inputs, kernel, weight, bias=None, stride=1, padding=0, 
                dilation=1):
        """Forward computation.
        
        Args:
            inputs: A tensor with shape [batch, channels, height, width] 
                representing a batch of images.
            kernel: A tensor with shape [kernel_size, batch, 1, kernel_size,
                kernel_size, out_height, out_width] holding the depth masks.
            weight: A tensor with shape [k, out_channels, in_channels, 
                kernel_size, kernel_size].
            bias: None or a tensor with shape [out_channels].
            
        Returns:
            outputs: A tensor with shape [batch, out_channels, height, width].
        """
        (batch_size, channels), input_size = inputs.shape[:2], inputs.shape[2:]
        ctx.in_channels = channels
        ctx.input_size = input_size
        ctx.kernel_size = tuple(weight.shape[-2:])
        ctx.dilation = torch.nn.modules.utils._pair(dilation)
        ctx.padding = torch.nn.modules.utils._pair(padding)
        ctx.stride = torch.nn.modules.utils._pair(stride)
        
        needs_input_grad = ctx.needs_input_grad
        ctx.save_for_backward(
            inputs if (needs_input_grad[1] or needs_input_grad[2]) else None,
            kernel if (needs_input_grad[0] or needs_input_grad[2]) else None,
            weight if (needs_input_grad[0] or needs_input_grad[1]) else None)
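        # Legacy THNN backend handle: backward() uses its Im2Col_updateGradInput
        # (i.e. col2im). This private API exists in the PyTorch 1.x releases this
        # post was written against and was removed in later versions.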
        ctx._backend = torch._thnn.type2backend[inputs.type()]
        
        # Slide windows, [batch, channels x kernel_size x kernel_size, N],
        # where N is the number of slide windows.
        inputs_wins = torch.nn.functional.unfold(inputs, ctx.kernel_size, 
                                                 ctx.dilation, ctx.padding,
                                                 ctx.stride)
        inputs_wins = inputs_wins.view(
            1, batch_size, channels, *kernel.shape[3:])
        inputs_mul_kernel = inputs_wins * kernel
                
        # Matrix multiplication
        outputs = torch.einsum(
            'hijklmn,hojkl->iomn', (inputs_mul_kernel, weight))
        
        if bias is not None:
            outputs += bias.view(1, -1, 1, 1)
        return outputs
        
    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_outputs):
        grad_inputs = grad_kernel = grad_weight = grad_bias = None
        batch_size, out_channels = grad_outputs.shape[:2]
        output_size = grad_outputs.shape[2:]
        in_channels = ctx.in_channels
        
        # Compute gradients
        inputs, kernel, weight = ctx.saved_tensors
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            grad_inputs_mul_kernel = torch.einsum('iomn,hojkl->hijklmn',
                                                  (grad_outputs, weight))
        if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            inputs_wins = torch.nn.functional.unfold(inputs, ctx.kernel_size, 
                                                     ctx.dilation, ctx.padding,
                                                     ctx.stride)
            inputs_wins = inputs_wins.view(1, batch_size, in_channels,
                                           ctx.kernel_size[0], 
                                           ctx.kernel_size[1],
                                           output_size[0], output_size[1])
        if ctx.needs_input_grad[0]:
            grad_inputs = grad_outputs.new()
            grad_inputs_wins = grad_inputs_mul_kernel * kernel
            grad_inputs_wins = grad_inputs_wins.view(
                ctx.kernel_size[0], batch_size, -1, output_size[0], output_size[1])
            ctx._backend.Im2Col_updateGradInput(ctx._backend.library_state,
                                                grad_inputs_wins,
                                                grad_inputs,
                                                ctx.input_size[0],
                                                ctx.input_size[1],
                                                ctx.kernel_size[0],
                                                ctx.kernel_size[1],
                                                ctx.dilation[0], 
                                                ctx.dilation[1],
                                                ctx.padding[0], 
                                                ctx.padding[1],
                                                ctx.stride[0],
                                                ctx.stride[1])
        if ctx.needs_input_grad[1]:
            grad_kernel = inputs_wins * grad_inputs_mul_kernel
            grad_kernel = grad_kernel.sum(dim=1, keepdim=True)
        if ctx.needs_input_grad[2]:
            inputs_mul_kernel = inputs_wins * kernel
            grad_weight = torch.einsum('iomn,hijklmn->hojkl',
                                       (grad_outputs, inputs_mul_kernel))
        if ctx.needs_input_grad[3]:
            grad_bias = torch.einsum('iomn->o', (grad_outputs,))
        return (grad_inputs, grad_kernel, grad_weight, grad_bias, None, None,
                None)
        
        
class DepthKernelFn(torch.autograd.function.Function):
    """Compute mask in paper: 
        2.5D convolution for rgb-d semantic segmentation.
    """
    
    @staticmethod
    def forward(ctx, depth, f, kernel_size, stride, padding, dilation):
        """Forward computation.
        
        Args:
            depth: A tensor with shape [batch, 1, height, width] representing
                a batch of depth maps.
            f: A constant.
            
        Returns:
            A tensor with shape [kernel_size, batch, 1, kernel_size,
            kernel_size, out_height, out_width], one mask per sub-kernel.
        """
        ctx.kernel_size = torch.nn.modules.utils._pair(kernel_size)
        ctx.stride = torch.nn.modules.utils._pair(stride)
        ctx.padding = torch.nn.modules.utils._pair(padding)
        ctx.dilation = torch.nn.modules.utils._pair(dilation)
        
        batch_size, channels, in_height, in_width = depth.shape
        out_height = (in_height + 2 * ctx.padding[0] - 
                      ctx.dilation[0] * (ctx.kernel_size[0] - 1)
                      -1) // ctx.stride[0] + 1
        out_width = (in_width + 2 * ctx.padding[1] - 
                     ctx.dilation[1] * (ctx.kernel_size[1] - 1)
                     -1) // ctx.stride[1] + 1
        
        depth_wins = torch.nn.functional.unfold(depth, ctx.kernel_size,
                                                ctx.dilation, ctx.padding,
                                                ctx.stride)
        depth_wins = depth_wins.view(batch_size, channels, ctx.kernel_size[0],
                                     ctx.kernel_size[1], out_height, out_width)
        s_wins = torch.nn.functional.unfold(depth / f, ctx.kernel_size,
                                            ctx.dilation, ctx.padding,
                                            ctx.stride)
        s_wins = s_wins.view(batch_size, channels, ctx.kernel_size[0],
                             ctx.kernel_size[1], out_height, out_width)
        
        kernels = []
        center_y, center_x = ctx.kernel_size[0] // 2, ctx.kernel_size[1] // 2
        for l in range(ctx.kernel_size[0]):
            z_l = depth_wins + (l - (ctx.kernel_size[0] - 1) / 2) * s_wins
            z_l_0 = z_l.contiguous()[:, :, center_y:center_y + 1,
                                     center_x:center_x + 1, :, :]
            s_0 = s_wins.contiguous()[:, :, center_y:center_y + 1,
                                      center_x:center_x + 1, :, :]
            # A window position contributes to sub-kernel l only when its depth
            # lies inside [z_l_0 - s_0, z_l_0 + s_0).
            mask_l = torch.where(
                (depth_wins >= z_l_0 - s_0) & (depth_wins < z_l_0 + s_0),
                torch.full_like(depth_wins, 1),
                torch.full_like(depth_wins, 0))
            kernels.append(mask_l.unsqueeze(dim=0))
        return torch.cat(kernels, dim=0)
    
    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_outputs):
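        # The depth-derived masks are treated as piecewise-constant with respect
        # to the depth map, so no gradient is propagated back through them.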
        return None, None, None, None, None, None
    
    
class Conv2_5d(torch.nn.Module):
    """Implementation of 2.5D convolution."""
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=True):
        """Constructor."""
        super(Conv2_5d, self).__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = torch.nn.modules.utils._pair(kernel_size)
        self.stride = torch.nn.modules.utils._pair(stride)
        self.padding = torch.nn.modules.utils._pair(padding)
        self.dilation = torch.nn.modules.utils._pair(dilation)
        
        # Parameters: weight, bias
        self.weight = torch.nn.parameter.Parameter(
            torch.Tensor(self.kernel_size[0], out_channels, in_channels,
                         self.kernel_size[0], self.kernel_size[1]))
        if bias:
            self.bias = torch.nn.parameter.Parameter(
                torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
            
        # Initialization
        self.reset_parameters()
        
    def forward(self, inputs, depth, f=1):
        """Forward computation.
        
        Args:
            inputs: A tensor with shape [batch, in_channels, height, width] 
                representing a batch of images.
            depth: A tensor with shape [batch, 1, height, width] representing
                    a batch of depth maps.
            f: A constant.
            
        Returns:
            outputs: A tensor with shape [batch, out_channels, height, width].
        """
        kernel = DepthKernelFn.apply(depth, f, self.kernel_size, self.stride,
                                     self.padding, self.dilation)
        
        outputs = RepeatKernelConvFn.apply(inputs, kernel, self.weight,
                                           self.bias, self.stride,
                                           self.padding, self.dilation)
        return outputs
    
    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        return s.format(**self.__dict__)
    
    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

The key points of implementing a custom convolution layer are:

  • Use torch.nn.functional.unfold to split the input into sliding-window blocks:

For an input x with batch size b, m channels and resolution R\times S, i.e. of shape b\times m\times R\times S, and with kernel size, padding, stride and dilation equal to k, p, s, d respectively, the output of this function has shape b\times(m\times k\times k)\times (R^\prime\times S^\prime); it is a 3-D tensor (a runnable shape sketch is given after this list), where:
R^\prime=\lfloor[R+2p-d(k-1)-1]/s\rfloor+1,\\ S^\prime=\lfloor[S+2p-d(k-1)-1]/s\rfloor+1

  • Use torch.einsum to sum the tensors according to the convolution operation.

According to the Einstein summation convention, a repeated index that appears as both a superscript and a subscript implies summation, so the summation sign can be omitted, e.g.:
s=\sum^N_{i=1} a^ib_i=a^ib_i\\
Expressing this convention symbolically and evaluating it is exactly what the einsum function does:

torch.einsum('i,i->', (a, b))

For example:

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
s=torch.einsum('i,i->', (a, b))
# s = tensor(14)

Combining the two functions above: for a depth map z of shape [b, 1, R, S] and a fixed l, following Eq. (7) we obtain with torch.nn.functional.unfold a mask m_l of shape [b, 1, k, k, R^\prime, S^\prime] (after reshaping with view); concatenating the m_l over all l yields a tensor of shape [k, b, 1, k, k, R^\prime, S^\prime]. Likewise, after the sliding-window operation the input x has shape [1, b, m, k, k, R^\prime, S^\prime] (an extra leading dimension is added). These two 7-D tensors are multiplied element-wise to give a new 7-D tensor, which, following Eq. (4), is then contracted with torch.einsum against the weight parameter of shape [k, n, m, k, k], where n is the number of output channels.
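The following minimal, standalone sketch (sizes are made up for illustration; it is not part of conv2_5d.py) traces exactly this shape flow:

import torch

b, m, n, k = 2, 3, 8, 3                      # batch, in/out channels, kernel size
R = S = 16                                   # input resolution
p, s, d = 1, 1, 1                            # padding, stride, dilation
Rp = (R + 2 * p - d * (k - 1) - 1) // s + 1  # output height R'
Sp = (S + 2 * p - d * (k - 1) - 1) // s + 1  # output width S'

x = torch.rand(b, m, R, S)
# unfold: [b, m*k*k, R'*S'], then reshape and add a leading broadcast dimension.
x_wins = torch.nn.functional.unfold(x, k, d, p, s).view(1, b, m, k, k, Rp, Sp)

masks = torch.rand(k, b, 1, k, k, Rp, Sp)    # stands in for the k depth masks m_l
weight = torch.rand(k, n, m, k, k)           # same layout as Conv2_5d.weight

# Element-wise masking, then contraction over the window offsets and the k masks.
out = torch.einsum('hijklmn,hojkl->iomn', x_wins * masks, weight)
print(out.shape)                             # torch.Size([2, 8, 16, 16])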

The above is exactly what the forward functions of the two classes derived from torch.autograd.function.Function, DepthKernelFn and RepeatKernelConvFn, compute. Their backward functions simply apply the chain rule to the computation in forward to obtain the gradients, so no further explanation is needed.
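As a quick forward-only sanity check (separate from the training test in the next section), something like the following can be run, with arbitrary sizes, on a PyTorch version that still provides the legacy torch._thnn backend used above:

import torch
import conv2_5d

conv = conv2_5d.Conv2_5d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.rand(2, 3, 32, 32)   # a batch of RGB feature maps
z = torch.rand(2, 1, 32, 32)   # the matching depth maps
y = conv(x, z, f=1)
print(y.shape)                 # torch.Size([2, 8, 32, 32])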

3. Verifying the 2.5D Convolution Implementation

To verify that the code above runs without errors during back-propagation, we define a simple two-layer network as a test (file name: conv2_5d_test.py):

# -*- coding: utf-8 -*-
"""
Created on Wed Nov 27 13:41:23 2019
@author: lijingxiong
"""
import torch
import conv2_5d


class ConvTest(torch.nn.Module):
    """A mini network to test Conv2_5d in forward and backward computation."""
    
    def __init__(self, num_classes=2):
        super(ConvTest, self).__init__()
        
        self._head_conv = conv2_5d.Conv2_5d(in_channels=3, 
                                            out_channels=32, 
                                            kernel_size=5, 
                                            padding=2, 
                                            bias=False)
        self._pred_conv = torch.nn.Conv2d(in_channels=32,
                                          out_channels=num_classes,
                                          kernel_size=3,
                                          padding=1,
                                          bias=False)
        self._batch_norm = torch.nn.BatchNorm2d(num_features=num_classes,
                                                momentum=0.995)
        
    def forward(self, x, z, f=1):
        x = self._head_conv(x, z, f)
        x = self._pred_conv(x)
        x = self._batch_norm(x)
        return x
    
    
if __name__ == '__main__':
    # Device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    model = ConvTest().to(device)
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    
    num_steps = 100
    for i in range(num_steps):
        images = torch.rand((2, 3, 64, 64)).to(device)
        depth = torch.rand((2, 1, 64, 64)).to(device)
        labels = torch.full((2, 64, 64), 0, dtype=torch.int64).to(device)
        
        # Forward pass
        outputs = model(images, depth)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print('Step: {}/{}, Loss: {:.4f}'.format(i+1, num_steps, loss.item()))

Run it directly:

python3 conv2_5d_test.py

The script finishes normally and the loss decreases steadily, so the code is (tentatively) considered correct.

https://www.jianshu.com/p/89d6e78fba82
