郭震 AI公众号:郭震AI

15 改善 GAN 训练之模型架构的变化

发布日期:

最近更新:

分类: GAN网络从零教程

预计阅读: 4 分钟

阅读次数: 0

预计阅读4 分钟
结构重点8 个
图文要点6 张
正文规模1.5k 字

整理说明

这篇内容怎么整理

郭震 · 2026-06-04

独立整理围绕 8 个结构重点拆成环境、步骤、验证点和常见误区,尽量让读者能照着复现。
图文对照保留 6 张和配置、流程、判断结果有关的图片,方便快速定位正文重点。
持续校对工具、模型和命令变化较快,后续优先修正入口、参数和风险提醒。

阅读路线

先按这条路线读

先抓住主线,再回到代码、配置和图文细节,读起来会更稳。

图文要点

先看本文图文节点

按图先建立主线,再跳回正文核对步骤、配置和判断标准。

改善 GAN 训练之模型架构的变化结构图查看大图
改善 GAN 训练之模型架构的变化结构图

GAN 的关键是生成器和判别器互相推动,学习时要同时看结构、训练和样本质量。阅读时可以按「深度卷积生成对抗网络 -> DCGAN 的架构 -> Wasserstein GAN -> WGAN 的架构示例」建立结构,再回到正文里的代码、案例或指标做验证。

改善 GAN 训练之模型架构的变化核对图查看大图
改善 GAN 训练之模型架构的变化核对图

读完后,用一个真实小任务复查:输入是什么,处理环节在哪里,输出是否可验收;失败时先查「深度卷积生成对抗网络」,再查「DCGAN 的架构」。

在上一篇中,我们讨论了引入正则化技术以改善 GAN 的训练。正如我们所知,GAN(生成对抗网络)是一种通过生成器和判别器之间的对抗学习来生成新数据的有力工具。然而,除了正则化技术之外,调整模型的架构也是提高 GAN 训练性能的一个有效方法。本篇将探讨几种模型架构的变化,以改进 GAN 的训练效果。

1. 深度卷积生成对抗网络(DCGAN)

在 GAN 的发展的初期,标准的 GAN 使用了浅层的全连接网络,但这在生成复杂数据(如图像)时效果不佳。为了应对这一挑战,深度卷积生成对抗网络(DCGAN) 的提出极大地改善了 GAN 的生成效果。

GAN架构变化判断卡查看大图
GAN架构变化判断卡

比较 GAN 架构变化时,先看网络深度、归一化、残差连接、条件输入、判别器能力和训练稳定性。

DCGAN 的架构

DCGAN 主要通过以下几点来改善生成效果:

  • 使用卷积层:采用卷积层而非全连接层,允许生成器和判别器在空间上保留更多的信息。
  • 批量归一化:在每个卷积层后使用批量归一化,可以加速收敛并提高模型的稳定性。
  • 使用激活函数:在生成器中使用 ReLU 激活函数,而在输出层则使用 tanh 激活函数。判别器则使用 Leaky ReLU 激活。

代码示例

下面是一个简单的 DCGAN 生成器的代码示例:

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, noise_dim, image_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, 128, 4, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, image_dim, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.model(input)

在该代码中,nn.ConvTranspose2d 被用于构建转置卷积层,从随机噪声生成图像。

2. Wasserstein GAN(WGAN)

WGAN 提出了一个新的损失函数,即 Wasserstein 距离,来解决 GAN 训练时的不稳定性和模式崩溃问题。WGAN 的关键在于其改进的判别器(也称为“ critic”),采用了以下策略:

生成对抗网络阅读地图卡查看大图
生成对抗网络阅读地图卡

看《改善 GAN 训练之模型架构的变化》时,先把图中的问题、关键词、操作和验收标准对上,再读正文会更省力。读完后,最好能用自己的项目重新讲一遍。

  • 权重裁剪:在每次权重更新后对判别器的权重进行裁剪,以强制执行 1-Lipschitz 连续性。
  • 平滑标签:使用平滑标签(例如,将真实样本标签 1.0 替换为 0.9)可以进一步提高训练的稳定性。

WGAN 的架构示例

WGAN 的判别器可以简单地修改为以下结构,保持 Conv 层的设计理念:

class Critic(nn.Module):
    def __init__(self, image_dim):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(image_dim, 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, 4, 1, 0)
        )

    def forward(self, input):
        return self.model(input)

3. 采用残差网络(ResNet)

残差网络的引入也使得 GAN 的结构更为灵活和强大。通过使用残差连接,可以使网络更深,并解决梯度消失的问题。生成器和判别器都可以采用残差块的结构,来进一步提高复杂数据的生成能力。

残差块示例

以下是一个简单的残差块实现:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual  # 残差连接
        out = self.relu(out)
        return out
改善 GAN 训练之模型架构的变化应用复盘卡查看大图
改善 GAN 训练之模型架构的变化应用复盘卡

如果《改善 GAN 训练之模型架构的变化》还没完全消化,可以从这张卡片的四个动作重新走一遍。

改善 GAN 训练之模型架构的变化应用检查卡查看大图
改善 GAN 训练之模型架构的变化应用检查卡

回看《改善 GAN 训练之模型架构的变化》时,不必一次做大项目,先用一条简单样例确认主线是否清楚。

结论

通过不同的模型架构的变化,如引入 DCGANWGAN残差网络,可以显著提高 GAN 的训练效果和生成数据的质量。在实际应用中,选择合适的架构可以帮助我们更好地适应特定的生成任务。在下一篇中,我们将探讨 GAN 的应用案例,重点讨论其在 图像生成 领域的具体使用和实际案例分析。

继续阅读

从这篇继续找到相关教程

AI 教程总索引

常见问题

读前先确认这三点

改善 GAN 训练之模型架构的变化适合谁读?

这是 GAN 网络教程 系列第 15 / 21 篇,适合正在学习GAN 网络教程,并且需要把概念落到操作步骤或判断标准里的读者。

读这篇GAN 网络教程要多久?

按中文技术文章阅读速度估算,通读大约 4 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。

这篇文章里的图文节点怎么用?

正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。

分享文章

转发到常用平台

微信/朋友圈可先复制链接

相关教程

AI 教程总索引

继续阅读

继续找到相关 AI 教程

返回栏目

Reader Messages

读者留言

有问题、补充资料或实测结果,可以直接留下。这里不需要登录。

最多 800 字

为了防刷,每条留言会做长度、链接数量和提交频率限制。

0/800

留言列表

0
正在加载留言...