Shortcuts

Source code for common.utils.analysis.tsne

"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import matplotlib

matplotlib.use('Agg')
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as col


[docs]def visualize(source_feature: torch.Tensor, target_feature: torch.Tensor, filename: str, source_color='r', target_color='b'): """ Visualize features from different domains using t-SNE. Args: source_feature (tensor): features from source domain in shape :math:`(minibatch, F)` target_feature (tensor): features from target domain in shape :math:`(minibatch, F)` filename (str): the file name to save t-SNE source_color (str): the color of the source features. Default: 'r' target_color (str): the color of the target features. Default: 'b' """ source_feature = source_feature.numpy() target_feature = target_feature.numpy() features = np.concatenate([source_feature, target_feature], axis=0) # map features to 2-d using TSNE X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features) # domain labels, 1 represents source while 0 represents target domains = np.concatenate((np.ones(len(source_feature)), np.zeros(len(target_feature)))) # visualize using matplotlib fig, ax = plt.subplots(figsize=(10, 10)) ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.spines['left'].set_visible(False) plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains, cmap=col.ListedColormap([target_color, source_color]), s=20) plt.xticks([]) plt.yticks([]) plt.savefig(filename)

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started