A Study on Image Classification of Diabetic Retinopathy Based on Transfer Learning
More than 400 million people worldwide suffer from diabetes and more than 100 million in China. A third of diabetics have complications that can lead to visual impairment or even blindness, namely, diabetic retinopathy. Therefore, fundus examination is a necessary check for everyone with diabetes, and the more severe the condition, the higher the frequency of the check. Such a large number of examinations places a heavy workload on medical staff and also increases the country’s medical expenditure. To conduct the examination more economically, efficiently, and accurately, it is necessary to establish a computer-aided diagnostic system. Hence, this work has become an important research topic for scientific researchers. However, due to the characteristics of medical data itself (belonging to small datasets and data imbalance) as well as the “black box” issue of neural networks, the effect of using deep learning for disease classification is not satisfactory. In response to the above problems, this paper conducts research in the following aspects: an interpretable classifier based on transfer learning and DenseNet121 is proposed. By utilizing fine-tuning transfer learning techniques on the pre-trained DenseNet121 network and modifying some of the network structures, the automatic diagnosis task of DR is realized. The experimental results show a sensitivity of 0.77, specificity of 0.91, and accuracy of 88%. At the same time, the interpretability of the neural network is visualized by using gradient-weighted class activation mapping to visualize the classification results of the classifier. Through this method, it can be shown which part of the input image has a greater impact on the classification results of the classifier. By this visualization approach, it is found that although the classifier has achieved a high accuracy rate, it does not always focus on the lesion areas, and there is a bias towards fundus image samples with fewer categories. This reflects the problems existing in models trained with imbalanced samples, allowing us to better understand the shortcomings of the models we have trained.
Computer-Aided Diagnosis
DR临床诊断过程:临床医生通过分析微血管动脉瘤(MA)、渗出、出血、新生毛细血管和棉絮状点的数量、位置、大小来判断DR病情
国内,2015年张庆鹏开发了一个自动判断眼底是否出现微血管瘤的软件
国外,2019年,Dekhil等人利用迁移学习和卷积神经网络在Kaggle APTOS 2019数据集上达到了77%的诊断准确率
以上这些研究,效果越来越好,因为采用的深度神经网络越来越深,算力也更强,再加上迁移学习等技术的应用。从以上论文中的实验结果上看,已经使得DR分类达到了很好的效果,但由于以下问题还没有得到解决,大部分神经网络还未能在医疗领域得到实际应用:神经网络的黑箱问题;数据不平衡和数据量小等。虽然神经网络的分类准确率很高,但考虑到大部分样本是不需要治疗的DR,而真正分类那些需要治疗的样本很少。这就是说,如果神经网络判断所有样本都为正常眼底图像,则其准确率也很高。而这样的分类器显然是毫无意义的,因此数据不平衡对分类器影响很大,而且采用的神经网络不具有可解释性,导致无法对输出分类结果做出解释分析,导致即使模型犯了错误也难以发现。2019年,Li T.等人比较各个常用模型在DR分类和特征识别中的表现,发现虽然大多数算法准确率很高,但与医生判断DR病情(通过观察和计量眼底损失程度)并不相同,而且这些方法在损失分割和检测上表现很不理想
关于神经网络的可解释性,深度学习要在医疗领域得到应用,除了很高准确率外,神经网络的诊断过程和原理也同样重要,否则,医疗机构通常不会选择这种具有不确定性设备,因为,医疗机构要为自己所用的设备造成的后果负责。这导致深度学习难以在医疗领域大范围应用
关于分类模型可解释性主要分为两大类:对分类器的可解释性和对抗生样本的可解释性。本文主要介绍对分类器的可解释性。对分类器的可解释性包括:对预测结果的解释和对内部表征的可视化和分析。
对预测结果的可解释性,最早是2016年,Zhou等人在
内部表征的可解释性工作:Mu等人通过给模型内部神经元打上标签,利用神经元的语义概念给出可解释性
除此之外,越复杂的数据需要更深的网络才能取得较好的效果,而越深的网络需要更多数据才能充分训练模型,进而取得好的效果。然而糖尿病视网膜病变数据集通常为小数据,但DR图像却十分复杂,因此需要大型网络来拟合数据,所以DR分类任务面临数据小的问题。
为了利用小数据集训练大型网络,但小数据难以训练整个大型网络。那能否利用那些花费了大量人力物力训练好的大型网络中的大部分网络,而用我们目标数据集训练其中的小部分网络参数,使得训练后的大型网络能够适用于数据不足的任务,迁移学习就是采用了这个想法。对于小数据集的分类问题,用迁移学习是很好的解决办法,很多工作
基于以上分析,本研究采用迁移学习,通过在小数据集上使用大型预训练模型来训练DR分类模型。随后,利用Grad-CAM对模型的分类结果进行可视化解释,从而构建一个具有可解释性的DR分类深度学习模型。
数据来源于Kaggle公开数据集Diabetic Retinopathy Detection公开竞赛数据集,数据集是从多家医院收集的,其中的图像和标签都含有噪声,图像是很多成像设备拍摄的,有的图像包含伪影、曝光不够或曝光过度、不聚焦和尺寸不同等缺陷,而且部分图像无法判断属于哪个类别。数据集中包括35126张眼底图像,其中33545张是不影响视力的眼底图像,其它1581张是对视力造成伤害的DR眼底图像,显然数据存在分布不均的情况。
数据下载后,使用torch库中transposes方法进行处理:由于数据本身特点(眼睛是圆形的),因此数据处理采用随机旋转、水平翻转、垂直翻转、平移、随机灰度化和随机颜色变化。并同时把所有图像数据统一更改为统一尺寸并转换为tensor格式以便输入模型。
除传统方法外,并采用过采样(即喂入模型中的数据按照各个类别的数据数量基本相同使得网络学习的样本数量各个类别一样,有些类别样本偏少,那么该类样本就会反复喂入网络)。按照放回抽样的方式,从各个类别中均匀抽取样本,根据样本分布d设计取样器权重w。
(2-1)
考虑到本研究是针对大量图像数据,选择GPU会节省模型训练时间。考虑到易用性和灵活性,采用pytorch框架。运行环境如下:
操作系统:Windows 10;编程语言:Python 3.8.4;深度学习框架:Pytorch 1.71。
CPU型号:IntelCorei3-9100F @ 3.60 GHz;31 GB内存。
GPU:GeForce RTX 2080 Ti。
为了使DR图片分类任务效果更好,考虑到糖尿病视网膜病变图像比较复杂,为了达到更好的效果,采用相对较深的网络,因为深层网络被证明能更好地处理复杂的分类任务。由于采用梯度下降法来优化神经网络,使得深层网络容易出现梯度消失的问题,处理该问题常用方法有采用长短时记忆(LSTM)、ReLU、Residual neural networks (ResNets)、Batch Normalization (BN)和DenseNet。ReLU和BN对梯度消失问题帮助有限,直到ResNets的提出,才使得深层网络模型开始大量使用。在ResNets基础上,DenseNet在更少的模型参数和更短的训练时间达到了更好效果
模型采用DenseNet121预训练的神经网络,采用微调迁移学习,即更改网络最后一层(FC层),使得输出通道为5,训练该FC层和最后一个dense块中的最后三个dense层,固定其它层的网络参数。本研究采用的更改后的网络如
模型搭建(Pytorch代码如下):
model = models.densenet121(pretrained=True)
其中最后一层全连接层网络,即FC网络(只显示全连接层)如下:
model.classifier = nn.Sequential(
nn.Linear(in_features=1024, out_features=512, bias=True),
nn.Linear(in_features=512, out_features=256, bias=True),
nn.Linear(in_features=256, out_features=128, bias=True),
nn.Linear(in_features=128, out_features=64, bias=True),
nn.Linear(in_features=64, out_features=32, bias=True),
nn.Linear(in_features=32, out_features=16, bias=True),
nn.Linear(in_features=16, out_features=5, bias=True),
)
层的名称 |
输出尺寸 |
参数信息 |
Convolution |
112, 112 |
7 × 7 conv, stride 2 |
Polling |
56, 56 |
3 × 3 max pool, stride 2 |
Dense Block (1) |
56, 56 |
[1 × 1 conv, 3 × 3 conv] × 6 |
Transition Layer (1) |
28, 28 |
1 × 1 conv, 2 × 2 average pool, stride 2 |
Dense Block (2) |
28, 28 |
[1 × 1 conv, 3 × 3 conv] × 12 |
Transition Layer (2) |
14, 14 |
1 × 1 conv, 2 × 2 average pool, stride 2 |
Dense Block (3) |
14, 14 |
[1 × 1 conv, 3 × 3 conv] × 24 |
Transition Layer (3) |
7, 7 |
1 × 1 conv, 2 × 2 average pool, stride 2 |
Dense Block (4) |
7, 7 |
[1 × 1 conv, 3 × 3 conv] × 16 |
Polling |
1, 1 |
7 × 7 global average pool |
Classification Layer |
1, 1 |
FC: 1024…5 |
超参数设置:
学习率:0.0001,batch size:256 (设置大些以提高训练速度和训练效果)。
数据集:将数据集80%用于训练集,20%用于测试集,由于数据量小,训练集比重设置大些,提高模型训练效果。
模型:预训练DenseNet121。
优化器:Adam;损失函数:交叉熵。
主要程序如下:
#设置可训练层
for name, param in critic0.features.denseblock4.denselayer14.named_parameters():
param.requires_grad=True
for name, param in critic0.features.denseblock4.denselayer15.named_parameters():
param.requires_grad=True
for name, param in critic0.features.denseblock4.denselayer16.named_parameters():
param.requires_grad=True
critic0 = critic0.to(device)
#选择交叉熵损失函数
criterion = nn.CrossEntropyLoss()
#优化器选择Adam
optim = optim.Adam(critic0.parameters(), lr=LEARNING_RATE, betas=(0., 0.99), eps=1e-8)
d_scheduler = build_lr_scheduler(optim, -1)
#采用tensorboard可视化训练过程
writer_real = SummaryWriter(f"logs/GAN_MNIST/real")
writer_fake = SummaryWriter(f"logs/GAN_MNIST/fake")
step = 0
为了验证DenseNet121网络的优势,与常用图片分类网络Resnet101和VGG16进行比较,网络都是在相同的图像大数据集ImageNet上训练好的网络,并更改了网络最后一层全连接层,全连接层的设置相同,并且都相应选择最后一个包含卷积的网络块设为可训练,其它浅层网络设置为不可训练,超参数选择都相同。
采用tensorboard可视化模型的损失曲线(DenseNet、VGG和ResNet)如
由
对比
进一步比较基于迁移学习DenseNet121网络与ReseNet101网络和VGG16网络之间的敏感性(Sensitivity)、特异性(Specificity)和准确率(Accuracy),计算公式如2-2,2-3,2-4所示:
(2-2)
(2-3)
(2-4)
其中:TP为真正样本数、TN为真负样本数、FP为假正样本数、FN为假负样本数。
通过
模型 |
敏感性 |
特异性 |
准确率 |
Resnet101_based network |
29% |
89% |
76% |
VGG16_based network |
6% |
96% |
74% |
Densenet121_based network |
77% |
91.3% |
88.3% |
2020年,Junsuk Choe等人在中提出了一种评估弱监督定位的方法,基于此方法,评估了最近的几个模型弱监督定位的方法,包括CAM、HaS、SPG、ADL和CutMix,发现在2016年的CAM方法之后提出的其他方法都没有比CAM效果更好
具体步骤为,在模型训练前,将Grad-CAM加入到网络中,选取要加入的位置以及输出热力图保存的路径,本研究采用Densenet网络,Grad-CAM加到DensNet121的最后一个Dense Block的输出端,即Dense Block 4为最后一个Dense Block,其输出为(7, 7, 1024)。然后,计算模型分类结果中概率最大的值,然后计算损失函数再反向传播,计算那1024个神经元的梯度并求平均,用此梯度乘上一层的网络输出再求平均(在最后一个轴)得到(7, 7)大小的图,再上采样就得到了和图片大小相同的热力图。测试时,通过任选一个样本输入网络,即可得到分类结果和解释性热图。主要程序如下:
#选取样本
sample_x, sample_y = next(iter(train_loader))
critic0 = critic0.to('cpu')
#输出热图保存至文件夹:attention_maps
critic0=medcam.inject(critic0,output_dir=f"attention_maps", lay-er='features.denseblock4.denselayer16', backend='gcam', save_maps=True)
sample_x = sample_x.cpu()
s = s_x.unsqueeze(0)
prep = critic0(s)
#将图片通道移至最后一个轴上
sample_x1 = s.permute(0,2,3,1)
sample_np = sample_x1.cpu().squeeze(0).numpy()
#标准化图片
a = sample_np - np.min(sample_np)
b = np.max(a)
sample_np_nml = (a)/(b)
#读取热图
interpre_img_dir=f'/home/featurize/work/CWGAN-GP/attention_maps
/features.denseblock4.denselayer16/attention_map_0_0_0.png'
image_fm_PIL = PIL.Image.open(interpre_img_dir)
img_inter = transforms(image_fm_PIL).permute(1,2,0)
#将热图和输入图片进行叠加:
mix_img = img_inter * 0.1 + sample_np_nml*0.8
其中选择热图img_inter占比0.1,是因为热图比重太大就会看不出眼底图像的细节信息,同理如果样本图像sample_np_nml比重太小也看不清楚眼底图像了。
如
从
网络判断是根据图片中的一些特征包括病灶部位和其它信息,通过这种方式,我们能够更好地了解训练的模型,是否是按照人类判断的方法,同时,也许会为医疗人员带来一些新的想法,比如可能人类并没有发现的一些特征,而神经网络发现了,然后网络通过热力图告诉我们,它们关注的点。
热力图显示比较直观,但也存在不足,从
如
可能存在的,因为数据不平衡,导致模型倾向于寻找无病的DR特征,而无病的DR眼底图像是不存在病灶特征的,更多的是血管特征。从
根据对热力图的分析发现,我们的模型还存在不足,如果单从准确率和敏感性特异性上看,只能看出模型效果并非十分理想,敏感性稍差0.77,同时准确率也只有88%,但除此之外并没有对模型判断过程或者关注点有更直观的信息。但通过热力图,我们可以明显发现模型出错的原因(数据少造成模型产生偏见),这给我们如何进一步提高模型,带来了更直观和重要的信息。也为我们更加直观地显示出模型判断过程是否正确。显然,通过这种方法让我们更好地了解模型,也给“黑箱”带进来一缕阳光。
糖尿病视网膜病变(DR)是糖尿病的一种并发症,如不及时检查和治疗,可能导致病人视力受损甚至失明,将给病人、病人家属和社会带来极大的负担。这种疾病是由于病人血糖控制问题所导致的,一般病人血糖控制都存在一定问题,因此这种疾病发病率高,检测次数也就因此非常大。为了缓解眼科医生的工作负担,同时也为了降低就医费用。利用深度学习来完成这项检测分类任务十分必要。
然而,深度学习在医学领域的应用还面临一些困难:医学数据通常是小数据集,而且数据存在不平衡现象。因此,本研究利用生成对抗网络、迁移学习、可解释性工具对这项任务进行研究,主要研究:(1) 研究利用迁移学习搭建生成对抗网络(GAN)生成DR眼底图像;(2) 研究利用预训练的生成对抗网络构造条件生成对抗网络,增加对生成DR眼底图像的控制;(3) 研究搭建有可解释性的DR病情分类神经网络。
主要创新点:建立了带可解释性的基于迁移学习的DensNet121分类神经网络,用小数据训练大型分类神经网络的目的,并且达到了较好分类的效果:准确率:88%,敏感性:77%,特异性91%。并且,实现了网络对每个样本的预测,做出了可视化的解释。
本研究还需进一步的改善如下:对神经网络可解释性,还需进一步研究,如何把热图做到更精细化,使得Grad-CAM的解释能力更强,还需要继续研究。