Shortcuts

Source code for common.vision.datasets.keypoint_detection.freihand

"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import json
import time
import torch
import os
import os.path as osp
from torchvision.datasets.utils import download_and_extract_archive

from ...transforms.keypoint_detection import *
from .keypoint_dataset import Hand21KeypointDataset
from .util import *


""" General util functions. """
def _assert_exist(p):
    msg = 'File does not exists: %s' % p
    assert os.path.exists(p), msg


def json_load(p):
    _assert_exist(p)
    with open(p, 'r') as fi:
        d = json.load(fi)
    return d


def load_db_annotation(base_path, set_name=None):
    if set_name is None:
        # only training set annotations are released so this is a valid default choice
        set_name = 'training'

    print('Loading FreiHAND dataset index ...')
    t = time.time()

    # assumed paths to data containers
    k_path = os.path.join(base_path, '%s_K.json' % set_name)
    mano_path = os.path.join(base_path, '%s_mano.json' % set_name)
    xyz_path = os.path.join(base_path, '%s_xyz.json' % set_name)

    # load if exist
    K_list = json_load(k_path)
    mano_list = json_load(mano_path)
    xyz_list = json_load(xyz_path)

    # should have all the same length
    assert len(K_list) == len(mano_list), 'Size mismatch.'
    assert len(K_list) == len(xyz_list), 'Size mismatch.'

    print('Loading of %d samples done in %.2f seconds' % (len(K_list), time.time()-t))
    return list(zip(K_list, mano_list, xyz_list))


def projectPoints(xyz, K):
    """ Project 3D coordinates into image space. """
    xyz = np.array(xyz)
    K = np.array(K)
    uv = np.matmul(K, xyz.T).T
    return uv[:, :2] / uv[:, -1:]


""" Dataset related functions. """
def db_size(set_name):
    """ Hardcoded size of the datasets. """
    if set_name == 'training':
        return 32560  # number of unique samples (they exists in multiple 'versions')
    elif set_name == 'evaluation':
        return 3960
    else:
        assert 0, 'Invalid choice.'


class sample_version:
    gs = 'gs'  # green screen
    hom = 'hom'  # homogenized
    sample = 'sample'  # auto colorization with sample points
    auto = 'auto'  # auto colorization without sample points: automatic color hallucination

    db_size = db_size('training')

    @classmethod
    def valid_options(cls):
        return [cls.gs, cls.hom, cls.sample, cls.auto]


    @classmethod
    def check_valid(cls, version):
        msg = 'Invalid choice: "%s" (must be in %s)' % (version, cls.valid_options())
        assert version in cls.valid_options(), msg

    @classmethod
    def map_id(cls, id, version):
        cls.check_valid(version)
        return id + cls.db_size*cls.valid_options().index(version)


[docs]class FreiHand(Hand21KeypointDataset): """`FreiHand Dataset <https://lmb.informatik.uni-freiburg.de/projects/freihand/>`_ Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``. task (str, optional): The post-processing option to create dataset. Choices include ``'gs'``: green screen \ recording, ``'auto'``: auto colorization without sample points: automatic color hallucination, \ ``'sample'``: auto colorization with sample points, ``'hom'``: homogenized, \ and ``'all'``: all hands. Default: 'all'. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and its labels) and returns a transformed version. E.g, :class:`~common.vision.transforms.keypoint_detection.Resize`. image_size (tuple): (width, height) of the image. Default: (256, 256) heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64) sigma (int): sigma parameter when generate the heatmap. Default: 2 .. note:: In `root`, there will exist following files after downloading. :: *.json training/ evaluation/ """ def __init__(self, root, split='train', task='all', download=True, **kwargs): if download: if not osp.exists(osp.join(root, "training")) or not osp.exists(osp.join(root, "evaluation")): download_and_extract_archive("https://lmb.informatik.uni-freiburg.de/data/freihand/FreiHAND_pub_v2.zip", download_root=root, filename="FreiHAND_pub_v2.zip", remove_finished=False, extract_root=root) assert split in ['train', 'test', 'all'] self.split = split assert task in ['all', 'gs', 'auto', 'sample', 'hom'] self.task = task if task == 'all': samples = self.get_samples(root, 'gs') + self.get_samples(root, 'auto') + self.get_samples(root, 'sample') + self.get_samples(root, 'hom') else: samples = self.get_samples(root, task) random.seed(42) random.shuffle(samples) samples_len = len(samples) samples_split = min(int(samples_len * 0.2), 3200) if self.split == 'train': samples = samples[samples_split:] elif self.split == 'test': samples = samples[:samples_split] super(FreiHand, self).__init__(root, samples, **kwargs) def __getitem__(self, index): sample = self.samples[index] image_name = sample['name'] image_path = os.path.join(self.root, image_name) image = Image.open(image_path) keypoint3d_camera = np.array(sample['keypoint3d']) # NUM_KEYPOINTS x 3 keypoint2d = np.array(sample['keypoint2d']) # NUM_KEYPOINTS x 2 intrinsic_matrix = np.array(sample['intrinsic_matrix']) Zc = keypoint3d_camera[:, 2] # Crop the images such that the hand is at the center of the image # The images will be 1.5 times larger than the hand # The crop process will change Xc and Yc, leaving Zc with no changes bounding_box = get_bounding_box(keypoint2d) w, h = image.size left, upper, right, lower = scale_box(bounding_box, w, h, 1.5) image, keypoint2d = crop(image, upper, left, lower - upper, right - left, keypoint2d) # Change all hands to right hands if sample['left'] is False: image, keypoint2d = hflip(image, keypoint2d) image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix) keypoint2d = data['keypoint2d'] intrinsic_matrix = data['intrinsic_matrix'] keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc) # noramlize 2D pose: visible = np.ones((self.num_keypoints, ), dtype=np.float32) visible = visible[:, np.newaxis] # 2D heatmap target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size) target = torch.from_numpy(target) target_weight = torch.from_numpy(target_weight) # normalize 3D pose: # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system # and make distance between wrist and middle finger MCP joint to be of length 1 keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :] keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2)) z = keypoint3d_n[:, 2] meta = { 'image': image_name, 'keypoint2d': keypoint2d, # (NUM_KEYPOINTS x 2) 'keypoint3d': keypoint3d_n, # (NUM_KEYPOINTS x 3) 'z': z, } return image, target, target_weight, meta def get_samples(self, root, version='gs'): set = 'training' # load annotations of this set db_data_anno = load_db_annotation(root, set) version_map = { 'gs': sample_version.gs, 'hom': sample_version.hom, 'sample': sample_version.sample, 'auto': sample_version.auto } samples = [] for idx in range(db_size(set)): image_name = os.path.join(set, 'rgb', '%08d.jpg' % sample_version.map_id(idx, version_map[version])) mask_name = os.path.join(set, 'mask', '%08d.jpg' % idx) intrinsic_matrix, mano, keypoint3d = db_data_anno[idx] keypoint2d = projectPoints(keypoint3d, intrinsic_matrix) sample = { 'name': image_name, 'mask_name': mask_name, 'keypoint2d': keypoint2d, 'keypoint3d': keypoint3d, 'intrinsic_matrix': intrinsic_matrix, 'left': False } samples.append(sample) return samples

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started