点击下方卡片,关注「集智书童」公众号
测试时的域自适应旨在使用几张 未标注 的图像将源域上训练的模型适应到未见过的目标域。新兴研究表明,标签和域信息分别嵌入在权重矩阵和批量归一化(BN)层中。以前的工作通常直接更新整个网络,而没有明确地将标签和域知识解耦。结果,这导致了知识干扰和缺陷的域适应。
在本工作中,作者提出通过仅操作BN层来减少这种学习干扰并提高域知识学习。然而,BN的归一化步骤在从少量样本重新估计统计信息时内在地不稳定。作者发现,仅更新BN中的两个对数参数而保持源域统计信息可以大大减少歧义。为了进一步增强从无标签数据中提取域知识,作者构建了一个具有标签无关自监督学习(SSL)的辅助分支以提供监督。
此外,作者提出了一种基于元学习的双级优化方法,以强制辅助分支和主分支的学习目标对齐。作者的目标是使用辅助分支适应域并为主任务提供后续推理时的益处。作者的方法在推理时的计算成本与辅助分支相同,因为辅助分支可以在适应后完全丢弃。
在五个WILDS真实世界域移位数据集上,作者的方法超过了先前的研究。作者的方法还可以与具有标签相关优化的方法集成,以进一步推动性能边界。
代码:https://github.com/ynanwu/MABN
Introduction
深度模型由于与训练和测试数据分布的匹配而实现了惊人的性能。然而,这种假设在实际世界中是脆弱的,因为收集训练数据以覆盖通用分布是不可能的。因此,在推理时遇到的未见分布会导致性能退化,这源于分布转移。
无监督域自适应(UDA)是一种研究,通过将标记源数据和 未标注 目标数据的互相依赖性结合来减轻分布转移。显然,对于每个未见过的目标域,重复执行UDA是不切实际的。相比之下,域泛化(DG)旨在克服这一限制,通过训练在源数据上能够有效泛化的模型,以应对未见过的目标域。然而,期望一个通用的模型来处理所有不同的未见域是不现实的。
为了解决这个问题,一些先驱工作,如ARM提出,在推理之前,使用少量 未标注 数据将模型适应到每个目标域。作者将这种场景称为测试时域自适应(TT-DA),如图1所示。TT-DA背后的动机是,目标域中少量 未标注 数据(例如,在相机校准过程中获得的照片)易于获得,并且它们为目标域的底层分布提供了线索。
适应一个有限的 未标注 数据集仍然是一个挑战。ARM通过元学习自适应模块来解决无监督适应问题。其主要限制是,内层和外层优化是在相同的 未标注 数据批上进行的。直觉上,自适应模块仅优化以适应单个批次而不是更广泛的数据分布。不幸的是,这可能会阻碍有效的泛化。在更具挑战性的实际世界基准测试中,ARM甚至有时会落后于经验风险最小化(ERM) Baseline ,如WILDS排行榜所示。
相反,Meta-DMoE强制适应模型在训练期间将 Query 集划分为不重叠的 Query 集,以提高整体性能。然而,Meta-DMoE将适应性视为从一组在源域上训练的教师模型中 Query 目标相关知识的过程。适应性受教师模型的限制。此外,教师集合的大小随着源域的数量而增长,从而放大计算需求并显著减慢适应过程。
值得注意的是,ARM和Meta-DMoE都没有故意将域和标签之间的知识解耦。这种知识的潜在重叠可能会引入干扰,使模型容易受到性能下降的影响。
在本工作中,作者提出了一种简单而有效的解决方案,通过解耦标签相关知识来增强域知识获取的细化。作者的工作部分受到观察的影响,即权重矩阵倾向于包含标签信息,而特定域的知识嵌入在批量归一化(BN)层中。作者提出了一种战略性的BN层操作来优化域特定知识的获取和传输。BN层对输入特征进行归一化,然后使用两个仿射参数重新缩放和移动。
然而,在TT-DA下计算的目标域归一化统计信息可能不稳定,因为作者只有从目标域获取的一小批示例。相反,作者提出只适应两个仿射参数,而直接使用在训练过程中从源域学习的归一化统计信息。直觉是,特征首先会归一化到源分布,但适应的仿射参数旨在将归一化的特征拉向目标分布,如图1底部所示。在不稳定的统计估计引起的歧义下,在适应步骤中学习域知识优化更加稳定。
另一方面,这种策略在推理时的计算成本相同,无需承担额外的操作或参数的负担,这些负担学习更好的统计更为繁琐。此外,为了在无标签数据上生成面向域的监督信号,作者构建了一个辅助分支并采用类无关自监督学习(SSL)。
总体而言,作者进行两阶段训练。第一阶段,作者训练整个模型以学习标签知识和归一化统计信息,通过混合所有源数据。为了赋予仿射参数在新领域的自适应能力,在第二阶段,作者使用元学习中的双级优化。具体而言,作者将每个源域视为一个“任务”,并使用少量 未标注 图像通过辅助分支(同时保持其他模型参数不变)更新仿射参数(如图1所示)。为了确保仿射参数的优化与主要任务对齐,作者在任务上的一个不重叠集合上评估适应后的仿射参数的元目标。请注意,仿射参数在元 Level 优化,以充当适应的初始化。通过这种学习范式,作者的模型学会有效地使用 未标注 数据适应一个域,并使用适应后的模型进行推理。作者将这种方法称为Meta-Adaptive BN (MABN)。
作者的贡献如下:
-
提出了一种简单而有效的无监督自适应方法,专门针对TT-DA进行定制。作者通过每个目标域的自我监督损失仅适应仿射参数,以提高域知识学习。
-
采用双级优化来使学习目标与评估协议对齐,从而得到能够适应域知识的仿射参数。
-
进行了广泛的实验来证明作者的方法在域知识学习方面更有效。因此,作者的域适应模型可以无缝集成到基于熵的TTA方法(例如,TNET)中,其中优化更倾向于标签知识。
-
在WILDS-iWildCam上超过了ARM和Meta-DMoE,分别取得了9.7%和4.3%的Macro F1。作者在WILDS上的五个真实世界域移位基准测试中,无论是分类还是回归任务,都取得了优越的性能。
Related Work
域移位。无监督域自适应(UDA)通过联合训练标记源数据和 未标注 目标数据来解决域移位问题。流行的方法包括在不同分布之间对统计差异进行对齐;通过对抗训练开发通用特征空间。
最近,源域无关的UDA方法被提出,允许源数据不存在,例如基于生成的方法。然而,这两种设置都是不现实的。它们需要访问足够大的 未标注 目标数据集,并且其目标是在特定目标域上实现高性能。域泛化(DG)是另一种研究线程,它从一个或几个源域学习一个通用的模型,并期望该模型在未见过的目标域上表现良好。这两种域都无法同时访问。然而,将这样的模型部署到所有未见过的目标域上,无法探索域专业知识,通常会产生较差的解决方案。
测试时适应/训练(TTA)旨在通过适应测试数据来克服由域移位引起的表现下降。Sun等人(2020)在推理过程中使用旋转预测来更新模型。Chiet等人(2021);刘等人(2022, 2023)通过重建输入图像来实现内部学习。ARM将测试时适应与域泛化结合,元学习一个能够适应未见目标域的模型。Meta-DMoE建议将每个源域视为专家,通过 Query 目标域数据和这些专家之间的相关知识来解决域移位问题。然而,这些方法没有明确地识别出哪些知识以及如何学习它们以扩大性能提升。
批量归一化。Nado等人(2020)引入了预测时的批量归一化,利用测试批统计进行标准化。同样,(Du等人,2020)和(Hu等人,2021)分别使用预定义的超参数和移动平均来更新测试统计。相反,(Schneider等人,2020)提出通过结合源和测试批统计来适应批量归一化统计以减轻协变量漂移。 (Lim等人,2023)通过调整源和测试批统计之间的重要性来根据每个批量归一化层的域移移敏感性插值统计。与现有方法的主要区别在于,作者不是学习一个可扰动的统计量,而是专注于研究在TT-DA设置下少量学习者的泛化能力中仿射参数的作用。
元学习。现有的元学习方法可以分为:
- 基于模型的元学习;
- 基于优化的元学习;
- 基于度量的元学习。
典型的元学习方法利用双级优化来训练一个适用于下游适应的模型。作者的工作基于MAML,它通过任务的多个episodes来初始化模型并通过梯度更新进行快速适应。这种学习范式已经在不同的视觉任务中得到广泛应用,如零样本学习和类增量学习。在作者的情况下,适应是在无监督的方式下实现的,并且利用双级优化来适应未见域并在该域的所有数据样本上泛化良好。
The Proposed Method
问题设置。在本工作中,作者考虑了Zhong等人(2022)中提出的问题设置,作者称之为测试时域自适应(TT-DA)。在离线学习阶段,作者有权访问个标记源域,这些域分别表示为。每个源域包含一组标签数据,即,其中分别表示输入图像和相应的标签。在这些源域上进行离线训练后,作者获得了一个训练好的模型和一个适应机制。
在测试阶段,作者得到了一个新的目标域。作者的目标是使用从目标域获取的仅有的少量无标签图像来适应这个目标域。在这里,作者假设所有域共享相同的标签空间,但它们之间可以存在任何源域和目标域之间的域移位。
TT-DA与域自适应相关,但有一些关键区别。传统的无监督域自适应(UDA)假设在训练期间可以访问源域(可以是单个或多源域)和目标域。相比之下,在离线训练期间,TT-DA没有访问目标域的权限。在测试阶段,TT-DA的适应性不能访问源域。另一个相关设置是源域无关域自适应(SFDA)。关键区别在于,SFDA的适应性假设可以访问目标域的大量无标签数据。
相比之下,TT-DA只需要很少的目标图像来进行适应。TT-DA是许多实际场景中更现实的设置。例如,在野生动物监测应用中,在安装新位置(即域)的监控摄像头后,作者可能只需要收集很少的图像就可以部署一个适应性模型。TT-DA也与测试时适应(TTA)有关。关键区别在于,TTA通常为一批测试示例适应模型,然后使用适应性模型对同一批示例进行预测。
实际上,每次在需要预测之前都适应模型可能是不现实的。相比之下,TT-DA仅使用一小批图像一次适应模型,然后使用适应性模型对目标域的所有测试图像进行预测。在实际场景中,这是一个更现实的设置。
Motivations
利用少量 未标注 数据来适应模型是一个具有挑战性的问题,尤其是在遇到未知分布时。在这个复杂的设置中,有两个基本问题需要仔细考虑:
- 什么类型的知识对于适应未见域最有效?
- 如何获取足够的监督来指导模型更新以适应该域?
先前的工作已经表明,标签和域知识分别被编码在权重矩阵和批量归一化(BN)层中。在TT-DA的背景下,所有域共享相同的标签空间,这暗示着可以从大量的源数据中学习标签信息。
因此,作者的方法主要关注选择性地调节BN层,同时保持已经良好获得的标签知识不变。给定一个特征图,BN层对每个通道分别进行归一化和仿射变换。
和是批量的均值和方差,是一个小数以防止除以0。和是两个仿射参数。在训练过程中,运行均值和方差由每个批次更新,并收敛到源数据的真实统计信息。它们将在推理中使用。然而,使用少量示例估计未知目标域的统计信息是不稳定的,因为数据点是稀疏采样。
为了减少这种不稳定性引起的干扰,作者提出利用来自源数据的和,只更新和。尽管第一个归一化步骤将使输入特征向源分布转换,但和是优化的以适应并拉向目标分布的。这种方法实现简单,保持可学习参数的最小数量,且不引入任何额外推理成本。
为了为和提供监督,并同时进一步增强域知识学习,作者提出利用类无关的自监督学习。具体而言,一个辅助分支与主要的分类分支并行集成,如图2(a)所示。这两个分支共享相同的输入特征,该特征由 Backbone 网络编码。辅助分支的结构取决于自监督学习算法 – 它可以是一个简单的MLP或更复杂的结构。
在本工作中,作者并不打算设计新的自监督学习方法,而是采用现有的方法,例如BYOL。在适应目标域后,辅助分支可以被丢弃,只有原始网络,例如ResNet,用于推理。
Learning label-dependent representation
考虑到权重矩阵编码了丰富的标签信息,且所有域共享相同的标签空间,作者首先在源数据上进行大规模的训练。作者将中的所有数据混合在一起,以均匀地采样 mini-batch。辅助分支和主分支通过优化联合损失进行更新。
和分别表示自监督损失和监督交叉熵损失, respectively。请注意,可以相应地替换(例如,均方损失)以处理回归问题。这两个损失通过进行平衡。在训练期间收集的运行均值和方差统计信息收敛到源数据的准确统计信息。
Learning to adapt to unseen domain knowledge
通过优化的模型不一定准备好适应未见域,因为辅助损失()和主要损失()是相互独立的。换句话说,由辅助分支的梯度更新参数不能保证对主要任务有正面的改进。此外,该模型缺乏对其后续学习任务的意识,这涉及适应未见域。
Meta-auxiliary training.
为了应对这个问题,作者提出了一种元辅助学习方案,将两个损失之间的梯度对齐,并赋予模型学习适应未见域的能力。在元辅助训练阶段,作者将权重矩阵冻结以保留丰富的标签信息。作者还直接采用源数据的运行均值和方差来减少由不稳定的少量数据引起的域信息学习干扰。
因此,在这个阶段,只有可学习的仿射参数。这里使用上标和分别表示共享 Backbone 网络和辅助分支的参数。请注意,作者采用一个没有BN的单层MLP作为分类头,并且取决于自监督学习。如果辅助分支没有BN,可以简单地忽略它。
为了训练,作者采用元学习中的 episodic 学习方法,将每个域视为一个“任务”。在任务 Level 进行嵌套优化,而不是实例 Level ,这样可以元学习以满足学习适应未见域的任务。
对于每个迭代,作者选择一个源域(任务)。作者从中选择一个无标签的支持集和一个有标签的 Query 集。作者首先使用的学习率在支持集上对进行自监督学习以适应这个域:
其中表示所有权重矩阵。理想情况下,应该适应域并改进主分支。换句话说,适应后的模型应该在该域的所有数据上按照主要任务进行泛化。因此,作者使用公式2作为元目标,在分离的标记 Query 集上评估并更新如下:
其中是外层循环学习率。更新需要来自的监督,因此作者使用联合损失。请注意,在公式4中,评估是在上进行的,但梯度更新是在原始上进行的,以实现元 Level 更新。
为了简单起见,在公式3和4中省略了输入和输出的符号。该过程重复,直到收敛。算法1和图2(b)详细说明了整个训练流程。
元辅助测试。元参数已经专门学习,以帮助在未见过的目标域上进行域知识适应。在测试时,给定一个未见过的目标域,通过执行算法1和图2(c)中的第12行,使用包含几个 未标注 图像的支撑集,可以简单地获得适应参数。然后,适应模型可以在这个域的所有测试样本上进行推理,并丢弃辅助分支以保留计算成本。
Experiments
数据集和评估指标。在本工作中,作者遵循Meta-DMoE对方法在WILDS的五个基准测试进行评估:iWildCam,Camelyon17,RxRx1,FMoW和PovertyMap。
请注意,作者遵循官方的训练/验证/测试划分,并报告与[12]中相同的指标,包括准确性,Macro F1,最差(WC)准确性,皮尔逊相关系数(r)及其最差对应物。作者还将在DomainNet基准测试[20]上进行评估。补充说明提供了基准测试的详细描述。
模型架构。为了进行忠实的比较,作者遵循WILDS使用ResNet50,DenseNet12和ResNet18作为作者提出的方法的iWildCam/RxRx1,Camelyon17/FMoW和PovertyMap数据集的 Backbone 。 Backbone 的最后平均池化层的输出作为辅助分支和主分支的输入。
对于辅助分支,作者选择BYOL作为自监督学习方法。其架构包括两个带有BN和ReLU激活的MLP层,读者可以参考原始论文了解更多细节。对于主分支,作者简单地使用一个线性层用于分类和回归任务。
实现。作者遵循[13]使用ImageNet-1K预训练权重作为初始化进行联合训练。使用Adam优化器最小化公式2,学习率(LR)为,经过20个周期后,LR减少为原来的1/2。公式2中的设置为0.1。在元辅助训练期间,作者将整个网络的权重矩阵固定,并直接使用来自源数据的运行统计和进行BN层的统计。只有BN层的仿射参数和进一步使用算法1在固定LR为(对于)和(对于)下优化10个周期。
在测试期间,对于每个目标域,作者从iWildCam中随机选择12张图像,从其他数据集中选择32张图像进行适应性训练(算法1的第12-13行)。然后,使用适应后的模型在该域的所有图像上进行测试。该过程对于所有目标域重复进行。所有实验采用5个随机种子进行,以展示变化。
在iWildCam,Camelyon17和RxRx1数据集上,作者的方法在分类精度上分别比Meta-DMoE提高了3.5%,18.4%和3.9%。作者还以1.2/4.6的百分比优势在iWildCam上超过了Meta-DMoE,以1.0在Camelyon17上超过了Meta-DMoE,以2.9在RxRx1上超过了Meta-DMoE,以1.2/0.7在FMoW上超过了Meta-DMoE,以0.05/0.04在PovertyMap上超过了Meta-DMoE。
请注意,作者的方法只使用了一个模型,因此更加轻量,而Meta-DMoE则有一组教师模型。这些结果表明,确定域特定知识适应的关键参数是有效的。这种适应性对于增强目标域的泛化非常重要。表2报告了作者在DomainNet上的优势。
作者的方法MABN真的在学习域知识吗?为了证明作者的适应性模型已经学习了域特定信息并适应了每个目标域,作者进行了一个实验,将适应性参数在目标域之间进行Shuffle。
具体而言,假设存在个目标域,作者将所有适应性参数保存为。对于第个目标域,作者随机使用其他域(即)的,作者称之为“未匹配”。然后,作者将它与第个目标域使用其自己的进行比较,作者称之为“匹配”。
如表3所示,可以观察到几个观察结果:
- 当目标域使用非匹配的适应性域知识时,性能显著下降。然而,当目标域使用其自己的知识时,性能提升;
- 与非自适应模型相比,使用正确的适应性可以带来巨大的改进。
因此,作者可以得出结论,作者的方法为每个目标域学习独特的域知识,并提高整体性能。作者还使用t-SNE可视化调整前后的特征。如图3所示,经过调整后的每个类簇更具判别性。
为了进一步证明作者的方法正在学习域知识,作者将MABN与基于熵的TTA方法TENT无缝集成,TENT更注重标签相关性。为此,对于每个目标域,作者首先使用MABN来适应域,然后在每个批次上应用TENT。作者还比较了整个BN层和仅学习域特定信息的BN层的情况。
从表4中可以得出以下几个结论:1)适应整个BN层存在问题,因为归一化不稳定且与仿射参数干扰;2)由于权重矩阵中已经包含了与标签相关的信息,因此需要学习域特定信息;3)由于作者旨在学习域信息,与TENT的集成进一步通过1.28%和0.58%的准确率和F1得分取得了巨大的优势。
Elaboration on Batch Normalization (BN)
在主论文(动机部分)中,作者讨论了关于BN的动机。主论文中的公式1是简化的。在本补充材料中,作者将BN操作进行了更详细的阐述。对于每个层,给定一个特征图,其中维度为,其中是批处理大小,是空间维度,是通道数量。BN首先对每个通道进行归一化:
其中和分别是第个通道的均值和方差。归一化后的特征进一步通过两个仿射参数进行变换。
其中和是缩放和偏移参数。这样的批量归一化操作确保在将它们发送到下一个权重矩阵时,它们的分布不会改变。在训练期间,移动平均统计量和按如下方式更新:
其中是动量,意味着移动平均值由衰减,但仅在或上更新一小部分。经过多次迭代,或逐渐收敛到满足源域分布的稳定数字,这些数字将在推理中使用。显然,这些归一化统计量和学习的仿射变换仅适用于接近源数据分布的数据。一旦域移位发生,BN可能会出现问题。
在TT-DA设置下,对于每个目标域,作者允许使用少量 未标注 数据来更新模型。这个数量远小于训练批处理大小(例如,128或256)。在训练期间,通常将批处理大小设置得很大(例如,128或256),以稳定训练。
然而,适应图像的数量要少得多(例如,12张图像)。这导致了归一化统计量的估计问题,从而干扰了学习目标分布。因此,这激励作者固定归一化步骤,只更新可学习的仿射参数。
作者将不同适应方法(如图4所示)前后BN下的特征分布进行了可视化。第一行的右侧显示,如果没有适应,分布的尾部会出现扭曲。第三行的右侧显示,当更新正常的统计学和仿射参数时,扭曲仍然存在,分布不再是零中心。第二行的右侧显示,仅适应仿射参数时,尾部扭曲消失,整个分布更靠近零。这表明,在保持归一化统计量不变的情况下,学习分布的效果更有效,干扰更小。
Adaption without Meta-Auxiliary Learning
元辅助学习的最大贡献是将两个学习目标对齐,以便辅助分支更新的模型可以对主分类/回归分支产生益处。为了进一步展示其有效性,作者直接使用联合训练的模型进行适应。
图5显示,无论用于适应的图像数量如何,性能都非常稳定。有时,与没有适应的 Baseline 相比,性能甚至下降。这表明了强制两个目标保持一致和相互依赖的重要性。
Details of Datasets
作者在WILDS的5个图像测试平台上进行评估:iWildCam包括323个摄像头陷阱的203,029张图像,其中243和48个摄像头陷阱作为源域和目标域。Camelyon17包括5家医院提供的455,954个组织切片,其中3家医院作为源域,1家医院作为目标域。RxRx1包括51个运行批次的125,510张细胞图像,其中33和14个批次被认为作为源域和目标域。
在PovertyMap中,有来自46个国家和地区的19,669张卫星图像,其中26和10个国家被认为作为源域和目标域。FMoW包括在5个地理区域和16个不同时间拍摄的141,696张卫星图像,其中55个区域被认为作为源域,10个区域被认为作为目标域。在所有这些数据集中,作者选择具有最高验证性能的最佳模型。
DomainNet:包含6个域和5690000张图像,每个域有345个类别。作者遵循官方的“留一法”来训练6个模型。请注意,作者遵循官方的训练/测试划分,而不是随机选择数据集的一部分作为测试划分。
参考
[1]. Test-Time Domain Adaptation by Learning Domain-Aware Batch Normalization
扫码加入👉「集智书童」交流群
(备注:方向+学校/公司+昵称)
想要了解更多:
前沿AI视觉感知全栈知识👉「分类、检测、分割、关键点、车道线检测、3D视觉(分割、检测)、多模态、目标跟踪、NerF」
行业技术方案👉「AI安防、AI医疗、AI自动驾驶」AI模型部署落地实战👉「CUDA、TensorRT、NCNN、OpenVINO、MNN、ONNXRuntime以及地平线框架」
欢迎扫描上方二维码,加入「集智书童-知识星球」,日常分享论文、学习笔记、问题解决方案、部署方案以及全栈式答疑,期待交流!
免责声明凡本公众号注明“来源:XXX(非集智书童)”的作品,均转载自其它媒体,版权归原作者所有,如有侵权请联系我们删除,谢谢。
点击下方“阅读原文”,了解更多AI学习路上的「武功秘籍」