Importance Weighted Adversarial Nets (IWAN)
class dalib.adaptation.iwan.ImportanceWeightModule(discriminator, partial_classes_index=None)
Calculates class weights based on the output of the discriminator. Introduced by Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018).
- Parameters
discriminator (torch.nn.Module) – A domain discriminator object, which predicts the domains of features. Its input shape is \((N, F)\) and its output shape is \((N, 1)\).
partial_classes_index (list[int], optional) – The indices of the partial classes. This parameter is only for debugging, since in real-world datasets we have no access to the indices of the partial classes. Default: None.
Examples:
>>> domain_discriminator = DomainDiscriminator(1024, 1024)
>>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
>>> num_iterations = 10000
>>> for _ in range(num_iterations):
...     # feature from source domain
...     f_s = torch.randn(32, 1024)
...     # importance weights for source instance
...     w_s = importance_weight_module.get_importance_weight(f_s)
get_importance_weight(feature)
Get importance weights for each instance.
- Parameters
feature (tensor) – feature from source domain, in shape \((N, F)\)
- Returns
instance weight in shape \((N, 1)\)
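As an illustration of what such a weight can look like, the sketch below follows the weighting rule described in the IWAN paper: each source instance is weighted by one minus the discriminator output, and the weights are then normalized to have mean 1. It is not taken from the library source. The helper name `sketch_importance_weight`, the toy discriminator, and the convention that the discriminator outputs the probability of a feature coming from the source domain are all assumptions made for this example.

```python
import torch
import torch.nn as nn


def sketch_importance_weight(discriminator: nn.Module, feature: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: IWAN-style instance weights in shape (N, 1)."""
    # No gradients are tracked, so the weights act as constants in later losses.
    with torch.no_grad():
        # Assumed convention: output is P(source | feature), in shape (N, 1).
        d = discriminator(feature)
        # Source samples that look like target samples (small d) get large weights.
        weight = 1.0 - d
        # Normalize so that the weights average to 1 over the batch.
        weight = weight / weight.mean()
    return weight


torch.manual_seed(0)
# Toy stand-in for a domain discriminator (assumption, not the library class).
toy_discriminator = nn.Sequential(nn.Linear(1024, 1), nn.Sigmoid())
f_s = torch.randn(32, 1024)
w_s = sketch_importance_weight(toy_discriminator, f_s)
print(w_s.shape, w_s.mean().item())
```

Normalizing by the batch mean keeps the overall scale of a weighted loss comparable to the unweighted one, whatever the discriminator's average output is.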
get_partial_classes_weight(weights, labels)
Get class weight averaged on the partial classes and non-partial classes respectively.
- Parameters
weights (tensor) – instance weight in shape \((N, 1)\)
labels (tensor) – ground truth labels in shape \((N, 1)\)
Warning
This function is only for debugging: in real-world datasets we have no access to the indices of the partial classes, and the function raises an error when partial_classes_index is None.
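The averaging described above can be sketched without any dependencies. Plain Python lists stand in for tensors here so the arithmetic is easy to follow. The helper name `sketch_partial_classes_weight`, the use of `ValueError`, and the choice to return 0.0 for an empty group are assumptions of this example, not the library's documented behavior.

```python
def sketch_partial_classes_weight(weights, labels, partial_classes_index):
    """Hypothetical helper: mean weight over partial and non-partial classes."""
    if partial_classes_index is None:
        # Assumed error type; the documentation only states that an error occurs.
        raise ValueError("partial_classes_index is required for this debugging helper")

    partial = set(partial_classes_index)
    # Split the instance weights by whether the label is a partial class.
    in_partial = [w for w, y in zip(weights, labels) if y in partial]
    not_partial = [w for w, y in zip(weights, labels) if y not in partial]

    def mean(values):
        # Assumed fallback: an empty group contributes a weight of 0.0.
        return sum(values) / len(values) if values else 0.0

    return mean(in_partial), mean(not_partial)


weights = [1.5, 0.5, 1.2, 0.8]   # one importance weight per instance
labels = [0, 3, 1, 4]            # ground truth class of each instance
partial_w, non_partial_w = sketch_partial_classes_weight(weights, labels, [0, 1])
print(partial_w, non_partial_w)
```

Comparing the two averages during debugging gives a quick check of whether the instance weights separate the two groups of classes.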