
Importance Weighted Adversarial Nets (IWAN)

class dalib.adaptation.iwan.ImportanceWeightModule(discriminator, partial_classes_index=None)

Calculates importance weights for source instances from the output of a domain discriminator. Introduced by Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018).
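For context, the IWAN paper trains the discriminator \(D\) to output the probability that a feature \(z\) comes from the source domain. Source instances that look target-like (small \(D(z)\)) are the ones most likely to belong to the shared classes, so they receive large weights. The weighting below is written from the paper's formulation, not read off this class's code; estimating the expectation with a mini-batch mean is an assumption about how it is computed in practice.

```latex
% Optimal discriminator and the resulting (unnormalized) importance weight
D^*(z) = \frac{p_s(z)}{p_s(z) + p_t(z)}, \qquad
w(z) = 1 - D^*(z) = \frac{1}{\frac{p_s(z)}{p_t(z)} + 1}

% Normalized so that the weights average to 1 over the source distribution
\tilde{w}(z) = \frac{w(z)}{\mathbb{E}_{z \sim p_s(z)}\left[ w(z) \right]}
```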

Parameters
  • discriminator (torch.nn.Module) – A domain discriminator that predicts which domain each feature comes from. Its input shape is \((N, F)\) and its output shape is \((N, 1)\).

  • partial_classes_index (list[int], optional) – The indices of the partial classes, i.e. the source classes that also appear in the target domain. This parameter is only for debugging, since the indices of the partial classes are not available in real-world datasets. Default: None.
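The discriminator only needs to satisfy the contract stated above: features of shape \((N, F)\) in, per-instance probabilities of shape \((N, 1)\) out. The sketch below is a minimal, torch-only illustration of training such a discriminator on synthetic features. The toy architecture, the synthetic data and the label convention (source = 1, target = 0) are all assumptions chosen to match the IWAN formulation; they are not taken from this library.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy discriminator matching the documented contract:
# (N, F) features in, (N, 1) probabilities of "comes from source" out.
discriminator = nn.Sequential(
    nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid()
)
optimizer = torch.optim.SGD(discriminator.parameters(), lr=0.1)

# Synthetic features: source and target are shifted apart so they are separable.
f_s = torch.randn(128, 64) + 1.0
f_t = torch.randn(128, 64) - 1.0

for _ in range(100):
    d_s, d_t = discriminator(f_s), discriminator(f_t)
    # Assumed label convention: source = 1, target = 0
    loss = F.binary_cross_entropy(d_s, torch.ones_like(d_s)) + \
        F.binary_cross_entropy(d_t, torch.zeros_like(d_t))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():
    # Source features should now score higher than target features
    print(discriminator(f_s).mean().item(), discriminator(f_t).mean().item())
```

With this convention, \(1 - D(z)\) is large exactly for source features that the discriminator mistakes for target features.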

Examples:

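>>> import torch
>>> from dalib.modules.domain_discriminator import DomainDiscriminator
>>> from dalib.adaptation.iwan import ImportanceWeightModule
>>> domain_discriminator = DomainDiscriminator(1024, 1024)
>>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
...     # features from the source domain, shape (32, 1024)
...     f_s = torch.randn(32, 1024)
...     # importance weights for the source instances, shape (32, 1)
...     w_s = importance_weight_module.get_importance_weight(f_s)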

get_importance_weight(feature)

Get importance weights for each instance.

Parameters

feature (tensor) – features from the source domain, of shape \((N, F)\)

Returns

Instance weights of shape \((N, 1)\).
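To make the computation concrete, here is a minimal, torch-only sketch of IWAN-style instance weights, \(w = 1 - D(z)\) normalized to mean 1 over the mini-batch. The helper `importance_weight` and the toy discriminator are hypothetical; treating the weights as constants (no gradient) and normalizing per mini-batch are assumptions, so this should not be read as this method's actual implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def importance_weight(discriminator: nn.Module, feature: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: IWAN-style weights w = 1 - D(z), normalized to mean 1.

    Assumes `discriminator` maps (N, F) features to (N, 1) probabilities
    that each feature comes from the source domain.
    """
    with torch.no_grad():  # weights are treated as constants; no gradient flows back
        w = 1.0 - discriminator(feature)  # shape (N, 1)
    # Normalize over the mini-batch; clamp guards against an all-zero batch
    return w / w.mean().clamp_min(1e-8)


# Hypothetical toy discriminator: features -> source probability in [0, 1]
discriminator = nn.Sequential(
    nn.Linear(1024, 256), nn.ReLU(), nn.Linear(256, 1), nn.Sigmoid()
)
f_s = torch.randn(32, 1024)
w_s = importance_weight(discriminator, f_s)
print(w_s.shape)  # torch.Size([32, 1])
```

Normalizing to mean 1 keeps the overall scale of a weighted loss comparable to the unweighted one, so only the relative emphasis between instances changes.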

get_partial_classes_weight(weights, labels)

Get the instance weights averaged over the partial classes and over the non-partial classes, respectively.

Parameters
  • weights (tensor) – instance weights of shape \((N, 1)\)

  • labels (tensor) – ground-truth labels of shape \((N, 1)\)

Warning

This function is only for debugging: the indices of the partial classes are not available in real-world datasets, and this function raises an error when partial_classes_index is None.
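As a debugging aid, the two averages can be compared: if the weighting works, instances from partial classes should receive a higher mean weight than the rest. The sketch below shows one way to compute them. The helper `partial_classes_weight` is hypothetical, and its return format (a pair of scalar means, with 0 for a group that has no instances in the batch) and its use of `ValueError` are assumptions, since the documentation above specifies neither.

```python
import torch


def partial_classes_weight(weights: torch.Tensor, labels: torch.Tensor, partial_classes_index):
    """Hypothetical helper: mean instance weight over partial vs. non-partial classes.

    Returns (mean weight of instances whose label is a partial class,
             mean weight of all other instances); an empty group yields 0.
    """
    if partial_classes_index is None:
        raise ValueError("partial_classes_index is required for this debugging check")
    weights = weights.reshape(-1)  # (N, 1) -> (N,)
    labels = labels.reshape(-1)    # (N, 1) -> (N,)
    partial_ids = torch.as_tensor(partial_classes_index, device=labels.device)
    is_partial = torch.isin(labels, partial_ids)  # boolean mask of shape (N,)
    zero = weights.new_zeros(())
    partial_mean = weights[is_partial].mean() if is_partial.any() else zero
    other_mean = weights[~is_partial].mean() if (~is_partial).any() else zero
    return partial_mean, other_mean


weights = torch.tensor([[2.0], [1.0], [0.5], [0.5]])
labels = torch.tensor([[0], [1], [2], [3]])
p, n = partial_classes_weight(weights, labels, partial_classes_index=[0, 1])
print(p.item(), n.item())  # 1.5 0.5
```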
