Shortcuts

Source code for dglib.generalization.coral

"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
import torch.nn as nn


class CorrelationAlignmentLoss(nn.Module):
    r"""Correlation alignment (CORAL) loss between source and target features.

    Based on `Deep CORAL: Correlation Alignment for Deep Domain Adaptation
    (ECCV 2016) <https://arxiv.org/pdf/1607.01719.pdf>`_.

    Given source features :math:`f_S` and target features :math:`f_T`, the
    per-domain feature means are

    .. math::
        \mu_S = \frac{1}{n_S}\textbf{1}^Tf_S, \qquad
        \mu_T = \frac{1}{n_T}\textbf{1}^Tf_T

    and the covariance matrices are

    .. math::
        C_S = \frac{1}{n_S-1}(f_S^Tf_S-\frac{1}{n_S}(\textbf{1}^Tf_S)^T(\textbf{1}^Tf_S))

    .. math::
        C_T = \frac{1}{n_T-1}(f_T^Tf_T-\frac{1}{n_T}(\textbf{1}^Tf_T)^T(\textbf{1}^Tf_T))

    where :math:`\textbf{1}` denotes a column vector with all elements equal
    to 1 and :math:`n_S, n_T` denote the number of source and target samples,
    respectively. With :math:`d` the feature dimension and
    :math:`{\Vert\cdot\Vert}^2_F` the squared matrix `Frobenius norm`, the
    value returned by this module is

    .. math::
        l = \frac{1}{d}\Vert \mu_S-\mu_T \Vert^2_2
            + \frac{1}{d^2}\Vert C_S-C_T \Vert^2_F

    i.e. the element-wise mean squared difference of the means plus the
    element-wise mean squared difference of the covariances.

    .. note::
        This is not the exact expression
        :math:`\frac{1}{4d^2}\Vert C_S-C_T \Vert^2_F`: the covariance term
        is averaged over its :math:`d^2` entries without the extra
        :math:`\frac{1}{4}` factor, and a mean-alignment term is added.

    .. note::
        When a domain contributes a single sample, the :math:`n-1`
        normalizer is clamped to 1, so that domain's covariance is the zero
        matrix and the loss stays finite instead of becoming NaN.

    Inputs:
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - f_t (tensor): feature representations on target domain, :math:`f^t`

    Shape:
        - f_s: :math:`(n_S, d)` and f_t: :math:`(n_T, d)`, where d is the
          dimension of the input features. :math:`n_S` and :math:`n_T` are
          normalized independently and do not have to be equal.
        - Outputs: scalar.
    """

    def __init__(self):
        super().__init__()

    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        # First-order statistics: per-feature mean of each domain, shape (1, d).
        mean_s = f_s.mean(0, keepdim=True)
        mean_t = f_t.mean(0, keepdim=True)

        # Center the features so the products below are covariances.
        cent_s = f_s - mean_s
        cent_t = f_t - mean_t

        # Unbiased (n - 1) normalizer. Clamp to 1 so that a single-sample
        # batch produces a zero covariance rather than 0 / 0 = NaN.
        denom_s = max(len(f_s) - 1, 1)
        denom_t = max(len(f_t) - 1, 1)

        # Second-order statistics: (d, d) covariance matrix of each domain.
        cov_s = torch.mm(cent_s.t(), cent_s) / denom_s
        cov_t = torch.mm(cent_t.t(), cent_t) / denom_t

        # Element-wise mean squared discrepancy of both statistics.
        mean_diff = (mean_s - mean_t).pow(2).mean()
        cov_diff = (cov_s - cov_t).pow(2).mean()

        return mean_diff + cov_diff

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started