Minimum Class Confusion (MCC)
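class dalib.adaptation.mcc.MinimumClassConfusionLoss(temperature)
Minimum Class Confusion loss minimizes the class confusion in the target predictions, that is, the correlation between the predicted probabilities of different classes over a minibatch of target samples.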
See Minimum Class Confusion for Versatile Domain Adaptation (ECCV 2020) for more details.
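To make the computation concrete, here is a minimal PyTorch sketch of the MCC loss as we understand it from the paper: a temperature-scaled softmax, entropy-based sample reweighting, a weighted C×C class correlation matrix, row-wise category normalization, and the average off-diagonal mass as the loss. The function name `mcc_loss_sketch` and the `1e-5` epsilon are our own assumptions, and dalib's actual implementation may differ in details such as the normalization step.

```python
import torch
import torch.nn.functional as F


def mcc_loss_sketch(logits: torch.Tensor, temperature: float = 2.0) -> torch.Tensor:
    """Hypothetical MCC sketch following the paper's description (not dalib's code)."""
    assert temperature > 0, "temperature must be larger than 0"
    batch_size, num_classes = logits.shape
    # Temperature-rescaled class probabilities, shape (B, C)
    probs = F.softmax(logits / temperature, dim=1)
    # Per-sample entropy (epsilon is an assumption); lower entropy -> larger weight
    entropy = -(probs * torch.log(probs + 1e-5)).sum(dim=1).detach()
    weight = 1 + torch.exp(-entropy)
    # Rescale weights so they average to 1 over the batch, shape (B, 1)
    weight = (batch_size * weight / weight.sum()).unsqueeze(1)
    # Weighted class correlation matrix, shape (C, C)
    confusion = (probs * weight).t() @ probs
    # Category normalization: every row sums to 1
    confusion = confusion / confusion.sum(dim=1, keepdim=True)
    # Off-diagonal mass measures confusion between different classes
    return (confusion.sum() - confusion.trace()) / num_classes
```

With row normalization the result lies in [0, 1): uniform predictions over C classes give (C - 1) / C, while confident predictions that cover every class give a value close to 0.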
- Parameters
temperature (float) – The temperature for rescaling: the logits are divided by temperature before the softmax, so the prediction reduces to the vanilla softmax when temperature is 1.0.
Note
Make sure that temperature is larger than 0.
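The effect of `temperature` can be checked in isolation. The sketch below assumes, as in the paper, that logits are divided by the temperature before the softmax; `rescaled_softmax` is a hypothetical helper, not part of dalib.

```python
import torch
import torch.nn.functional as F


def rescaled_softmax(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # MCC requires a strictly positive temperature
    if temperature <= 0:
        raise ValueError("temperature must be larger than 0")
    # temperature > 1 flattens the distribution; temperature == 1 leaves it unchanged
    return F.softmax(logits / temperature, dim=1)


logits = torch.tensor([[2.0, 1.0, 0.0]])
p1 = rescaled_softmax(logits, 1.0)  # identical to the vanilla softmax
p2 = rescaled_softmax(logits, 2.0)  # flatter, so its largest probability is smaller
```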
- Inputs:
g_t (tensor): unnormalized classifier predictions (logits) on the target domain, \(g^t\)
- Shape:
g_t: \((minibatch, C)\) where C is the number of classes.
Output: scalar.
- Examples::
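>>> import torch
>>> from dalib.adaptation.mcc import MinimumClassConfusionLoss
>>> num_classes = 2
>>> batch_size = 10
>>> temperature = 2.0
>>> loss = MinimumClassConfusionLoss(temperature)
>>> # logits output from target domain
>>> g_t = torch.randn(batch_size, num_classes)
>>> output = loss(g_t)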
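A minimal sketch of one training step, assuming the setup we understand the paper to use: a cross-entropy loss on labeled source data plus a weighted MCC term on unlabeled target logits. To stay self-contained it re-implements MCC compactly instead of importing dalib; `mcc`, `trade_off`, and the toy linear classifier are hypothetical stand-ins.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mcc(logits: torch.Tensor, temperature: float = 2.0) -> torch.Tensor:
    # Compact MCC sketch based on the paper; not dalib's exact implementation
    b, c = logits.shape
    p = F.softmax(logits / temperature, dim=1)
    ent = -(p * torch.log(p + 1e-5)).sum(dim=1).detach()
    w = 1 + torch.exp(-ent)
    w = (b * w / w.sum()).unsqueeze(1)
    cc = (p * w).t() @ p
    cc = cc / cc.sum(dim=1, keepdim=True)
    return (cc.sum() - cc.trace()) / c


torch.manual_seed(0)
num_classes, feature_dim, batch_size = 3, 8, 16
classifier = nn.Linear(feature_dim, num_classes)  # toy stand-in for a real network
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1)
trade_off = 1.0  # hypothetical weight of the MCC term

# Labeled source batch and unlabeled target batch (random toy data)
x_s = torch.randn(batch_size, feature_dim)
y_s = torch.randint(0, num_classes, (batch_size,))
x_t = torch.randn(batch_size, feature_dim)

# One step: supervised loss on source logits + MCC on target logits
loss = F.cross_entropy(classifier(x_s), y_s) + trade_off * mcc(classifier(x_t))
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Only the target logits enter the MCC term, so no target labels are needed.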
MCC can also serve as a regularizer for existing methods. Examples:
>>> import torch
>>> from dalib.modules.domain_discriminator import DomainDiscriminator
>>> from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss
>>> from dalib.adaptation.mcc import MinimumClassConfusionLoss
>>> num_classes = 2
>>> feature_dim = 1024
>>> batch_size = 10
>>> temperature = 2.0
>>> # CDAN (without randomization) feeds the discriminator the outer product
>>> # of features and predictions, so its input size is feature_dim * num_classes
>>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)
>>> cdan_loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
>>> mcc_loss = MinimumClassConfusionLoss(temperature)
>>> # features from source domain and target domain
>>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
>>> # logits output from source domain and target domain
>>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
>>> total_loss = cdan_loss(g_s, f_s, g_t, f_t) + mcc_loss(g_t)