21 Keras进阶之自定义回调
系列进度
Keras 入门 · 第 21 / 28 篇
整理说明
这篇内容怎么整理
郭震 · 2026-06-04
阅读路线
先按这条路线读
先抓住主线,再回到代码、配置和图文细节,读起来会更稳。
自定义回调适合处理项目特有的监控、记录和控制需求。它应该小而明确,避免隐藏太多业务逻辑。
我会让回调只做一件事,并用日志确认触发时机。副作用太多的回调,后面很难排查。
在上一期的教程中,我们介绍了迁移学习的概念以及如何在Keras中实施迁移学习。这篇文章将深入探讨Keras中的自定义回调(Custom Callbacks)。自定义回调是Keras中一个强大的功能,它允许开发者在训练过程中动态地实现控制和监测。这对于模型的监控、训练过程的调整以及其他个性化需求非常重要。
什么是回调?
回调函数是Keras在训练过程中执行的特定功能。Keras提供了一些内置的回调,例如监测模型性能的EarlyStopping和保存模型的ModelCheckpoint。自定义回调使得我们可以针对具体需求设计自己的回调逻辑。
回调的基本结构
在Keras中,自定义回调需要继承自tf.keras.callbacks.Callback类,并重写其中的方法。以下是常用的回调方法:
on_epoch_begin: 在每个epoch开始时执行on_epoch_end: 在每个epoch结束时执行on_batch_begin: 在每个batch开始时执行on_batch_end: 在每个batch结束时执行on_train_begin: 在训练开始时执行on_train_end: 在训练结束时执行
示例:自定义回调
下面,我们将创建一个简单的自定义回调,它会在每个epoch结束时打印出当前的损失和精度,并保存最佳的模型权重。
import tensorflow as tf
class CustomCallback(tf.keras.callbacks.Callback):
def __init__(self):
super(CustomCallback, self).__init__()
self.best_loss = float('inf')
def on_epoch_end(self, epoch, logs=None):
current_loss = logs.get('loss')
current_accuracy = logs.get('accuracy')
print(f"Epoch {epoch + 1}: loss = {current_loss:.4f}, accuracy = {current_accuracy:.4f}")
# 保存最佳模型
if current_loss < self.best_loss:
self.best_loss = current_loss
print("Saving the best model...")
self.model.save_weights('best_model_weights.h5')
# 示例模型
def create_model():
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
return model
# 创建并训练模型
model = create_model()
custom_callback = CustomCallback()
# 生成一些随机数据
import numpy as np
X_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000, 1))
model.fit(X_train, y_train, epochs=10, batch_size=32, callbacks=[custom_callback])
代码解释
- 我们定义了一个名为
CustomCallback的类,它继承自tf.keras.callbacks.Callback。 - 在
__init__方法中,我们初始化best_loss为正无穷,这样在第一个epoch中无论损失是多少,它总会被更新。 - 在
on_epoch_end方法中,我们获取当前epoch的损失和准确率,并打印它们。如果当前损失小于记录的最佳损失,那么就保存模型的权重。
使用 Keras 自定义回调时,先看触发时机、日志字段、模型保存、学习率调整、异常处理和验证曲线。
使用自定义回调的好处
使用自定义回调的优势包括但不限于:
读《Keras进阶之自定义回调》时,可以把配图当成路线卡:先看整体顺序,再看每一步为什么这样做,最后再检查边界条件。
- 灵活性: 可以根据特定需求为训练过程添加细粒度的控制。
- 监控: 可以在训练过程中监控关键信息并做出相应调整。
- 自动化: 可以自动化一些常见任务,例如保存模型、调整学习率等。
复习《Keras进阶之自定义回调》时,建议把关键概念、操作步骤和可见结果放在同一页里回看。
练习《Keras进阶之自定义回调》时,建议把输入条件、处理动作和可见结果写在一起,方便下次复查。
总结
在这篇文章中,我们深入探讨了如何创建和使用自定义回调,这在Keras模型训练过程中能提供额外的灵活性和控制能力。通过上述示例,你可以看到自定义回调如何在训练过程中监控模型性能并做出反应。在下一篇文章中,我们将讨论调整学习率的方法——Fine-tuning,并结合自定义回调进一步优化模型表现。
希望这篇文章能够帮助你更好地理解和利用Keras中的自定义回调功能!
继续阅读
从这篇继续找到相关教程
常见问题
读前先确认这三点
Keras进阶之自定义回调适合谁读?
这是 Keras 入门 系列第 21 / 28 篇,适合正在学习Keras 入门,并且需要把概念落到操作步骤或判断标准里的读者。
读这篇Keras 入门教程要多久?
按中文技术文章阅读速度估算,通读大约 3 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。
这篇文章里的图文节点怎么用?
正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。
分享文章
转发到常用平台
微信/朋友圈可先复制链接
相关教程
从相近问题继续读
继续阅读