Similar Class Style Augmentation for Efficient Cross-Domain Few-Shot Learning

Indian Institute of Science, Bengaluru, India
CVPRW, 2023

Abstract

Cross-Domain Few-Shot Learning (CD-FSL) aims to recognize new classes from unseen domains given limited training samples. The majority of state-of-the-art approaches for this task introduce additional task-specific parameters to adapt to the novel task, which changes the trained model architecture and increases the number of model parameters. The first contribution of this work is to revisit two existing techniques for adapting the trained model to new tasks without changing its architecture: modifying the Batch Normalization affine parameters, and tuning the scale hyperparameter in the cosine-similarity-based softmax loss. Second, to aid learning from few examples per class, we propose to augment the data of each class with the styles of semantically similar classes. Extensive evaluation on the challenging Meta-Dataset shows that this simple framework is very effective for the CD-FSL task. We also show that the Similar-class Style Augmentation module can be seamlessly integrated with existing approaches to further improve their performance, establishing a new state of the art in this challenging area.

Cross-Domain Few-Shot Learning

In CD-FSL, a universal feature extractor \(F\) is trained on labeled data from multiple source domains \(D_\text{train}\). At test time, it must adapt to N-way K-shot tasks sampled from unseen classes in unseen domains \(D_\text{test}\). Each task \(\mathcal{T} = (\mathcal{S}, \mathcal{Q})\) consists of a labeled support set \(\mathcal{S}\) and an unlabeled query set \(\mathcal{Q}\).

Most state-of-the-art methods (FLUTE, URL, TSA) handle this by introducing task-specific learnable modules, which either increase model parameters or change the trained architecture. This is undesirable in many practical settings. SSA-BNS addresses CD-FSL without any architectural changes or extra parameters, by revisiting two under-explored components: BatchNorm adaptation and the cosine similarity scale factor.

Figure 1. CD-FSL task. Training (left): use labeled multi-domain data to learn a universal feature extractor. Testing (right): N-way K-shot tasks with unseen classes from unseen domains. The support set adapts the model; the query set evaluates it.

SSA-BNS Framework

Figure 2. SSA-BNS framework. A support set instance \(x_i^s\) is augmented using a similar-class sample \(x_j^s\) (where \(y_j^s \in \mathcal{S}_{y_i^s}\)) through SSA modules inserted in the universal feature extractor. Task adaptation minimizes the \(\mathcal{L}_\text{NCC}\) loss over actual and SSA-augmented features.

1. BatchNorm Adaptation

Given the universal feature extractor, we adapt only the BatchNorm affine parameters \(\{\gamma, \beta\}\) at each layer \(l\), without changing any other part of the model. Batch-normalized activations are:

\begin{align} f^l_\text{BN} = \gamma^l \hat{f}^l + \beta^l; \quad \hat{f}^l = \frac{f^l - \mu^l}{\sqrt{(\sigma^l)^2 + \epsilon}} \end{align}
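As a minimal sketch of the batch-normalization equation above, the following pure-Python snippet normalizes the activations of a single channel and applies the affine pair \(\{\gamma, \beta\}\). The function name `batch_norm_channel` and the sample values are illustrative assumptions. The statistics are computed from the given values here; whether batch or running statistics are used during adaptation is not specified in this text.

```python
import math

def batch_norm_channel(f, gamma, beta, eps=1e-5):
    """Apply f_BN = gamma * f_hat + beta to one channel's activations."""
    mu = sum(f) / len(f)
    var = sum((x - mu) ** 2 for x in f) / len(f)  # biased variance
    f_hat = [(x - mu) / math.sqrt(var + eps) for x in f]
    return [gamma * x + beta for x in f_hat]

# Hypothetical activations for a single channel.
acts = [1.0, 2.0, 3.0, 4.0]
out = batch_norm_channel(acts, gamma=2.0, beta=0.5)
print(out)
```

After normalization the output mean equals `beta` and its spread is set by `gamma`, which is why tuning only these two parameters per channel can shift and rescale features for a new domain.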

The BN parameters \(\{\gamma, \beta\}\) are optimized to minimize the Nearest Centroid Classifier (NCC) loss over the support set:

\begin{align} \min_{\gamma, \beta} \; \frac{1}{n_\mathcal{S}} \sum_{(x_i^s, y_i^s) \in \mathcal{S}} \mathcal{L}_\text{NCC}(z_i^s, y_i^s; \eta) \quad \text{where} \quad z_i^s = F(x_i^s) \end{align}

Class centroids are the mean of support features per class:

\begin{align} \mathbf{c}_k = \frac{1}{|\mathcal{S}_k|} \sum_{x_i^s \in \mathcal{S}_k} z_i^s \end{align}
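The centroid definition can be illustrated with a short sketch. The helper name `class_centroids` and the toy 2-D features are assumptions made for this example, not part of the original method description.

```python
from collections import defaultdict

def class_centroids(features, labels):
    """Return {class: mean feature vector} over the support set."""
    groups = defaultdict(list)
    for z, y in zip(features, labels):
        groups[y].append(z)
    # Average each feature dimension over the samples of a class.
    return {
        k: [sum(col) / len(zs) for col in zip(*zs)]
        for k, zs in groups.items()
    }

# Hypothetical support features: two samples of class 0, one of class 1.
support_feats = [[1.0, 0.0], [3.0, 2.0], [0.0, 4.0]]
support_labels = [0, 0, 1]
centroids = class_centroids(support_feats, support_labels)
print(centroids)
```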

2. Cosine Similarity Scale Factor

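The NCC loss uses cosine similarity with a scale hyperparameter \(\eta\). With \(\theta_{i,k}\) the angle between feature \(z_i^s\) and centroid \(\mathbf{c}_k\), and \(C\) the number of classes in the task, the predicted class probability is: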

\begin{align} p(y=k \mid z_i^s; \eta) = \frac{e^{\eta \cos\theta_{i,k}}}{\sum_{j=1}^{C} e^{\eta \cos\theta_{i,j}}} \end{align}

Prior works URL and TSA fix \(\eta = 10\). In CD-FSL the test domain can differ substantially from the training domains, so cosine similarities tend to be low and a small scale compresses the softmax probabilities. We find that \(\eta = 25\) works better (Table 2): it widens the probability range so that correctly classified samples receive high confidence, while avoiding the collapse seen at \(\eta = 50\), which overfits the support set within about 10 iterations.
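To make the role of \(\eta\) concrete, here is a small worked example of the scaled softmax. The cosine values are hypothetical and chosen to be low, as one might expect in an unseen domain; `scaled_softmax` is an illustrative helper name.

```python
import math

def scaled_softmax(cosines, eta):
    """Softmax over eta * cos(theta), one entry per class."""
    logits = [eta * c for c in cosines]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical cosine similarities to 3 centroids; class 0 is correct.
cosines = [0.30, 0.20, 0.10]
for eta in (10, 25, 50):
    p_true = scaled_softmax(cosines, eta)[0]
    print(f"eta={eta}: p(correct)={p_true:.3f}")
```

For these inputs the probability of the correct class is roughly 0.665 at \(\eta = 10\), 0.918 at \(\eta = 25\), and 0.993 at \(\eta = 50\). The same cosine margin yields a nearly saturated prediction at the largest scale, which leaves little loss signal and is consistent with the fast overfitting described above.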

3. Similar Class Style Augmentation (SSA)

To overcome limited support data, we augment each sample with the style (channel-wise feature statistics) of a semantically similar class sample. Class similarity is measured via cosine similarity of centroids:

\begin{align} \text{sim}(\mathbf{c}_i, \mathbf{c}_j) = \frac{\mathbf{c}_i^T \mathbf{c}_j}{\|\mathbf{c}_i\| \|\mathbf{c}_j\|} \end{align}

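The similar class set for class \(k\) contains every class whose centroid has cosine similarity above a threshold \(\tau\) with \(\mathbf{c}_k\). Note that the symbol \(\mathcal{S}_k\) is reused here for a set of classes, whereas in the centroid definition it denotes the support samples of class \(k\); as written, the set also contains \(k\) itself, since \(\text{sim}(\mathbf{c}_k, \mathbf{c}_k) = 1\):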

\begin{align} \mathcal{S}_k = \{t \mid \text{sim}(\mathbf{c}_t, \mathbf{c}_k) > \tau;\; t = 1,\ldots,C\} \end{align}
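The two definitions above can be sketched as follows. The helper names and the toy centroids are assumptions for illustration; the threshold of 0.7 is the value reported for \(\tau\) in this document.

```python
import math

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

def similar_class_sets(centroids, tau=0.7):
    """Map each class k to the classes t with sim(c_t, c_k) > tau."""
    return {
        k: [t for t, c_t in centroids.items() if cosine(c_t, c_k) > tau]
        for k, c_k in centroids.items()
    }

# Hypothetical centroids: classes 0 and 1 are close, class 2 is orthogonal to 0.
centroids = {0: [1.0, 0.0], 1: [0.9, 0.1], 2: [0.0, 1.0]}
sets = similar_class_sets(centroids, tau=0.7)
print(sets)
```

In this toy case classes 0 and 1 end up in each other's similar set, while class 2 is similar only to itself, so its samples would be mixed only with other samples of class 2.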

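For sample \(x_i\) of class \(y_i\), we randomly pick a sample \(x_j\) from a similar class \(y_j \in \mathcal{S}_{y_i}\) and mix their intermediate feature statistics at layer \(l\), where \(\mu(\cdot)\) and \(\sigma(\cdot)\) denote the channel-wise mean and standard deviation of a feature map and \(\lambda\) is the mixing weight: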

\begin{align} \mu_\text{ssa}(f_i; f_j) &= \lambda\,\mu(f_i) + (1-\lambda)\,\mu(f_j) \\ \sigma_\text{ssa}(f_i; f_j) &= \lambda\,\sigma(f_i) + (1-\lambda)\,\sigma(f_j) \\ f_i^\text{ssa} &= \sigma_\text{ssa} \odot \frac{f_i - \mu(f_i)}{\sigma(f_i)} + \mu_\text{ssa} \end{align}
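A minimal sketch of the mixing step, assuming a feature map is represented as a list of channels, each holding its flattened spatial values. The names `channel_stats` and `ssa_mix` are illustrative, and the small `eps` added under the square root is an assumption for numerical safety that does not appear in the equations.

```python
import math

def channel_stats(channel, eps=1e-6):
    """Mean and standard deviation of one channel (eps is an added assumption)."""
    mu = sum(channel) / len(channel)
    var = sum((v - mu) ** 2 for v in channel) / len(channel)
    return mu, math.sqrt(var + eps)

def ssa_mix(f_i, f_j, lam=0.5):
    """Re-style f_i with statistics mixed from f_j, channel by channel."""
    mixed = []
    for ch_i, ch_j in zip(f_i, f_j):
        mu_i, sd_i = channel_stats(ch_i)
        mu_j, sd_j = channel_stats(ch_j)
        mu_mix = lam * mu_i + (1 - lam) * mu_j
        sd_mix = lam * sd_i + (1 - lam) * sd_j
        # Normalize f_i with its own statistics, then apply the mixed ones.
        mixed.append([sd_mix * (v - mu_i) / sd_i + mu_mix for v in ch_i])
    return mixed

f_i = [[0.0, 2.0, 4.0]]     # one channel, three spatial positions
f_j = [[10.0, 10.0, 16.0]]  # hypothetical sample from a similar class
f_ssa = ssa_mix(f_i, f_j, lam=0.5)
print(f_ssa)
```

The output channel has mean 7.0, halfway between the two input means (2.0 and 12.0), while the relative ordering of the spatial values of `f_i` is unchanged.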

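The mixing changes only the channel-wise mean and standard deviation; the normalized activations of \(f_i\) are untouched. The content of \(x_i\) is therefore preserved in \(f_i^\text{ssa}\), and the augmented sample retains its class label \(y_i\). The final SSA-BNS objective jointly minimizes the NCC loss on real and augmented features, where \(z_i^\text{ssa}\) denotes the feature of \(x_i^s\) computed with the SSA modules active: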

\begin{align} \min_{\gamma, \beta} \; \frac{1}{2n_\mathcal{S}} \sum_{(x_i^s, y_i^s)} \left[ \mathcal{L}_\text{NCC}(z_i^s, y_i^s; \eta) + \mathcal{L}_\text{NCC}(z_i^\text{ssa}, y_i^s; \eta) \right] \end{align}
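The objective can be sketched as below, assuming that \(\mathcal{L}_\text{NCC}\) is the negative log-probability of the true class under the scaled cosine softmax and that the centroids are supplied by the caller. Both are assumptions of this example, as are the function names and toy features; in practice the minimization over \(\{\gamma, \beta\}\) would be carried out by a gradient-based optimizer in a deep learning framework.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

def ncc_loss(z, label, centroids, eta=25.0):
    """Negative log-probability of `label` under the scaled cosine softmax."""
    logits = {k: eta * cosine(z, c) for k, c in centroids.items()}
    m = max(logits.values())  # log-sum-exp with max subtraction
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits.values()))
    return log_norm - logits[label]

def ssa_bns_objective(z_real, z_ssa, labels, centroids, eta=25.0):
    """Average NCC loss over real and augmented features (the 1 / 2n_S factor)."""
    total = sum(
        ncc_loss(zr, y, centroids, eta) + ncc_loss(za, y, centroids, eta)
        for zr, za, y in zip(z_real, z_ssa, labels)
    )
    return total / (2 * len(labels))

# Hypothetical 2-D features for a 2-class task.
centroids = {0: [1.0, 0.0], 1: [0.0, 1.0]}
z_real = [[0.9, 0.1], [0.2, 0.8]]
z_ssa = [[0.8, 0.3], [0.3, 0.7]]
labels = [0, 1]
loss = ssa_bns_objective(z_real, z_ssa, labels, centroids)
print(loss)
```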

SSA is inserted after the first two ResNet blocks with \(\lambda = 0.5\) and similarity threshold \(\tau = 0.7\). No additional parameters are introduced.

Experimental Results

We evaluate on the Meta-Dataset benchmark (8 seen + 5 unseen domains) using a ResNet-18 universal feature extractor. Average accuracy and 95% confidence interval are reported over 600 tasks.

| Dataset | SUR | URT | FLUTE | tri-M | URL* | TSA* | SSA-BNS | TSA*+SSA |
|---|---|---|---|---|---|---|---|---|
| ImageNet | 56.2±1.0 | 56.8±1.1 | 58.6±1.0 | 51.8±1.1 | 58.8±1.1 | 59.5±1.0 | 56.6±1.0 | 58.9±1.1 |
| Omniglot | 94.1±0.4 | 94.2±0.4 | 92.0±0.6 | 93.2±0.5 | 94.5±0.4 | 94.9±0.4 | 95.2±0.5 | 95.6±0.4 |
| Aircraft | 85.5±0.5 | 85.8±0.5 | 82.8±0.7 | 87.2±0.5 | 89.4±0.4 | 89.9±0.4 | 89.6±0.4 | 90.0±0.5 |
| Birds | 71.0±1.0 | 76.2±0.8 | 75.3±0.8 | 79.2±0.8 | 80.7±0.8 | 81.1±0.8 | 81.8±0.8 | 82.2±0.7 |
| Textures | 71.0±0.8 | 71.6±0.7 | 71.2±0.8 | 68.8±0.8 | 77.2±0.7 | 77.5±0.7 | 76.4±0.7 | 77.6±0.7 |
| Quick Draw | 81.8±0.6 | 82.4±0.6 | 77.3±0.7 | 79.5±0.7 | 82.5±0.6 | 81.7±0.6 | 82.8±0.6 | 82.7±0.7 |
| Fungi | 64.3±0.9 | 64.0±1.0 | 48.5±1.0 | 58.1±1.1 | 68.1±0.9 | 66.3±0.8 | 66.7±0.8 | 66.6±0.8 |
| VGG Flower | 82.9±0.8 | 87.9±0.6 | 90.5±0.5 | 91.6±0.6 | 92.0±0.5 | 92.2±0.5 | 92.8±0.6 | 93.0±0.5 |
| Traffic Sign | 51.0±1.1 | 48.2±1.1 | 63.0±1.0 | 58.4±1.1 | 63.3±1.1 | 82.8±1.0 | 77.9±1.1 | 84.9±1.1 |
| MSCOCO | 52.0±1.1 | 51.5±1.1 | 52.8±1.1 | 50.0±1.0 | 57.3±1.0 | 57.6±1.0 | 56.1±0.9 | 58.1±1.0 |
| MNIST | 94.3±0.4 | 90.6±0.5 | 96.2±0.3 | 95.6±0.5 | 94.7±0.4 | 96.7±0.4 | 98.3±0.5 | 98.5±0.4 |
| CIFAR-10 | 66.5±0.9 | 67.0±0.8 | 75.4±0.8 | 78.6±0.7 | 74.2±0.8 | 82.9±0.7 | 79.4±0.7 | 82.9±0.7 |
| CIFAR-100 | 56.9±1.1 | 57.3±1.0 | 62.0±1.0 | 67.1±1.0 | 63.5±1.0 | 70.4±0.9 | 69.0±0.9 | 70.8±0.9 |
| Avg seen | 75.9 | 77.4 | 74.5 | 76.2 | 80.4 | 80.4 | 80.2 | 80.8 |
| Avg unseen | 64.1 | 62.9 | 69.9 | 69.9 | 70.6 | 78.1 | 76.1 | 79.0 |
| Avg all | 71.4 | 71.8 | 72.7 | 73.8 | 76.6 | 79.5 | 78.7 | 80.1 |

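Table 1. Average accuracy (%) with 95% confidence intervals over 600 tasks on Meta-Dataset. The first eight datasets are seen domains; the last five (Traffic Sign through CIFAR-100) are unseen. * indicates methods that use additional parameters beyond the feature extractor. SSA-BNS uses no additional parameters, yet it outperforms URL (262K extra parameters) on the unseen-domain and overall averages. TSA*+SSA achieves the best average in all three groups.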

Ablation: Effect of BN Adaptation and SSA

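| SSA | BNS (η) | Aircraft | Fungi | CIFAR-100 | MSCOCO |
|---|---|---|---|---|---|
| ✗ | ✗ | 87.0 | 65.6 | 59.9 | 53.1 |
| ✗ | η=10 | 89.1 | 66.0 | 66.9 | 54.5 |
| ✗ | η=25 | 89.5 | 66.4 | 68.4 | 55.7 |
| ✗ | η=50 | 89.5 | 66.2 | 67.7 | 55.4 |
| ✓ | η=25 | 89.6 | 66.7 | 69.0 | 56.1 |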

Table 2. BN adaptation with η=25 consistently outperforms η=10 (default in URL/TSA). Adding SSA further improves performance across both seen and unseen domains.

Comparison with Other Augmentation Strategies

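| Augmentation | Aircraft | Fungi | CIFAR-100 | MSCOCO |
|---|---|---|---|---|
| RandAugment | 88.8 | 65.2 | 66.9 | 55.2 |
| MixUp | 88.4 | 66.3 | 67.9 | 54.6 |
| Feature MixUp | 88.9 | 66.3 | 68.3 | 55.3 |
| Random MixStyle | 89.6 | 66.2 | 68.2 | 55.3 |
| SSA (Proposed) | 89.6 | 66.7 | 69.0 | 56.1 |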

Table 3. SSA matches or outperforms every compared augmentation technique on all four datasets. Restricting style mixing to semantically similar classes (SSA) ties class-agnostic Random MixStyle on Aircraft and beats it on Fungi, CIFAR-100, and MSCOCO.

Parameter Efficiency

| Method | Additional parameters | Trainable parameters |
|---|---|---|
| FLUTE | 32K | 32K |
| URL | 262K | 262K |
| TSA | 1482K | 1482K |
| TSA+SSA | 1482K | 1482K |
| SSA-BNS | None | 9.6K |

Table 4. SSA-BNS introduces no additional parameters and trains only the existing BN affine parameters (9.6K in ResNet-18). It outperforms URL, which adds 262K parameters, and uses roughly 154x fewer trainable parameters than TSA (1482K / 9.6K).

BibTeX

@InProceedings{sreenivas2023ssabns,
  author    = {Sreenivas, Manogna and Biswas, Soma},
  title     = {Similar Class Style Augmentation for Efficient Cross-Domain Few-Shot Learning},
  booktitle = {CVPR Workshops},
  year      = {2023}
}