Shortcuts

Source code for common.utils.analysis

import itertools

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import tqdm


def collect_feature(data_loader: DataLoader, feature_extractor: nn.Module,
                    device: torch.device, max_num_features=None) -> torch.Tensor:
    r"""
    Fetch data from `data_loader`, and then use `feature_extractor` to collect features

    Args:
        data_loader (torch.utils.data.DataLoader): Data loader. Each batch is unpacked as
            ``(images, target)``; only ``images`` is used.
        feature_extractor (torch.nn.Module): A feature extractor. It is switched to eval
            mode by this function and is left in eval mode on return.
        device (torch.device): Device the images are moved to before the forward pass.
        max_num_features (int, optional): Despite its name, the maximum number of
            **mini-batches** to process. ``None`` (default) processes the whole loader.

    Returns:
        The per-batch outputs of `feature_extractor`, moved to the CPU and concatenated
        along dim 0. Assuming the extractor returns ``(batch, |F|)`` tensors (TODO confirm),
        the result has shape (number of samples in the first
        ``min(len(data_loader), max_num_features)`` batches, :math:`|\mathcal{F}|`).

    Raises:
        RuntimeError: if no batch was processed (empty loader or ``max_num_features == 0``),
            because ``torch.cat`` cannot concatenate an empty list.
    """
    feature_extractor.eval()
    all_features = []
    with torch.no_grad():
        # islice with stop=None iterates everything; with an int it stops after that many
        # batches without pulling an extra batch from the loader first.
        batches = itertools.islice(tqdm.tqdm(data_loader), max_num_features)
        for images, _ in batches:
            images = images.to(device)
            # Move to CPU right away so accumulated features do not hold device memory.
            feature = feature_extractor(images).cpu()
            all_features.append(feature)
    return torch.cat(all_features, dim=0)

Docs

Access comprehensive documentation for the Transfer Learning Library

View Docs

Tutorials

Get started with the Transfer Learning Library

Get Started