- Keras深度学习:入门、实战与进阶
- 谢佳标
- 1002字
- 2024-10-30 00:46:57
1.4.4 模型评估及预测
前面我们已经完成了训练,现在使用测试数据集来评估模型准确率。通过evaluate()函数实现。
> # 模型评估 > scores <- model %>% + evaluate(x_Test_normalize,y_TestOneHot) 10000/10000 [===========================] - 1s 118us/sample - loss: 0.0712 - acc: 0.9787 > scores $loss [1] 0.07122663 $acc [1] 0.9787
以上程序代码的执行结果的准确率为0.9787,效果非常不错。接下来我们使用predict_classes()函数对测试数据集进行类别预测。
> # 模型预测 > prediction <- model %>% + predict_classes(x_Test_normalize) > prediction[1:9] [1] 7 2 1 0 4 1 4 9 6
通过可视化手段展示测试集前9张数字图像及其实际和预测标签,如图1-7所示。
> par(mfrow=c(3,3)) > par(mar=c(0, 0, 1.5, 0), xaxs='i', yaxs='i') > for(i in 1:9){ + plot(as.raster(x_test_image[i,,],max = 255)) + title(main = paste0('label=',y_test_label[i],'predict=',prediction[i])) + } > par(mfrow=c(1,1)
图1-7 测试集前9个数字图像及标签展示
从图1-7可知,前8个数字图像均预测正确,但是第9个数字图像实际标签为5,却被预测为6。
如果想要进一步了解预测结果中哪些数字准确率最高,哪些数字最容易混淆,可以使用混淆矩阵(confusion matrix)来显示。
> # 构建混淆矩阵 > table('label' = y_test_label, + 'predict' = prediction) predict label 0 1 2 3 4 5 6 7 8 9 0 971 0 2 2 0 0 1 1 2 1 1 0 1122 3 1 0 2 2 1 4 0 2 4 0 1015 3 1 0 2 3 4 0 3 0 0 5 989 0 5 0 3 1 7 4 0 0 5 0 961 0 2 2 0 12 5 2 0 0 9 0 867 7 1 3 3 6 6 2 3 1 4 3 937 0 2 0 7 0 2 12 7 0 0 0 1000 0 7 8 5 0 4 11 3 3 2 4 937 5 9 3 3 0 5 5 3 0 2 0 988
混淆矩阵可以展示各个标签的误分类情况。比如在1万个测试样本中,有7个样本的实际标签为5却误分为6,有6个标签的实际标签为6却误分为0的情况。
最后,让我们把实际标签与预测标签组成一个新的数据框df,方便进行结果对比查看。比如我们想查看所有实际标签为5却误分为6的样本,可以通过以下程序代码实现。
> # 构建结果集 > df <- data.frame('label' = y_test_label, + 'predict' = prediction) > # 查看实际标签为5,预测标签为6的样本 > df_sub <- df[df$label==5 & df$predict==6,] > df_sub label predict 9 5 6 1379 5 6 3894 5 6 8864 5 6 9730 5 6 9750 5 6 9983 5 6
除了第9号(指的是该样本在测试集中的第几个)样本,还有1379、3894、8864、9730、9750、9983号共7个样本的实际标签为5却被预测为6。
最后,我们通过可视化手段查看这些数字图像为什么不易被正确识别,运行以下程序代码得到如图1-8所示结果。
> index <- as.numeric(rownames(df_sub)) > par(mfrow=c(2,4)) > for(i in index){ + plot(as.raster(x_test_image[i,,],max = 255)) + title(main = paste0(i,':label=',y_test_label[i],'predict=',prediction[i])) + } > par(mfrow=c(1,1))
图1-8 测试集中实际标签为5却预测为6的数字图像
从图1-8可知,误分类的样本数字书写比较随意,数字整体向右倾斜,5字下部分有部分数字是合并在一起的,不易识别。
如果读者好奇那些实际标签为6却被预测为0的数字图像是怎样的,可以通过以下程序代码实现,结果如图1-9所示。
> # 实际标签为6,预测为0的数字图像展示 > index1 <- as.numeric(rownames(df[df$label==6 & df$predict==0,])) > par(mfrow=c(2,3)) > for(i in index1){ + plot(as.raster(x_test_image[i,,],max = 255)) + title(main = paste0(i,':label=',y_test_label[i],'predict=',prediction[i])) + } > par(mfrow=c(1,1))
图1-9 训练集中实际标签为6、预测为0的数字图像
图1-9 (续)
从图1-9可知,这些数字图像中6的出头部分太不明显,故模型被误预测为0。