1.4.4 模型评估及预测

前面我们已经完成了训练,现在使用测试数据集来评估模型准确率。通过evaluate()函数实现。

> # 模型评估
> scores <- model %>%
+   evaluate(x_Test_normalize,y_TestOneHot)
10000/10000 [===========================] - 1s 118us/sample - loss: 0.0712 - acc: 0.9787
> scores
$loss
[1] 0.07122663
$acc
[1] 0.9787

以上程序代码的执行结果的准确率为0.9787,效果非常不错。接下来我们使用predict_classes()函数对测试数据集进行类别预测。

> # 模型预测
> prediction <- model %>%
+   predict_classes(x_Test_normalize)
> prediction[1:9]
[1] 7 2 1 0 4 1 4 9 6

通过可视化手段展示测试集前9张数字图像及其实际和预测标签,如图1-7所示。

> par(mfrow=c(3,3))
> par(mar=c(0, 0, 1.5, 0), xaxs='i', yaxs='i')
> for(i in 1:9){
+     plot(as.raster(x_test_image[i,,],max = 255))
+     title(main = paste0('label=',y_test_label[i],'predict=',prediction[i]))
+ }
> par(mfrow=c(1,1)
029-1

图1-7 测试集前9个数字图像及标签展示

从图1-7可知,前8个数字图像均预测正确,但是第9个数字图像实际标签为5,却被预测为6。

如果想要进一步了解预测结果中哪些数字准确率最高,哪些数字最容易混淆,可以使用混淆矩阵(confusion matrix)来显示。

> # 构建混淆矩阵
> table('label' = y_test_label,
+       'predict' = prediction)
     predict
label    0    1    2    3    4    5    6    7    8    9
    0  971    0    2    2    0    0    1    1    2    1
    1    0 1122    3    1    0    2    2    1    4    0
    2    4    0 1015    3    1    0    2    3    4    0
    3    0    0    5  989    0    5    0    3    1    7
    4    0    0    5    0  961    0    2    2    0   12
    5    2    0    0    9    0  867    7    1    3    3
    6    6    2    3    1    4    3  937    0    2    0
    7    0    2   12    7    0    0    0 1000    0    7
    8    5    0    4   11    3    3    2    4  937    5
    9    3    3    0    5    5    3    0    2    0  988

混淆矩阵可以展示各个标签的误分类情况。比如在1万个测试样本中,有7个样本的实际标签为5却误分为6,有6个标签的实际标签为6却误分为0的情况。

最后,让我们把实际标签与预测标签组成一个新的数据框df,方便进行结果对比查看。比如我们想查看所有实际标签为5却误分为6的样本,可以通过以下程序代码实现。

> # 构建结果集
> df <- data.frame('label' = y_test_label,
+                  'predict' = prediction)
> # 查看实际标签为5,预测标签为6的样本
> df_sub <- df[df$label==5 & df$predict==6,]
> df_sub
       label    predict
9        5         6
1379     5         6
3894     5         6
8864     5         6
9730     5         6
9750     5         6
9983     5         6

除了第9号(指的是该样本在测试集中的第几个)样本,还有1379、3894、8864、9730、9750、9983号共7个样本的实际标签为5却被预测为6。

最后,我们通过可视化手段查看这些数字图像为什么不易被正确识别,运行以下程序代码得到如图1-8所示结果。

> index <- as.numeric(rownames(df_sub))
> par(mfrow=c(2,4))
> for(i in index){
+     plot(as.raster(x_test_image[i,,],max = 255))
+     title(main = paste0(i,':label=',y_test_label[i],'predict=',prediction[i]))
+ }
> par(mfrow=c(1,1))
031-1

图1-8 测试集中实际标签为5却预测为6的数字图像

从图1-8可知,误分类的样本数字书写比较随意,数字整体向右倾斜,5字下部分有部分数字是合并在一起的,不易识别。

如果读者好奇那些实际标签为6却被预测为0的数字图像是怎样的,可以通过以下程序代码实现,结果如图1-9所示。

> # 实际标签为6,预测为0的数字图像展示
> index1 <- as.numeric(rownames(df[df$label==6 & df$predict==0,]))
> par(mfrow=c(2,3))
> for(i in index1){
+     plot(as.raster(x_test_image[i,,],max = 255))
+     title(main = paste0(i,':label=',y_test_label[i],'predict=',prediction[i]))
+ }
> par(mfrow=c(1,1))
031-2

图1-9 训练集中实际标签为6、预测为0的数字图像

032-1

图1-9 (续)

从图1-9可知,这些数字图像中6的出头部分太不明显,故模型被误预测为0。