15 改善 GAN 训练之模型架构的变化
系列进度
GAN 网络教程 · 第 15 / 21 篇
整理说明
这篇内容怎么整理
郭震 · 2026-06-04
阅读路线
先按这条路线读
先抓住主线,再回到代码、配置和图文细节,读起来会更稳。
GAN 的关键是生成器和判别器互相推动,学习时要同时看结构、训练和样本质量。阅读时可以按「深度卷积生成对抗网络 -> DCGAN 的架构 -> Wasserstein GAN -> WGAN 的架构示例」建立结构,再回到正文里的代码、案例或指标做验证。
读完后,用一个真实小任务复查:输入是什么,处理环节在哪里,输出是否可验收;失败时先查「深度卷积生成对抗网络」,再查「DCGAN 的架构」。
在上一篇中,我们讨论了引入正则化技术以改善 GAN 的训练。正如我们所知,GAN(生成对抗网络)是一种通过生成器和判别器之间的对抗学习来生成新数据的有力工具。然而,除了正则化技术之外,调整模型的架构也是提高 GAN 训练性能的一个有效方法。本篇将探讨几种模型架构的变化,以改进 GAN 的训练效果。
1. 深度卷积生成对抗网络(DCGAN)
在 GAN 的发展的初期,标准的 GAN 使用了浅层的全连接网络,但这在生成复杂数据(如图像)时效果不佳。为了应对这一挑战,深度卷积生成对抗网络(DCGAN) 的提出极大地改善了 GAN 的生成效果。
比较 GAN 架构变化时,先看网络深度、归一化、残差连接、条件输入、判别器能力和训练稳定性。
DCGAN 的架构
DCGAN 主要通过以下几点来改善生成效果:
- 使用卷积层:采用卷积层而非全连接层,允许生成器和判别器在空间上保留更多的信息。
- 批量归一化:在每个卷积层后使用批量归一化,可以加速收敛并提高模型的稳定性。
- 使用激活函数:在生成器中使用
ReLU激活函数,而在输出层则使用tanh激活函数。判别器则使用Leaky ReLU激活。
代码示例
下面是一个简单的 DCGAN 生成器的代码示例:
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, noise_dim, image_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.ConvTranspose2d(noise_dim, 128, 4, 1, 0, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(True),
nn.ConvTranspose2d(32, image_dim, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, input):
return self.model(input)
在该代码中,nn.ConvTranspose2d 被用于构建转置卷积层,从随机噪声生成图像。
2. Wasserstein GAN(WGAN)
WGAN 提出了一个新的损失函数,即 Wasserstein 距离,来解决 GAN 训练时的不稳定性和模式崩溃问题。WGAN 的关键在于其改进的判别器(也称为“ critic”),采用了以下策略:
看《改善 GAN 训练之模型架构的变化》时,先把图中的问题、关键词、操作和验收标准对上,再读正文会更省力。读完后,最好能用自己的项目重新讲一遍。
- 权重裁剪:在每次权重更新后对判别器的权重进行裁剪,以强制执行 1-Lipschitz 连续性。
- 平滑标签:使用平滑标签(例如,将真实样本标签 1.0 替换为 0.9)可以进一步提高训练的稳定性。
WGAN 的架构示例
WGAN 的判别器可以简单地修改为以下结构,保持 Conv 层的设计理念:
class Critic(nn.Module):
def __init__(self, image_dim):
super(Critic, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(image_dim, 32, 4, 2, 1),
nn.LeakyReLU(0.2),
nn.Conv2d(32, 64, 4, 2, 1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, 4, 2, 1),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 1, 4, 1, 0)
)
def forward(self, input):
return self.model(input)
3. 采用残差网络(ResNet)
残差网络的引入也使得 GAN 的结构更为灵活和强大。通过使用残差连接,可以使网络更深,并解决梯度消失的问题。生成器和判别器都可以采用残差块的结构,来进一步提高复杂数据的生成能力。
残差块示例
以下是一个简单的残差块实现:
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual # 残差连接
out = self.relu(out)
return out
如果《改善 GAN 训练之模型架构的变化》还没完全消化,可以从这张卡片的四个动作重新走一遍。
回看《改善 GAN 训练之模型架构的变化》时,不必一次做大项目,先用一条简单样例确认主线是否清楚。
结论
通过不同的模型架构的变化,如引入 DCGAN、WGAN 和 残差网络,可以显著提高 GAN 的训练效果和生成数据的质量。在实际应用中,选择合适的架构可以帮助我们更好地适应特定的生成任务。在下一篇中,我们将探讨 GAN 的应用案例,重点讨论其在 图像生成 领域的具体使用和实际案例分析。
继续阅读
从这篇继续找到相关教程
常见问题
读前先确认这三点
改善 GAN 训练之模型架构的变化适合谁读?
这是 GAN 网络教程 系列第 15 / 21 篇,适合正在学习GAN 网络教程,并且需要把概念落到操作步骤或判断标准里的读者。
读这篇GAN 网络教程要多久?
按中文技术文章阅读速度估算,通读大约 4 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。
这篇文章里的图文节点怎么用?
正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。
分享文章
转发到常用平台
微信/朋友圈可先复制链接
相关教程
从相近问题继续读
继续阅读