5 GAN的基本原理之损失函数的定义
系列进度
GAN 网络教程 · 第 5 / 21 篇
整理说明
这篇内容怎么整理
郭震 · 2026-06-04
阅读路线
先按这条路线读
先抓住主线,再回到代码、配置和图文细节,读起来会更稳。
GAN 的关键是生成器和判别器互相推动,学习时要同时看结构、训练和样本质量。阅读时可以按「损失函数的基本概念 -> 对抗损失函数 -> 生成器的损失 -> 最优解」建立结构,再回到正文里的代码、案例或指标做验证。
读完后,用一个真实小任务复查:输入是什么,处理环节在哪里,输出是否可验收;失败时先查「损失函数的基本概念」,再查「对抗损失函数」。
在上一篇中,我们探讨了生成对抗网络(GAN)中生成器和判别器的角色。生成器的任务是生成尽可能真实的数据,而判别器则负责区分实际数据和生成数据的真假。在这一节中,我们将深入了解损失函数的定义,它是衡量生成器与判别器性能的核心。
损失函数的基本概念
在 GAN 中,损失函数用于优化生成器和判别器。我们需要定义损失函数,使两个网络相互竞争,从而提升生成器的生成能力和判别器的识别能力。
学习 GAN 损失函数时,先看判别器如何区分真假,再看生成器如何利用这个反馈改进样本。
对抗损失函数
GAN 的核心思想是“对抗”。我们通过以下公式来定义对抗损失:
在这个公式中:
- 是判别器在真实数据 上的输出。
- 是判别器在生成数据 上的输出。
这里, 越接近 1, 越接近 0,损失就越小,说明判别器能够很好地区分真实和生成的数据。
生成器的损失
生成器的目标是使判别器误以为生成的数据是真实的。因此,生成器的损失函数为:
在这个公式中, 是生成器生成的数据。生成器的目标是最大化 ,使判别器认为这些生成的数据是真实的。
最优解
在理论上,当 GAN 的训练达到平衡状态时,总损失函数 应该减少到 0:
- 判别器 的输出来区分真实样本和生成样本都相等,即 和 。
- 此时生成器能够生成非常逼真的样本,以至于判别器无法区分。
案例分析
请考虑一个简单的场景,我们使用 GAN 来生成手写数字图像(例如 MNIST 数据集)。在训练过程中,生成器试图生成手写数字图像,而判别器则试图区分真实的手写数字和生成的手写数字。
阅读《GAN的基本原理之损失函数的定义》前,可以先用配图确认主线;读完后再检查哪些步骤能直接操作,哪些还需要补资料。
代码示例
以下是一个简单的 GAN 实现示例,演示如何定义损失函数并进行优化。
import torch
import torch.nn as nn
import torch.optim as optim
# 定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, 28*28),
nn.Tanh()
)
def forward(self, z):
return self.model(z).view(-1, 1, 28, 28)
# 定义判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, img):
return self.model(img)
# 初始化网络
generator = Generator()
discriminator = Discriminator()
# 定义损失函数
criterion = nn.BCELoss()
# 定义优化器
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# 假设 z 是从标准正态分布中随机采样的噪声
# 真实样本的标签是真实标签 1,生成样本的标签是假标签 0
z = torch.randn(64, 100)
real_samples = torch.randint(0, 2, (64, 1)).float() # 假设这是从真实数据集中提取的真实样本
# 判别器的损失
D_real = discriminator(real_samples)
D_fake = discriminator(generator(z))
loss_d = criterion(D_real, torch.ones_like(D_real)) + criterion(D_fake, torch.zeros_like(D_fake))
# 生成器的损失
loss_g = criterion(D_fake, torch.ones_like(D_fake)) # 生成器希望 D_fake 接近 1
# 更新判别器和生成器的参数
optimizer_d.zero_grad()
loss_d.backward()
optimizer_d.step()
optimizer_g.zero_grad()
loss_g.backward()
optimizer_g.step()
在这个例子中,我们定义了生成器和判别器的结构,并使用二元交叉熵损失(BCE)作为损失函数。通过如下动作,生成器和判别器可以在训练过程中不断优化。
复习《GAN的基本原理之损失函数的定义》时,建议把关键概念、操作步骤和可见结果放在同一页里回看。
练习《GAN的基本原理之损失函数的定义》时,建议把输入条件、处理动作和可见结果写在一起,方便下次复查。
总结
本节我们详细讨论了 GAN 中损失函数的定义。我们了解了生成器和判别器如何通过对抗性损失进行优化,从而不断提升生成数据的质量。损失函数是 GAN 训练的核心,通过精心设计的损失函数,我们可以实现理想的对抗训练。在下一节中,我们将探讨 GAN 的对抗训练流程,深入分析如何应用这些损失函数来实现有效的训练。
继续阅读
从这篇继续找到相关教程
常见问题
读前先确认这三点
GAN的基本原理之损失函数的定义适合谁读?
这是 GAN 网络教程 系列第 5 / 21 篇,适合正在学习GAN 网络教程,并且需要把概念落到操作步骤或判断标准里的读者。
读这篇GAN 网络教程要多久?
按中文技术文章阅读速度估算,通读大约 4 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。
这篇文章里的图文节点怎么用?
正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。
分享文章
转发到常用平台
微信/朋友圈可先复制链接
相关教程
从相近问题继续读
继续阅读