2. 南京信息工程大学 人工智能学院,江苏 南京 210044
2. School of Artificial Intelligence, Nanjing University of Information Science and Technology, Nanjing 210044, China
在大数据时代的背景下,数据呈爆炸式增长,但大部分数据缺失有效的标注信息。由于数据标注任务的成本较高,通过无监督学习的方式进行模型训练可以大大减少投入的人力、物力和时间成本,所以无监督学习成为机器学习领域一个重要的研究方向[1-2]。其次,传统机器学习算法中存在用训练集数据进行训练得到的模型无法适应现实场景的问题,这是由训练集数据与实际测试数据的特征分布不同导致的[3]。
针对以上问题,迁移学习(transfer learning,TL)方法被提出[4],域适应学习(domain adaptation learning,DAL)作为一种同构迁移学习方法[5],在源域与目标域样本特征分布不同但相似的前提下,将源域样本分类模型迁移到目标域,使模型适应目标域数据。无监督域适应模型通过带标签源域数据和无标签目标域数据进行训练,即使训练过程中不包含目标域标注信息,也可以在目标域数据中实现很好的识别效果。
Ghifary等[6]利用传统DAL思想,使用自编码器学习共享编码以获得域不变特征,实现在特征向量空间中,不同域样本特征之间的距离减小的目的,从而使无标签目标域样本得到正确分类。Sener等[7]提出利用聚类和伪标签的方法来获取分类特征,从而实现在无标签目标域上的分类。卷积神经网络中间特征的分布匹配被认为是实现域适应的有效方法[8]。最大均值差异(maximum mean discrepancy,MMD)[9]使用核函数映射特征来度量两不同分布之间的距离,通过最小化源域与目标域之间的距离得到域共享特征。Tzeng等[10]在分类损失的基础上加了一层适配层,通过在适配层上引入MMD距离来度量最小化两个领域的分布差异。Long等[11-12]在MMD方法的基础上改进,采用多层适配和多核MMD使域差异最小化,实现源域和目标域特征具有相似的特征分布。借鉴生成对抗网络(generative adversarial network,GAN)[13]独特的对抗训练方式,Ganin等[14]提出包含特征生成器和域分类器结构的模型DANN,利用特征生成器生成欺骗域分类器的特征,从而将源域和目标域数据映射到相似的概率分布上。王格格等[15]通过联合使用生成对抗网络和多核最大均值差异度量准则优化域间差异,以学习源域分布和目标域分布之间的共享特征。Sankaranarayanan等[16]提出了一个能够直接学习联合特征空间的对抗图像生成的无监督域适应方法GTA,利用图像生成的对抗过程学习一个源域和目标域特征分布最小化的特征空间。但由于上述使用GAN或MMD的分布对齐方法仅将不同域之间的距离拉近,没有考虑目标样本与决策边界之间的关系,因此无法优化域内类间差异,从而影响域适应分类效果。Saito等[17]通过训练两个分类器以最大化分类差异,但其方法只是减少源域和目标域之间的距离,而未增大目标域不同类之间的距离,这会使目标域样本靠近决策边界,使分类不确定性增加。
为此,本文提出一种基于分类差异和信息熵对抗的无监督域适应模型。利用两个分类器之间的不一致性对齐域间差异,使源域和目标域数据之间的距离最小,同时利用最小化熵的方式降低不确定性,使目标域特征远离决策边界,提高了目标域样本的类间差异。
1 分类差异和信息熵对抗假设给定带标签的源域数据集
相比于其他域适应算法,本文算法在最小化域间差异的同时,可以使目标域内不同类别样本之间的差异最大化。如图1所示,对于目标域数据,其他方法因为仅对齐域间差异,缩小源域和目标域数据之间的距离,所以特征生成器会在分类边界附近生成模糊特征。本文模型方法利用对抗训练思想,最小化源域与目标域数据之间的距离,同时使目标域不同类别远离分类边界,获得更加具有区分性的特征,从而提高域适应分类的准确率。
Download:
|
|
分类器的输出为经过Softmax函数得到的不同类别概率,根据信息熵的定义,可以得到该分类器结果的信息熵大小,信息熵越大表示不同类别的概率值越接近,表明分类边界越模糊,反之,信息熵越小,表明分类边界越清晰。如图2所示,借鉴对抗训练思想、特征生成器最小化信息熵、分类器最大化信息熵,实现使生成的特征向量
Download:
|
|
本文算法的目标是利用特定任务的分类器作为判别器来减小源域和目标域特征的距离,以考虑类边界和目标样本之间的关系。为实现这个目标,必须检测到靠近分类边界的目标域样本,本文算法利用了两种分类器在目标样本预测上的不一致性。由于源域数据带标签,所以分类器可以对源域样本正确分类,两分类器
Download:
|
|
使用距离
$ \forall h \in H,{R_T}(h) \leqslant {R_S}(h) + \frac{1}{2}{d_{{\mathcal{H}}}}(S,T) + \lambda $ | (1) |
$ {d_{{\mathcal{H}}}}(S,T) = 2\mathop {\sup }\limits_{(h,h') \in {H^2}} \left| {\mathop E\limits_{x \sim S} I[h(x) \ne h'(x)] - \mathop E\limits_{x \sim T} I[h(x) \ne h'(x)]} \right| $ | (2) |
$ \lambda = \min [{R_S}(h) + {R_T}(h)] $ | (3) |
式中:
$ {d_{{\mathcal{H}}}}(S,T) = \mathop {\sup }\limits_{(h,h') \in {H^2}} \mathop E\limits_{x \sim T} I[h(x) \ne h'(x)] $ | (4) |
式(4)表示两个分类器对目标域样本预测差异的极限值。将
$ \mathop {\sup }\limits_{{{F_1,F_2}}} \mathop E\limits_{x \sim T} I[{{F_1}} \circ G(x) \ne {{F_2}} \circ G(x)] $ | (5) |
引入对抗训练的方式,实现对特征提取器
$ \mathop {\min }\limits_G \mathop {\max }\limits_{{{F_1,F_2}}} \mathop E\limits_{x \sim T} I[{{F_1}} \circ G(x) \ne {{F_2}} \circ G(x)] $ | (6) |
本文算法的目标是获得一个特征生成器,这个特征生成器可以将目标样本的分类不确定性最小化,并且可以使目标域样本与源域样本的距离最小化。
1.3 Softmax交叉熵损失本文使用Softmax交叉熵损失来优化有标注源域数据集上的监督学习分类任务,通过对源域数据的监督学习可以保证特征生成器在先验特征空间上有合理的构造。Softmax交叉熵损失定义为
$ {L_{{\rm{cl}}}}({X_s},{Y_s}) = - \frac{1}{K}\sum\limits_{i = 1}^K {I(i = y_s^{(i)})\log {p_s}(x_s^{(i)})} $ | (7) |
式中:
将两个分类器的概率输出之差的绝对值之和定义为分类距离损失:
$ {L_d}({X_t}) = d({p_1}(y|{x_t}),{p_2}(y|{x_t})){\rm{ = }}\frac{1}{K}\sum\limits_{k = 1}^K {\left| {{p_{1k}} - {p_{2k}}} \right|} $ | (8) |
式中
在目标域中,一个理想的特征向量f输入分类器得到的概率输出应该集中于某一类上。由于目标域数据没有标注信息,无法知道样本的类别,因此本文通过最小化信息熵的方法来促使目标域样本分类概率集中于某一类上,使得到的分类结果更加具有确定性。定义熵损失如下:
$ {L_{{\rm{ent}}}}({X_t}) = H({X_t}) = \frac{1}{K}\sum\limits_{i = 1}^K { - F(G(x_t^{(i)}))\log F(G(x_t^{(i)}))} $ | (9) |
源域由于有标注信息,其样本的分类概率往往集中在所标注的类别上;而目标域由于存在域间差异,其在分类概率上往往不够集中。训练特征提取器最小化信息熵可以在特征向量层减小源域和目标域的域间差异,即使特征提取器具有更强的泛化能力。
1.6 算法流程1)从
2)通过有标注数据进行监督训练;
3)计算损失函数
4)反向传播梯度信号,更新
5)通过无标注数据进行域适应训练;
6)计算损失函数
7)计算损失函数
8)反向传播梯度信号,更新
9)重复训练步骤7)~8)n次。
2 训练步骤分类器
1)模型预训练
为了使特征生成器获得特定任务的区分特征,首先通过监督学习的方式训练特征生成器和分类器以正确地对源域样本进行分类。训练网络
$ \mathop {\min }\limits_{G,{F_1},{F_2}} {L_{{\rm{cl}}}}({X_s},{Y_s}) $ | (10) |
Download:
|
|
2)训练分类器
固定特征生成器
$ \mathop {\min }\limits_{{F_1},{F_2}} {L_{{\rm{cl}}}}({X_s},{Y_s}) - {L_d}({X_t}) - {L_{{\rm{ent1}}}}({X_t}) - {L_{{\rm{ent2}}}}({X_t}) $ | (11) |
3)训练特征生成器
固定分类器
$ \mathop {\min }\limits_G {L_d}({X_t}) + {L_{{\rm{ent1}}}}({X_t}) + {L_{{\rm{ent2}}}}({X_t}) $ | (12) |
在训练过程中,将不断重复上述3个步骤,以实现特征生成器和分类器关于分类距离和信息熵的对抗训练。
3 实验设计与结果分析为了评价ACDIE算法的性能和效果,本文设计了4种实验:数字标识域适应实验、实物域适应实验、t-SNE图可视化实验、信息熵损失对比实验。特征生成器
选择机器学习领域常用数据集进行域适应实验,包括MNIST[19]、USPS[20]、SVHN[21]、SYN SIG[22]和GTSRB[23],示例图片如图5所示。SVHN是现实生活中的街道门牌号数字数据集,包含99289张32像素
Download:
|
|
对于这5个域的数据样本,设置5种不同的域适应情况:
使用mini-batch随机梯度下降的优化器算法,batch size设置为128,随机种子值设置为1,Learning rate设置为0.0002,通过Adam优化器实现网络参数更新,weight decay设置为0.0005。
3.1.3 对比实验结果将本文算法与其他在域适应领域有代表性的方法进行比较,包括MMD[9]、DANN[14]、分离域共享特征和域独有特征的DSN[24]、基于域鉴别器对抗训练的ADDA[25]、学习多域联合分布的CoGAN[26]、利用图像生成的对抗过程学习源域和目标域特征分布差异最小化的GTA[16],以及最大化决策分类器差异的MCD[17]。表1展示了不同方法在5种实验设置情况下的域适应准确率,其中:Source Only表示只使用源域数据进行训练而不进行域适应;分类精度最高的值用粗体表示。根据实验结果,对于5种不同的域适应情况,ACDIE算法的准确率都为最高值。特别是,在MNIST
为了测试模型对于实际物体图片的域适应效果,设计在Ofiice-31数据集的域适应实验。Ofiice-31数据集含有31类不同物品的图片,共计4652张,是测试域适应算法的通用数据集。该数据集的图片分别来自3种不同的数据域,包括在亚马逊网站收集的样本数据Amazon(A)、通过电脑摄像头拍摄得到的样本数据Webcam(W)、利用单反相机拍摄得到的样本数据DSLR(D)。图6分别为A、D、W这3个不同域的图片数据。对于这3个域的数据样本,设置6种不同的域适应情况:A
Download:
|
|
使用mini-batch随机梯度下降的优化器算法,batch size设置为32,随机种子值设置为2 020。特征提取器
为了对比实验的合理性,所有方法在同等条件下进行对比实验,选取ResNet-50网络作为特征提取网络,对比方法包括DANN[14]、GTA[16]和使用条件对抗域适应的CDAN[27]。表2展示了不同方法在6种实验设置情况下的域适应准确率,其中ResNet-50表示使用ResNet-50作为特征提取器对源域数据进行训练而不进行域适应。
从实验结果可以看出,相较于现有的算法模型,本文所提出的ACDIE模型在不同域适应情况下的分类准确率都有不同程度的提高。在D
为了更加直观地看到经过域适应后特征向量的变化,本文采用t-SNE[28]方法将高维特征向量映射到适合观察的二维向量,进而实现数据的可视化。
图7和图8分别是在SVHN
Download:
|
|
Download:
|
|
为了验证将信息熵损失加入对抗训练的有效性,以基于分类差异的域适应模型为基础,设置4组对比实验:1)不加入信息熵损失;2) 仅在优化
从表3的对比实验结果可以看出,在实验3的情况下,通过在优化特征生成器
现有无监督域适应算法仅将不同域之间的距离拉近,没有考虑目标样本与决策边界之间的关系,没有扩大目标域内不同类别样本之间的距离。针对上述问题,本文提出利用两个分类器之间的不一致性对齐域间差异,减小源域和目标域之间的距离,同时通过最小化信息熵来降低分类不确定性的ACDIE模型。最小化信息熵能使相同类别的数据更加聚集,不同类别数据之间的距离更大,而且可以使目标域样本与源域样本在语义空间上分布更加对齐。大量的实验表明,本文提出的的模型相比于领域内其他模型取得了更优的性能,验证了所提改进算法的有效性。
尽管ACDIE模型在多个数据集中都有不错的表现,但它仍存在一些提升空间。在今后的工作中,将进一步从信息论的角度思考,考虑互信息等因素对模型的影响,以提升模型的准确率和鲁棒性。同时将进一步探究不同距离分布度量对域适应结果的影响。
[1] | WANG Xiaolong, GUPTA A. Unsupervised learning of visual representations using videos[C]//Proceedings of the IEEE International Conference on Computer Vision. Santiago, Chile: IEEE, 2015: 2794−2802. (0) |
[2] | MAHJOURIAN R, WICKE M, ANGELOVA A. Unsupervised learning of depth and ego-motion from monocular video using 3D geometric constraints[C]//Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Salt Lake City, USA: IEEE, 2018: 5667−5675. (0) |
[3] |
刘建伟, 孙正康, 罗雄麟. 域自适应学习研究进展[J]. 自动化学报, 2014, 40(8): 1576-1600. LIU Jianwei, SUN Zhengkang, LUO Xionglin. Review and research development on domain adaptation learning[J]. Acta automatica sinica, 2014, 40(8): 1576-1600. (0) |
[4] | PAN S J, YANG Qiang. A survey on transfer learning[J]. IEEE transactions on knowledge and data engineering, 2010, 22(10): 1345-1359. DOI:10.1109/TKDE.2009.191 (0) |
[5] | ROZANTSEV A, SALZMANN M, FUA P. Beyond sharing weights for deep domain adaptation[J]. IEEE transactions on pattern analysis and machine intelligence, 2019, 41(4): 801-814. DOI:10.1109/TPAMI.2018.2814042 (0) |
[6] | GHIFARY M, KLEIJN W B, ZHANG Mengjie, et al. Deep reconstruction-classification networks for unsupervised domain adaptation[C]//Proceedings of the 14th European Conference on Computer Vision. Amsterdam, The Netherlands: Springer, 2016: 597−613. (0) |
[7] | SENER O, SONG H O, SAXENA A, et al. Learning transferrable representations for unsupervised domain adaptation[C]//Proceedings of the 30th International Conference on Neural Information Processing Systems. Barcelona, Spain: Curran Associates Inc., 2016: 2110−2118. (0) |
[8] | SUN Baochen, FENG Jiashi, SAENKO K. Return of frustratingly easy domain adaptation[C]//Proceedings of the Thirtieth AAAI Conference on Artificial Intelligence. Phoenix, Arizona: AAAI Press, 2016: 2058−2065. (0) |
[9] | GRETTON A, BORGWARDT K M, RASCH M J, et al. A kernel two-sample test[J]. The journal of machine learning research, 2012, 13: 723-773. (0) |
[10] | TZENG E, HOFFMAN J, ZHANG Ning, et al. Deep domain confusion: maximizing for domain invariance[J]. Computer science, 2014. (0) |
[11] | LONG Mingsheng, CAO Yue, WANG Jianmin, et al. Learning transferable features with deep adaptation networks[C]//Proceedings of the 32nd International Conference on Machine Learning. Lille, France: JMLR, 2015: 97−105. (0) |
[12] | LONG Mingsheng, ZHU Han, WANG Jianmin, et al. Unsupervised domain adaptation with residual transfer networks[C]//Proceedings of the 30th International Conference on Neural Information Processing Systems. Barcelona, Spain: Curran Associates Inc., 2016: 136−144. (0) |
[13] | GOODFELLOW I J, POUGET-ABADIE J, MIRZA M, et al. Generative adversarial nets[C]//Proceedings of the 27th International Conference on Neural Information Processing Systems. Montreal, Canada: MIT Press, 2014: 2672−2680. (0) |
[14] | GANIN Y, USTINOVA E, AJAKAN H, et al. Domain-adversarial training of neural networks[J]. The journal of machine learning research, 2016, 17(1): 2096-2030. (0) |
[15] |
王格格, 郭涛, 余游, 等. 基于生成对抗网络的无监督域适应分类模型[J]. 电子学报, 2020, 48(6): 1190-1197. WANG Gege, GUO Tao, YU You, et al. Unsupervised domain adaptation classification model based on generative adversarial network[J]. Acta electronica sinica, 2020, 48(6): 1190-1197. DOI:10.3969/j.issn.0372-2112.2020.06.021 (0) |
[16] | SANKARANARAYANAN S, BALAJI Y, CASTILLO C D, et al. Generate to adapt: aligning domains using generative adversarial networks[C]//Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Salt Lake City, USA: IEEE, 2018: 8503−8512. (0) |
[17] | SAITO K, WATANABE K, USHIKU Y, et al. Maximum classifier discrepancy for unsupervised domain adaptation[C]//Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Salt Lake City, USA: IEEE, 2018: 3723−3732. (0) |
[18] | BEN-DAVID S, BLITZER J, CRAMMER K, et al. A theory of learning from different domains[J]. Machine learning, 2010, 79(1/2): 151-175. (0) |
[19] | LECUN Y, BOTTOU L, BENGIO Y, et al. Gradient-based learning applied to document recognition[J]. Proceedings of the IEEE, 1998, 86(11): 2278-2324. DOI:10.1109/5.726791 (0) |
[20] | HULL J J. A database for handwritten text recognition research[J]. IEEE transactions on pattern analysis and machine intelligence, 1994, 16(5): 550-554. DOI:10.1109/34.291440 (0) |
[21] | NETZER Y, WANG T, COATES A, et al. Reading digits in natural images with unsupervised feature learning[C]//Proceedings of the NIPS Workshop on Deep Learning and Unsupervised Feature Learning. Granada, Spain, 2011: 5−16. (0) |
[22] | MOISEEV B, KONEV A, CHIGORIN A, et al. Evaluation of traffic sign recognition methods trained on synthetically generated data[C]//Proceedings of the 15th International Conference on Advanced Concepts for Intelligent Vision Systems. Poznań, Poland: Springer, 2013: 576−583. (0) |
[23] | STALLKAMP J, SCHLIPSING M, SALMEN J, et al. The German traffic sign recognition benchmark: a multi-class classification competition[C]//Proceedings of 2011 International Joint Conference on Neural Networks. San Jose, USA: IEEE, 2011: 1453−1460. (0) |
[24] | BOUSMALIS K, TRIGEORGIS G, SILBERMAN N, et al. Domain separation networks[C]//Proceedings of the 30th International Conference on Neural Information Processing Systems. Barcelona, Spain: Curran Associates Inc., 2016: 343−351. (0) |
[25] | TZENG E, HOFFMAN J, SAENKO K, et al. Adversarial discriminative domain adaptation[C]//Proceedings of 2017 IEEE Conference on Computer Vision and Pattern Recognition. Honolulu, USA: IEEE, 2017: 7167−7176. (0) |
[26] | LIU Mingyu, TUZEL O. Coupled generative adversarial networks[C]//Proceedings of the 30th International Conference on Neural Information Processing Systems. Barcelona, Spain: Curran Associates Inc., 2016: 469−477. (0) |
[27] | LONG Mingsheng, CAO Zhangjie, WANG Jianmin, et al. Conditional adversarial domain adaptation[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. Montréal, Canada: Curran Associates Inc., 2018: 1647−1657. (0) |
[28] | VAN DER MAATEN L, HINTON G. Visualizing data using t-SNE[J]. Journal of machine learning research, 2008, 9(2605): 2579-2605. (0) |