Shortcuts

Source code for dalib.adaptation.bsp

"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import Optional
import torch
import torch.nn as nn
from common.modules.classifier import Classifier as ClassifierBase


[docs]class BatchSpectralPenalizationLoss(nn.Module): r"""Batch spectral penalization loss from `Transferability vs. Discriminability: Batch Spectral Penalization for Adversarial Domain Adaptation (ICML 2019) <http://ise.thss.tsinghua.edu.cn/~mlong/doc/batch-spectral-penalization-icml19.pdf>`_. Given source features :math:`f_s` and target features :math:`f_t` in current mini batch, singular value decomposition is first performed .. math:: f_s = U_s\Sigma_sV_s^T .. math:: f_t = U_t\Sigma_tV_t^T Then batch spectral penalization loss is calculated as .. math:: loss=\sum_{i=1}^k(\sigma_{s,i}^2+\sigma_{t,i}^2) where :math:`\sigma_{s,i},\sigma_{t,i}` refer to the :math:`i-th` largest singular value of source features and target features respectively. We empirically set :math:`k=1`. Inputs: - f_s (tensor): feature representations on source domain, :math:`f^s` - f_t (tensor): feature representations on target domain, :math:`f^t` Shape: - f_s, f_t: :math:`(N, F)` where F means the dimension of input features. - Outputs: scalar. """ def __init__(self): super(BatchSpectralPenalizationLoss, self).__init__() def forward(self, f_s, f_t): _, s_s, _ = torch.svd(f_s) _, s_t, _ = torch.svd(f_t) loss = torch.pow(s_s[0], 2) + torch.pow(s_t[0], 2) return loss
class ImageClassifier(ClassifierBase): def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs): bottleneck = nn.Sequential( # nn.AdaptiveAvgPool2d(output_size=(1, 1)), # nn.Flatten(), nn.Linear(backbone.out_features, bottleneck_dim), nn.BatchNorm1d(bottleneck_dim), nn.ReLU(), ) super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started