郭震 AI公众号:郭震AI

11 GAN的训练过程之训练循环的实现

发布日期:

最近更新:

分类: GAN网络从零教程

预计阅读: 3 分钟

阅读次数: 0

预计阅读3 分钟
结构重点5 个
图文要点6 张
正文规模1.2k 字

整理说明

这篇内容怎么整理

郭震 · 2026-06-04

独立整理围绕 5 个结构重点拆成环境、步骤、验证点和常见误区,尽量让读者能照着复现。
图文对照保留 6 张和配置、流程、判断结果有关的图片,方便快速定位正文重点。
持续校对工具、模型和命令变化较快,后续优先修正入口、参数和风险提醒。

阅读路线

先按这条路线读

先抓住主线,再回到代码、配置和图文细节,读起来会更稳。

图文要点

先看本文图文节点

按图先建立主线,再跳回正文核对步骤、配置和判断标准。

GAN的训练过程之训练循环的实现结构图查看大图
GAN的训练过程之训练循环的实现结构图

GAN 的关键是生成器和判别器互相推动,学习时要同时看结构、训练和样本质量。阅读时可以按「训练循环的基本结构 -> 训练具体实现 -> 定义判别器和生成器的损失函数 -> 训练过程的代码实现」建立结构,再回到正文里的代码、案例或指标做验证。

GAN的训练过程之训练循环的实现核对图查看大图
GAN的训练过程之训练循环的实现核对图

读完后,用一个真实小任务复查:输入是什么,处理环节在哪里,输出是否可验收;失败时先查「训练循环的基本结构」,再查「训练具体实现」。

在上一篇中,我们讨论了GAN的训练过程,其中涉及数据准备与预处理。这一步是至关重要的,因为好的数据不仅可以帮助我们更好地训练生成对抗网络(GAN),而且能显著提升生成效果。在本篇中,我们将专注于实现GAN的训练循环,通过实际的代码示例来说明具体细节。

训练循环的基本结构

GAN的训练循环主要包括以下几个步骤:

GAN训练循环判断卡查看大图
GAN训练循环判断卡

实现 GAN 训练循环时,先看数据批次、判别器更新、生成器更新、损失记录和样本保存。

  1. 生成器前向传播:使用随机噪声生成伪造的数据。
  2. 判别器前向传播:将生成的数据与真实的数据一起输入判别器,计算损失。
  3. 反向传播和优化
    • 优化判别器,利用真实样本和生成样本的损失。
    • 优化生成器,利用判别器对生成样本的反馈进行调整。

以下是一个基本的训练循环结构示意:

for epoch in range(num_epochs):
    for i, (real_data, _) in enumerate(data_loader):
        # 生成器的训练
        noise = torch.randn(batch_size, z_dim)  # 噪声输入
        fake_data = generator(noise)             # 生成伪造数据
        d_loss, g_loss = train_gan(real_data, fake_data)

        # 打印损失值
        if (i+1) % log_interval == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], D Loss: {d_loss:.4f}, G Loss: {g_loss:.4f}')

训练具体实现

接下来,我们将具体实现train_gan函数,该函数将会实现前述的训练优化步骤。

生成对抗网络阅读地图卡查看大图
生成对抗网络阅读地图卡

《GAN的训练过程之训练循环的实现》适合边看图边读正文。先确认问题和判断标准,再看概念解释与练习步骤,信息会更容易连成一条线。

1. 定义判别器和生成器的损失函数

在训练GAN时,我们通常使用交叉熵损失。对于判别器,我们需要最大化真实样本的概率,同时最小化生成样本的概率。生成器的目标是让生成样本尽可能地被判别器识别为真实样本。其损失函数可以定义如下:

  • 判别器损失
D_loss=12(E[log(D(real))]+E[log(1D(fake))])D\_loss = -\frac{1}{2}(E[log(D(real))] + E[log(1 - D(fake))])
  • 生成器损失
G_loss=E[log(D(fake))]G\_loss = -E[log(D(fake))]

2. 训练过程的代码实现

以下是更为详细的代码,包含如何在每个训练步骤中更新生成器和判别器:

import torch
import torch.nn as nn
import torch.optim as optim

# 伪造数据使用的网络
class Generator(nn.Module):
    # 定义生成器结构
    ...

class Discriminator(nn.Module):
    # 定义判别器结构
    ...

generator = Generator()
discriminator = Discriminator()

criterion = nn.BCELoss() 
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

def train_gan(real_data, fake_data):
    # 判别器训练
    d_optimizer.zero_grad()

    # 标签设置
    real_labels = torch.ones(real_data.size(0), 1)  # 真实样本标签
    fake_labels = torch.zeros(fake_data.size(0), 1) # 伪造样本标签

    # 计算判别器对真实数据的损失
    outputs = discriminator(real_data)
    d_loss_real = criterion(outputs, real_labels)

    # 计算判别器对伪造数据的损失
    outputs = discriminator(fake_data.detach())
    d_loss_fake = criterion(outputs, fake_labels)

    # 反向传播并优化判别器
    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    d_optimizer.step()

    # 生成器训练
    g_optimizer.zero_grad()

    # 生成器损失
    outputs = discriminator(fake_data)
    g_loss = criterion(outputs, real_labels)
    
    # 反向传播并优化生成器
    g_loss.backward()
    g_optimizer.step()

    return d_loss.item(), g_loss.item()
GAN的训练过程之训练循环的实现应用复盘卡查看大图
GAN的训练过程之训练循环的实现应用复盘卡

如果《GAN的训练过程之训练循环的实现》还没完全消化,可以从这张卡片的四个动作重新走一遍。

GAN的训练过程之训练循环的实现应用检查卡查看大图
GAN的训练过程之训练循环的实现应用检查卡

回看《GAN的训练过程之训练循环的实现》时,不必一次做大项目,先用一条简单样例确认主线是否清楚。

总结

在本篇中,我们介绍了GAN训练过程中不可或缺的训练循环实现,涵盖了生成器和判别器的损失计算、优化步骤及相关代码示例。这个循环的高效实现是GAN训练成功的关键,通过不断调整和优化,实现生成过程的迭代提升。接下来,我们将在下一篇中集中讨论如何评估GAN的性能,这也是理解模型可靠性的重要步骤。通过完整的训练与评估流程,可以帮助我们创建高质量的生成模型。

继续阅读

从这篇继续找到相关教程

AI 教程总索引

常见问题

读前先确认这三点

GAN的训练过程之训练循环的实现适合谁读?

这是 GAN 网络教程 系列第 11 / 21 篇,适合正在学习GAN 网络教程,并且需要把概念落到操作步骤或判断标准里的读者。

读这篇GAN 网络教程要多久?

按中文技术文章阅读速度估算,通读大约 3 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。

这篇文章里的图文节点怎么用?

正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。

分享文章

转发到常用平台

微信/朋友圈可先复制链接

相关教程

AI 教程总索引

继续阅读

继续找到相关 AI 教程

返回栏目

Reader Messages

读者留言

有问题、补充资料或实测结果,可以直接留下。这里不需要登录。

最多 800 字

为了防刷,每条留言会做长度、链接数量和提交频率限制。

0/800

留言列表

0
正在加载留言...