
Partial Adversarial Domain Adaptation (PADA)

class dalib.adaptation.pada.ClassWeightModule(temperature=0.1)

Calculates class weights from the outputs of the classifier. Introduced in Partial Adversarial Domain Adaptation (ECCV 2018).

Given the classification logits \(\{\hat{y}_i\}_{i=1}^n\), where \(n\) is the dataset size, the weight indicating the contribution of each class to training is calculated as follows:

\[\mathcal{\gamma} = \dfrac{1}{n} \sum_{i=1}^{n} \mathrm{softmax}(\hat{y}_i / T),\]

where \(\mathcal{\gamma}\) is a \(|\mathcal{C}|\)-dimensional weight vector quantifying the contribution of each class and \(T\) is a hyperparameter called the temperature.

In practice, some of the weights can be very small, so we normalize \(\mathcal{\gamma}\) by dividing it by its largest element, i.e. \(\mathcal{\gamma} \leftarrow \mathcal{\gamma} / \max(\mathcal{\gamma})\).

Parameters

temperature (float, optional) – hyperparameter \(T\). Default: 0.1

Shape:
  • Inputs: (minibatch, \(|\mathcal{C}|\))

  • Outputs: (\(|\mathcal{C}|\),)
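The calculation described above can be sketched without any framework, which keeps the arithmetic easy to follow. This is a minimal illustration on plain Python lists, not the library's implementation; the helper names `softmax` and `class_weights` and the sample logits are assumptions made for the example.

```python
import math


def softmax(logits, temperature):
    """Temperature-scaled softmax of one logit vector (numerically stable)."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def class_weights(all_logits, temperature=0.1):
    """Average the softmax outputs over all samples, then divide by the max."""
    n = len(all_logits)
    num_classes = len(all_logits[0])
    gamma = [0.0] * num_classes
    for logits in all_logits:
        for c, p in enumerate(softmax(logits, temperature)):
            gamma[c] += p / n
    largest = max(gamma)
    return [g / largest for g in gamma]


# Hypothetical logits for n = 3 samples and |C| = 3 classes.
logits = [[2.0, 0.5, -1.0], [1.5, 1.0, -2.0], [0.2, 1.8, -1.5]]
weights = class_weights(logits)
```

With a small temperature such as 0.1 the softmax is close to one-hot, so a class that is never the top prediction (the third class here) ends up with a weight near zero, while the most frequently predicted class has weight exactly 1 after normalization.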

class dalib.adaptation.pada.AutomaticUpdateClassWeightModule(update_steps, data_loader, classifier, num_classes, device, temperature=0.1, partial_classes_index=None)

Calculates class weights from the outputs of the classifier and updates them automatically every N iterations. See ClassWeightModule for details of the calculation.

Parameters
  • update_steps (int) – N, the number of iterations between two updates of the class weight.

  • data_loader (torch.utils.data.DataLoader) – The data loader from which we can collect classification outputs.

  • classifier (torch.nn.Module) – Classifier.

  • num_classes (int) – Number of classes.

  • device (torch.device) – The device on which to run the classifier.

  • temperature (float, optional) – T, temperature in ClassWeightModule. Default: 0.1

  • partial_classes_index (list[int], optional) – The indices of the partial classes. Note that this parameter is for debugging only, since in real-world datasets we have no access to the indices of the partial classes. Default: None.

Examples:

>>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
>>>     class_weight_module.step()
>>>     # weight for F.cross_entropy
>>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
>>>     # weight for dalib.adaptation.dann.DomainAdversarialLoss
>>>     # (source_labels: labels of the current source minibatch)
>>>     w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(source_labels)
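
The "update every N iterations" behaviour can be pictured as a simple step counter. The sketch below is only an illustration under stated assumptions: `PeriodicUpdater`, its `recompute` callback and the `num_updates` counter are hypothetical names, and whether the library triggers its first update at step 0 or at step N is not stated here (this sketch assumes step N).

```python
class PeriodicUpdater:
    """Hypothetical helper: refreshes a value once every `update_steps` steps."""

    def __init__(self, update_steps, recompute, initial_value=None):
        self.update_steps = update_steps
        self.recompute = recompute    # callable returning fresh class weights
        self.value = initial_value    # weights used before the first update
        self.num_steps = 0
        self.num_updates = 0

    def step(self):
        self.num_steps += 1
        # Assumption: the update fires on steps N, 2N, 3N, ...
        if self.num_steps % self.update_steps == 0:
            self.value = self.recompute()
            self.num_updates += 1


# Stand-in for "collect classifier outputs and recompute the class weights".
updater = PeriodicUpdater(update_steps=500, recompute=lambda: [1.0, 0.5, 0.0])
for _ in range(1200):
    updater.step()
```

Calling `step()` once per training iteration keeps the comparatively expensive recomputation (a full pass over the data loader) off the critical path of most iterations.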

get_class_weight_for_adversarial_loss(source_labels)

Outputs: w_s and w_t, the weights for dalib.adaptation.dann.DomainAdversarialLoss

Shape:
  • w_s: \((minibatch, )\)

  • w_t: \((minibatch, )\)
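The text above gives only the output shapes. One plausible way to turn a per-class weight vector into per-sample weights, assumed here for illustration and not confirmed by this page, is to give each source sample the weight of its labelled class and to give every target sample the mean of the source weights so that both domains carry the same total weight. The function name `adversarial_weights` is hypothetical.

```python
def adversarial_weights(class_weight, source_labels):
    """Sketch: per-sample weights of shape (minibatch,) for both domains."""
    # Source samples: look up the weight of each sample's class.
    w_s = [class_weight[y] for y in source_labels]
    # Target samples (assumption): share the mean source weight.
    mean_w = sum(w_s) / len(w_s)
    w_t = [mean_w] * len(w_s)
    return w_s, w_t


class_weight = [1.0, 0.5, 0.0]   # hypothetical normalized class weights
source_labels = [0, 0, 1, 2]     # labels of one source minibatch
w_s, w_t = adversarial_weights(class_weight, source_labels)
```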

get_class_weight_for_cross_entropy_loss()

Outputs: weight for F.cross_entropy

Shape: \((C, )\) where C means the number of classes.
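
To show what a \((C, )\)-shaped weight does inside a class-weighted cross-entropy, here is a framework-free sketch. It follows the weighted-mean convention (sum of weighted per-sample losses divided by the sum of the weights of the target classes), which to my understanding is what F.cross_entropy uses with its default reduction; the function names and sample values are assumptions made for the example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one logit vector."""
    largest = max(logits)
    log_sum = largest + math.log(sum(math.exp(z - largest) for z in logits))
    return [z - log_sum for z in logits]


def weighted_cross_entropy(batch_logits, labels, class_weight):
    """Class-weighted mean of the per-sample negative log-likelihoods."""
    weighted_sum = 0.0
    weight_total = 0.0
    for logits, y in zip(batch_logits, labels):
        nll = -log_softmax(logits)[y]
        weighted_sum += class_weight[y] * nll
        weight_total += class_weight[y]
    # Note: undefined if every sample in the batch has weight 0.
    return weighted_sum / weight_total


batch_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 2.0]]
labels = [0, 2]
w_c = [1.0, 0.5, 0.0]  # hypothetical class weights; class 2 is switched off

loss = weighted_cross_entropy(batch_logits, labels, w_c)
loss_first_only = weighted_cross_entropy(batch_logits[:1], labels[:1], w_c)
```

Because class 2 has weight 0, the second sample contributes nothing, and the loss equals the loss of the first sample alone.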

get_partial_classes_weight()

Returns the class weight averaged over the partial classes and over the non-partial classes, respectively.

Warning

This function is for debugging only, since in real-world datasets we have no access to the indices of the partial classes. It raises an error when partial_classes_index is None.
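
The two averages can be sketched as follows. This is an illustrative stand-alone function, not the library's code: the name `partial_classes_weight` is hypothetical, and the choice of `ValueError` is an assumption, since the text above says only that an error is raised.

```python
def partial_classes_weight(class_weight, partial_classes_index=None):
    """Sketch: mean weight over partial classes and over the remaining classes."""
    if partial_classes_index is None:
        # Assumption: the concrete exception type is not specified in the docs.
        raise ValueError("partial_classes_index is required for this debug helper")
    partial = set(partial_classes_index)
    inside = [w for c, w in enumerate(class_weight) if c in partial]
    outside = [w for c, w in enumerate(class_weight) if c not in partial]
    return sum(inside) / len(inside), sum(outside) / len(outside)


# Hypothetical weights for 4 classes, of which classes 0 and 1 are partial.
partial_mean, non_partial_mean = partial_classes_weight(
    [1.0, 0.5, 0.0, 0.1], partial_classes_index=[0, 1]
)
```

A large gap between the two means suggests that the weighting is separating the partial classes from the rest, which is the kind of sanity check this debugging method appears intended for.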

dalib.adaptation.pada.collect_classification_results(data_loader, classifier, device)

Fetches data from data_loader and runs the classifier on it to collect the classification results.

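Parameters
  • data_loader (torch.utils.data.DataLoader) – The data loader to fetch data from.

  • classifier (torch.nn.Module) – Classifier.

  • device (torch.device) – The device on which to run the classifier.

Returns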

Classification results in shape (len(data_loader), \(|\mathcal{C}|\)).
