郭震 AI公众号:郭震AI

9 条件GAN的训练和评估

发布日期:

最近更新:

分类: GANs进阶

预计阅读: 4 分钟

阅读次数: 0

预计阅读4 分钟
结构重点6 个
图文要点6 张
正文规模1.8k 字

整理说明

这篇内容怎么整理

郭震 · 2026-06-04

独立整理围绕 6 个结构重点拆成环境、步骤、验证点和常见误区,尽量让读者能照着复现。
图文对照保留 6 张和配置、流程、判断结果有关的图片,方便快速定位正文重点。
持续校对工具、模型和命令变化较快,后续优先修正入口、参数和风险提醒。

阅读路线

先按这条路线读

先抓住主线,再回到代码、配置和图文细节,读起来会更稳。

图文要点

先看本文图文节点

按图先建立主线,再跳回正文核对步骤、配置和判断标准。

条件GAN的训练和评估结构图查看大图
条件GAN的训练和评估结构图

GAN 进阶内容要围绕稳定性、条件控制、架构变化和评估方法建立判断框架。阅读时可以按「条件GAN的训练 -> 训练过程 -> 训练中的技巧 -> 条件GAN的评估」建立结构,再回到正文里的代码、案例或指标做验证。

条件GAN的训练和评估核对图查看大图
条件GAN的训练和评估核对图

读完后,用一个真实小任务复查:输入是什么,处理环节在哪里,输出是否可验收;失败时先查「条件GAN的训练」,再查「训练过程」。

在之前的文章中,我们探讨了条件生成对抗网络(cGAN)的应用实例。为了更深入地了解cGAN的工作原理,本篇将着重讨论其训练和评估方法。在深度学习的实践中,训练过程的设计和评估标准的选择直接影响模型的质量和应用效果。因此,我们将详细分析如何有效训练cGAN以及如何评估其生成结果。

1. 条件GAN的训练

1.1 训练过程

cGAN训练评估判断卡查看大图
cGAN训练评估判断卡

训练和评估 cGAN 时,先看条件标签是否生效、样本是否多样、判别器是否过强、指标是否稳定。

cGAN的训练过程与传统GAN类似,但我们在生成器和判别器中引入了条件信息。下面,我们将以MNIST手写数字生成的示例来说明cGAN的训练步骤。

  1. 准备数据集: 首先,我们需要加载MNIST数据集,并将其转换为可以供模型使用的格式。我们将每个图像与其对应的标签相结合,以使得生成器能够根据标签生成特定的数字。

    from keras.datasets import mnist
    import numpy as np
    
    # 加载数据
    (X_train, y_train), (_, _) = mnist.load_data()
    X_train = X_train.astype('float32') / 255.0
    y_train = y_train.astype('float32')
    
    # 将数据集扩展为(样本,宽度,高度,通道)
    X_train = np.expand_dims(X_train, axis=-1)
    
  2. 构建生成器和判别器: cGAN的生成器和判别器的构建需同时接收条件信息。例如,生成器将随机噪声和标签作为输入,判别器将图像和标签作为输入。

    from keras.layers import Input, Dense, Reshape, Concatenate
    from keras.models import Model
    
    def build_generator():
        noise = Input(shape=(100,))
        label = Input(shape=(10,))
        model_input = Concatenate()([noise, label])
        x = Dense(128)(model_input)
        x = Reshape((4, 4, 8))(x)
        return Model([noise, label], x)
    
    def build_discriminator():
        img = Input(shape=(28, 28, 1))
        label = Input(shape=(10,))
        model_input = Concatenate()([img, label])
        x = Dense(128)(model_input)
        return Model([img, label], x)
    
  3. 定义损失和优化器: 在cGAN中,损失函数通常使用二元交叉熵(binary crossentropy)。同时,将生成器和判别器编译为可优化的模型。

from keras.optimizers import Adam

generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam())
  • 训练循环: cGAN的训练循环包括以下步骤:

    • 随机选择一个标签;
    • 生成随机噪声;
    • 将噪声和标签输入生成器,生成伪样本;
    • 真实样本与伪样本一起喂入判别器进行训练。
    for epoch in range(num_epochs):
        for _ in range(batch_count):
            # 随机选择一个标签
            random_indices = np.random.randint(0, X_train.shape[0], batch_size)
            real_images = X_train[random_indices]
            labels = y_train[random_indices]
            
            # 生成随机噪声
            noise = np.random.normal(0, 1, (batch_size, 100))
            generated_images = generator.predict([noise, labels])
            
            # 生成标签one-hot编码
            real_labels = np.zeros((batch_size, 1))
            fake_labels = np.ones((batch_size, 1))
            d_loss_real = discriminator.train_on_batch([real_images, labels], real_labels)
            d_loss_fake = discriminator.train_on_batch([generated_images, labels], fake_labels)
            
            # 训练生成器
            noise = np.random.normal(0, 1, (batch_size, 100))
            valid_labels = np.ones((batch_size, 1))
            g_loss = combined_model.train_on_batch([noise, labels], valid_labels)
    
  • 1.2 训练中的技巧

    • Label Smoothing:通过降低真实标签的值来增强判别器的稳定性。
    • 样本平衡:确保从每个类中均匀选取样本,以减少数据偏差。
    • 动态学习率:根据训练阶段动态调整学习率,优化训练效果。

    2. 条件GAN的评估

    评估生成模型的性能具有挑战性,特别是当生成数据与真实数据的质量和多样性都需要被考虑时。以下是几种评估方法:

    条件GAN的训练和评估应用检查卡查看大图
    条件GAN的训练和评估应用检查卡

    练习《条件GAN的训练和评估》时,建议把输入条件、处理动作和可见结果写在一起,方便下次复查。

    条件GAN的训练和评估应用复盘卡查看大图
    条件GAN的训练和评估应用复盘卡

    复习《条件GAN的训练和评估》时,建议把关键概念、操作步骤和可见结果放在同一页里回看。

    GAN 进阶阅读地图卡查看大图
    GAN 进阶阅读地图卡

    看完《条件GAN的训练和评估》后,建议用一分钟复盘:关键概念是否分清、练习步骤是否可复现、结论能不能换成自己的话。

    2.1 可视化生成效果

    最直接的方法是通过可视化生成的图像来评估其质量。在MNIST例子中,可以随机生成几个样本并展示:

    import matplotlib.pyplot as plt
    
    # 随机生成一些样本
    noise = np.random.normal(0, 1, (10, 100))
    labels = np.array([i for i in range(10)]).reshape(-1, 1)
    labels = np.random.randint(0, 10, size=(10, 10))  # Random one-hot labels
    
    generated_images = generator.predict([noise, labels])
    plt.figure(figsize=(10, 10))
    for i in range(10):
        plt.subplot(5, 10, i + 1)
        plt.imshow(generated_images[i].reshape(28, 28), cmap='gray')
        plt.axis('off')
    plt.show()
    

    2.2 FID和IS指标

    Fréchet Inception Distance (FID)Inception Score (IS)是评估生成模型性能的常用指标。FID越低,表示生成样本与真实样本的相似度越高。IS则评估生成图像的多样性和质量。

    实现FID的Python代码示例:

    from scipy.linalg import sqrtm
    
    def calculate_fid(real_images, generated_images):
        # 假设real_images和generated_images的形状都为(num_samples, 28, 28, 1)
        mu1, sigma1 = calculate_statistics(real_images)
        mu2, sigma2 = calculate_statistics(generated_images)
        fid_value = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
        return fid_value
    
    def calculate_statistics(images):
        # 计算均值和协方差矩阵
        mu = np.mean(images, axis=0)
        sigma = np.cov(images, rowvar=False)
        return mu, sigma
    
    def calculate_frechet_distance(mu1, sigma
    

    继续阅读

    从这篇继续找到相关教程

    AI 教程总索引

    常见问题

    读前先确认这三点

    条件GAN的训练和评估适合谁读?

    这是 生成对抗网络高级 系列第 9 / 21 篇,适合正在学习生成对抗网络高级,并且需要把概念落到操作步骤或判断标准里的读者。

    读这篇生成对抗网络高级教程要多久?

    按中文技术文章阅读速度估算,通读大约 4 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。

    这篇文章里的图文节点怎么用?

    正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。

    分享文章

    转发到常用平台

    微信/朋友圈可先复制链接

    相关教程

    AI 教程总索引

    继续阅读

    继续找到相关 AI 教程

    返回栏目

    Reader Messages

    读者留言

    有问题、补充资料或实测结果,可以直接留下。这里不需要登录。

    最多 800 字

    为了防刷,每条留言会做长度、链接数量和提交频率限制。

    0/800

    留言列表

    0
    正在加载留言...