第二军医大学学报  2018, Vol. 39 Issue (8): 897-902   PDF    
基于迁移学习的全连接神经网络舌象分类方法
杨晶东, 张朋     
上海理工大学光电信息与计算机工程学院自主机器人实验室, 上海 200093
摘要: 目的 针对深度学习在舌象分类中训练数据量大、训练设备要求高、训练时间长等问题,提出一种基于迁移学习的全连接神经网络小样本舌象分类方法。方法 应用经ImageNet海量数据集训练后的卷积Inception_v3网络提取舌象点、线等有效特征,再使用全连接神经网络对特征进行训练分类,将深度学习网络学习到的图像知识迁移到舌象识别任务中。利用舌象数据集进行训练、测试。结果 与典型舌象分类方法K最近邻(KNN)算法、支持向量机(SVM)算法和卷积神经网络(CNN)深度学习方法相比,本实验使用的两种方法(Inception_v3+2NN和Inception_v3+3NN)具有较高的舌象分类识别率,准确率分别达90.30%和93.98%,且样本训练时间明显缩短。结论 与KNN算法、SVM算法和CNN深度学习方法相比,基于迁移学习的全连接神经网络舌象分类方法可有效提高舌象分类的准确率、缩短网络的训练时间。
关键词: 人工智能     迁移学习     深度学习     舌象     卷积神经网络    
Tongue image classification method based on transfer learning and fully connected neural network
YANG Jing-dong, ZHANG Peng     
School of Optical-Electrical and Computer Engineering, University of Shanghai for Science and Technology, Shanghai 200093, China
Supported by National Natural Science Foundation of China (61374039), Natural Science Foundation of Shanghai (15ZR1429100), and Hujiang Foundation (C14002).
Abstract: Objective To propose a classification method for small sample tongue images based on transfer learning and fully connected neural network, so as to solve the problems of large amount of data, high requirement of training equipment and long training time of deep learning in the classification of tongue images. Methods Effective features such as tongue points and lines of tongue images were extracted by the convolution Inception_v3 network after training on the massive data set of ImageNet. The above features were classified by the fully connected neural network, and the image knowledge acquired by the deep learning network was transferred to the tongue image recognition task, and then the tongue data set were used to train and test the efficiency of the network. Results Compared with the typical tongue image classification method such as K-nearest neighbor (KNN) algorithm, support vector machine (SVM) algorithm and convolutional neural network (CNN) deep learning method, the two methods (Inception_v3+2NN and Inception_v3+3NN) in our experiment had higher classification rates for tongue images, with the accuracy rates being 90.30% and 93.98%, respectively, and had shorter training time for the sample. Conclusion Compared with KNN algorithm, SVM algorithm and CNN deep learning method, the tongue image classification method based on transfer learning and fully connected neural network can effectively improve the accuracy rate of tongue image classification and shorten the training time.
Key words: artificial intelligence     transfer learning     deep learning     tongue presentations     convolutional neural network    

传统中医舌诊依赖医师肉眼观察、经验辨析,其结果受诸多主观和外界因素影响,缺乏客观评价标准,制约了舌诊的应用与发展。随着图像处理、机器学习技术的发展,舌诊客观化研究取得了一定进展。有学者研究发现经平衡预处理后舌苔数据集训练GoogLeNet网络模型的结果优于传统舌象诊断方法[1-4]。刘关松等[5]和王昇等[6]提出了一种基于神经网络集成舌苔分类方法,发现该分类器具有较好的精细分类效果,符合中医舌诊要求。传统舌象识别过程包括舌象图片采集、颜色校正、舌体分割、分类器训练等,需要人工寻找特征,这在一定程度上影响分类结果。随着深度学习技术的发展,利用深度学习技术来进行舌象分类成为一种趋势。然而,深度学习技术需要大量标注样本、设备配置要求高且训练时间长,难以在实际中运用。

近年来,迁移学习获得了广泛认可和研究。Pan和Yang[7]、Cheriyadat[8]、Yang和Newsam[9]对迁移学习的历史、分类和挑战进行了详尽解释;李冠东等[10]利用迁移学习在深度卷积神经网络(convolutional neural network,CNN)模型下进行高分影像场景分类,提高了场景分类精度;刘晨等[11]、Lin等[12]、Schmidhuber[13]利用迁移学习先通过大样本对CNN进行预训练,再用超限学习机(extreme learning machine,ELM)代替全连接层,最后使用小样本对CNN模型进行训练,提高了分类准确率且缩短了训练时间;关胤[14]提出了一种基于残差网络迁移学习的花卉识别方法,提高了网络在线识别花卉系统的准确率、实时反馈和泛化能力。迁移学习可协助人类解决小样本分类问题,使用在海量数据集上训练好的模型对小样本进行特征提取,再用特征微调全连接神经网络降低利用深度网络解决问题的难度系数。本实验针对传统机器学习和深度学习在小样本舌象分类中出现的问题,将迁移学习方法应用到舌象分类的任务中,发现迁移学习可在提高分类准确率的同时降低配置要求并减少训练时间,现报告如下。

1 材料和方法 1.1 数据集

为了使网络模型对移动互联网终端采集到舌象图片有较好的泛化能力,本研究主要收集了正常舌象、裂纹、厚苔、点刺、齿痕和剥落苔6种舌象图片,如图 1所示。每种舌象图片有30张,共180张。简单处理使图片像素大小均为256×256,且使用独热编码(one-hot-encoding)对样本类别进行标注。若正常舌象为第1类,裂纹为第2类,厚苔为第3类,则它们的标签分别为[1 0 0 0 0 0]、[0 1 0 0 0 0]和[0 0 1 0 0 0],其他类别以此类推。为拟合模型和评估模型泛化误差,将数据样本分为训练集和测试集,训练集和测试集尽可能互斥。训练集与测试集的划分方法有留出法(hold-out)、交叉验证法(cross validation)和自助法(boot strapping)3种。本研究采用留出法,直接将数据集D划分为2个互斥的集合,其中一个集合作为训练集S,另一个集合作为测试集T,即D=STST=φ。将2/3~4/5的数据用于训练,剩余样本用于测试。为评估结果稳定性,随机选取90%的数据集作为训练集,10%的数据集作为测试集。

图 1 6类舌象图片 Fig 1 Six tongue images A: Normal tongue; B: Crack tongue; C: Thick tongue coating; D: Prick tongue; E: Indented tongue; F: Peeling coating

1.2 迁移学习与模型特征提取

迁移学习数学定义为:给定源域DS和学习任务TS,一个目标域DT和学习任务TT,迁移学习用DSTS中的知识帮助提高DT中目标预测函数fT的学习,并且DSDTTSTT。迁移学习可以分为样本迁移、特征迁移、关系迁移和模型迁移4类,研究如何将源任务上训练的知识迁移到仅有少量标签样本数据的目标任务上。深度网络的微调是一种深度网络迁移学习方法,利用他人已训练好的网络再调整自己的任务,节省时间和成本,间接扩充了实验数据,提升了模型鲁棒性。

GoogLeNet网络模型由Szegedy等[15]提出,其在ILSVRC14(ImageNet Large-Scale Visual Recognition Challenge 2014)大赛上以6.67%的top-5错误率赢得冠军。该研究提出的Inception_v1模型在不增加计算量的同时扩大了网络的深度和宽度,丰富了特征学习并提高了网络分类性能。Inception_v2模型中加入了BN层,减少了Internal Covariate Shift;同时模型使用2个3×3的卷积核代替Inception模块中5×5的卷积核,降低了参数数量并加快了计算速度。Inception_v3模型在上一模型的基础上进一步分解,为了加速计算,将7×7和3×3的卷积核分解为2个一维卷积(1×7,7×1)和(1×3,3×1)形式[16];又将1个卷积层分解为2个卷积层,使网络更深并增加了网络的非线性。Inception_v3模型经2012年的ImageNet数据集训练后,在测试集上取得了3.5%的top-5错误率(人类专家为5%[17])。

经ImageNet数据集训练好的Inception_v3模型具有抽取图像低层次特征(如点、线、曲线和边缘等)和抽象特征的能力。利用模型迁移和模型微调的思想,本研究将在ImageNet数据集上预训练好的Inception_v3模型知识(保存的所有卷积层参数)迁移到舌象分类的任务中,从而对舌象图像进行有效特征向量提取,减少舌象训练样本量和模型训练时间,提高分类的准确性。例如将图 1中的每张图像输入训练好的Inception_v3模型得到的特征为2 048维的向量,可视化像素大小为32×64的有效特征图,如图 2所示(图 2中的特征图与图 1相应舌象的位置一一对应)。

图 2 通过迁移学习得到的特征图 Fig 2 Feature images based on transfer learning A: Normal tongue feature; B: Crack tongue feature; C: Thick tongue coating feature; D: Prick tongue feature; E: Indented tongue feature; F: Peeling coating feature

1.3 基于迁移学习的全连接神经网络分类模型

本实验采用的全连接网络为多层前馈神经网络,如图 3所示,该网络有d个输入神经元、q个隐层单元和l个输出单元。隐含层第h个神经元的阈值为γh,输出层第j个神经元的阈值为θj,输入层第i个神经元与隐含层第h个神经元的连接权重为vih,隐含层第h个神经元与输出层第j个神经元的连接权重为whj。隐含层使用ReLU激活函数,输出层使用Softmax函数,分别记为f1(vT.x+γ)和f2(wT.B+θ)。其中vd×q大小的权值矩阵,xd×1大小的样本输入向量,γq×1大小的阈值向量;wq×l大小的权值矩阵,B=f1(.)大小为q×1,θl×1大小的阈值向量。网络前向传播输出为y=f2(wT.B+θ)(网络预测的概率值),其中y为样本对应的标签向量,大小为l×1。用交叉熵计算损失函数Loss=Y1log(y1)+…+ Yilog(yi)+…+ Yllog(yl),使用反向传播(back propagation,BP)算法、梯度下降最小化损失函数,dropout技术减少过拟合,寻找最优模型分类参数。

图 3 单隐层前馈神经网络 Fig 3 Single hidden layer feed forward neural network

基于迁移学习的全连接网络结构如图 4所示。左侧为已经训练好的深度卷积Inception_v3模型(即Google公开训练好的Inception_v3模型),网络中保存了拟合网络参数,具有提取图像强特征的能力,使用该深度模型提取小样本舌象特征。右侧为全连接神经网络,本实验使用了2层和3层2种全连接神经网络。Inception_v3+2NN为Inception_v3模型+dense1(2 048,ReLU)+dense2(6,Softmax),Inception_v3+3NN为Inception_v3模型+dense1(2 048,ReLU)+dense2(1 024,ReLU)+dense3(6,Softmax),使用提取的强特征微调全连接神经网络进行小样本舌象分类。舌象分类流程如图 5所示。

图 4 基于迁移学习的全连接网络结构图 Fig 4 Fully connected neural connection network structure diagram based on transfer learning

图 5 舌象图像分类流程图 Fig 5 Flow chart of tongue image classification

2 结果和讨论 2.1 实验环境

本实验使用了python和Google开源深度学习框架TensorFlow搭建神经网络,训练笔记本电脑配置为处理器Inter® CoreTM i5-3230M CPU 2.60 GHz,安装内存为4.00 GB,没有使用图形处理器,全程单线程运行。180张舌象图像中训练集为162张,测试集为18张,将数据转化为TFRecord格式备用。实验前已获取Google在ImageNet数据集上训练好的Inception_v3模型BP文件,本实验的网络模型为Inception_v3+2NN和Inception_v3+3NN。同时选取CNN、改进的CNN、K最近邻(K-nearest neighbor,KNN)算法、支持向量机(support vector machine,SVM)算法为对照,CNN和改进的CNN的网络架构大纲如表 1表 2所示。

表 1 卷积神经网络架构大纲 Tab 1 Architecture outline of convolutional neural network

表 2 改进的卷积神经网络架构大纲 Tab 2 Architecture outline of improved convolutional neural network

2.2 训练集的收敛性分析

分别采用CNN、改进的CNN、Inception_v3+2NN和Inception_v3+ 3NN基于训练集图像进行训练,其中学习率为0.001,批样本数为10,keep_prob为0.5,迭代次数为4 100,每迭代训练100次评估1次网络模型的分类准确率。由于KNN算法没有训练过程,SVM算法的训练过程是调优主要参数,均无法记录准确率,故只记录以上4种网络模型的准确率曲线,如图 6所示。

图 6 原始数据集训练正确率曲线 Fig 6 Accuracy rate curve of four networks trained by origin tongue sets CNN: Convolutional neural network; Inception_v3+2NN: Inception_v3 model plus two-layer neural network; Inception_v3+3NN: Inception_v3 model plus three-layer neural network

图 6分析可知,4种网络模型中CNN的收敛性最差,其在测试集中的平均分类准确率较Inception_v3+2NN和Inception_v3+3NN分别低73.63%、77.31%。分析其原因为CNN需要反复对训练集进行迭代以提取舌象特征,而训练集样本数量较少,池化层使用较多,使CNN难以进行特征学习,故训练收敛性差。改进后CNN的平均准确率相比Inception_v3+2NN和Inception_v3+3NN分别低62.52%、66.20%,分析其原因为改进的CNN只使用了一个池化层,并调小了网络结构,使网络在小样本训练集上也可以提取到有用特征,故收敛性较好。Inception_v3+3NN收敛性最好,这是由于Inception_v3模型能直接提取舌象的强表达基础特征,然后将这些特征在全连接神经网络上进行迭代训练,故其收敛性优于CNN和改进的CNN。

2.3 测试集的泛化性分析

使用迁移学习将训练集输入Inception_v3模型,提取具有强表达能力的特征向量,再分别使用2NN(两层全连接神经网络)和3NN(三层全连接神经网络)对特征向量进行训练,训练完成后使用测试集分别进行测试,记录算法模型的平均分类结果和运行时间。本实验方法与传统KNN、SVM和CNN方法的比较结果如表 3所示。

表 3 本研究分类方法与其他方法效果对比 Tab 3 Comparison of our methods with other methods

3 结论

KNN算法没有训练过程,其在利用测试集进行测试时,将图像转化为196 608维,根据离测试集最近的K个训练集类别出现频率进行分类,分类结果易受K值影响,样本维度较高时分类效果受限,其分类准确率为50%;SVM算法需要在一个最大边缘超平面进行分类,SVM算法的分类准确率低于KNN算法;CNN内有大量参数用于学习知识表示,CNN训练时需要大量数据拟合参数用于分类过程。为了减少计算量和训练时间,其采用了池化技术,这将造成原始图像信息丢失,致使CNN在小数据集上迭代训练时提取不到具有强代表性的舌象特征,易产生过拟合,因而表现较差。经减少CNN的池化次数并调节网络结构改进后CNN的分类准确率由16.66%提高至27.78%,但仍然低于其他方法,表明CNN不适用于小样本数据分类。

本实验将迁移学习与全连接神经网络结合,使用训练好的Inception_v3从小样本舌象训练集中提取特征,再利用特征微调全连接神经网络,较好地解决了传统机器学习和深度学习在小样本、多分类中的问题。本实验中Inception_v3+2NN和Inception_v3+3NN两种方法的分类准确率分别为90.30%和93.98%,表明对小样本数据进行分类时,迁移学习方法具有较好的特征提取效果。

综上,KNN和SVM算法由于没有大量迭代过程,因此分类时间较短,但在面对高维数据时容易受自身分类能力限制;CNN每次迭代都需大量的卷积运算、梯度计算进行参数拟合,且图片尺寸较大,故处理时间较长,难以使用小样本数据拟合大量参数并提取有用特征,泛化性较差。本实验将迁移学习与全连接神经网络结合并应用到小样本舌象分类领域,采用预训练好的Inception_v3模型,在实验中无需训练即可直接输入图像进行特征提取,有效缩短了训练时间。综合分析发现,与KNN、SVM和CNN方法相比,基于迁移学习的全连接神经网络舌象分类方法可有效提高舌象分类准确率、缩短网络训练时间。

参考文献
[1]
FU S, ZHENG H, YANG Z, LIU Y. Computerized tongue coating nature diagnosis using convolutional neural network[C]//IEEE, International Conference on Big Data Analysis. IEEE, 2017:730-734.
[2]
MARMANIS D, DATCU M, ESCH T, STILLA U. Deep learning earth observation classification using ImageNet pretrained networks[J]. IEEE Geosci Remote S, 2016, 13: 105-109. DOI:10.1109/LGRS.2015.2499239
[3]
SIMONYAN K, ZISSERMAN A. Very deep convolutional networks for large-scale image recognition[Z/OL]. arXiv:1409.1556, 2014. https://arxiv.org/pdf/1409.1556.pdf.
[4]
KRIZHEVSKY A, SUTSKEVER I, HINTON G E. ImageNet classification with deep convolutional neural networks[C]//International Conference on Neural Information Processing Systems. Curran Associates Inc., 2012:1097-1105.
[5]
刘关松, 徐建国, 高敦岳. 基于神经网络集成的舌苔分类方法[J]. 计算机工程, 2003, 29: 100-102.
[6]
王昇, 刘开华, 王丽婷. 舌诊图像点刺和瘀点的识别与提取[J]. 计算机工程与科学, 2017, 39: 1126-1132. DOI:10.3969/j.issn.1007-130X.2017.06.016
[7]
PAN S J, YANG Q. A survey on transfer learning[J]. IEEE T Knowl Data En, 2010, 22: 1345-1359. DOI:10.1109/TKDE.2009.191
[8]
CHERIYADAT A M. Unsupervised feature learning for aerial scene classification[J]. IEEE T Geosci Remote, 2013, 52: 439-451.
[9]
YANG Y, NEWSAM S. Spatial pyramid co-occurrence for image classification[C]//IEEE International Conference on Computer Vision. IEEE, 2011:1465-1472.
[10]
李冠东, 张春菊, 王铭恺, 张雪英. 卷积神经网络迁移的高分影像场景分类学习[J]. 测绘科学, 2019(6): 1-13.
[11]
刘晨, 曲长文, 周强, 李智, 李健伟. 基于卷积神经网络迁移学习的SAR图像目标分类[J]. 现代雷达, 2018, 40: 38-42.
[12]
LIN K, YANG H F, CHEN C S. Flower classification with few training examples via recalling visual patterns from deep CNN[C]//IPPR Conference on Computer Vision, Graphics, and Image Processing (CVGIP), 2015.
[13]
SCHMIDHUBER J. Deep learning in neural networks:an overview[J]. Neural Netw, 2014, 61: 85-117.
[14]
关胤. 基于残差网络迁移学习的花卉识别系统[J/OL]. 计算机工程与应用. [2018-05-15]. http://kns.cnki.net/kcms/detail/11.2127.TP.20180404.1703.020.html.
[15]
SZEGEDY C, LIU W, JIA Y, SERMANET P, REED S E, ANGUELOV D, et al. Going deeper with convolutions[C/OL]. 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2014. doi:10.1109/CVPR.2015.7298594.
[16]
SZEGEDY C, VANHOUCKE V, IOFFE S, SHLENS J, WOJNA Z. Rethinking the inception architecture for computer vision[C]//Conference on Computer Vision and Pattern Recognition (CVPR), Las Vegas, NV. IEEE, 2016:2818-2826.
[17]
RUSSAKOVSKY O, DENG J, SU H, KRAUSE J, SATHEESH S, MA S, et al. ImageNet large scale visual recognition challenge[J]. Int J Comput Vision, 2015, 115: 211-252. DOI:10.1007/s11263-015-0816-y