Minimum Class Confusion (MCC)

class dalib.adaptation.mcc.MinimumClassConfusionLoss(temperature)

Minimum Class Confusion loss minimizes the class confusion in the target predictions.

See Minimum Class Confusion for Versatile Domain Adaptation (ECCV 2020) for more details.

Parameters

temperature (float) – The temperature for rescaling the logits; when temperature is 1.0 the rescaled predictions reduce to the vanilla softmax.

Note

Make sure that temperature is larger than 0.
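
For intuition, temperatures above 1.0 flatten the rescaled target predictions, while a temperature of 1.0 recovers the vanilla softmax. A small illustrative snippet (not part of the API):

>>> import torch
>>> import torch.nn.functional as F
>>> logits = torch.tensor([[2.0, 0.5, -1.0]])
>>> F.softmax(logits / 1.0, dim=1)  # vanilla softmax (temperature = 1.0)
>>> F.softmax(logits / 2.5, dim=1)  # higher temperature gives a smoother distribution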

Inputs: g_t
  • g_t (tensor): unnormalized classifier predictions on the target domain, \(g^t\)

Shape:
  • g_t: \((minibatch, C)\) where C is the number of classes.

  • Output: scalar.
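
The objective described in the paper can be sketched as follows: rescale the target logits with the temperature, weight each example by a decreasing function of its prediction entropy, build a weighted class confusion matrix, row-normalize it, and minimize the off-diagonal (cross-class) terms. The snippet below is an illustrative re-implementation under these assumptions, not necessarily identical to the library's internals:

>>> import torch
>>> import torch.nn.functional as F
>>> g_t, temperature = torch.randn(10, 4), 2.0      # assumed target logits and temperature
>>> y = F.softmax(g_t / temperature, dim=1)         # temperature-rescaled predictions, (B, C)
>>> h = -torch.sum(y * torch.log(y + 1e-8), dim=1)  # per-example prediction entropy
>>> w = 1.0 + torch.exp(-h)                         # confident (low-entropy) examples weigh more
>>> w = g_t.size(0) * w / w.sum()
>>> c = (y * w.unsqueeze(1)).t() @ y                # weighted class confusion matrix, (C, C)
>>> c = c / c.sum(dim=1, keepdim=True)              # row-normalize
>>> mcc = (c.sum() - c.trace()) / g_t.size(1)       # mean cross-class confusion to be minimized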

Examples:
>>> import torch
>>> from dalib.adaptation.mcc import MinimumClassConfusionLoss
>>> batch_size, num_classes = 10, 2
>>> temperature = 2.0
>>> loss = MinimumClassConfusionLoss(temperature)
>>> # unnormalized logits from the target domain
>>> g_t = torch.randn(batch_size, num_classes)
>>> output = loss(g_t)

MCC can also serve as a regularizer for existing methods. Examples:

>>> import torch
>>> from dalib.modules.domain_discriminator import DomainDiscriminator
>>> from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss
>>> num_classes = 2
>>> feature_dim = 1024
>>> batch_size = 10
>>> temperature = 2.0
>>> # CDAN conditions the discriminator on the multilinear map of features and predictions,
>>> # so its input dimension is feature_dim * num_classes
>>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)
>>> cdan_loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
>>> mcc_loss = MinimumClassConfusionLoss(temperature)
>>> # features from the source domain and target domain
>>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
>>> # logits output from the source domain and target domain
>>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
>>> total_loss = cdan_loss(g_s, f_s, g_t, f_t) + mcc_loss(g_t)
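
In a full pipeline the combined loss is minimized jointly with the supervised task loss. The sketch below shows one training step, assuming hypothetical backbone and classifier modules (placeholders introduced here for illustration, not library APIs) and reusing discriminator, cdan_loss and mcc_loss from above:

>>> import torch.nn as nn
>>> backbone = nn.Linear(256, feature_dim)            # placeholder feature extractor
>>> classifier = nn.Linear(feature_dim, num_classes)  # placeholder task head
>>> params = list(backbone.parameters()) + list(classifier.parameters()) + list(discriminator.parameters())
>>> optimizer = torch.optim.SGD(params, lr=0.01)
>>> x_s, x_t = torch.randn(batch_size, 256), torch.randn(batch_size, 256)
>>> labels_s = torch.randint(0, num_classes, (batch_size,))
>>> f_s, f_t = backbone(x_s), backbone(x_t)
>>> g_s, g_t = classifier(f_s), classifier(f_t)
>>> total_loss = nn.functional.cross_entropy(g_s, labels_s) + cdan_loss(g_s, f_s, g_t, f_t) + mcc_loss(g_t)
>>> optimizer.zero_grad(); total_loss.backward(); optimizer.step()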
