3.4.1 回调函数介绍

Keras通过回调函数API提供Checkpoint功能。回调函数是一组在训练的特定阶段被调用的函数集,可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型的fit()函数中,即可在给定的训练阶段调用该函数集中的函数。比如在3.3.3节出现的callback_tensorboard()函数就是一个回调函数,其将日志信息写入TensorBoard,使得可以动态地观察训练和测试指标的图像。

我们之前的训练过程是先训练一遍,然后得到一个验证集的识别率变化趋势,从而知道最佳的训练周期,然后根据得到的最佳训练周期再训练一遍,得到最终结果。这样很浪费时间。一个好方法就是在测试识别率不再上升的时候就终止训练,Keras中的回调函数可以帮助我们做到这一点。回调函数属于obj类型,它可以让模型去拟合,也常在各个点被调用。它存储模型的状态,能够打断训练,保存模型,加载不同的权重,或者替代模型状态。

回调函数可以实现如下功能。

  • 模型断点续训:保存当前模型的所有权重。
  • 提早结束:当模型的损失不再下降的时候就终止训练,当然,回调函数会保存最优的模型。
  • 动态调整训练时的参数,比如优化的学习速度。

以下是Keras内置的回调函数。

  • callback_progbar_logger():将metrics指定的监视输出到标准输出上。
  • callback_model_checkpoint():在每个训练期之后保存模型。
  • callback_early_stopping():当监测值不再改善时,终止训练。
  • callback_remote_monitor():用于向服务器发送事件流。
  • callback_learning_rate_scheduler():学习率调度器。
  • callback_tensorboard():TensorBoard可视化。
  • callback_reduce_lr_on_plateau():当指标停止改善时,降低学习率。
  • callback_csv_logger():把训练周期的训练结果保存到csv文件。
  • callback_lambda():在训练过程中创建简单、自定义的回调函数。
  • KerasCallback:创建基础R6类的Keras回调函数。

接下来,让我们先学习常用回调函数的基本用法。

1. callback_progbar_logger()函数

将metrics指定的监视输出到标准输出上的回调函数,其基本形式为:

callback_progbar_logger(count_mode = "samples",
    stateful_metrics = NULL)

各参数描述如下。

  • count_mode:steps或者samples,表示进度条是否应该对样本或步骤(批量)计数。
  • stateful_metrics:不应在训练周期求平均值的度量名称列表。此列表中的度量标准将按原样记录在on_epoch_end中。列表外的其他指标将在on_epoch_end中取平均值。

2. callback_model_checkpoint()函数

在每个训练期之后保存模型的回调函数,其基本形式为:

callback_model_checkpoint(filepath, monitor = "val_loss", verbose = 0,
    save_best_only = FALSE, save_weights_only = FALSE, mode = c("auto",
    "min", "max"), period = 1)

各参数描述如下。

  • filepath:字符串,保存模型的路径。filepath可以包含命名格式选项,由epoch的值和logs的键(由on_epoch_end参数传递)来填充。
  • monitor:被监测的数据。
  • verbose:信息展示模式,0或者1。
  • save_best_only:如果为TRUE,代表我们只保存最优的训练结果。
  • save_weights_only:如果为TRUE,那么只有模型的权重会被保存(save_model_weights_hdf5(filepath)),否则,整个模型会被保存(save_model_hdf5(filepath))。
  • mode:有auto、min、max三种模式,在save_best_only=TRUE时决定性能最佳模型的评判准则。例如,当监测值为val_acc时,模式应为max;当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
  • period:每个检测点之间的间隔(训练轮数)。

3. callback_early_stopping()函数

当监测值不再改善时,终止训练的回调函数,其基本形式为:

callback_early_stopping(monitor = "val_loss", min_delta = 0,
    patience = 0, verbose = 0, mode = c("auto", "min", "max"),
    baseline = NULL, restore_best_weights = FALSE)

各参数描述如下。

  • monitor:被监测的数据。
  • min_delta:在被监测的数据中被认为是提升的最小变化,例如,小于min_delta的绝对变化会被认为没有提升。
  • patience:没有进步的训练轮数,在这之后的训练会被停止。
  • verbose:详细信息模式。
  • mode:有auto、min、max三种模式。在min模式下,如果监测值停止下降,终止训练;在max模式下,当监测值不再上升时停止训练;在auto模式中,方向会自动从被监测的数据的名字中判断出来。
  • baseline:要监控的数量的基准值。如果模型没有显示基准的改善,训练将停止。
  • restore_best_weights:是否从具有监测数量的最佳值的时期恢复模型权重。如果为FALSE,则使用在训练的最后一步获得的模型权重。

4. callback_learning_rate_scheduler()函数

学习率调度器的回调函数,其基本形式为:

callback_learning_rate_scheduler(schedule)

其中,参数schedule是一个函数,接收轮数作为输入(整数),然后返回一个学习速率作为输出(浮点数)。

5. callback_reduce_lr_on_plateau()函数

当指标停止改善时,降低学习率的回调函数,其基本形式为:

callback_reduce_lr_on_plateau(monitor = "val_loss", factor = 0.1,
    patience = 10, verbose = 0, mode = c("auto", "min", "max"),
    min_delta = 1e-04, cooldown = 0, min_lr = 0)

当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果。该回调函数检测指标的情况,如果在指定训练轮数(默认为10个)中看不到模型性能提升,则减少学习率。

各参数描述如下。

  • monitor:被监测的数据。
  • factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少。
  • patience:没有进步的训练轮数,在这之后的训练速率会降低。
  • mode:有auto、min、max三种模式。在min模式下,如果监测值停止下降,触发学习率减少;在max模式下,当监测值不再上升时触发学习率减少。
  • min_delta:对于测量新的最优化的阈值,只关注巨大的改变。
  • cooldown:在学习速率被降低之后,重新恢复正常操作之前等待的训练轮数。
  • min_lr:学习速率的下边界。

6. callback_lambda()函数

在训练过程中创建简单、自定义的回调函数,其基本形式为:

callback_lambda(on_epoch_begin = NULL, on_epoch_end = NULL,
    on_batch_begin = NULL, on_batch_end = NULL,
    on_train_batch_begin = NULL, on_train_batch_end = NULL,
    on_train_begin = NULL, on_train_end = NULL,
    on_predict_batch_begin = NULL, on_predict_batch_end = NULL,
    on_predict_begin = NULL, on_predict_end = NULL,
    on_test_batch_begin = NULL, on_test_batch_end = NULL,
    on_test_begin = NULL, on_test_end = NULL)

该回调函数将会在适当的时候调用,注意,这里假定了一些位置参数:

  • on_epoch_begin和on_epoch_end假定输入的参数是epoch和logs。
  • on_batch_begin和on_batch_end假定输入的参数是batch和logs。
  • on_train_begin和on_train_end假定输入的参数是logs。

主要参数描述如下。

  • on_epoch_begin:在每个epoch开始时调用。
  • on_epoch_end:在每个epoch结束时调用。
  • on_batch_begin:在每个batch开始时调用。
  • on_batch_end:在每个batch结束时调用。
  • on_train_begin:在训练开始时调用。
  • on_train_end:在训练结束时调用。