Shortcuts

Source code for dglib.generalization.mixstyle.sampler

"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import copy
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data.sampler import Sampler


class RandomDomainMultiInstanceSampler(Sampler):
    r"""Randomly sample :math:`N` domains, then randomly select :math:`P` instances
    (classes) in each domain; for each instance, randomly select :math:`K` images,
    forming a mini-batch of size :math:`N\times P\times K`.

    Every item of ``dataset`` must unpack as ``(data, class_label, domain_id)``.
    Class labels and domain ids are used as dictionary keys, so they must be hashable.

    When a domain has fewer than :math:`P` classes with at least :math:`K` remaining
    images, that domain's share of the batch is drawn uniformly at random from its
    remaining images instead.

    Args:
        dataset (ConcatDataset): dataset that contains data from multiple domains
        batch_size (int): mini-batch size (:math:`N\times P\times K` here)
        n_domains_per_batch (int): number of domains to select in a single mini-batch
            (:math:`N` here)
        num_instances (int): number of images to select for each instance
            (:math:`K` here)

    Raises:
        AssertionError: if the dataset has fewer domains than ``n_domains_per_batch``,
            or the batch size does not divide evenly into domains and instances.
        ValueError: if some domain holds fewer than ``batch_size // n_domains_per_batch``
            samples, so not even one batch could be drawn from it.
    """

    def __init__(self, dataset, batch_size, n_domains_per_batch, num_instances):
        # NOTE: ``super(Sampler, self)`` deliberately bypasses ``Sampler.__init__``;
        # its signature has differed between torch releases (older ones appear to
        # require a ``data_source`` argument), so the original call is kept as-is.
        super(Sampler, self).__init__()
        self.dataset = dataset

        # Single pass over the dataset: group sample indices by domain and remember
        # each sample's class label, so that sampling never has to load a sample
        # from the dataset again just to read its label.
        self.sample_idxes_per_domain = {}
        self._cls_per_idx = []
        for idx, (_, cls, domain_id) in enumerate(self.dataset):
            self.sample_idxes_per_domain.setdefault(domain_id, []).append(idx)
            self._cls_per_idx.append(cls)

        self.n_domains_in_dataset = len(self.sample_idxes_per_domain)
        self.n_domains_per_batch = n_domains_per_batch
        assert self.n_domains_in_dataset >= self.n_domains_per_batch, \
            "dataset contains fewer domains than n_domains_per_batch"

        assert batch_size % n_domains_per_batch == 0, \
            "batch_size must be divisible by n_domains_per_batch"
        self.batch_size_per_domain = batch_size // n_domains_per_batch

        assert self.batch_size_per_domain % num_instances == 0, \
            "per-domain batch size must be divisible by num_instances"
        self.num_instances = num_instances
        self.num_classes_per_domain = self.batch_size_per_domain // num_instances

        # Fail fast: a domain that cannot fill even one per-domain batch would
        # otherwise make ``random.sample`` raise only when it happens to be drawn.
        for domain_id, idxes in self.sample_idxes_per_domain.items():
            if len(idxes) < self.batch_size_per_domain:
                raise ValueError(
                    "domain {!r} has only {} samples, but {} are required per "
                    "domain in each mini-batch".format(
                        domain_id, len(idxes), self.batch_size_per_domain))

        # The epoch length depends on the random draws, so it is estimated by
        # running one full pass. ``__iter__`` is called explicitly because
        # ``list(self)`` would query ``__len__`` before ``self.length`` exists.
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        """Return an iterator over the sample indices of one epoch.

        Indices are laid out batch by batch: each batch draws
        ``n_domains_per_batch`` distinct domains and contributes
        ``batch_size_per_domain`` indices from each. Sampling is without
        replacement and stops after the first batch that leaves any selected
        domain with fewer than ``batch_size_per_domain`` remaining indices.
        """
        # Work on per-epoch copies so the index lists built in __init__ stay intact.
        remaining = {domain: list(idxes)
                     for domain, idxes in self.sample_idxes_per_domain.items()}
        # Use the real domain ids as keys; they need not be 0..N-1.
        domain_ids = list(remaining)
        final_idxes = []
        stop_flag = False

        while not stop_flag:
            for domain in random.sample(domain_ids, self.n_domains_per_batch):
                selected_idxes = self.sample_multi_instances(remaining[domain])
                final_idxes.extend(selected_idxes)

                # Drop the drawn indices in one order-preserving pass instead of
                # repeated ``list.remove`` calls.
                taken = set(selected_idxes)
                remaining[domain] = [idx for idx in remaining[domain] if idx not in taken]

                # The current batch is still completed for the other selected
                # domains; the flag only prevents starting another batch.
                if len(remaining[domain]) < self.batch_size_per_domain:
                    stop_flag = True

        return iter(final_idxes)

    def sample_multi_instances(self, sample_idxes):
        """Draw ``batch_size_per_domain`` indices from ``sample_idxes``.

        Args:
            sample_idxes (list): remaining sample indices of a single domain

        Returns:
            list: ``num_classes_per_domain`` randomly chosen classes with
            ``num_instances`` indices each. If fewer than ``num_classes_per_domain``
            classes have at least ``num_instances`` indices left, a uniformly
            random draw of ``batch_size_per_domain`` indices is returned instead.
        """
        idxes_per_cls = {}
        for idx in sample_idxes:
            # Labels come from the cache built in __init__, not from the dataset.
            idxes_per_cls.setdefault(self._cls_per_idx[idx], []).append(idx)

        cls_list = [cls for cls, idxes in idxes_per_cls.items()
                    if len(idxes) >= self.num_instances]
        if len(cls_list) < self.num_classes_per_domain:
            return random.sample(sample_idxes, self.batch_size_per_domain)

        selected_idxes = []
        for cls in random.sample(cls_list, self.num_classes_per_domain):
            selected_idxes.extend(random.sample(idxes_per_cls[cls], self.num_instances))
        return selected_idxes

    def __len__(self):
        """Return the number of indices produced by the pass run in ``__init__``."""
        return self.length

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started