Pytorch 图片降噪实现

阿里云2000元红包!本站用户参与享受九折优惠!

思路

要训练一个图片降噪的模型,首先需要将添加噪声的图片作为输入,而原图作为输出目标数据,由于是对图像的处理,并且输入输出都是图片,所以可以用卷积提取特征,maxpool降维,再用upsampling升维回图片,原理还是挺简单的,网上示例也很多,这里就用pytorch代码实现一下

导入相关模块

由于这里使用mnist集,所以需要使用到torchvision模块导入数据,下面是使用到的模块:

from torchvision import datasets
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

数据预处理

首先我们的输入数据应该是原图添加噪声后的图片,而输出则是原图,并且形状为:(batch, channel, width, height),因此这里在载入mnist集之后,执行以下步骤:

  • 将数据reshape成对应格式
  • 给数据添加噪声,并且控制像素值还是在0到255之间
  • 对数据进行归一化
  • 分训练和测试数据,这里选前50000数据训练,剩下的10000拿来测试

代码如下:

dataset = datasets.MNIST('data/', download=True)
# (N, 28, 28) uint8 -> (N, 1, 28, 28) float, the (batch, channel, h, w) layout.
data = dataset.data.reshape(-1, 1, 28, 28).float()
# Add uniform noise in [0, 80) per pixel, clamp back into the valid pixel
# range, then normalise to [0, 1].  rand_like generalises the previously
# hard-coded (60000, 1, 28, 28) shape, so this keeps working if the
# dataset size ever differs.  x is the noisy input.
data_x = (data + 80 * torch.rand_like(data)).clamp(0, 255) / 255.
# y is the clean original (normalised) — the reconstruction target.
data_y = data / 255.
# Uncomment to eyeball one clean/noisy pair:
# plt.imshow(data_y[0].squeeze())
# plt.show()
# plt.imshow(data_x[0].squeeze())
# plt.show()
# Split: first 50000 samples for training, remaining 10000 for testing.
train_x, train_y = data_x[:50000], data_y[:50000]
test_x, test_y = data_x[50000:], data_y[50000:]

定义网络模型

这里的网络就用几次卷积加maxpool提取特征并降维之后,再通过upsampling升维,代码如下:

class net(nn.Module):
    """Convolutional denoising autoencoder for 28x28 single-channel images.

    Encoder: two conv + maxpool stages (28 -> 14 -> 7).
    Decoder: two conv + nearest-neighbour upsampling stages (7 -> 14 -> 28),
    then a final conv back to one channel.  Sigmoid activations keep every
    intermediate (and the output) in [0, 1], matching the normalised targets.

    Note: the unused `self.relu` attribute was removed; nn.ReLU holds no
    parameters, so saved state_dicts remain loadable.
    """

    def __init__(self):
        super(net, self).__init__()
        self.sigmoid = nn.Sigmoid()
        # --- encoder ---
        self.layer1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        # (N, 1, 28, 28) -> (N, 32, 28, 28)
        self.layer2 = nn.MaxPool2d(2, stride=2)
        # (N, 32, 28, 28) -> (N, 32, 14, 14)
        self.layer3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer4 = nn.MaxPool2d(2, stride=2)
        # (N, 32, 14, 14) -> (N, 32, 7, 7)
        # --- decoder ---
        self.layer5 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer6 = nn.UpsamplingNearest2d(scale_factor=2)
        # (N, 32, 7, 7) -> (N, 32, 14, 14)
        self.layer7 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer8 = nn.UpsamplingNearest2d(scale_factor=2)
        # (N, 32, 14, 14) -> (N, 32, 28, 28)
        self.layer9 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)
        # (N, 32, 28, 28) -> (N, 1, 28, 28)

    def forward(self, x):
        """Map a noisy batch (N, 1, 28, 28) to a denoised batch of the same shape."""
        x = self.sigmoid(self.layer1(x))
        x = self.layer2(x)
        x = self.sigmoid(self.layer3(x))
        x = self.layer4(x)
        x = self.sigmoid(self.layer5(x))
        x = self.layer6(x)
        x = self.sigmoid(self.layer7(x))
        x = self.layer8(x)
        x = self.sigmoid(self.layer9(x))
        return x
# Instantiate the denoising network defined above.
model = net()

定义损失函数和优化器

这里损失函数就用简单的mse就行,优化器用adam,代码如下:

# Mean-squared error between reconstruction and clean image.
loss_fun = nn.MSELoss()
# Adam with a small fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

训练模型

batch_size = 1000
# Train the model.
model.train()
for epoch in range(15):
    # range(0, 50000, batch_size) covers all 50 mini-batches; the previous
    # bound `50000 - batch_size` silently skipped the last one.
    for batch in range(0, 50000, batch_size):
        output = model(train_x[batch: batch + batch_size])
        # nn.MSELoss convention is (input, target); MSE is symmetric so the
        # value is unchanged, but keep the conventional order.
        loss = loss_fun(output, train_y[batch: batch + batch_size])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on the held-out split without building an autograd graph.
    model.eval()
    with torch.no_grad():
        test_output = model(test_x)
        loss_test = loss_fun(test_output, test_y)
    model.train()
    # .item() replaces the deprecated .data for reading a scalar loss.
    print('Epoch: {}, Loss: {}, test_loss: {}'.format(epoch, loss.item(), loss_test.item()))
    # Checkpoint after every epoch.
    torch.save(model.state_dict(), "autodecode.mdl")

测试模型

model.eval()
# Inference only — no gradients needed for plotting.
with torch.no_grad():
    test_output = model(test_x[:1000])
    train_output = model(train_x[:1000])
# -----------------------------------
# Show noisy input / model output / clean target side by side.
n = 10
plt.figure(figsize=(10, 50))
for i in range(n):
    idx = i * 3 + 1  # sample index shared by all three columns of row i
    ax = plt.subplot(n, 3, i * 3 + 1)
    # np.int was removed in NumPy 1.24; the builtin int dtype is the
    # supported replacement.
    plt.imshow((test_x[idx].squeeze().numpy() * 255.).astype(int))
    ax = plt.subplot(n, 3, i * 3 + 2)
    plt.imshow((test_output[idx].squeeze().numpy() * 255.).astype(int))
    ax = plt.subplot(n, 3, i * 3 + 3)
    plt.imshow((test_y[idx].squeeze().numpy() * 255.).astype(int))
plt.show()

完整代码

from torchvision import datasets
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
# -----------------------------------
# Data preprocessing
dataset = datasets.MNIST('data/', download=True)
# (N, 28, 28) uint8 -> (N, 1, 28, 28) float, the (batch, channel, h, w) layout.
data = dataset.data.reshape(-1, 1, 28, 28).float()
# Add uniform noise in [0, 80) per pixel, clamp back into the valid pixel
# range, then normalise to [0, 1].  rand_like generalises the previously
# hard-coded (60000, 1, 28, 28) shape.  x is the noisy input.
data_x = (data + 80 * torch.rand_like(data)).clamp(0, 255) / 255.
# y is the clean original (normalised) — the reconstruction target.
data_y = data / 255.
# Uncomment to eyeball one clean/noisy pair:
# plt.imshow(data_y[0].squeeze())
# plt.show()
# plt.imshow(data_x[0].squeeze())
# plt.show()
# Split: first 50000 samples for training, remaining 10000 for testing.
train_x, train_y = data_x[:50000], data_y[:50000]
test_x, test_y = data_x[50000:], data_y[50000:]
# -----------------------------------
# 定义网络模型
class net(nn.Module):
    """Convolutional denoising autoencoder for 28x28 single-channel images.

    Encoder: two conv + maxpool stages (28 -> 14 -> 7).
    Decoder: two conv + nearest-neighbour upsampling stages (7 -> 14 -> 28),
    then a final conv back to one channel.  Sigmoid activations keep every
    intermediate (and the output) in [0, 1], matching the normalised targets.

    Note: the unused `self.relu` attribute was removed; nn.ReLU holds no
    parameters, so saved state_dicts remain loadable.
    """

    def __init__(self):
        super(net, self).__init__()
        self.sigmoid = nn.Sigmoid()
        # --- encoder ---
        self.layer1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        # (N, 1, 28, 28) -> (N, 32, 28, 28)
        self.layer2 = nn.MaxPool2d(2, stride=2)
        # (N, 32, 28, 28) -> (N, 32, 14, 14)
        self.layer3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer4 = nn.MaxPool2d(2, stride=2)
        # (N, 32, 14, 14) -> (N, 32, 7, 7)
        # --- decoder ---
        self.layer5 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer6 = nn.UpsamplingNearest2d(scale_factor=2)
        # (N, 32, 7, 7) -> (N, 32, 14, 14)
        self.layer7 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer8 = nn.UpsamplingNearest2d(scale_factor=2)
        # (N, 32, 14, 14) -> (N, 32, 28, 28)
        self.layer9 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)
        # (N, 32, 28, 28) -> (N, 1, 28, 28)

    def forward(self, x):
        """Map a noisy batch (N, 1, 28, 28) to a denoised batch of the same shape."""
        x = self.sigmoid(self.layer1(x))
        x = self.layer2(x)
        x = self.sigmoid(self.layer3(x))
        x = self.layer4(x)
        x = self.sigmoid(self.layer5(x))
        x = self.layer6(x)
        x = self.sigmoid(self.layer7(x))
        x = self.layer8(x)
        x = self.sigmoid(self.layer9(x))
        return x
# -----------------------------------
# Instantiate the network, the loss function and the optimizer
model = net()
loss_fun = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# -----------------------------------
# Train the model
# model.load_state_dict(torch.load("autodecode.mdl"))
# # Resume from a previously saved checkpoint
batch_size = 1000
# Start training.
model.train()
for epoch in range(15):
    # range(0, 50000, batch_size) covers all 50 mini-batches; the previous
    # bound `50000 - batch_size` silently skipped the last one.
    for batch in range(0, 50000, batch_size):
        output = model(train_x[batch: batch + batch_size])
        # nn.MSELoss convention is (input, target); MSE is symmetric so the
        # value is unchanged, but keep the conventional order.
        loss = loss_fun(output, train_y[batch: batch + batch_size])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on the held-out split without building an autograd graph.
    model.eval()
    with torch.no_grad():
        test_output = model(test_x)
        loss_test = loss_fun(test_output, test_y)
    model.train()
    # .item() replaces the deprecated .data for reading a scalar loss.
    print('Epoch: {}, Loss: {}, test_loss: {}'.format(epoch, loss.item(), loss_test.item()))
    # Checkpoint after every epoch.
    torch.save(model.state_dict(), "autodecode.mdl")
# -----------------------------------
# Test the model
model.eval()
# Inference only — no gradients needed for plotting.
with torch.no_grad():
    test_output = model(test_x[:1000])
    train_output = model(train_x[:1000])
# -----------------------------------
# Show noisy input / model output / clean target side by side.
n = 10
plt.figure(figsize=(10, 50))
for i in range(n):
    idx = i * 3 + 1  # sample index shared by all three columns of row i
    ax = plt.subplot(n, 3, i * 3 + 1)
    # np.int was removed in NumPy 1.24; the builtin int dtype is the
    # supported replacement.
    plt.imshow((test_x[idx].squeeze().numpy() * 255.).astype(int))
    ax = plt.subplot(n, 3, i * 3 + 2)
    plt.imshow((test_output[idx].squeeze().numpy() * 255.).astype(int))
    ax = plt.subplot(n, 3, i * 3 + 3)
    plt.imshow((test_y[idx].squeeze().numpy() * 255.).astype(int))
plt.show()

https://www.jianshu.com/p/f27526a6773f

Python量化投资网携手4326手游为资深游戏玩家推荐:《《奥拉星》萌新入门指南

「点点赞赏,手留余香」

    还没有人赞赏,快来当第一个赞赏的人吧!
0 条回复 A 作者 M 管理员
    所有的伟大,都源于一个勇敢的开始!
欢迎您,新朋友,感谢参与互动!欢迎您 {{author}},您在本站有{{commentsCount}}条评论