Source code for common.utils.analysis
import itertools

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import tqdm
def collect_feature(data_loader: DataLoader, feature_extractor: nn.Module,
                    device: torch.device, max_num_features=None) -> torch.Tensor:
    r"""
    Fetch data from `data_loader`, and then use `feature_extractor` to collect features.

    Gradients are disabled and every submodule is put in eval mode during
    extraction. Each submodule's previous train/eval flag is restored
    afterwards, so calling this in the middle of training has no lasting
    side effect on the module.

    Args:
        data_loader (torch.utils.data.DataLoader): Data loader. Each batch must
            unpack into exactly two items ``(images, target)``; ``target`` is ignored.
        feature_extractor (torch.nn.Module): A feature extractor.
        device (torch.device): Device each input batch is moved to before the
            forward pass. Extracted features are moved back to CPU.
        max_num_features (int, optional): Despite its name, the maximum number
            of *mini-batches* to process. ``None`` (default) processes all batches.

    Returns:
        The per-batch outputs concatenated along dim 0. Assuming the extractor
        returns ``(batch, |F|)`` tensors, the shape is
        (min(number of samples, max_num_features * mini-batch size), :math:`|\mathcal{F}|`).

    Raises:
        RuntimeError: if no batch was processed (empty ``data_loader`` or
            ``max_num_features <= 0``), since there is nothing to concatenate.
    """
    # Snapshot every submodule's flag: ``eval()`` overwrites them recursively,
    # and a single ``train(mode)`` call could not restore mixed states
    # (e.g. a frozen batch-norm layer inside a module that is training).
    previous_modes = [(module, module.training) for module in feature_extractor.modules()]
    feature_extractor.eval()

    progress = tqdm.tqdm(data_loader)
    batches = progress
    if max_num_features is not None:
        # islice stops *before* requesting the next batch, so the loader never
        # loads and collates a batch that would just be discarded.
        batches = itertools.islice(progress, max(max_num_features, 0))

    all_features = []
    try:
        with torch.no_grad():
            for images, _ in batches:
                images = images.to(device)
                all_features.append(feature_extractor(images).cpu())
    finally:
        progress.close()
        for module, was_training in previous_modes:
            module.training = was_training

    if not all_features:
        raise RuntimeError(
            "collect_feature: no batches were processed "
            "(data_loader is empty or max_num_features <= 0)"
        )
    return torch.cat(all_features, dim=0)