11 神经网络基础之如何定义模型
系列进度
PyTorch 入门 · 第 11 / 20 篇
整理说明
这篇内容怎么整理
郭震 · 2026-06-04
阅读路线
先按这条路线读
先抓住主线,再回到代码、配置和图文细节,读起来会更稳。
用 nn.Module 定义模型时,__init__ 放层,forward 放数据流。两者分清,模型才容易调试。
官方教程:PyTorch Build the Neural Network
模型写完后,我会先用一批假输入跑一次 forward,确认输出 shape 对得上,再进入训练循环。
在学习神经网络时,除了了解其基本结构外,如何定义和构建一个神经网络模型是接下来的重要步骤。在本篇中,我们将通过 PyTorch 这个深受欢迎的深度学习框架,来学习如何定义一个基本的神经网络模型。
定义模型的基本步骤
在 PyTorch 中,定义一个神经网络模型主要涉及到以下几个步骤:
用 PyTorch 定义模型时,先写清层参数、输入形状、forward 流程、输出维度和损失函数需求。
-
导入所需的库: 首先,我们需要导入相关的 PyTorch 库。
-
创建模型类: 在 PyTorch 中,神经网络模型通常是通过继承
torch.nn.Module类来定义的。 -
定义网络层: 在模型的构造函数中定义需要的网络层,例如全连接层、卷积层等。
-
实现前向传播方法
forward: 定义如何将输入数据通过网络层进行转换。
1. 导入所需的库
在开始之前,我们需要导入 PyTorch 和相关的库:
import torch
import torch.nn as nn
import torch.optim as optim
2. 创建模型类
接下来,我们创建一个名为 SimpleNN 的模型类,继承自 nn.Module:
class SimpleNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleNN, self).__init__()
# 定义全连接层
self.fc1 = nn.Linear(input_size, hidden_size) # 隐藏层
self.fc2 = nn.Linear(hidden_size, output_size) # 输出层
def forward(self, x):
# 前向传播
x = torch.relu(self.fc1(x)) # 使用 ReLU 激活函数
x = self.fc2(x)
return x
在这段代码中,__init__ 方法用于定义网络的层,而 forward 方法定义了如何通过这些层进行前向传播。
3. 定义网络层
在 __init__ 方法中,我们定义了两个全连接层:
self.fc1:输入层到隐藏层。self.fc2:隐藏层到输出层。
隐藏层的神经元数量由 hidden_size 参数决定。
4. 实现前向传播
在 forward 方法中,我们首先将输入数据 x 传递给第一层 fc1,得到隐藏层的输出,然后使用 ReLU 激活函数进行非线性映射。最后,将隐藏层的输出传递给第二层 fc2,得到最终的输出。
模型实例化与使用
一旦模型类已经定义好,我们就可以实例化该模型并进行训练或测试了。
《神经网络基础之如何定义模型》可以按“场景、概念、动作、结果”来读。先把这四件事对齐,再回到正文里的参数、代码或流程。
示例代码
以下是如何实例化该模型并创建一个随机输入数据的示例:
# 定义输入、隐藏和输出层的神经元数量
input_size = 10
hidden_size = 5
output_size = 2
# 实例化模型
model = SimpleNN(input_size, hidden_size, output_size)
# 创建一个随机输入数据(例如,批大小为 1)
input_data = torch.randn(1, input_size)
# 进行前向传播
output_data = model(input_data)
print("Output:", output_data)
在这个示例中,我们定义了一个输入为10个神经元、隐藏层为5个神经元和输出层为2个神经元的模型。通过用 torch.randn 创建的随机输入数据,可以看到模型的输出。
如果《神经网络基础之如何定义模型》还没完全消化,可以从这张卡片的四个动作重新走一遍。
回看《神经网络基础之如何定义模型》时,不必一次做大项目,先用一条简单样例确认主线是否清楚。
总结
在本篇中,我们学习了如何在 PyTorch 中定义一个简单的神经网络模型。我们通过定义模型类、初始化网络层和实现前向传播等步骤,为后续的模型训练和推理奠定了基础。随着接下来的学习,我们将深入探讨激活函数的使用及其对模型表现的影响。
在下一篇中,我们将重点讨论 激活函数 的使用以及它们在神经网络中的重要性,敬请期待!
继续阅读
从这篇继续找到相关教程
常见问题
读前先确认这三点
神经网络基础之如何定义模型适合谁读?
这是 PyTorch 入门 系列第 11 / 20 篇,适合正在学习PyTorch 入门,并且需要把概念落到操作步骤或判断标准里的读者。
读这篇PyTorch 入门教程要多久?
按中文技术文章阅读速度估算,通读大约 3 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。
这篇文章里的图文节点怎么用?
正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。
分享文章
转发到常用平台
微信/朋友圈可先复制链接
相关教程
从相近问题继续读
继续阅读