AI源码解读.数字图像处理案例:Python版

1.3.2 超分辨率模块

采用SRGAN算法,完成数据载入与处理、模型创建与训练及模型生成。因为训练过程对GPU性能要求高,所以在百度AI studio上运行。

1.数据载入与处理

本模块采用MS COCO数据集作为训练集,BSDS100和BSDS300数据集作为测试集,并创建json文件记录位置。将训练时数据集中的图片缩小作为高分辨率图像,用opencv放大原尺寸作为低分辨率图像。

MS COCO数据集包括82783张训练图像和40504张验证图像,下载地址为http://cocodataset.org/。

BSDS100和BSDS300数据集包含100张图像和300张图像,下载地址为https://www.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/。

相关代码如下:

2.模型创建与训练

SRGAN模型结构分为生成网络和判别网络。生成网络(SRResNet)包含多个残差块,每个残差块中包含两个3×3的卷积层,卷积层后接批规范化层(batch normalization,BN)和PReLU作为激活函数,两个亚像素卷积层(sub-pixel convolution layers)被用来增大特征尺寸。判别网络包含8个卷积层,随着网络层数加深,特征个数不断增加,尺寸不断减小,选取激活函数为LeakyReLU,通过两个全连接层和最终的sigmoid激活函数得到预测为自然图像的概率。相关代码如下:

SRGAN采用交替训练模式的方式,先训练生成器部分(SRResNet)模型,在该模型的基础上再训练SRGAN。相关代码如下:

3.模型生成

给定图像的输入和输出地址,并通过输入的放大倍数加载所需预训练模型得到输出。相关代码如下: