Domain Adversarial Learning¶
Domain Adversarial Training in Close-set DA¶
DANN: Domain Adversarial Neural Network¶
- class dalib.adaptation.dann.DomainAdversarialLoss(domain_discriminator, reduction='mean', grl=None)[source]¶
The Domain Adversarial Loss proposed in Domain-Adversarial Training of Neural Networks (ICML 2015)
Domain adversarial loss measures the domain discrepancy through training a domain discriminator. Given domain discriminator \(D\), feature representation \(f\), the definition of DANN loss is
\[loss(\mathcal{D}_s, \mathcal{D}_t) = \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(f_i^s)] + \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(f_j^t)].\]
- Parameters
domain_discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of features. Its input shape is (N, F) and output shape is (N, 1)
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Default: 'mean'
grl (WarmStartGradientReverseLayer, optional) – Default: None.
- Inputs:
f_s (tensor): feature representations on source domain, \(f^s\)
f_t (tensor): feature representations on target domain, \(f^t\)
w_s (tensor, optional): a rescaling weight given to each instance from source domain.
w_t (tensor, optional): a rescaling weight given to each instance from target domain.
- Shape:
f_s, f_t: \((N, F)\) where F means the dimension of input features.
Outputs: scalar by default. If reduction is 'none', then \((N, )\).
Examples:
>>> from dalib.modules.domain_discriminator import DomainDiscriminator
>>> discriminator = DomainDiscriminator(in_feature=1024, hidden_size=1024)
>>> loss = DomainAdversarialLoss(discriminator, reduction='mean')
>>> # features from source domain and target domain
>>> f_s, f_t = torch.randn(20, 1024), torch.randn(20, 1024)
>>> # If you want to assign different weights to each instance, you should pass in w_s and w_t
>>> w_s, w_t = torch.randn(20), torch.randn(20)
>>> output = loss(f_s, f_t, w_s, w_t)
CDAN: Conditional Domain Adversarial Network¶
- class dalib.adaptation.cdan.ConditionalDomainAdversarialLoss(domain_discriminator, entropy_conditioning=False, randomized=False, num_classes=-1, features_dim=-1, randomized_dim=1024, reduction='mean')[source]¶
The Conditional Domain Adversarial Loss used in Conditional Adversarial Domain Adaptation (NIPS 2018)
Conditional Domain adversarial loss measures the domain discrepancy through training a domain discriminator in a conditional manner. Given domain discriminator \(D\), feature representation \(f\) and classifier predictions \(g\), the definition of CDAN loss is
\[\begin{split}loss(\mathcal{D}_s, \mathcal{D}_t) &= \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(T(f_i^s, g_i^s))] \\ &+ \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(T(f_j^t, g_j^t))],\\\end{split}\]
where \(T\) is a MultiLinearMap or RandomizedMultiLinearMap which converts two tensors into a single tensor.
- Parameters
domain_discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of features. Its input shape is (N, F) and output shape is (N, 1)
entropy_conditioning (bool, optional) – If True, use entropy-aware weight to reweight each training example. Default: False
randomized (bool, optional) – If True, use the randomized multi-linear map. Otherwise, use the multi-linear map. Default: False
num_classes (int, optional) – Number of classes. Default: -1
features_dim (int, optional) – Dimension of input features. Default: -1
randomized_dim (int, optional) – Dimension of features after randomized. Default: 1024
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Default: 'mean'
Note
You need to provide num_classes, features_dim and randomized_dim only when randomized is set to True.
- Inputs:
g_s (tensor): unnormalized classifier predictions on source domain, \(g^s\)
f_s (tensor): feature representations on source domain, \(f^s\)
g_t (tensor): unnormalized classifier predictions on target domain, \(g^t\)
f_t (tensor): feature representations on target domain, \(f^t\)
- Shape:
g_s, g_t: \((minibatch, C)\) where C means the number of classes.
f_s, f_t: \((minibatch, F)\) where F means the dimension of input features.
Outputs: scalar by default. If reduction is 'none', then \((minibatch, )\).
Examples:
>>> from dalib.modules.domain_discriminator import DomainDiscriminator
>>> from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss
>>> import torch
>>> num_classes = 2
>>> feature_dim = 1024
>>> batch_size = 10
>>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)
>>> loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
>>> # features from source domain and target domain
>>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
>>> # logits output from source domain and target domain
>>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
>>> output = loss(g_s, f_s, g_t, f_t)
- class dalib.adaptation.cdan.RandomizedMultiLinearMap(features_dim, num_classes, output_dim=1024)[source]¶
Randomized multi-linear map
Given two inputs \(f\) and \(g\), the definition is
\[T_{\odot}(f,g) = \dfrac{1}{\sqrt{d}} (R_f f) \odot (R_g g),\]where \(\odot\) is element-wise product, \(R_f\) and \(R_g\) are random matrices sampled only once and fixed in training.
- Parameters
features_dim (int) – dimension of input \(f\)
num_classes (int) – number of classes, i.e. the dimension of input \(g\)
output_dim (int, optional) – dimension of the output tensor. Default: 1024
- Shape:
f: (minibatch, features_dim)
g: (minibatch, num_classes)
Outputs: (minibatch, output_dim)
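A minimal usage sketch, assuming the map is applied directly to a pair of tensors \((f, g)\) as in the definition above:
>>> from dalib.adaptation.cdan import RandomizedMultiLinearMap
>>> import torch
>>> # map 1024-dimensional features and 31-class predictions to a 1024-dimensional joint representation
>>> mapping = RandomizedMultiLinearMap(features_dim=1024, num_classes=31, output_dim=1024)
>>> f = torch.randn(10, 1024)
>>> g = torch.randn(10, 31)
>>> h = mapping(f, g)  # shape (10, 1024)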
ADDA: Adversarial Discriminative Domain Adaptation¶
- class dalib.adaptation.adda.DomainAdversarialLoss[source]¶
Domain adversarial loss from Adversarial Discriminative Domain Adaptation (CVPR 2017). Similar to the original GAN paper, ADDA argues that replacing \(\text{log}(1-p)\) with \(-\text{log}(p)\) in the adversarial loss provides better gradient qualities. Detailed optimization process can be found at examples/domain_adaptation/image_classification/adda.py.
- Inputs:
domain_pred (tensor): predictions of domain discriminator
domain_label (str, optional): whether the data comes from source or target. Must be ‘source’ or ‘target’. Default: ‘source’
- Shape:
domain_pred: \((minibatch,)\).
Outputs: scalar.
Note
ADDAgrl is also implemented and benchmarked. You can find code at examples/domain_adaptation/image_classification/addagrl.py.
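A minimal usage sketch, assuming discriminator predictions of shape \((minibatch,)\) as described above (the full optimization schedule lives in the adda.py script):
>>> from dalib.adaptation.adda import DomainAdversarialLoss
>>> import torch
>>> loss = DomainAdversarialLoss()
>>> # probabilities predicted by a domain discriminator on source and target features
>>> d_s, d_t = torch.rand(32), torch.rand(32)
>>> output = 0.5 * (loss(d_s, 'source') + loss(d_t, 'target'))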
BSP: Batch Spectral Penalization¶
- class dalib.adaptation.bsp.BatchSpectralPenalizationLoss[source]¶
Batch spectral penalization loss from Transferability vs. Discriminability: Batch Spectral Penalization for Adversarial Domain Adaptation (ICML 2019).
Given source features \(f_s\) and target features \(f_t\) in current mini batch, singular value decomposition is first performed
\[f_s = U_s\Sigma_sV_s^T\]
\[f_t = U_t\Sigma_tV_t^T\]
Then the batch spectral penalization loss is calculated as
\[loss=\sum_{i=1}^k(\sigma_{s,i}^2+\sigma_{t,i}^2),\]
where \(\sigma_{s,i},\sigma_{t,i}\) refer to the \(i\)-th largest singular value of the source features and target features respectively. We empirically set \(k=1\).
- Inputs:
f_s (tensor): feature representations on source domain, \(f^s\)
f_t (tensor): feature representations on target domain, \(f^t\)
- Shape:
f_s, f_t: \((N, F)\) where F means the dimension of input features.
Outputs: scalar.
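A minimal usage sketch, assuming feature tensors shaped as in the Shape section above; in practice the penalty is added to a DANN-style objective with a small trade-off weight:
>>> from dalib.adaptation.bsp import BatchSpectralPenalizationLoss
>>> import torch
>>> bsp_penalty = BatchSpectralPenalizationLoss()
>>> # features from source domain and target domain
>>> f_s, f_t = torch.randn(32, 1024), torch.randn(32, 1024)
>>> loss = bsp_penalty(f_s, f_t)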
Domain Adversarial Training in Open-set DA¶
OSBP: Open Set Domain Adaptation by Backpropagation¶
- class dalib.adaptation.osbp.UnknownClassBinaryCrossEntropy(t=0.5)[source]¶
Binary cross entropy loss to make a boundary for unknown samples, proposed by Open Set Domain Adaptation by Backpropagation (ECCV 2018).
Given a sample on the target domain \(x_t\) and its classification outputs \(y\), the binary cross entropy loss is defined as
\[L_{\text{adv}}(x_t) = -t \text{log}(p(y=C+1|x_t)) - (1-t)\text{log}(1-p(y=C+1|x_t)),\]
where \(t\) is a hyper-parameter and \(C\) is the number of known classes.
- Parameters
t (float) – Predefined hyper-parameter. Default: 0.5
- Inputs:
y (tensor): classification outputs (before softmax).
- Shape:
y: \((minibatch, C+1)\) where C is the number of known classes.
Outputs: scalar
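A minimal usage sketch, assuming \(C = 10\) known classes plus one unknown class:
>>> from dalib.adaptation.osbp import UnknownClassBinaryCrossEntropy
>>> import torch
>>> loss = UnknownClassBinaryCrossEntropy(t=0.5)
>>> # classifier outputs on target samples over C = 10 known classes plus the unknown class
>>> y_t = torch.randn(32, 11)
>>> output = loss(y_t)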
Domain Adversarial Training in Partial DA¶
PADA: Partial Adversarial Domain Adaptation¶
- class dalib.adaptation.pada.ClassWeightModule(temperature=0.1)[source]¶
Calculating class weight based on the output of the classifier. Introduced by Partial Adversarial Domain Adaptation (ECCV 2018)
Given classification logits outputs \(\{\hat{y}_i\}_{i=1}^n\), where \(n\) is the dataset size, the weight indicating the contribution of each class to the training can be calculated as follows
\[\mathcal{\gamma} = \dfrac{1}{n} \sum_{i=1}^{n}\text{softmax}( \hat{y}_i / T),\]
where \(\mathcal{\gamma}\) is a \(|\mathcal{C}|\)-dimensional weight vector quantifying the contribution of each class and \(T\) is a hyper-parameter called temperature.
In practice, some of the weights may be very small, so we normalize \(\mathcal{\gamma}\) by dividing it by its largest element, i.e. \(\mathcal{\gamma} \leftarrow \mathcal{\gamma} / \max(\mathcal{\gamma})\).
- Parameters
temperature (float, optional) – hyper-parameter \(T\). Default: 0.1
- Shape:
Inputs: (minibatch, \(|\mathcal{C}|\))
Outputs: (\(|\mathcal{C}|\),)
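A minimal usage sketch, assuming the logits stand in for classification outputs collected over the whole target dataset (e.g. via collect_classification_results below):
>>> from dalib.adaptation.pada import ClassWeightModule
>>> import torch
>>> class_weight_module = ClassWeightModule(temperature=0.1)
>>> # classification logits on n = 1000 samples over |C| = 31 classes
>>> logits = torch.randn(1000, 31)
>>> class_weight = class_weight_module(logits)  # shape (31,)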
- class dalib.adaptation.pada.AutomaticUpdateClassWeightModule(update_steps, data_loader, classifier, num_classes, device, temperature=0.1, partial_classes_index=None)[source]¶
Calculating class weight based on the output of the classifier. See ClassWeightModule for the details of the calculation. Every N iterations, the class weight is updated automatically.
- Parameters
update_steps (int) – N, the number of iterations to update class weight.
data_loader (torch.utils.data.DataLoader) – The data loader from which we can collect classification outputs.
classifier (torch.nn.Module) – Classifier.
num_classes (int) – Number of classes.
device (torch.device) – The device to run classifier.
temperature (float, optional) – T, temperature in ClassWeightModule. Default: 0.1
partial_classes_index (list[int], optional) – The index of partial classes. Note that this parameter is just for debugging, since in real-world datasets we have no access to the index of partial classes. Default: None.
Examples:
>>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     class_weight_module.step()
>>>     # weight for F.cross_entropy
>>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
>>>     # weight for dalib.adaptation.dann.DomainAdversarialLoss
>>>     w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss()
- get_class_weight_for_adversarial_loss(source_labels)[source]¶
- Outputs:
w_s: source weight for DomainAdversarialLoss
w_t: target weight for DomainAdversarialLoss
- Shape:
w_s: \((minibatch, )\)
w_t: \((minibatch, )\)
- get_class_weight_for_cross_entropy_loss()[source]¶
Outputs: weight for F.cross_entropy
Shape: \((C, )\) where C means the number of classes.
- get_partial_classes_weight()[source]¶
Get class weight averaged on the partial classes and non-partial classes respectively.
Warning
This function is just for debugging, since in real-world datasets we have no access to the index of partial classes, and it will throw an error when partial_classes_index is None.
- dalib.adaptation.pada.collect_classification_results(data_loader, classifier, device)[source]¶
Fetch data from data_loader, and then use classifier to collect classification results.
- Parameters
data_loader (torch.utils.data.DataLoader) – Data loader.
classifier (torch.nn.Module) – A classifier.
device (torch.device) – The device to run classifier.
- Returns
Classification results in shape (len(data_loader), \(|\mathcal{C}|\)).
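A minimal usage sketch, assuming data_loader and classifier are a target-domain DataLoader and a trained classifier defined elsewhere:
>>> import torch
>>> from dalib.adaptation.pada import collect_classification_results
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> # classification outputs over the whole loader, shape (len(data_loader), |C|)
>>> outputs = collect_classification_results(data_loader, classifier, device)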
IWAN: Importance Weighted Adversarial Nets¶
- class dalib.adaptation.iwan.ImportanceWeightModule(discriminator, partial_classes_index=None)[source]¶
Calculating class weight based on the output of the discriminator. Introduced by Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018)
- Parameters
discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of features. Its input shape is \((N, F)\) and output shape is \((N, 1)\)
partial_classes_index (list[int], optional) – The index of partial classes. Note that this parameter is just for debugging, since in real-world datasets we have no access to the index of partial classes. Default: None.
Examples:
>>> domain_discriminator = DomainDiscriminator(1024, 1024)
>>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     # feature from source domain
>>>     f_s = torch.randn(32, 1024)
>>>     # importance weights for source instance
>>>     w_s = importance_weight_module.get_importance_weight(f_s)
- get_importance_weight(feature)[source]¶
Get importance weights for each instance.
- Parameters
feature (tensor) – feature from source domain, in shape \((N, F)\)
- Returns
instance weight in shape \((N, 1)\)
- get_partial_classes_weight(weights, labels)[source]¶
Get class weight averaged on the partial classes and non-partial classes respectively.
- Parameters
weights (tensor) – instance weight in shape \((N, 1)\)
labels (tensor) – ground truth labels in shape \((N, 1)\)
Warning
This function is just for debugging, since in real-world datasets we have no access to the index of partial classes, and it will throw an error when partial_classes_index is None.
Domain Adversarial Training in Segmentation¶
ADVENT: Adversarial Entropy Minimization¶
- class dalib.adaptation.advent.Discriminator(num_classes, ndf=64)[source]¶
Domain discriminator model from ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation (CVPR 2019)
Distinguish pixel-by-pixel whether the input predictions come from the source domain or the target domain. The source domain label is 1 and the target domain label is 0.
- Parameters
num_classes (int) – Number of classes in the input predictions.
ndf (int, optional) – Dimension of the discriminator's hidden features. Default: 64
- Shape:
Inputs: \((minibatch, C, H, W)\) where \(C\) is the number of classes
Outputs: \((minibatch, 1, H, W)\)
- class dalib.adaptation.advent.DomainAdversarialEntropyLoss(discriminator)[source]¶
The Domain Adversarial Entropy Loss
Minimizing entropy with adversarial learning through training a domain discriminator.
- Parameters
discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of predictions. Its input shape is \((minibatch, C, H, W)\) and output shape is \((minibatch, 1, H, W)\)
- Inputs:
logits (tensor): logits output of segmentation model
domain_label (str, optional): whether the data comes from source or target. Choices: [‘source’, ‘target’]. Default: ‘source’
- Shape:
logits: \((minibatch, C, H, W)\) where \(C\) means the number of classes
Outputs: scalar.
Examples:
>>> B, C, H, W = 2, 19, 512, 512
>>> discriminator = Discriminator(num_classes=C)
>>> dann = DomainAdversarialEntropyLoss(discriminator)
>>> # logits output on source domain and target domain
>>> y_s, y_t = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
>>> loss = 0.5 * (dann(y_s, "source") + dann(y_t, "target"))
- eval()[source]¶
Sets the module in evaluation mode. In evaluation mode, all the parameters of the discriminator are set to requires_grad=False.
This is equivalent to self.train(False).
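A minimal sketch of the alternating optimization that this freezing behaviour supports, assuming the usual ADVENT recipe of fooling the discriminator with target predictions labelled as source:
>>> discriminator = Discriminator(num_classes=19)
>>> adv_loss = DomainAdversarialEntropyLoss(discriminator)
>>> # phase 1: update the segmentation model, keep the discriminator frozen
>>> adv_loss.eval()
>>> # ... compute adv_loss on target predictions labelled as 'source' and backpropagate ...
>>> # phase 2: update the discriminator with its parameters trainable again
>>> adv_loss.train()
>>> # ... compute adv_loss on detached source and target predictions with their true domain labels and backpropagate ...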