A Simple Signal for Domain Shift

¹IISER Pune, ²Indian Institute of Science, Bengaluru, India
ICCVW, 2023

Abstract

Test-time domain adaptation has recently come to the forefront as a challenging scenario. Although single-domain test-time adaptation has been well studied and has shown impressive performance, it can be limiting when the model is deployed in a dynamic test environment. Here we explore this continual test-time adaptation problem. Specifically, we ask whether the effectiveness of single-domain adaptation methods can be carried over to the continual test-time adaptation scenario. We take a step towards bridging the gap between these two settings by proposing a domain shift detection mechanism, which allows current test-time adaptation methods to be employed even in a continual setting. We propose to use the given source-trained model to continually measure the similarity between the feature representations of consecutive batches. A domain shift is detected when this measure crosses a certain threshold, which we use as a trigger to reset the model back to the source and continue test-time adaptation. We demonstrate the effectiveness of our method through experiments across datasets, batch sizes and different single-domain test-time adaptation baselines.

Problem Setting

In standard Test-Time Adaptation (TTA), a model trained on source domain data is adapted at test time to a single target domain. Existing methods (TENT, AaD, EATA) work well under this assumption. However, when the model is deployed in the real world, it encounters a continuous stream of data from changing domains — a setting known as Continual Test-Time Adaptation (CTTA).

In CTTA, no explicit signal is provided about when the domain changes. Naively applying TTA methods in this setting causes error accumulation and catastrophic forgetting: the model overfits to past domains and degrades as new domains arrive. For example, TENT-CTTA achieves 86.8% mean error on ImageNet-C, while TENT-TTA (with oracle domain shift knowledge) achieves only 58.6%.

The key challenge is: can we automatically detect when a domain shift has occurred, so that TTA methods can be applied effectively in a continual setting, without requiring oracle domain shift labels?

Figure 1: (Top) t-SNE plots of batch features for batch sizes 25, 50, and 200, coloured by corruption order. Clusters become more compact and separated as batch size grows. (Bottom) Corresponding DSS signals: spikes align with actual domain shifts (red dotted lines), showing that the mean feature cosine dissimilarity is a reliable domain shift indicator.

Proposed Method: DSS

Source Model as a Domain Shift Detector

We observe that the feature extractor of the source-trained model naturally encodes domain-specific information. For a batch of test samples \(x_t = \{x_1, \ldots, x_N\}\) at time \(t\), we compute the mean feature vector:

\[ \mathbb{E}[v_f(t)] = \frac{1}{N} \sum_{k=1}^{N} f(x_k) \]

As batch size increases, the class-specific component of the mean averages out, leaving the domain-specific component dominant. This makes the mean feature a reliable proxy for the domain identity of the batch.
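A minimal sketch of the mean feature vector and of the averaging argument above, in plain Python. The feature vectors stand in for the outputs \(f(x_k)\) of the frozen source extractor. The toy data model (a fixed domain vector plus zero-mean Gaussian "class" noise) and the names `mean_feature`, `cosine` and `toy_batch` are illustrative assumptions, not part of the method.

```python
import math
import random


def mean_feature(features):
    """E[v_f(t)] = (1/N) * sum_k f(x_k) for a batch of N feature vectors."""
    n = len(features)
    return [sum(col) / n for col in zip(*features)]


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


# Toy data model (an assumption for illustration): every feature is a shared
# domain vector plus an independent, zero-mean "class" perturbation.
random.seed(0)
DIM = 64
domain = [random.gauss(0.0, 1.0) for _ in range(DIM)]


def toy_batch(n, class_std=3.0):
    """Draw n toy feature vectors around the shared domain vector."""
    return [[d + random.gauss(0.0, class_std) for d in domain] for _ in range(n)]


for n in (1, 25, 200):
    # Similarity between the batch mean and the pure domain vector.
    print(n, round(cosine(mean_feature(toy_batch(n)), domain), 3))
```

Under this toy model the printed similarity should rise toward 1 as the batch size grows, echoing the trend in Figure 1. Real class components need not be zero-mean; the sketch only illustrates why independent per-sample variation in the mean shrinks roughly as \(1/\sqrt{N}\).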

Domain Shift Signal (DSS)

We define the Domain Shift Signal as the cosine dissimilarity between the mean features of consecutive batches:

\[ \text{DSS} = 1 - \cos\!\left(\mathbb{E}[v_f(t)],\; \mathbb{E}[v_f(t-1)]\right) \tag{1} \]

When consecutive batches come from the same domain, DSS is low. When a domain shift occurs, the mean features diverge and DSS spikes.
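Eq. (1) translates directly into a few lines of Python. This sketch assumes the two mean feature vectors have already been computed and are non-zero; `domain_shift_signal` and `cosine_similarity` are hypothetical helper names.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def domain_shift_signal(mean_t, mean_prev):
    """DSS = 1 - cos(E[v_f(t)], E[v_f(t-1)]), Eq. (1)."""
    return 1.0 - cosine_similarity(mean_t, mean_prev)


# Nearly parallel means (same domain) give a DSS close to 0 ...
print(domain_shift_signal([1.0, 1.0, 0.0], [1.0, 0.9, 0.0]))
# ... while orthogonal means (very different domains) give a DSS of 1.0.
print(domain_shift_signal([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]))
```

For non-negative features, such as globally pooled post-ReLU ResNet activations, the cosine lies in [0, 1], so DSS also lies in [0, 1]; for arbitrary real-valued features it can reach 2.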

Domain Shift Detector (DSD)

To robustly detect shifts, we compute a moving average of DSS and define the detector as:

\[ \text{DSD} = \mathbb{1}\!\left\{\text{DSS} > k \cdot \text{MovAvg}(\text{DSS})\right\} \tag{2} \]

where \(k = 3\) across all experiments. When DSD fires, the model is reset to the source checkpoint, allowing the underlying TTA method to restart adaptation on the new domain.
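Eq. (2) needs a choice of moving average, which the text does not pin down. The sketch below assumes a plain mean over the last `window` DSS values (a hypothetical parameter), compares each new value only against previous ones, and clears the history after a detection so the average re-settles on the new domain. Only \(k = 3\) comes from the text; `DomainShiftDetector` is a hypothetical name.

```python
from collections import deque


class DomainShiftDetector:
    """DSD = 1{DSS > k * MovAvg(DSS)}, Eq. (2).

    Assumed (not specified in the text): MovAvg is the plain mean of the last
    `window` DSS values, the current value is compared against previous values
    only, and the history is cleared after a detection.
    """

    def __init__(self, k=3.0, window=10):
        self.k = k
        self.history = deque(maxlen=window)  # recent DSS values from the current domain

    def update(self, dss):
        """Return True if this DSS value signals a domain shift."""
        if self.history:
            moving_avg = sum(self.history) / len(self.history)
            if dss > self.k * moving_avg:
                self.history.clear()  # start a fresh average for the new domain
                return True
        self.history.append(dss)
        return False


detector = DomainShiftDetector(k=3.0, window=10)
stream = [0.02, 0.03, 0.02, 0.025, 0.30, 0.02, 0.03, 0.02]
print([detector.update(s) for s in stream])
# [False, False, False, False, True, False, False, False]
```

Keeping the spike out of the history is a deliberate choice in this sketch: a large value left in the window would inflate the average and could mask a second shift that arrives shortly afterwards.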

Algorithm

The DSS module is entirely decoupled from the choice of TTA method. It wraps any single-domain TTA method (TENT, AaD, EATA, etc.) and converts it into a continual TTA method:

  1. For each incoming batch, compute the mean feature vector using the fixed source feature extractor.
  2. Compute DSS as cosine dissimilarity with the previous batch's mean feature.
  3. If DSD fires (DSS exceeds adaptive threshold): reset model to source checkpoint and reset optimizer state.
  4. Continue with standard TTA adaptation on the current batch.
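The four steps can be combined into a wrapper around an arbitrary TTA learner. This is a sketch under stated assumptions: `extract_features` (the frozen source feature extractor) and `make_adapter` (a factory returning a TTA learner whose model and optimizer start from the source checkpoint) are hypothetical callables, and the moving average uses the same simple-window assumption as above. `CountingAdapter` is a stand-in that only counts updates; it is not TENT, AaD or EATA.

```python
import math
from collections import deque


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


class DSSContinualTTA:
    """Turn a single-domain TTA learner into a continual one via DSS-triggered resets."""

    def __init__(self, extract_features, make_adapter, k=3.0, window=10):
        self.extract_features = extract_features  # frozen source feature extractor (hypothetical)
        self.make_adapter = make_adapter          # returns a learner reset to the source checkpoint
        self.adapter = make_adapter()
        self.k = k
        self.history = deque(maxlen=window)       # recent in-domain DSS values
        self.prev_mean = None
        self.num_resets = 0

    def step(self, batch):
        # Step 1: mean feature of the batch under the frozen source extractor.
        feats = self.extract_features(batch)
        mean = [sum(col) / len(feats) for col in zip(*feats)]
        if self.prev_mean is not None:
            # Step 2: DSS against the previous batch, Eq. (1).
            dss = 1.0 - cosine(mean, self.prev_mean)
            # Step 3: DSD, Eq. (2); on a shift, rebuild model and optimizer from source.
            if self.history and dss > self.k * sum(self.history) / len(self.history):
                self.adapter = self.make_adapter()
                self.history.clear()
                self.num_resets += 1
            else:
                self.history.append(dss)
        self.prev_mean = mean
        # Step 4: ordinary TTA update on the current batch.
        return self.adapter.step(batch)


class CountingAdapter:
    """Stand-in for a TTA learner; it only counts how many batches it has adapted on."""

    def __init__(self):
        self.steps = 0

    def step(self, batch):
        self.steps += 1
        return self.steps


def toy_mean(i):
    # Batches 0-4 come from "domain A", batches 5-9 from "domain B".
    wobble = 0.1 if i % 2 == 0 else 0.2
    return [1.0, wobble, 0.0] if i < 5 else [0.0, wobble, 1.0]


stream = [[toy_mean(i), toy_mean(i)] for i in range(10)]  # two identical samples per batch
runner = DSSContinualTTA(extract_features=lambda b: b, make_adapter=CountingAdapter)
outputs = [runner.step(batch) for batch in stream]
print(runner.num_resets, outputs)  # 1 [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
```

On the toy stream the wrapper resets once, at the first batch of the second domain, and the fresh learner's update count restarts. In a PyTorch setting, `make_adapter` might, for instance, `copy.deepcopy` the source model and build a fresh optimizer, which covers both resets in step 3.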

This approach is computationally cheap: it adds only a single forward pass through the frozen source feature extractor per batch, with negligible memory overhead.

Results

We evaluate DSS on three benchmarks (ImageNet-C, CIFAR-100C, CIFAR-10C) with 15 corruption types at severity 5. Each TTA method is compared in three settings: TTA (oracle domain-shift labels, used as a reference), CTTA (continual, no domain labels; baseline) and DSS (our method, no domain labels). Metric: error % (lower is better). Numbers in brackets show the improvement of DSS over CTTA.

Mean error (%) ↓

Method             ImageNet-C     CIFAR-100C     CIFAR-10C
Source             82.0           46.4           43.5
BN Stats           68.6           35.4           20.4
CoTTA              62.6           32.5           16.2
RMT                59.9           30.2           17.0
SATA               60.1           30.3           16.1
TENT-TTA           58.6           31.0           18.6
TENT-CTTA          86.8           61.2           20.7
TENT-DSS (Ours)    57.9 (+28.9)   31.5 (+29.7)   18.4 (+2.3)
AaD-TTA            68.5           34.9           19.6
AaD-CTTA           90.0           62.5           22.0
AaD-DSS (Ours)     68.4 (+21.6)   35.0 (+27.5)   19.5 (+2.5)
EATA-TTA           56.9           31.6           18.1
EATA-CTTA          58.0           32.2           17.9
EATA-DSS (Ours)    56.7 (+1.3)    31.4 (+0.8)    17.9 (0.0)

Table 1: Mean error % (lower is better) across 15 corruptions at severity 5. Numbers in brackets show improvement of DSS over the corresponding CTTA baseline. Backbones: ResNet-50 (ImageNet-C), ResNeXt-29 (CIFAR-100C), WideResNet-28 (CIFAR-10C).

BibTeX

@InProceedings{chakrabarty2023dss,
  author    = {Chakrabarty, Goirik and Sreenivas, Manogna and Biswas, Soma},
  title     = {A Simple Signal for Domain Shift},
  booktitle = {ICCVW},
  year      = {2023}
}