10 超分辨率生成对抗网络(SRGAN)之SRGAN的架构
系列进度
生成对抗网络高级 · 第 10 / 21 篇
整理说明
这篇内容怎么整理
郭震 · 2026-06-04
阅读路线
先按这条路线读
先抓住主线,再回到代码、配置和图文细节,读起来会更稳。
GAN 进阶内容要围绕稳定性、条件控制、架构变化和评估方法建立判断框架。阅读时可以按「SRGAN的基本框架 -> 生成器 -> 判别器 -> SRGAN的损失函数」建立结构,再回到正文里的代码、案例或指标做验证。
读完后,用一个真实小任务复查:输入是什么,处理环节在哪里,输出是否可验收;失败时先查「SRGAN的基本框架」,再查「生成器」。
在上篇中,我们探讨了条件生成对抗网络(cGAN)的训练和评估,了解了如何利用条件信息来生成目标数据。在本篇中,我们将专注于超分辨率生成对抗网络(SRGAN)的具体架构。SRGAN是一种用于图像超分辨率重建的强大模型,能够将低分辨率图像转化为高分辨率图像,同时保持图像的细节和纹理。
SRGAN的基本框架
SRGAN的架构主要由两个部分构成:生成器(Generator)和判别器(Discriminator)。与一般的GAN架构相似,SRGAN的生成器用于生成与真实高分辨率图像相似的图像,而判别器则用于区分生成的图像和真实图像。
理解 SRGAN 架构时,先看低分辨率输入如何放大,判别器如何评价细节,感知损失如何保留纹理。
生成器
SRGAN的生成器通常采用卷积神经网络(CNN)结构。以下是SRGAN生成器的主要特点:
- 输入:低分辨率图像(通常是经过降采样的高分辨率图像)。
- 特征提取:使用多个卷积层提取图像特征,通过激活函数(如ReLU)引入非线性因素。
- 上采样:通过像素Shuffle等方法将低分辨率图像上采样到目标高分辨率。
- 输出:生成高分辨率图像。
一个典型的SRGAN生成器可能实现如下:
import tensorflow as tf
from tensorflow.keras import layers
def build_generator():
inputs = tf.keras.Input(shape=(None, None, 3))
# 低分辨率特征提取
x = layers.Conv2D(64, kernel_size=9, padding='same')(inputs)
x = layers.PReLU()(x)
# 残差块
for _ in range(16):
residual = x
x = layers.Conv2D(64, kernel_size=3, padding='same')(x)
x = layers.PReLU()(x)
x = layers.Conv2D(64, kernel_size=3, padding='same')(x)
x = layers.add([residual, x])
# 上采样
x = layers.Conv2D(256, kernel_size=3, padding='same')(x)
x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x) # PixelShuffle
# 最后一层
outputs = layers.Conv2D(3, kernel_size=9, padding='same', activation='tanh')(x)
return tf.keras.Model(inputs, outputs)
判别器
SRGAN的判别器也是基于卷积神经网络的,通常结构如下:
- 输入:生成的高分辨率图像和真实的高分辨率图像(通过合并操作)。
- 多层卷积:逐层使用卷积层提取特征,逐渐缩小图像的空间维度。
- 输出:经过sigmoid激活函数后输出一个二分类结果,表示输入图像为真实图像的概率。
判别器的代码示例如下:
def build_discriminator():
inputs = tf.keras.Input(shape=(None, None, 3))
x = layers.Conv2D(64, kernel_size=3, strides=2, padding='same')(inputs)
x = layers.LeakyReLU(alpha=0.2)(x)
for _ in range(3):
x = layers.Conv2D(64 * (2 ** (_ + 1)), kernel_size=3, strides=2, padding='same')(x)
x = layers.LeakyReLU(alpha=0.2)(x)
x = layers.Flatten()(x)
x = layers.Dense(1024)(x)
x = layers.LeakyReLU(alpha=0.2)(x)
outputs = layers.Dense(1, activation='sigmoid')(x)
return tf.keras.Model(inputs, outputs)
SRGAN的损失函数
SRGAN引入了感知损失(Perceptual Loss),该损失通过深度网络提取图像特征,同时结合对抗损失来优化生成图像的质量。感知损失定义为生成图像与真实图像在高层特征空间上的差异:
《超分辨率生成对抗网络(SRGAN)之SRGAN的架构》适合边看图边读正文。先确认问题和判断标准,再看概念解释与练习步骤,信息会更容易连成一条线。
其中,是生成器生成的图像,是真实的高分辨率图像,是一个预训练的特征提取网络(如VGG网络)的第层。
实践案例
以下是一段完整的训练SRGAN的示例代码框架,其中包含生成器、判别器和训练过程的简要实现:
import numpy as np
def train_srgan(generator, discriminator, dataset, epochs=100, batch_size=16):
for epoch in range(epochs):
for low_res_images, high_res_images in dataset.batch(batch_size):
# 生成高分辨率图片
generated_images = generator(low_res_images)
# 训练判别器
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))
discriminator_loss_real = discriminator.train_on_batch(high_res_images, real_labels)
discriminator_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
# 训练生成器
generator_loss = srgan_train_on_batch(low_res_images, high_res_images)
print(f"Epoch: {epoch+1}, Discriminator Loss: {discriminator_loss_real + discriminator_loss_fake}, Generator Loss: {generator_loss}")
# 利用上面构建的模型进行训练
generator = build_generator()
discriminator = build_discriminator()
# dataset 应该是加载的低分辨率和高分辨率图像对
train_srgan(generator, discriminator, dataset)
学完《超分辨率生成对抗网络(SRGAN)之SRGAN的架构》后,不妨换一个自己的场景试一次,重点观察输入、处理和输出是否能对应起来。
如果想把《超分辨率生成对抗网络(SRGAN)之SRGAN的架构》用到自己的任务里,可以先缩小场景,只验证一个最关键的判断点。
总结
在本篇中,我们详细介绍了超分辨率生成对抗网络(SRGAN)的架构,包括其生成器和判别器的具体设计,以及损失函数的构建。SRGAN不仅在传统图像处理领域展示了良好的超分辨率性能,且为深度学习领域的图像生成任务提供了重要的思路和灵感。接下来的篇幅将集中在超分辨率的实际实现上,我们将探讨如何使用SRGAN对给定的低分辨率图像进行超分辨率重建。
继续阅读
从这篇继续找到相关教程
常见问题
读前先确认这三点
超分辨率生成对抗网络(SRGAN)之SRGAN的架构适合谁读?
这是 生成对抗网络高级 系列第 10 / 21 篇,适合正在学习生成对抗网络高级,并且需要把概念落到操作步骤或判断标准里的读者。
读这篇生成对抗网络高级教程要多久?
按中文技术文章阅读速度估算,通读大约 4 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。
这篇文章里的图文节点怎么用?
正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。
分享文章
转发到常用平台
微信/朋友圈可先复制链接
相关教程
从相近问题继续读
继续阅读