郭震 AI公众号:郭震AI

21 Keras进阶之自定义回调

发布日期:

最近更新:

分类: Keras

预计阅读: 3 分钟

阅读次数: 0

预计阅读3 分钟
结构重点6 个
图文要点6 张
正文规模1.4k 字

整理说明

这篇内容怎么整理

郭震 · 2026-06-04

独立整理围绕 6 个结构重点拆成环境、步骤、验证点和常见误区,尽量让读者能照着复现。
图文对照保留 6 张和配置、流程、判断结果有关的图片,方便快速定位正文重点。
持续校对工具、模型和命令变化较快,后续优先修正入口、参数和风险提醒。

阅读路线

先按这条路线读

先抓住主线,再回到代码、配置和图文细节,读起来会更稳。

图文要点

先看本文图文节点

按图先建立主线,再跳回正文核对步骤、配置和判断标准。

自定义回调流程图查看大图
自定义回调流程图

自定义回调适合处理项目特有的监控、记录和控制需求。它应该小而明确,避免隐藏太多业务逻辑。

自定义回调实操核对图查看大图
自定义回调实操核对图

我会让回调只做一件事,并用日志确认触发时机。副作用太多的回调,后面很难排查。

在上一期的教程中,我们介绍了迁移学习的概念以及如何在Keras中实施迁移学习。这篇文章将深入探讨Keras中的自定义回调(Custom Callbacks)。自定义回调是Keras中一个强大的功能,它允许开发者在训练过程中动态地实现控制和监测。这对于模型的监控、训练过程的调整以及其他个性化需求非常重要。

什么是回调?

回调函数是Keras在训练过程中执行的特定功能。Keras提供了一些内置的回调,例如监测模型性能的EarlyStopping和保存模型的ModelCheckpoint。自定义回调使得我们可以针对具体需求设计自己的回调逻辑。

回调的基本结构

在Keras中,自定义回调需要继承自tf.keras.callbacks.Callback类,并重写其中的方法。以下是常用的回调方法:

  • on_epoch_begin: 在每个epoch开始时执行
  • on_epoch_end: 在每个epoch结束时执行
  • on_batch_begin: 在每个batch开始时执行
  • on_batch_end: 在每个batch结束时执行
  • on_train_begin: 在训练开始时执行
  • on_train_end: 在训练结束时执行

示例:自定义回调

下面,我们将创建一个简单的自定义回调,它会在每个epoch结束时打印出当前的损失和精度,并保存最佳的模型权重。

import tensorflow as tf

class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        super(CustomCallback, self).__init__()
        self.best_loss = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        current_loss = logs.get('loss')
        current_accuracy = logs.get('accuracy')

        print(f"Epoch {epoch + 1}: loss = {current_loss:.4f}, accuracy = {current_accuracy:.4f}")

        # 保存最佳模型
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            print("Saving the best model...")
            self.model.save_weights('best_model_weights.h5')

# 示例模型
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# 创建并训练模型
model = create_model()
custom_callback = CustomCallback()

# 生成一些随机数据
import numpy as np
X_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000, 1))

model.fit(X_train, y_train, epochs=10, batch_size=32, callbacks=[custom_callback])

代码解释

  1. 我们定义了一个名为CustomCallback的类,它继承自tf.keras.callbacks.Callback
  2. __init__方法中,我们初始化best_loss为正无穷,这样在第一个epoch中无论损失是多少,它总会被更新。
  3. on_epoch_end方法中,我们获取当前epoch的损失和准确率,并打印它们。如果当前损失小于记录的最佳损失,那么就保存模型的权重。
Keras自定义回调判断卡查看大图
Keras自定义回调判断卡

使用 Keras 自定义回调时,先看触发时机、日志字段、模型保存、学习率调整、异常处理和验证曲线。

使用自定义回调的好处

使用自定义回调的优势包括但不限于:

Keras阅读地图卡查看大图
Keras阅读地图卡

读《Keras进阶之自定义回调》时,可以把配图当成路线卡:先看整体顺序,再看每一步为什么这样做,最后再检查边界条件。

  • 灵活性: 可以根据特定需求为训练过程添加细粒度的控制。
  • 监控: 可以在训练过程中监控关键信息并做出相应调整。
  • 自动化: 可以自动化一些常见任务,例如保存模型、调整学习率等。
Keras进阶之自定义回调应用复盘卡查看大图
Keras进阶之自定义回调应用复盘卡

复习《Keras进阶之自定义回调》时,建议把关键概念、操作步骤和可见结果放在同一页里回看。

Keras进阶之自定义回调应用检查卡查看大图
Keras进阶之自定义回调应用检查卡

练习《Keras进阶之自定义回调》时,建议把输入条件、处理动作和可见结果写在一起,方便下次复查。

总结

在这篇文章中,我们深入探讨了如何创建和使用自定义回调,这在Keras模型训练过程中能提供额外的灵活性和控制能力。通过上述示例,你可以看到自定义回调如何在训练过程中监控模型性能并做出反应。在下一篇文章中,我们将讨论调整学习率的方法——Fine-tuning,并结合自定义回调进一步优化模型表现。

希望这篇文章能够帮助你更好地理解和利用Keras中的自定义回调功能!

继续阅读

从这篇继续找到相关教程

AI 教程总索引

常见问题

读前先确认这三点

Keras进阶之自定义回调适合谁读?

这是 Keras 入门 系列第 21 / 28 篇,适合正在学习Keras 入门,并且需要把概念落到操作步骤或判断标准里的读者。

读这篇Keras 入门教程要多久?

按中文技术文章阅读速度估算,通读大约 3 分钟;如果要跟着复现,建议把命令、配置和结果检查分开做。

这篇文章里的图文节点怎么用?

正文里有 6 个图文节点,可以先用它们抓住流程、配置和判断点,再回到对应段落细读。

分享文章

转发到常用平台

微信/朋友圈可先复制链接

相关教程

AI 教程总索引

继续阅读

继续找到相关 AI 教程

返回栏目

Reader Messages

读者留言

有问题、补充资料或实测结果,可以直接留下。这里不需要登录。

最多 800 字

为了防刷,每条留言会做长度、链接数量和提交频率限制。

0/800

留言列表

0
正在加载留言...