11 超分辨率生成对抗网络(SRGAN)之超分辨率的实现
系列进度
生成对抗网络高级 · 第 11 / 21 篇
整理说明
这篇内容怎么整理
郭震 · 2026-06-04
阅读路线
先按这条路线读
先抓住主线,再回到代码、配置和图文细节,读起来会更稳。
GAN 进阶内容要围绕稳定性、条件控制、架构变化和评估方法建立判断框架。阅读时可以按「数据准备 -> 数据集加载与预处理 -> 训练模型 -> GAN 训练步骤」建立结构,再回到正文里的代码、案例或指标做验证。
读完后,用一个真实小任务复查:输入是什么,处理环节在哪里,输出是否可验收;失败时先查「数据准备」,再查「数据集加载与预处理」。
在上一篇中,我们深入探讨了超分辨率生成对抗网络(SRGAN)的架构,了解了其生成器和判别器的设计理念和结构。今天,我们将关注于如何实际实现超分辨率。这一过程涉及到真实数据的预处理、模型的训练过程以及如何使用训练好的模型进行图像超分辨率重建。
数据准备
在进行超分辨率任务之前,首先需要准备数据集。一个常用的数据集是 DIV2K,它包括高分辨率图像,这是训练超分辨率模型的重要基础。
实现 SRGAN 超分辨率时,先看低清输入、高清目标、生成器输出、判别器反馈和感知损失。
数据集加载与预处理
import os
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
def load_images_from_folder(folder, scale_factor=4):
images = []
for filename in os.listdir(folder):
img = Image.open(os.path.join(folder, filename)).convert('RGB')
img = img.resize((img.width // scale_factor, img.height // scale_factor), Image.BICUBIC)
images.append(img)
return images
# 设定数据集目录与缩放因子
train_folder = 'path/to/DIV2K/train'
images = load_images_from_folder(train_folder)
在上述代码中,我们将每个高分辨率图像减少到其尺寸的四分之一,这样就得到了低分辨率(LR)图像。随后的处理我们会使用这些 LR 图像作为输入,同时使用原图作为目标(HR)图像。
训练模型
在 SRGAN 的实现中,训练过程分为若干个步骤:准备 GAN 的组成部分(生成器和判别器),设置损失函数,然后迭代训练模型。
读《超分辨率生成对抗网络(SRGAN)之超分辨率的实现》时,可以把配图当成路线卡:先看整体顺序,再看每一步为什么这样做,最后再检查边界条件。
GAN 训练步骤
训练环节的关键是调整生成器和判别器的参数,使得生成器能够生成高质量的超分辨率图像,而判别器则要能够辨别生成的图像与真实图像的区别。
import torch.optim as optim
from model import Generator, Discriminator # 假设你有一个模块 model 包含这两个类
# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()
criterion_GAN = torch.nn.BCELoss()
criterion_content = torch.nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0001)
# 训练过程
for epoch in range(num_epochs):
for i, (lr_images, hr_images) in enumerate(data_loader):
# 更新判别器
optimizer_D.zero_grad()
# 真实和生成的标签
real_labels = torch.ones((batch_size, 1), requires_grad=False)
fake_labels = torch.zeros((batch_size, 1), requires_grad=False)
# 判别器的损失
outputs = discriminator(hr_images)
d_loss_real = criterion_GAN(outputs, real_labels)
fake_images = generator(lr_images)
outputs = discriminator(fake_images.detach())
d_loss_fake = criterion_GAN(outputs, fake_labels)
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
optimizer_D.step()
# 更新生成器
optimizer_G.zero_grad()
outputs = discriminator(fake_images)
g_loss_GAN = criterion_GAN(outputs, real_labels)
g_loss_content = criterion_content(fake_images, hr_images)
g_loss = g_loss_GAN + lambda_content * g_loss_content # lambda_content 是超参数
g_loss.backward()
optimizer_G.step()
在上述代码中,我们通过交替更新判别器和生成器的参数来优化 GAN 模型。对于判别器的损失,主要采取应用于真实图像与生成图像的对比。对于生成器的损失,则包含了内容损失和对抗损失。
实现超分辨率图像的生成
一旦我们的模型训练完成,就可以使用它来生成超分辨率图像。将低分辨率图像输入到生成器中,即可获得高分辨率图像。
# 生成超分辨率图像
def generate_super_resolution(generator, lr_image):
with torch.no_grad():
sr_image = generator(lr_image.unsqueeze(0)) # 添加批量维度
return sr_image.squeeze(0) # 移除批量维度
# 使用训练好的生成器生成超分辨率图像
lr_test_image = load_images_from_folder('path/to/test/image')[0] # 加载测试图像
sr_image = generate_super_resolution(generator, lr_test_image)
如果《超分辨率生成对抗网络(SRGAN)之超分辨率的实现》还没完全消化,可以从这张卡片的四个动作重新走一遍。
回看《超分辨率生成对抗网络(SRGAN)之超分辨率的实现》时,不必一次做大项目,先用一条简单样例确认主线是否清楚。
结论
在本篇文章中,我们踏踏实实实现了 SRGAN 的超分辨率图像生成过程,从数据准备到模型训练,再到使用模型进行图像重建。接下来的篇幅中,我们将讨论如何评估模型生成的图像质量,使用一系列标准评估指标(如 PSNR 和 SSIM)来量化 SRGAN 的表现。
这一系列的设计和实现,突显了生成对抗网络在图像超分辨率领域中的强大能力,同时也为后续的评估提供了基础。希望您在实现中获得启发,并取得优秀的超分辨率效果!
继续阅读
从这篇继续找到相关教程
常见问题
读前先确认这三点
超分辨率生成对抗网络(SRGAN)之超分辨率的实现适合谁读?
这是 生成对抗网络高级 系列第 11 / 21 篇,适合正在学习生成对抗网络高级,并且需要把概念落到操作步骤或判断标准里的读者。
读这篇生成对抗网络高级教程要多久?
按中文技术文章阅读速度估算,通读大约 3 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。
这篇文章里的图文节点怎么用?
正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。
分享文章
转发到常用平台
微信/朋友圈可先复制链接
相关教程
从相近问题继续读
继续阅读