Partial Adversarial Domain Adaptation (PADA)¶
-
class dalib.adaptation.pada.ClassWeightModule(temperature=0.1)[source]¶
Calculates class weights based on the output of the classifier. Introduced by Partial Adversarial Domain Adaptation (ECCV 2018).
Given classification logits \(\{\hat{y}_i\}_{i=1}^n\), where \(n\) is the dataset size, the weight indicating the contribution of each class to the training can be calculated as follows
\[\mathcal{\gamma} = \dfrac{1}{n} \sum_{i=1}^{n} softmax(\hat{y}_i / T),\]where \(\mathcal{\gamma}\) is a \(|\mathcal{C}|\)-dimensional weight vector quantifying the contribution of each class and \(T\) is a hyper-parameter called the temperature.
In practice, some of the weights may be very small, so we normalize \(\mathcal{\gamma}\) by dividing it by its largest element, i.e. \(\mathcal{\gamma} \leftarrow \mathcal{\gamma} / max(\mathcal{\gamma})\). A minimal sketch of this computation is given below the shape information.
- Parameters
temperature (float, optional) – hyper-parameter \(T\). Default: 0.1
- Shape:
Inputs: (minibatch, \(|\mathcal{C}|\))
Outputs: (\(|\mathcal{C}|\),)
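A minimal sketch of the computation described above, assuming the collected logits form a single (n, \(|\mathcal{C}|\)) tensor; the function name below is illustrative and not part of the API:
import torch
import torch.nn.functional as F

def class_weight_from_logits(logits, temperature=0.1):
    # logits: (n, |C|) classification outputs; returns a (|C|,) class weight vector
    # average the temperature-scaled softmax over all samples
    gamma = F.softmax(logits / temperature, dim=1).mean(dim=0)
    # normalize so that the largest class weight equals 1
    return gamma / gamma.max()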
-
class dalib.adaptation.pada.AutomaticUpdateClassWeightModule(update_steps, data_loader, classifier, num_classes, device, temperature=0.1, partial_classes_index=None)[source]¶
Calculates class weights based on the output of the classifier. See ClassWeightModule for details of the calculation. Every N iterations, the class weight is updated automatically.
- Parameters
update_steps (int) – N, the number of iterations to update class weight.
data_loader (torch.utils.data.DataLoader) – The data loader from which we can collect classification outputs.
classifier (torch.nn.Module) – Classifier.
num_classes (int) – Number of classes.
device (torch.device) – The device on which to run the classifier.
temperature (float, optional) – \(T\), the temperature in ClassWeightModule. Default: 0.1
partial_classes_index (list[int], optional) – The indices of the partial classes. Note that this parameter is only for debugging, since for real-world datasets we have no access to the indices of the partial classes. Default: None.
Examples:
>>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     class_weight_module.step()
>>>     # weight for F.cross_entropy
>>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
>>>     # weight for dalib.adaptation.dann.DomainAdversarialLoss
>>>     w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(source_labels)
-
get_class_weight_for_adversarial_loss(source_labels)[source]¶
- Outputs:
w_s: source weight for DomainAdversarialLoss
w_t: target weight for DomainAdversarialLoss
- Shape:
w_s: \((minibatch, )\)
w_t: \((minibatch, )\)
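In the PADA formulation, each source sample is weighted by the class weight of its ground-truth label, while target samples keep uniform weight. A hedged sketch of that mapping, illustrative only and not necessarily the library's exact internals:
import torch

def per_sample_adversarial_weights(class_weight, source_labels):
    # class_weight: (|C|,) vector; source_labels: (minibatch,) integer labels
    w_s = class_weight[source_labels]   # weight each source sample by its class weight
    w_t = torch.ones_like(w_s)          # target samples are unweighted
    return w_s, w_t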
-
get_class_weight_for_cross_entropy_loss()[source]¶
Outputs: weight for F.cross_entropy
Shape: \((C, )\) where C means the number of classes.
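The returned (C,) vector can be passed as the weight argument of F.cross_entropy; for example (y_pred and labels below stand for tensors from your own training step):
>>> import torch.nn.functional as F
>>> w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
>>> loss = F.cross_entropy(y_pred, labels, weight=w_c)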
-
get_partial_classes_weight()[source]¶
Get the class weight averaged over the partial classes and over the non-partial classes, respectively.
Warning
This function is only for debugging, since for real-world datasets we have no access to the indices of the partial classes, and this function will throw an error when partial_classes_index is None.
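For illustration only, a sketch of what averaging over the partial and non-partial classes could look like, assuming a (|C|,) class_weight tensor and a list partial_classes_index; the helper below is hypothetical:
import torch

def split_average(class_weight, partial_classes_index):
    mask = torch.zeros_like(class_weight, dtype=torch.bool)
    mask[partial_classes_index] = True          # mark the partial classes
    return class_weight[mask].mean(), class_weight[~mask].mean()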
-
dalib.adaptation.pada.collect_classification_results(data_loader, classifier, device)[source]¶
Fetch data from data_loader, and then use classifier to collect classification results.
- Parameters
data_loader (torch.utils.data.DataLoader) – Data loader.
classifier (torch.nn.Module) – A classifier.
device (torch.device) – The device on which to run the classifier.
- Returns
Classification results in shape (len(data_loader), \(|\mathcal{C}|\)).
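A minimal sketch of what such a collection routine typically does, assuming each batch is an (input, label) pair and that the classifier directly returns logits; this is an illustration, not the library's exact implementation:
import torch

@torch.no_grad()
def collect_outputs(data_loader, classifier, device):
    classifier.eval()
    all_outputs = []
    for images, _ in data_loader:                 # assumed (input, label) batches
        logits = classifier(images.to(device))    # classification outputs
        all_outputs.append(logits.cpu())
    return torch.cat(all_outputs, dim=0)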