Learning Strategy

Group Distributionally Robust Optimization (GroupDRO)

class dglib.generalization.groupdro.AutomaticUpdateDomainWeightModule(num_domains, eta, device)[source]

Maintains group weights based on the loss history of all domains, following Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (ICLR 2020).

Suppose we have \(N\) domains. During each iteration, we first calculate the unweighted loss on each domain, resulting in \(loss \in R^N\). Then we update the domain weights by

\[w_k = w_k * \text{exp}(\eta * loss_k), \forall k \in [1, N]\]

where \(\eta\) is a hyperparameter that controls the step size of the weight update (a smaller \(\eta\) yields smoother weight changes). As \(w \in R^N\) denotes a distribution, we normalize \(w\) by its sum. Finally, the weighted loss is calculated as our objective

\[objective = \sum_{k=1}^N w_k * loss_k\]
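
For concreteness, a minimal numeric sketch of this update and the resulting objective (illustrative tensor values and variable names; not the module's internal code):

    import torch

    eta = 0.01                                 # hypothetical step-size hyperparameter
    w = torch.ones(3) / 3                      # uniform initial weights over N = 3 domains
    losses = torch.tensor([0.9, 1.5, 0.4])     # example unweighted per-domain losses

    w = w * torch.exp(eta * losses)            # exponentiated weight update
    w = w / w.sum()                            # renormalize so w stays a distribution
    objective = (w * losses).sum()             # weighted objective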
Parameters
  • num_domains (int) – The number of source domains.

  • eta (float) – Hyperparameter \(\eta\) controlling the weight update step size.

  • device (torch.device) – The device to run on.

get_domain_weight(sampled_domain_idxes)[source]

Get domain weight to calculate final objective.

Inputs:
  • sampled_domain_idxes (list): sampled domain indexes in the current mini-batch

Shape:
  • sampled_domain_idxes: \((D, )\) where \(D\) is the number of sampled domains in the current mini-batch

  • Outputs: \((D, )\)

update(sampled_domain_losses, sampled_domain_idxes)[source]

Update domain weight using loss of current mini-batch.

Inputs:
  • sampled_domain_losses (tensor): losses of the sampled domains in the current mini-batch

  • sampled_domain_idxes (list): sampled domain indexes in the current mini-batch

Shape:
  • sampled_domain_losses: \((D, )\) where \(D\) is the number of sampled domains in the current mini-batch

  • sampled_domain_idxes: \((D, )\)
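
A minimal usage sketch of this module in one training step, assuming the constructor and method signatures documented above (the random logits, labels, and the surrounding training loop are illustrative placeholders, not part of the library):

    import torch
    import torch.nn.functional as F
    from dglib.generalization.groupdro import AutomaticUpdateDomainWeightModule

    device = torch.device('cpu')
    module = AutomaticUpdateDomainWeightModule(num_domains=3, eta=1e-2, device=device)

    # Suppose the current mini-batch contains samples from domains 0 and 2.
    sampled_domain_idxes = [0, 2]
    logits = torch.randn(2, 8, 5)            # (D, samples per domain, num classes), random stand-in
    labels = torch.randint(0, 5, (2, 8))     # random stand-in ground truth

    # Unweighted loss per sampled domain, shape (D, ).
    sampled_domain_losses = torch.stack(
        [F.cross_entropy(logits[d], labels[d]) for d in range(len(sampled_domain_idxes))])

    module.update(sampled_domain_losses, sampled_domain_idxes)   # update weight history
    weight = module.get_domain_weight(sampled_domain_idxes)      # shape (D, )
    objective = (weight * sampled_domain_losses).sum()           # weighted objective to backpropagate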

Invariant Risk Minimization (IRM)

class dglib.generalization.irm.InvariancePenaltyLoss[source]

Invariance Penalty Loss from Invariant Risk Minimization. We adopt the implementation from DomainBed. Given classifier output \(y\) and ground truth \(labels\), we split \(y\) into two parts \(y_1, y_2\) with corresponding labels \(labels_1, labels_2\). Next, for each part we compute the gradient of its cross entropy loss with respect to a dummy classifier \(w\), resulting in \(grad_1, grad_2\). The invariance penalty is then \(grad_1 * grad_2\).

Inputs:
  • y: predictions from model

  • labels: ground truth

Shape:
  • y: \((N, C)\) where C means the number of classes.

  • labels: \((N, )\) where \(N\) means the mini-batch size
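
A minimal sketch of the penalty computation described above, in the style of the DomainBed implementation (the function and variable names are illustrative; the class itself takes the inputs documented above):

    import torch
    import torch.nn.functional as F

    def irm_penalty(y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Dummy classifier w, implemented as a scalar scale of 1 applied to the logits.
        scale = torch.ones(1, device=y.device, requires_grad=True)
        # Split the batch into two halves and compute each loss gradient w.r.t. the scale.
        loss_1 = F.cross_entropy(y[::2] * scale, labels[::2])
        loss_2 = F.cross_entropy(y[1::2] * scale, labels[1::2])
        grad_1 = torch.autograd.grad(loss_1, [scale], create_graph=True)[0]
        grad_2 = torch.autograd.grad(loss_2, [scale], create_graph=True)[0]
        # Invariance penalty: product of the two gradient estimates.
        return torch.sum(grad_1 * grad_2)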

Meta Learning for Domain Generalization (MLDG)

Learning to Generalize: Meta-Learning for Domain Generalization (AAAI 2018)

Consider \(S\) source domains. At each learning iteration, MLDG splits the original \(S\) source domains into meta-train domains \(S_1\) and meta-test domains \(S_2\). The inner objective is the cross entropy loss on the meta-train domains \(S_1\). The outer (meta-optimization) objective contains two terms. The first one (which is the same as the inner objective) is the cross entropy loss on the meta-train domains \(S_1\) with the current model parameters \(\theta\)

\[\mathbb{E}_{(x,y) \in S_1} l(f(\theta, x), y)\]

where \(l\) denotes the cross entropy loss and \(f(\theta, x)\) denotes the model's predictions. The second term is the cross entropy loss on the meta-test domains \(S_2\) with the inner-optimized model parameters \(\theta_{updated}\)

\[\mathbb{E}_{(x,y) \in S_2} l(f(\theta_{updated}, x), y)\]

In this way, MLDG simulates the train/test domain shift during training by synthesizing virtual testing domains within each mini-batch. The outer objective enforces that update steps which improve performance on the training domains should also improve performance on the testing domains.

Note

Because we need to compute second-order gradients, the full optimization process may take a long time and place a heavy demand on GPU memory. A first-order approximation implementation can be found in DomainBed.
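
A minimal sketch of one MLDG meta-step under these definitions. The function name, the per-domain batch lists, inner_lr, and the trade-off weight beta are illustrative assumptions, not library API, and torch.func.functional_call assumes a recent PyTorch version:

    import torch
    import torch.nn.functional as F

    def mldg_step(model, meta_train_batches, meta_test_batches, inner_lr=0.01, beta=1.0):
        names, params = zip(*[(n, p) for n, p in model.named_parameters() if p.requires_grad])

        # Inner objective: cross entropy on meta-train domains S_1 with current parameters theta.
        inner_loss = torch.stack(
            [F.cross_entropy(model(x), y) for x, y in meta_train_batches]).mean()

        # Inner update kept in the graph (create_graph=True) so the outer loss can
        # backpropagate through it; this is where the second-order gradients come from.
        grads = torch.autograd.grad(inner_loss, params, create_graph=True)
        updated = {n: p - inner_lr * g for n, p, g in zip(names, params, grads)}

        # Outer term: cross entropy on meta-test domains S_2 with the updated parameters,
        # evaluated through a functional forward pass.
        outer_loss = torch.stack(
            [F.cross_entropy(torch.func.functional_call(model, updated, (x,)), y)
             for x, y in meta_test_batches]).mean()

        # Meta-objective combining both terms; backpropagate this and step the optimizer.
        return inner_loss + beta * outer_loss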

Variance Risk Extrapolation (VREx)

Out-of-Distribution Generalization via Risk Extrapolation (ICML 2021)

VREx shows that reducing differences in risk across training domains can reduce a model’s sensitivity to a wide range of extreme distributional shifts. Consider \(S\) source domains. At each learning iteration, VREx first computes the cross entropy loss on each source domain separately, producing a loss vector \(l \in R^S\). The ERM (vanilla cross entropy) loss can be computed as

\[l_{\text{ERM}} = \frac{1}{S}\sum_{i=1}^S l_i\]

And the penalty loss is

\[penalty = \frac{1}{S} \sum_{i=1}^S {(l_i - l_{\text{ERM}})}^2\]

The final objective is then

\[objective = l_{\text{ERM}} + \beta * penalty\]

where \(\beta\) is the trade-off hyperparameter.
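
A minimal sketch of this objective, assuming the per-domain loss vector \(l \in R^S\) has already been computed (the function and argument names are illustrative):

    import torch

    def vrex_objective(per_domain_losses: torch.Tensor, beta: float) -> torch.Tensor:
        erm_loss = per_domain_losses.mean()                      # l_ERM
        penalty = ((per_domain_losses - erm_loss) ** 2).mean()   # variance of per-domain risks
        return erm_loss + beta * penalty                         # final objective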
