20 Keras进阶之迁移学习
系列进度
Keras 入门 · 第 20 / 28 篇
整理说明
这篇内容怎么整理
郭震 · 2026-06-04
阅读路线
先按这条路线读
先抓住主线,再回到代码、配置和图文细节,读起来会更稳。
迁移学习适合数据不多但任务相近的场景。先复用通用特征,再训练自己的任务头,通常比从零训练稳定。
我会记录哪些层被冻结、哪些层训练。迁移学习不清楚冻结边界,很容易把预训练能力洗掉。
在上一篇中,我们探讨了模型评估与预测,并学习了如何使用 Keras 进行模型的预测,接下来我们将打开迁移学习的新篇章。迁移学习是一种强大的深度学习技术,可以让我们在一个任务上利用在另一个任务上学到的知识,尤其是在处理数据有限的情况下。
什么是迁移学习?
迁移学习是利用已经训练好的模型参数以节省训练时间和资源的技术。通过将预训练模型的知识迁移到新的任务中,我们可以在较小的数据集上获得更好的性能。常见的预训练模型包括 VGG16、ResNet、Inception 等。
例如,在图像分类任务中,我们通常会使用在 ImageNet 上训练好的模型来初始化我们的网络,然后只需对少量新样本进行微调。
迁移学习的基本流程
- 选择预训练模型:根据任务的需求选择适合的预训练模型。
- 加载预训练权重:将预训练模型的权重加载到 Keras 中。
- 冻结部分层:冻结模型的一些层以保留它们的特征。
- 添加自定义层:根据新的任务添加自定义层。
- 编译和训练模型:编译模型并在目标数据集上进行训练。
使用 Keras 进行迁移学习的案例
以下是一个使用 Keras 进行迁移学习的实际案例,使用 VGG16 模型来识别猫和狗的图像。
1. 导入所需的库
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
2. 数据预处理
使用 ImageDataGenerator 来加载和增强数据。
# 定义图像大小和路径
img_size = (224, 224)
train_data_dir = 'path/to/train'
validation_data_dir = 'path/to/validation'
# 数据增强
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
validation_datagen = ImageDataGenerator(rescale=1./255)
# 加载数据
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=img_size,
batch_size=32,
class_mode='binary'
)
validation_generator = validation_datagen.flow_from_directory(
validation_data_dir,
target_size=img_size,
batch_size=32,
class_mode='binary'
)
3. 加载预训练模型
我们将使用 VGG16 模型,并去掉其最后的全连接层。
# 加载 VGG16 模型,且不包括顶层全连接层
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# 冻结前面的卷积层
for layer in base_model.layers:
layer.trainable = False
4. 添加自定义层
我们将根据新的任务添加自定义层。
使用 Keras 迁移学习时,先看预训练模型、输入尺寸、冻结策略、分类头、学习率和验证曲线。
# 创建新的模型
x = Flatten()(base_model.output)
x = Dense(256, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x)
model = Model(inputs=base_model.input, outputs=predictions)
5. 编译模型
model.compile(optimizer=Adam(learning_rate=0.0001),
loss='binary_crossentropy',
metrics=['accuracy'])
6. 训练模型
使用我们准备好的数据生成器来训练模型。
读《Keras进阶之迁移学习》时,可以先看配图里的任务、概念、练习和判断点,再回到正文补细节。这样更容易判断这篇内容能放到哪个真实场景里。
history = model.fit(
train_generator,
steps_per_epoch=train_generator.samples // train_generator.batch_size,
validation_data=validation_generator,
validation_steps=validation_generator.samples // validation_generator.batch_size,
epochs=10
)
评估模型
在模型训练完成后,可以通过验证集进行评估。
loss, accuracy = model.evaluate(validation_generator)
print(f"Validation Loss: {loss}, Validation Accuracy: {accuracy}")
读到这里,可以把《Keras进阶之迁移学习》整理成一张复盘表:先说清主线,再拿一个小任务检查结果。
读完《Keras进阶之迁移学习》后,可以先挑一个小样例走完整流程,再判断哪些步骤已经能独立完成。
结论
通过迁移学习,我们能够快速构建高效的深度学习模型,尤其是在小数据集的情况下。通过本文中详细的案例,我们展示了如何在 Keras 中实现迁移学习的基本流程。在下一篇中,我们将介绍如何进行 Keras 自定义回调,以便在训练过程中实现更细致的控制和监测。
针对迁移学习,你还有什么疑问或需要进一步探讨的地方吗?欢迎在评论区与我们讨论!
继续阅读
从这篇继续找到相关教程
常见问题
读前先确认这三点
Keras进阶之迁移学习适合谁读?
这是 Keras 入门 系列第 20 / 28 篇,适合正在学习Keras 入门,并且需要把概念落到操作步骤或判断标准里的读者。
读这篇Keras 入门教程要多久?
按中文技术文章阅读速度估算,通读大约 3 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。
这篇文章里的图文节点怎么用?
正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。
分享文章
转发到常用平台
微信/朋友圈可先复制链接
相关教程
从相近问题继续读
继续阅读