import glob
import os
from collections import OrderedDict
from typing import Optional, Union
import cv2
import imageio
import numpy as np
import torch
from ..geometry.geometryutils import relative_transformation
from natsort import natsorted
from torch.utils import data
from . import datautils
__all__ = ["Scannet"]
[docs]class Scannet(data.Dataset):
r"""A torch Dataset for loading in `the Scannet dataset <http://www.scan-net.org/>`_. Will fetch sequences of
rgb images, depth maps, intrinsics matrices, poses, frame to frame relative transformations (with first frame's
pose as the reference transformation), names of sequences, and semantic segmentation labels.
Args:
basedir (str): Path to the base directory containing the `sceneXXXX_XX/` directories from ScanNet. Each scene
subdirectory is assumed to contain `color/`, `depth/`, `intrinsic/`, `label-filt/` and `pose/` directories.
seqmetadir (str): Path to directory containing sequence associations. Directory is assumed to contain
metadata `.txt` files (one metadata per sequence): e.g. `sceneXXXX_XX-seq_Y.txt` .
scenes (str or tuple of str): Scenes to use from sequences (used for creating train/val/test splits). Can
be path to a `.txt` file where each line is a scene name (`sceneXXXX_XX`), a tuple of scene names, or None
to use all scenes.
start (int): Index of the frame from which to start for every sequence. Default: 0
end (int): Index of the frame at which to end for every sequence. Default: -1
height (int): Spatial height to resize frames to. Default: 480
width (int): Spatial width to resize frames to. Default: 640
seg_classes (str): The palette of classes that the network should learn. Either `"nyu40"` or `"scannet20"`.
Default: `"scannet20"`
channels_first (bool): If True, will use channels first representation :math:`(B, L, C, H, W)` for images
`(batchsize, sequencelength, channels, height, width)`. If False, will use channels last representation
:math:`(B, L, H, W, C)`. Default: False
normalize_color (bool): Normalize color to range :math:`[0, 1]` or leave it at range :math:`[0, 255]`.
Default: False
return_depth (bool): Determines whether to return depths. Default: True
return_intrinsics (bool): Determines whether to return intrinsics. Default: True
return_pose (bool): Determines whether to return poses. Default: True
return_transform (bool): Determines whether to return transforms w.r.t. initial pose being transformed to be
identity. Default: True
return_names (bool): Determines whether to return sequence names. Default: True
return_labels (bool): Determines whether to return segmentation labels. Default: True
Examples::
>>> dataset = Scannet(
basedir="ScanNet-gradSLAM/extractions/scans/",
seqmetadir="ScanNet-gradSLAM/extractions/sequence_associations/",
scenes=("scene0000_00", "scene0001_00")
)
>>> loader = data.DataLoader(dataset=dataset, batch_size=4)
>>> colors, depths, intrinsics, poses, transforms, names, labels = next(iter(loader))
"""
def __init__(
self,
basedir: str,
seqmetadir: str,
scenes: Union[tuple, str, None],
start: Optional[int] = 0,
end: Optional[int] = -1,
height: int = 480,
width: int = 640,
seg_classes: str = "scannet20",
channels_first: bool = False,
normalize_color: bool = False,
*,
return_depth: bool = True,
return_intrinsics: bool = True,
return_pose: bool = True,
return_transform: bool = True,
return_names: bool = True,
return_labels: bool = True,
):
super(Scannet, self).__init__()
basedir = os.path.normpath(basedir)
self.height = height
self.width = width
self.height_downsample_ratio = float(height) / 480
self.width_downsample_ratio = float(width) / 640
self.seg_classes = seg_classes
self.channels_first = channels_first
self.normalize_color = normalize_color
self.return_depth = return_depth
self.return_intrinsics = return_intrinsics
self.return_pose = return_pose
self.return_transform = return_transform
self.return_names = return_names
self.return_labels = return_labels
self.color_encoding = get_color_encoding(self.seg_classes)
# Start and end frames. Used to determine sequence length.
self.start = start
self.end = end
full_sequence = self.end == -1
if start < 0:
raise ValueError("Start frame cannot be less than 0.")
if not (end == -1 or end > start):
raise ValueError(
"End frame ({}) should be equal to -1 or greater than start ({})".format(
end, start
)
)
self.seqlen = self.end - self.start
# scenes should be a tuple
if isinstance(scenes, str):
if os.path.isfile(scenes):
with open(scenes, "r") as f:
scenes = tuple(f.read().split("\n"))
else:
raise ValueError("incorrect filename: {} doesn't exist".format(scenes))
elif not (scenes is None or isinstance(scenes, tuple)):
msg = "scenes should either be path to split.txt or tuple of scenes or None, but was of type %r instead"
raise TypeError(msg % type(scenes))
# Get a list of all color, depth, pose, label and intrinsics files.
colorfiles, depthfiles, posefiles = [], [], []
labelfiles, intrinsicsfiles, seqnames = [], [], []
seqmetapaths = natsorted(glob.glob(os.path.join(seqmetadir, "*.txt")))
for seqmetapath in seqmetapaths:
scene_name = os.path.basename(seqmetapath).split("-")[0]
if scenes is not None:
if scene_name not in scenes:
continue
seq_colorfiles, seq_depthfiles, seq_posefiles = [], [], []
seq_labelfiles, seq_intrinsicsfiles = [], []
with open(seqmetapath, "r") as f:
lines = f.readlines()
if full_sequence:
self.end = len(lines)
self.seqlen = self.end - self.start
if self.seqlen > len(lines):
msg = "sequence length can't be larger than dataset sequence length but it was: %r > %r"
raise ValueError(msg % (self.seqlen, len(lines)))
lines = lines[self.start : self.end]
for line in lines:
line = line.strip().split()
msg = "incorrect reading from scannet metadata"
if line[0] != "color":
raise ValueError(msg)
seq_colorfiles.append(os.path.join(basedir, line[1]))
if line[2] != "depth":
raise ValueError(msg)
seq_depthfiles.append(os.path.join(basedir, line[3]))
if line[4] != "pose":
raise ValueError(msg)
seq_posefiles.append(os.path.join(basedir, line[5]))
if line[6] != "label-filt":
raise ValueError(msg)
seq_labelfiles.append(os.path.join(basedir, line[7]))
if line[14] != "intrinsic_depth":
raise ValueError(msg)
seq_intrinsicsfiles.append(os.path.join(basedir, line[15]))
colorfiles.append(seq_colorfiles)
depthfiles.append(seq_depthfiles)
posefiles.append(seq_posefiles)
labelfiles.append(seq_labelfiles)
intrinsicsfiles.append(seq_intrinsicsfiles[0])
seqnames.append(os.path.basename(seqmetapath).split(".")[0])
self.num_sequences = len(colorfiles)
# Class members to store the list of valid filepaths.
self.colorfiles = colorfiles
self.depthfiles = depthfiles
self.posefiles = posefiles
self.labelfiles = labelfiles
self.intrinsicsfiles = intrinsicsfiles
self.seqnames = seqnames
# Scaling factor for depth images
self.scaling_factor = 1000.0
def __len__(self):
r"""Returns the length of the dataset. """
return self.num_sequences
def __getitem__(self, idx: int):
r"""Returns the data from the sequence at index idx.
Returns:
color_seq (torch.Tensor): Sequence of rgb images of each frame
depth_seq (torch.Tensor): Sequence of depths of each frame
pose_seq (torch.Tensor): Sequence of poses of each frame
transform_seq (torch.Tensor): Sequence of transformations between each frame in the sequence and the
previous frame. Transformations are w.r.t. the first frame in the sequence having identity pose
(relative transformations with first frame's pose as the reference transformation). First
transformation in the sequence will always be `torch.eye(4)`.
label_seq (torch.Tensor): Sequence of semantic segmentation labels
intrinsics (torch.Tensor): Intrinsics for the current sequence
seqname (str): Name of the sequence
Shape:
- color_seq: :math:`(L, H, W, 3)` if `channels_first` is False, else :math:`(L, 3, H, W)`. `L` denotes
sequence length.
- depth_seq: :math:`(L, H, W, 1)` if `channels_first` is False, else :math:`(L, 1, H, W)`. `L` denotes
sequence length.
- pose_seq: :math:`(L, 4, 4)` where `L` denotes sequence length.
- transform_seq: :math:`(L, 4, 4)` where `L` denotes sequence length.
- label_seq: :math:`(L, H, W)` where `L` denotes sequence length.
- intrinsics: :math:`(1, 4, 4)`
"""
# Read in the color, depth, pose, label and intrinstics info.
color_seq_path = self.colorfiles[idx]
depth_seq_path = self.depthfiles[idx]
pose_seq_path = self.posefiles[idx]
label_seq_path = self.labelfiles[idx]
intrinsics_path = self.intrinsicsfiles[idx]
seqname = self.seqnames[idx]
color_seq, depth_seq, pose_seq, label_seq = [], [], [], []
poses = []
for i in range(self.seqlen):
color = np.asarray(imageio.imread(color_seq_path[i]), dtype=float)
color = self._preprocess_color(color)
color = torch.from_numpy(color)
color_seq.append(color)
if self.return_depth:
depth = np.asarray(imageio.imread(depth_seq_path[i]), dtype=np.int64)
depth = self._preprocess_depth(depth)
depth = torch.from_numpy(depth)
depth_seq.append(depth)
if self.return_pose or self.return_transform:
pose = np.loadtxt(pose_seq_path[i]).astype(float)
poses.append(pose)
pose = torch.from_numpy(pose)
pose_seq.append(pose)
if self.return_labels:
label = np.asarray(imageio.imread(label_seq_path[i]), dtype=np.uint8)
label = self._preprocess_label(label)
label = torch.from_numpy(label)
label_seq.append(label)
output = []
color_seq = torch.stack(color_seq, 0).float()
output.append(color_seq)
if self.return_depth:
depth_seq = torch.stack(depth_seq, 0).float()
output.append(depth_seq)
if self.return_intrinsics:
intrinsics = np.loadtxt(intrinsics_path).astype(float)
intrinsics = self._preprocess_intrinsics(intrinsics)
intrinsics = torch.from_numpy(intrinsics).float()
output.append(intrinsics)
if self.return_pose:
pose_seq = torch.stack(pose_seq, 0).float()
pose_seq = self._preprocess_poses(pose_seq)
output.append(pose_seq)
if self.return_transform:
transform_seq = datautils.poses_to_transforms(poses)
transform_seq = [torch.from_numpy(x).float() for x in transform_seq]
transform_seq = torch.stack(transform_seq, 0).float()
output.append(transform_seq)
if self.return_names:
output.append(seqname)
if self.return_labels:
label_seq = torch.stack(label_seq, 0).float()
output.append(label_seq)
return tuple(output)
def _preprocess_color(self, color: np.ndarray):
r"""Preprocesses the color image by resizing to :math:`(H, W, C)`, (optionally) normalizing values to
:math:`[0, 1]`, and (optionally) using channels first :math:`(C, H, W)` representation.
Args:
color (np.ndarray): Raw input rgb image
Retruns:
np.ndarray: Preprocessed rgb image
Shape:
- Input: :math:`(H_\text{old}, W_\text{old}, C)`
- Output: :math:`(H, W, C)` if `self.channels_first == False`, else :math:`(C, H, W)`.
"""
color = cv2.resize(
color, (self.width, self.height), interpolation=cv2.INTER_LINEAR
)
if self.normalize_color:
color = datautils.normalize_image(color)
if self.channels_first:
color = datautils.channels_first(color)
return color
def _preprocess_depth(self, depth: np.ndarray):
r"""Preprocesses the depth image by resizing, adding channel dimension, and scaling values to meters. Optionally
converts depth from channels last :math:`(H, W, 1)` to channels first :math:`(1, H, W)` representation.
Args:
depth (np.ndarray): Raw depth image
Returns:
np.ndarray: Preprocessed depth
Shape:
- depth: :math:`(H_\text{old}, W_\text{old})`
- Output: :math:`(H, W, 1)` if `self.channels_first == False`, else :math:`(1, H, W)`.
"""
depth = cv2.resize(
depth.astype(float),
(self.width, self.height),
interpolation=cv2.INTER_NEAREST,
)
depth = np.expand_dims(depth, -1)
if self.channels_first:
depth = datautils.channels_first(depth)
return depth / self.scaling_factor
def _preprocess_intrinsics(self, intrinsics: Union[torch.Tensor, np.ndarray]):
r"""Preprocesses the intrinsics by scaling `fx`, `fy`, `cx`, `cy` based on new frame size and expanding the
0-th dimension.
Args:
intrinsics (torch.Tensor or np.ndarray): Intrinsics matrix to be preprocessed
Returns:
Output (torch.Tensor or np.ndarray): Preprocessed intrinsics
Shape:
- intrinsics: :math:`(4, 4)`
- Output: :math:`(1, 4, 4)`
"""
scaled_intrinsics = datautils.scale_intrinsics(
intrinsics, self.height_downsample_ratio, self.width_downsample_ratio
)
if torch.is_tensor(scaled_intrinsics):
return scaled_intrinsics.unsqueeze(0)
elif isinstance(scaled_intrinsics, np.ndarray):
return np.expand_dims(scaled_intrinsics, 0)
def _preprocess_poses(self, poses: torch.Tensor):
r"""Preprocesses the poses by transforming all of them such that the initial pose will be identity.
Args:
poses (torch.Tensor): Pose matrices to be preprocessed
Returns:
Output (torch.Tensor): Poses relative to the initial frame
Shape:
- poses: :math:`(L, 4, 4)` where :math:`L` denotes sequence length.
- Output: :math:`(L, 4, 4)` where :math:`L` denotes sequence length.
"""
return relative_transformation(
poses[0].unsqueeze(0).repeat(poses.shape[0], 1, 1), poses
)
def _preprocess_label(self, label: np.ndarray):
r"""Preprocesses the "nyu40" label image by resizing it and (optionally) converting to "scannet20" labels
Args:
label (np.ndarray): "nyu40" label image with `uint8` values
Returns:
np.ndarray: Preprocessed labels
Shape:
- label: :math:`(H_\text{old}, W_\text{old})`
- Output: :math:`(H, W)`
"""
label = cv2.resize(
label, (self.width, self.height), interpolation=cv2.INTER_NEAREST
)
if self.seg_classes.lower() == "scannet20":
label = nyu40_to_scannet20(label)
label = np.expand_dims(label, -1)
return label
def get_color_encoding(seg_classes):
r"""Gets the color palette for different sets of labels (`"nyu40"` or `"scannet20"`)
Args:
seg_classes (str): Determines whether to use `"nyu40"` labels or `"scannet20"`
Returns:
Output (OrderedDict): Label names as keys and color palettes as values.
"""
if seg_classes.lower() == "nyu40":
# Color palette for nyu40 labels
return OrderedDict(
[
("unlabeled", (0, 0, 0)),
("wall", (174, 199, 232)),
("floor", (152, 223, 138)),
("cabinet", (31, 119, 180)),
("bed", (255, 187, 120)),
("chair", (188, 189, 34)),
("sofa", (140, 86, 75)),
("table", (255, 152, 150)),
("door", (214, 39, 40)),
("window", (197, 176, 213)),
("bookshelf", (148, 103, 189)),
("picture", (196, 156, 148)),
("counter", (23, 190, 207)),
("blinds", (178, 76, 76)),
("desk", (247, 182, 210)),
("shelves", (66, 188, 102)),
("curtain", (219, 219, 141)),
("dresser", (140, 57, 197)),
("pillow", (202, 185, 52)),
("mirror", (51, 176, 203)),
("floormat", (200, 54, 131)),
("clothes", (92, 193, 61)),
("ceiling", (78, 71, 183)),
("books", (172, 114, 82)),
("refrigerator", (255, 127, 14)),
("television", (91, 163, 138)),
("paper", (153, 98, 156)),
("towel", (140, 153, 101)),
("showercurtain", (158, 218, 229)),
("box", (100, 125, 154)),
("whiteboard", (178, 127, 135)),
("person", (120, 185, 128)),
("nightstand", (146, 111, 194)),
("toilet", (44, 160, 44)),
("sink", (112, 128, 144)),
("lamp", (96, 207, 209)),
("bathtub", (227, 119, 194)),
("bag", (213, 92, 176)),
("otherstructure", (94, 106, 211)),
("otherfurniture", (82, 84, 163)),
("otherprop", (100, 85, 144)),
]
)
elif seg_classes.lower() == "scannet20":
# Color palette for scannet20 labels
return OrderedDict(
[
("unlabeled", (0, 0, 0)),
("wall", (174, 199, 232)),
("floor", (152, 223, 138)),
("cabinet", (31, 119, 180)),
("bed", (255, 187, 120)),
("chair", (188, 189, 34)),
("sofa", (140, 86, 75)),
("table", (255, 152, 150)),
("door", (214, 39, 40)),
("window", (197, 176, 213)),
("bookshelf", (148, 103, 189)),
("picture", (196, 156, 148)),
("counter", (23, 190, 207)),
("desk", (247, 182, 210)),
("curtain", (219, 219, 141)),
("refrigerator", (255, 127, 14)),
("showercurtain", (158, 218, 229)),
("toilet", (44, 160, 44)),
("sink", (112, 128, 144)),
("bathtub", (227, 119, 194)),
("otherfurniture", (82, 84, 163)),
]
)
def nyu40_to_scannet20(label):
r"""Remaps a label image from the `"nyu40"` class palette to the `"scannet20"` class palette"""
# Ignore indices 13, 15, 17, 18, 19, 20, 21, 22, 23, 25, 26. 27. 29. 30. 31. 32, 35. 37. 38, 40
# Because, these classes from 'nyu40' are absent from 'scannet20'. Our label files are in
# 'nyu40' format, hence this 'hack'. To see detailed class lists visit:
# http://kaldir.vc.in.tum.de/scannet_benchmark/labelids_all.txt ('nyu40' labels)
# http://kaldir.vc.in.tum.de/scannet_benchmark/labelids.txt ('scannet20' labels)
# The remaining labels are then to be mapped onto a contiguous ordering in the range [0,20]
# The remapping array comprises tuples (src, tar), where 'src' is the 'nyu40' label, and 'tar' is the
# corresponding target 'scannet20' label
remapping = [
(0, 0),
(13, 0),
(15, 0),
(17, 0),
(18, 0),
(19, 0),
(20, 0),
(21, 0),
(22, 0),
(23, 0),
(25, 0),
(26, 0),
(27, 0),
(29, 0),
(30, 0),
(31, 0),
(32, 0),
(35, 0),
(37, 0),
(38, 0),
(40, 0),
(14, 13),
(16, 14),
(24, 15),
(28, 16),
(33, 17),
(34, 18),
(36, 19),
(39, 20),
]
for src, tar in remapping:
label[np.where(label == src)] = tar
return label