Analysis Tools
- common.utils.analysis.collect_feature(data_loader, feature_extractor, device, max_num_features=None)
Fetch data from data_loader, then use feature_extractor to collect features.
- Parameters
data_loader (torch.utils.data.DataLoader) – Data loader.
feature_extractor (torch.nn.Module) – A feature extractor.
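device (torch.device) – Device on which to run feature_extractor.
max_num_features (int, optional) – The maximum number of mini-batches to collect features from. If None, all mini-batches are used. Default: None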
- Returns
Features in shape (min(len(data_loader.dataset), max_num_features * mini-batch size), \(|\mathcal{F}|\)), where \(|\mathcal{F}|\) is the feature dimension.
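To make the role of max_num_features concrete, here is a framework-free sketch of the batching logic described above. It is an assumption-laden stand-in, not the library's code: collect_feature_sketch, the toy batches, and the toy extractor are hypothetical, plain lists replace tensors, and device handling is omitted (a PyTorch version would typically also call feature_extractor.eval(), wrap the loop in torch.no_grad(), and move each batch to device).

```python
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

# A mini-batch as (inputs, labels); labels are ignored, mirroring the
# (data, target) pairs a typical PyTorch DataLoader yields. (Assumption.)
Batch = Tuple[Sequence[Sequence[float]], Sequence[int]]


def collect_feature_sketch(
    data_loader: Iterable[Batch],
    feature_extractor: Callable[[Sequence[Sequence[float]]], List[List[float]]],
    max_num_features: Optional[int] = None,
) -> List[List[float]]:
    """Hypothetical stand-in for collect_feature: gather features batch by batch."""
    all_features: List[List[float]] = []
    for i, (inputs, _labels) in enumerate(data_loader):
        # The cap counts mini-batches, so at most
        # max_num_features * batch size samples are collected.
        if max_num_features is not None and i >= max_num_features:
            break
        all_features.extend(feature_extractor(inputs))  # concatenate along the sample axis
    return all_features


def extractor(xs: Sequence[Sequence[float]]) -> List[List[float]]:
    """Toy feature extractor: one feature per sample, the sum of its inputs."""
    return [[x[0] + x[1]] for x in xs]


# Ten 2-D samples split into mini-batches of size 4 (sizes 4, 4, 2).
samples = [[float(i), float(i) + 0.5] for i in range(10)]
batches = [(samples[k:k + 4], [0] * len(samples[k:k + 4])) for k in range(0, 10, 4)]

capped = collect_feature_sketch(batches, extractor, max_num_features=2)
everything = collect_feature_sketch(batches, extractor)
print(len(capped), len(everything))  # 8 10
```

With 10 samples and a batch size of 4, a cap of 2 mini-batches yields min(10, 2 * 4) = 8 features, matching the shape given under Returns.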
- common.utils.analysis.a_distance.calculate(source_feature, target_feature, device, progress=True, training_epochs=10)
Calculate the \(\mathcal{A}\)-distance, which is a measure of distribution discrepancy.
The definition is \(dist_\mathcal{A} = 2 (1-2\epsilon)\), where \(\epsilon\) is the test error of a classifier trained to discriminate the source from the target.
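Evaluating the definition at a few error rates shows how to read the result: a classifier that tells the domains apart perfectly (\(\epsilon = 0\)) gives the maximum value 2, while one at chance level (\(\epsilon = 0.5\), assuming both domains are equally represented in the test set) gives 0.

```latex
\begin{aligned}
\epsilon = 0   &\;\Rightarrow\; dist_\mathcal{A} = 2\,(1 - 2 \cdot 0)   = 2   \\
\epsilon = 0.1 &\;\Rightarrow\; dist_\mathcal{A} = 2\,(1 - 2 \cdot 0.1) = 1.6 \\
\epsilon = 0.5 &\;\Rightarrow\; dist_\mathcal{A} = 2\,(1 - 2 \cdot 0.5) = 0
\end{aligned}
```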
- Parameters
source_feature (torch.Tensor) – Features from the source domain, in shape \((minibatch, F)\).
target_feature (torch.Tensor) – Features from the target domain, in shape \((minibatch, F)\).
device (torch.device) – Device on which to train the classifier.
progress (bool) – If True, displays the progress of training the A-Net (the domain classifier). Default: True
training_epochs (int) – Number of epochs for training the classifier. Default: 10
- Returns
\(\mathcal{A}\)-distance
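The procedure behind this definition can be sketched without any framework: label source samples 1 and target samples 0, train a domain classifier on a random split, measure its held-out error \(\epsilon\), and return \(2(1-2\epsilon)\). This is a minimal sketch under stated assumptions, not the library's implementation: proxy_a_distance and _train_logistic are hypothetical names, a logistic-regression classifier trained with per-sample SGD stands in for the A-Net, and the 80/20 split, learning rate, and seeds are illustrative choices. The device and progress arguments are omitted.

```python
import math
import random
from typing import List, Sequence, Tuple


def _train_logistic(
    xs: List[Sequence[float]], ys: List[int], epochs: int, lr: float = 0.1, seed: int = 0
) -> Tuple[List[float], float]:
    """Hypothetical domain classifier: logistic regression trained with per-sample SGD."""
    rng = random.Random(seed)
    w, b = [0.0] * len(xs[0]), 0.0
    order = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(order)
        for k in order:
            z = sum(wi * xi for wi, xi in zip(w, xs[k])) + b
            z = max(-30.0, min(30.0, z))  # clip so math.exp stays in range
            g = 1.0 / (1.0 + math.exp(-z)) - ys[k]  # d(binary cross-entropy)/dz
            w = [wi - lr * g * xi for wi, xi in zip(w, xs[k])]
            b -= lr * g
    return w, b


def proxy_a_distance(
    source_feature: Sequence[Sequence[float]],
    target_feature: Sequence[Sequence[float]],
    training_epochs: int = 10,
    train_ratio: float = 0.8,
    seed: int = 0,
) -> float:
    """Hypothetical sketch of dist_A = 2 * (1 - 2 * epsilon)."""
    # Label source samples 1 and target samples 0, then shuffle and split.
    data = [(list(x), 1) for x in source_feature] + [(list(x), 0) for x in target_feature]
    rng = random.Random(seed)
    rng.shuffle(data)
    split = int(train_ratio * len(data))
    train, test = data[:split], data[split:]
    w, b = _train_logistic([x for x, _ in train], [y for _, y in train], training_epochs, seed=seed)
    # epsilon: fraction of held-out samples assigned to the wrong domain.
    wrong = sum(
        1 for x, y in test if (sum(wi * xi for wi, xi in zip(w, x)) + b > 0) != (y == 1)
    )
    epsilon = wrong / len(test)
    return 2.0 * (1.0 - 2.0 * epsilon)


rng = random.Random(42)
source = [[rng.gauss(3.0, 0.5), rng.gauss(3.0, 0.5)] for _ in range(100)]
far_target = [[rng.gauss(-3.0, 0.5), rng.gauss(-3.0, 0.5)] for _ in range(100)]
near_target = [[rng.gauss(3.0, 0.5), rng.gauss(3.0, 0.5)] for _ in range(100)]
print(proxy_a_distance(source, far_target))   # well-separated domains
print(proxy_a_distance(source, near_target))  # samples from the same distribution
```

Well-separated domains drive \(\epsilon\) to 0 and the value to 2; when the classifier does worse than chance on the held-out split (\(\epsilon > 0.5\)), the formula as written goes negative.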
- common.utils.analysis.tsne.visualize(source_feature, target_feature, filename, source_color='r', target_color='b')
Visualize features from different domains using t-SNE.
- Parameters
source_feature (torch.Tensor) – Features from the source domain, in shape \((minibatch, F)\).
target_feature (torch.Tensor) – Features from the target domain, in shape \((minibatch, F)\).
filename (str) – File name under which to save the t-SNE plot.
source_color (str) – Color of the source features. Default: 'r'
target_color (str) – Color of the target features. Default: 'b'
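The entry above does not spell out the steps, so here is a hedged, dependency-free sketch of what such a visualization involves. It stacks both domains, embeds them in 2-D with exact t-SNE (Gaussian affinities calibrated to a target perplexity, Student-t similarities in the embedding, gradient descent on the KL divergence with an early-exaggeration phase), and tags each point with its domain color. tsne_2d and embed_domains are hypothetical names. The hyperparameters (perplexity 5, 500 plain gradient-descent steps without momentum, learning rate 1.0) are assumptions chosen for tiny inputs, and the library itself may rely on an off-the-shelf t-SNE implementation instead.

```python
import math
import random
from typing import List, Sequence, Tuple


def _pairwise_sq_dists(points: Sequence[Sequence[float]]) -> List[List[float]]:
    """Squared Euclidean distance between every pair of rows."""
    return [[sum((a - b) ** 2 for a, b in zip(p, q)) for q in points] for p in points]


def _row_probs(dists: List[float], i: int, perplexity: float) -> List[float]:
    """p_{j|i}, with the Gaussian precision found by binary search so that
    the row entropy equals log(perplexity)."""
    target, beta, lo, hi = math.log(perplexity), 1.0, 0.0, math.inf
    probs: List[float] = []
    for _ in range(60):
        weights = [0.0 if j == i else math.exp(-beta * d) for j, d in enumerate(dists)]
        total = sum(weights) or 1e-300
        probs = [w / total for w in weights]
        entropy = -sum(p * math.log(p) for p in probs if p > 0)
        if abs(entropy - target) < 1e-5:
            break
        if entropy > target:  # kernel too wide: increase the precision
            lo = beta
            beta = beta * 2 if hi == math.inf else (beta + hi) / 2
        else:  # kernel too narrow: decrease the precision
            hi = beta
            beta = (beta + lo) / 2
    return probs


def tsne_2d(points: Sequence[Sequence[float]], perplexity: float = 5.0,
            iterations: int = 500, learning_rate: float = 1.0, seed: int = 0) -> List[List[float]]:
    """Hypothetical exact t-SNE to 2-D using plain gradient descent."""
    n = len(points)
    dists = _pairwise_sq_dists(points)
    cond = [_row_probs(dists[i], i, perplexity) for i in range(n)]
    # Symmetrized joint probabilities p_ij = (p_{j|i} + p_{i|j}) / (2n).
    p = [[(cond[i][j] + cond[j][i]) / (2 * n) for j in range(n)] for i in range(n)]
    rng = random.Random(seed)
    y = [[rng.gauss(0.0, 1e-2), rng.gauss(0.0, 1e-2)] for _ in range(n)]
    for it in range(iterations):
        exaggeration = 4.0 if it < 100 else 1.0  # early exaggeration phase
        # Unnormalized Student-t similarities (1 degree of freedom) in the embedding.
        num = [[0.0 if i == j else
                1.0 / (1.0 + (y[i][0] - y[j][0]) ** 2 + (y[i][1] - y[j][1]) ** 2)
                for j in range(n)] for i in range(n)]
        z = sum(map(sum, num))
        grads = []
        for i in range(n):
            gx = gy = 0.0
            for j in range(n):
                # dKL/dy_i = 4 * sum_j (p_ij - q_ij) * num_ij * (y_i - y_j), with q_ij = num_ij / z
                coeff = 4.0 * (exaggeration * p[i][j] - num[i][j] / z) * num[i][j]
                gx += coeff * (y[i][0] - y[j][0])
                gy += coeff * (y[i][1] - y[j][1])
            grads.append((gx, gy))
        for (gx, gy), yi in zip(grads, y):
            yi[0] -= learning_rate * gx
            yi[1] -= learning_rate * gy
    return y


def embed_domains(source_feature, target_feature, source_color: str = "r",
                  target_color: str = "b") -> Tuple[List[List[float]], List[str]]:
    """Hypothetical helper: embed both domains jointly and color points by domain."""
    coords = tsne_2d([list(x) for x in source_feature] + [list(x) for x in target_feature])
    colors = [source_color] * len(source_feature) + [target_color] * len(target_feature)
    return coords, colors


rng = random.Random(1)
source = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(10)]
target = [[rng.gauss(10.0, 1.0) for _ in range(5)] for _ in range(10)]
coords, colors = embed_domains(source, target)
```

To get the saved figure that visualize produces, the coordinates and colors could be passed to a plotting library, for example matplotlib's plt.scatter(xs, ys, c=colors) followed by plt.savefig(filename). Embedding both domains in a single t-SNE run matters: separate runs would place each domain in its own, incomparable coordinate system.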