Shortcuts

Source code for common.vision.datasets.keypoint_detection.keypoint_dataset

"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from abc import ABC
import numpy as np
from torch.utils.data.dataset import Dataset
from webcolors import name_to_rgb
import cv2


[docs]class KeypointDataset(Dataset, ABC): """A generic dataset class for image keypoint detection Args: root (str): Root directory of dataset num_keypoints (int): Number of keypoints samples (list): list of data transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and its labels) and returns a transformed version. E.g, :class:`~common.vision.transforms.keypoint_detection.Resize`. image_size (tuple): (width, height) of the image. Default: (256, 256) heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64) sigma (int): sigma parameter when generate the heatmap. Default: 2 keypoints_group (dict): a dict that stores the index of different types of keypoints colored_skeleton (dict): a dict that stores the index and color of different skeleton """ def __init__(self, root, num_keypoints, samples, transforms=None, image_size=(256, 256), heatmap_size=(64, 64), sigma=2, keypoints_group=None, colored_skeleton=None): self.root = root self.num_keypoints = num_keypoints self.samples = samples self.transforms = transforms self.image_size = image_size self.heatmap_size = heatmap_size self.sigma = sigma self.keypoints_group = keypoints_group self.colored_skeleton = colored_skeleton def __len__(self): return len(self.samples)
[docs] def visualize(self, image, keypoints, filename): """Visualize an image with its keypoints, and store the result into a file Args: image (PIL.Image): keypoints (torch.Tensor): keypoints in shape K x 2 filename (str): the name of file to store """ assert self.colored_skeleton is not None image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR).copy() for (_, (line, color)) in self.colored_skeleton.items(): for i in range(len(line) - 1): start, end = keypoints[line[i]], keypoints[line[i + 1]] cv2.line(image, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), color=name_to_rgb(color), thickness=3) for keypoint in keypoints: cv2.circle(image, (int(keypoint[0]), int(keypoint[1])), 3, name_to_rgb('black'), 1) cv2.imwrite(filename, image)
[docs] def group_accuracy(self, accuracies): """ Group the accuracy of K keypoints into different kinds. Args: accuracies (list): accuracy of the K keypoints Returns: accuracy of ``N=len(keypoints_group)`` kinds of keypoints """ grouped_accuracies = dict() for name, keypoints in self.keypoints_group.items(): grouped_accuracies[name] = sum([accuracies[idx] for idx in keypoints]) / len(keypoints) return grouped_accuracies
[docs]class Body16KeypointDataset(KeypointDataset, ABC): """ Dataset with 16 body keypoints. """ # TODO: add image head = (9,) shoulder = (12, 13) elbow = (11, 14) wrist = (10, 15) hip = (2, 3) knee = (1, 4) ankle = (0, 5) all = (12, 13, 11, 14, 10, 15, 2, 3, 1, 4, 0, 5) right_leg = (0, 1, 2, 8) left_leg = (5, 4, 3, 8) backbone = (8, 9) right_arm = (10, 11, 12, 8) left_arm = (15, 14, 13, 8) def __init__(self, root, samples, **kwargs): colored_skeleton = { "right_leg": (self.right_leg, 'yellow'), "left_leg": (self.left_leg, 'green'), "backbone": (self.backbone, 'blue'), "right_arm": (self.right_arm, 'purple'), "left_arm": (self.left_arm, 'red'), } keypoints_group = { "head": self.head, "shoulder": self.shoulder, "elbow": self.elbow, "wrist": self.wrist, "hip": self.hip, "knee": self.knee, "ankle": self.ankle, "all": self.all } super(Body16KeypointDataset, self).__init__(root, 16, samples, keypoints_group=keypoints_group, colored_skeleton=colored_skeleton, **kwargs)
[docs]class Hand21KeypointDataset(KeypointDataset, ABC): """ Dataset with 21 hand keypoints. """ # TODO: add image MCP = (1, 5, 9, 13, 17) PIP = (2, 6, 10, 14, 18) DIP = (3, 7, 11, 15, 19) fingertip = (4, 8, 12, 16, 20) all = tuple(range(21)) thumb = (0, 1, 2, 3, 4) index_finger = (0, 5, 6, 7, 8) middle_finger = (0, 9, 10, 11, 12) ring_finger = (0, 13, 14, 15, 16) little_finger = (0, 17, 18, 19, 20) def __init__(self, root, samples, **kwargs): colored_skeleton = { "thumb": (self.thumb, 'yellow'), "index_finger": (self.index_finger, 'green'), "middle_finger": (self.middle_finger, 'blue'), "ring_finger": (self.ring_finger, 'purple'), "little_finger": (self.little_finger, 'red'), } keypoints_group = { "MCP": self.MCP, "PIP": self.PIP, "DIP": self.DIP, "fingertip": self.fingertip, "all": self.all } super(Hand21KeypointDataset, self).__init__(root, 21, samples, keypoints_group=keypoints_group, colored_skeleton=colored_skeleton, **kwargs)

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started