Learning Strategy¶
Group Distributionally Robust Optimization (GroupDRO)¶
-
class dglib.generalization.groupdro.AutomaticUpdateDomainWeightModule(num_domains, eta, device)[source]¶
Maintains group weights based on the loss history of all domains, according to Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (ICLR 2020).
Suppose we have \(N\) domains. During each iteration, we first calculate the unweighted loss on every domain, resulting in \(loss \in R^N\). Then we update the domain weights by
\[w_k = w_k * \text{exp}(\eta * loss_k), \forall k \in [1, N]\]
where \(\eta\) is a hyperparameter that ensures a smoother change of the weights. As \(w \in R^N\) denotes a distribution, we normalize \(w\) by its sum. Finally, the weighted loss is calculated as our objective
\[objective = \sum_{k=1}^N w_k * loss_k\]
- Parameters
num_domains (int) – the number of source domains \(N\)
eta (float) – the hyperparameter \(\eta\) controlling how fast the domain weights change
device (torch.device) – the device on which the domain weights are stored
-
get_domain_weight(sampled_domain_idxes)[source]¶
Get the domain weights used to calculate the final objective.
- Inputs:
sampled_domain_idxes (list): indexes of the domains sampled in the current mini-batch
- Shape:
sampled_domain_idxes: \((D, )\), where \(D\) is the number of domains sampled in the current mini-batch
- Outputs: \((D, )\)
-
update(sampled_domain_losses, sampled_domain_idxes)[source]¶
Update the domain weights using the losses of the current mini-batch.
- Inputs:
sampled_domain_losses (tensor): losses of the domains sampled in the current mini-batch
sampled_domain_idxes (list): indexes of the domains sampled in the current mini-batch
- Shape:
sampled_domain_losses: \((D, )\), where \(D\) is the number of domains sampled in the current mini-batch
sampled_domain_idxes: \((D, )\)
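The weight update above can be sketched in plain Python. This is a minimal illustration, not the library's actual torch-based implementation: the `device` argument is omitted, and the `eta` and loss values below are illustrative.

```python
import math

class AutomaticUpdateDomainWeightModule:
    """Minimal sketch of the GroupDRO weight module (plain Python, no torch)."""

    def __init__(self, num_domains, eta):
        self.eta = eta
        # Start from a uniform distribution over domains.
        self.weights = [1.0 / num_domains] * num_domains

    def update(self, sampled_domain_losses, sampled_domain_idxes):
        # w_k <- w_k * exp(eta * loss_k) for each sampled domain k
        for loss, k in zip(sampled_domain_losses, sampled_domain_idxes):
            self.weights[k] *= math.exp(self.eta * loss)
        # Normalize so that w stays a distribution.
        total = sum(self.weights)
        self.weights = [w / total for w in self.weights]

    def get_domain_weight(self, sampled_domain_idxes):
        return [self.weights[k] for k in sampled_domain_idxes]

module = AutomaticUpdateDomainWeightModule(num_domains=3, eta=0.01)
module.update([0.5, 2.0], [0, 2])   # domain 2 has the highest loss this iteration
w = module.get_domain_weight([0, 1, 2])
```

After the update, domain 2 (the highest-loss domain) carries the largest weight, so the weighted objective emphasizes the worst-performing group.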
Invariant Risk Minimization (IRM)¶
-
class dglib.generalization.irm.InvariancePenaltyLoss[source]¶
Invariance Penalty Loss from Invariant Risk Minimization. We adopt the implementation from DomainBed. Given classifier output \(y\) and ground truth \(labels\), we split \(y\) into two parts \(y_1, y_2\), with corresponding labels \(labels_1, labels_2\). Next, we calculate the gradient of the cross entropy loss with respect to a dummy classifier \(w\), resulting in \(grad_1, grad_2\). The invariance penalty is then \(grad_1 * grad_2\).
- Inputs:
y: predictions from model
labels: ground truth
- Shape:
y: \((N, C)\) where \(C\) means the number of classes
labels: \((N, )\) where \(N\) means the mini-batch size
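The penalty can be sketched in plain Python. DomainBed computes the dummy-classifier gradient with autograd; here, as an assumed simplification, the gradient of the cross entropy of \(w \cdot y\) with respect to the scalar dummy \(w\) is written out in closed form at \(w = 1\).

```python
import math

def ce_grad_wrt_dummy(logits, labels):
    """Gradient (at w = 1) of the mean cross entropy of w * logits w.r.t. the scalar dummy w.

    Per sample: d/dw [-w * y[label] + log(sum_j exp(w * y[j]))]
              = -y[label] + sum_j softmax(y)[j] * y[j]   (evaluated at w = 1)
    """
    grad = 0.0
    for y, label in zip(logits, labels):
        m = max(y)                                  # subtract max for numerical stability
        exps = [math.exp(v - m) for v in y]
        z = sum(exps)
        probs = [e / z for e in exps]
        grad += -y[label] + sum(p * v for p, v in zip(probs, y))
    return grad / len(logits)

def invariance_penalty(logits, labels):
    """Split the batch into two halves and multiply the two dummy-classifier gradients."""
    half = len(logits) // 2
    grad_1 = ce_grad_wrt_dummy(logits[:half], labels[:half])
    grad_2 = ce_grad_wrt_dummy(logits[half:], labels[half:])
    return grad_1 * grad_2
```

If both half-batches push the dummy classifier in the same direction, the product is positive and the penalty discourages that environment-specific signal.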
Meta Learning for Domain Generalization (MLDG)¶
Learning to Generalize: Meta-Learning for Domain Generalization (AAAI 2018)
Consider \(S\) source domains. At each learning iteration, MLDG splits the original \(S\) source domains into meta-train domains \(S_1\) and meta-test domains \(S_2\). The inner objective is the cross entropy loss on the meta-train domains \(S_1\). The outer (meta-optimization) objective contains two terms. The first (which is the same as the inner objective) is the cross entropy loss on the meta-train domains \(S_1\) with the current model parameters \(\theta\)
\[\mathbb{E}_{(x, y) \in S_1} l(f(\theta, x), y)\]
where \(l\) denotes the cross entropy loss and \(f(\theta, x)\) denotes the model's predictions. The second term is the cross entropy loss on the meta-test domains \(S_2\) with the inner-optimized model parameters \(\theta_{updated}\)
\[\mathbb{E}_{(x, y) \in S_2} l(f(\theta_{updated}, x), y)\]
In this way, MLDG simulates train/test domain shift during training by synthesizing virtual testing domains within each mini-batch. The outer objective enforces that gradient steps which improve training-domain performance also improve testing-domain performance.
Note
Because we need to compute second-order gradients, the full optimization process may take a long time and consume substantial GPU memory. A first-order approximation can be found at DomainBed.
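A first-order MLDG step can be sketched on a toy problem. This is an assumed simplification, not the library's implementation: the model is a 1-D linear map \(f(\theta, x) = \theta x\) with squared loss instead of cross entropy, gradients are written by hand, and the learning rates and \(\beta\) values are illustrative.

```python
def loss_and_grad(theta, data):
    """Mean squared loss and its gradient w.r.t. theta over (x, y) pairs."""
    n = len(data)
    loss = sum((theta * x - y) ** 2 for x, y in data) / n
    grad = sum(2 * (theta * x - y) * x for x, y in data) / n
    return loss, grad

def mldg_step(theta, meta_train, meta_test, inner_lr=0.1, outer_lr=0.1, beta=1.0):
    # Inner objective: optimize on the meta-train domains.
    _, g_train = loss_and_grad(theta, meta_train)
    theta_updated = theta - inner_lr * g_train
    # Outer objective: meta-train loss at theta plus beta * meta-test loss at
    # theta_updated. First-order approximation: treat theta_updated as constant
    # w.r.t. theta, so the outer gradient is g_train + beta * g_test(theta_updated).
    _, g_test = loss_and_grad(theta_updated, meta_test)
    return theta - outer_lr * (g_train + beta * g_test)

# Both virtual domains agree on the true parameter theta = 2.
theta = 0.0
for _ in range(200):
    theta = mldg_step(theta, meta_train=[(1.0, 2.0)], meta_test=[(2.0, 4.0)])
```

The full second-order method would additionally backpropagate through `theta_updated`; the first-order variant above simply adds the two gradients, which is the cheap approximation the note refers to.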
Variance Risk Extrapolation (VREx)¶
Out-of-Distribution Generalization via Risk Extrapolation (ICML 2021)
VREx shows that reducing differences in risk across training domains can reduce a model’s sensitivity to a wide range of extreme distributional shifts. Consider \(S\) source domains. At each learning iteration, VREx first computes the cross entropy loss on each source domain separately, producing a loss vector \(l \in R^S\). The ERM (vanilla cross entropy) loss can be computed as
\[l_{ERM} = \frac{1}{S} \sum_{k=1}^S l_k\]
And the penalty loss is the variance of the per-domain losses
\[l_{penalty} = \frac{1}{S} \sum_{k=1}^S (l_k - l_{ERM})^2\]
The final objective is then
\[objective = l_{ERM} + \beta * l_{penalty}\]
where \(\beta\) is the trade-off hyperparameter.
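The objective is simple enough to sketch directly. A minimal plain-Python version, assuming the per-domain losses have already been computed and with an illustrative default for \(\beta\):

```python
def vrex_objective(domain_losses, beta=1.0):
    """VREx objective: mean per-domain loss plus beta times the variance of the losses.

    `beta` trades off the ERM term against the variance penalty; the default is illustrative.
    """
    s = len(domain_losses)
    l_erm = sum(domain_losses) / s
    l_penalty = sum((lk - l_erm) ** 2 for lk in domain_losses) / s
    return l_erm + beta * l_penalty

# Equal per-domain losses incur no penalty; unequal losses are penalized.
# vrex_objective([1.0, 1.0, 1.0]) == 1.0
```

A batch whose domains all incur the same loss pays only the ERM term, while spreading the same total loss unevenly across domains strictly increases the objective.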