# Source code for gradslam.structures.rgbdimages

from typing import Optional, Union

from plotly.subplots import make_subplots
import torch

from .structutils import numpy_to_plotly_image
from ..geometry.geometryutils import create_meshgrid
from ..geometry.projutils import inverse_intrinsics

__all__ = ["RGBDImages"]


class RGBDImages(object):
    r"""A batch of sequences of rgb images, depth maps, camera intrinsics, and (optionally) poses.

    Args:
        rgb_image (torch.Tensor): 3-channel rgb image
        depth_image (torch.Tensor): 1-channel depth map
        intrinsics (torch.Tensor): camera intrinsics
        poses (torch.Tensor or None): camera extrinsics. Default: None
        channels_first (bool): indicates whether `rgb_image` and `depth_image` have channels first or
            channels last representation (i.e. rgb_image.shape is :math:`(B, L, H, W, 3)` or
            :math:`(B, L, 3, H, W)`). Default: False
        device (torch.device or str or None): The desired device of internal tensors. If None, sets device
            to be same as `rgb_image` device. Default: None
        pixel_pos (torch.Tensor or None): Similar to meshgrid but with extra channel of 1s at the end.
            If provided, can save computations when computing vertex maps. Default: None

    Shape:
        - rgb_image: :math:`(B, L, H, W, 3)` if `channels_first` is False, else :math:`(B, L, 3, H, W)`
        - depth_image: :math:`(B, L, H, W, 1)` if `channels_first` is False, else :math:`(B, L, 1, H, W)`
        - intrinsics: :math:`(B, 1, 4, 4)`
        - poses: :math:`(B, L, 4, 4)`
        - pixel_pos: :math:`(B, L, H, W, 3)` if `channels_first` is False, else :math:`(B, L, 3, H, W)`

    Examples::

        >>> colors = torch.rand([2, 8, 32, 32, 3])
        >>> depths = torch.rand([2, 8, 32, 32, 1])
        >>> intrinsics = torch.rand([2, 1, 4, 4])
        >>> poses = torch.rand([2, 8, 4, 4])
        >>> rgbdimages = gradslam.RGBDImages(colors, depths, intrinsics, poses)
        >>> print(rgbdimages.shape)
        (2, 8, 32, 32)
        >>> rgbd_select = rgbdimages[1, 4:8]
        >>> print(rgbd_select.shape)
        (1, 4, 32, 32)
        >>> print(rgbdimages.vertex_map.shape)
        (2, 8, 32, 32, 3)
        >>> print(rgbdimages.normal_map.shape)
        (2, 8, 32, 32, 3)
    """

    # Every (possibly lazily-computed) tensor attribute that indexing, cloning,
    # detaching, device moves and layout conversions must treat uniformly.
    _INTERNAL_TENSORS = [
        "_rgb_image",
        "_depth_image",
        "_intrinsics",
        "_poses",
        "_pixel_pos",
        "_vertex_map",
        "_normal_map",
        "_global_vertex_map",
        "_global_normal_map",
    ]

    def __init__(
        self,
        rgb_image: torch.Tensor,
        depth_image: torch.Tensor,
        intrinsics: torch.Tensor,
        poses: Optional[torch.Tensor] = None,
        channels_first: bool = False,
        device: Union[torch.device, str, None] = None,
        *,
        pixel_pos: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        # input type checks
        if not torch.is_tensor(rgb_image):
            msg = "Expected rgb_image to be of type tensor; got {}"
            raise TypeError(msg.format(type(rgb_image)))
        if not torch.is_tensor(depth_image):
            msg = "Expected depth_image to be of type tensor; got {}"
            raise TypeError(msg.format(type(depth_image)))
        if not torch.is_tensor(intrinsics):
            msg = "Expected intrinsics to be of type tensor; got {}"
            raise TypeError(msg.format(type(intrinsics)))
        if not (poses is None or torch.is_tensor(poses)):
            msg = "Expected poses to be of type tensor or None; got {}"
            raise TypeError(msg.format(type(poses)))
        if not isinstance(channels_first, bool):
            msg = "Expected channels_first to be of type bool; got {}"
            raise TypeError(msg.format(type(channels_first)))
        if not (pixel_pos is None or torch.is_tensor(pixel_pos)):
            msg = "Expected pixel_pos to be of type tensor or None; got {}"
            raise TypeError(msg.format(type(pixel_pos)))

        self._channels_first = channels_first

        # input ndim checks
        if rgb_image.ndim != 5:
            msg = "rgb_image should have ndim=5, but had ndim={}".format(rgb_image.ndim)
            raise ValueError(msg)
        if depth_image.ndim != 5:
            msg = "depth_image should have ndim=5, but had ndim={}".format(
                depth_image.ndim
            )
            raise ValueError(msg)
        if intrinsics.ndim != 4:
            msg = "intrinsics should have ndim=4, but had ndim={}".format(
                intrinsics.ndim
            )
            raise ValueError(msg)
        if poses is not None and poses.ndim != 4:
            msg = "poses should have ndim=4, but had ndim={}".format(poses.ndim)
            raise ValueError(msg)

        # expected shapes, all derived from rgb_image's shape
        self._rgb_image_shape = rgb_image.shape
        # FIX: store the expected depth shape as `_depth_image_shape` -- the
        # `depth_image` setter validates against this attribute, but it was
        # previously stored as `_depth_shape`, so assigning to `depth_image`
        # on a freshly constructed object raised AttributeError.
        self._depth_image_shape = tuple(
            v if i != self.cdim else 1 for i, v in enumerate(rgb_image.shape)
        )
        self._intrinsics_shape = (rgb_image.shape[0], 1, 4, 4)
        self._poses_shape = (*rgb_image.shape[:2], 4, 4)
        self._pixel_pos_shape = (
            *rgb_image.shape[: self.cdim],
            *rgb_image.shape[self.cdim + 1 :],
            3,
        )

        # input shape checks
        if rgb_image.shape[self.cdim] != 3:
            msg = "Expected rgb_image to have 3 channels on dimension {0}. Got {1} instead"
            raise ValueError(msg.format(self.cdim, rgb_image.shape[self.cdim]))
        if depth_image.shape != self._depth_image_shape:
            msg = "Expected depth_image to have shape {0}. Got {1} instead"
            raise ValueError(msg.format(self._depth_image_shape, depth_image.shape))
        if intrinsics.shape != self._intrinsics_shape:
            msg = "Expected intrinsics to have shape {0}. Got {1} instead"
            raise ValueError(msg.format(self._intrinsics_shape, intrinsics.shape))
        if poses is not None and (poses.shape != self._poses_shape):
            msg = "Expected poses to have shape {0}. Got {1} instead"
            raise ValueError(msg.format(self._poses_shape, poses.shape))
        if pixel_pos is not None and (pixel_pos.shape != self._pixel_pos_shape):
            msg = "Expected pixel_pos to have shape {0}. Got {1} instead"
            raise ValueError(msg.format(self._pixel_pos_shape, pixel_pos.shape))

        # all provided tensors must live on a single device
        inputs = [rgb_image, depth_image, intrinsics, poses, pixel_pos]
        devices = [x.device for x in inputs if x is not None]
        if len(set(devices)) != 1:
            raise ValueError(
                "All inputs must be on same device, but got more than 1 device: {}".format(
                    set(devices)
                )
            )

        self._rgb_image = rgb_image if device is None else rgb_image.to(device)
        self.device = self._rgb_image.device
        self._depth_image = depth_image.to(self.device)
        self._intrinsics = intrinsics.to(self.device)
        self._poses = poses.to(self.device) if poses is not None else None
        self._pixel_pos = pixel_pos.to(self.device) if pixel_pos is not None else None

        # lazily computed (see vertex_map / normal_map / valid_depth_mask properties)
        self._vertex_map = None
        self._global_vertex_map = None
        self._normal_map = None
        self._global_normal_map = None
        self._valid_depth_mask = None

        self._B, self._L = self._rgb_image.shape[:2]
        self.h = (
            self._rgb_image.shape[3]
            if self._channels_first
            else self._rgb_image.shape[2]
        )
        self.w = (
            self._rgb_image.shape[4]
            if self._channels_first
            else self._rgb_image.shape[3]
        )
        self.shape = (self._B, self._L, self.h, self.w)

    def __getitem__(self, index):
        r"""Index into the batch (and optionally sequence) dimensions.

        Args:
            index (int or slice or tuple): Specifying the index of the rgbdimages to retrieve.
                Can be an int, a slice, or a tuple of at most two of those (batch, sequence).

        Returns:
            gradslam.RGBDImages: Selected rgbdimages. The rgbdimages tensors are not cloned.
        """
        if isinstance(index, tuple) or isinstance(index, int):
            _index_slices = ()
            if isinstance(index, int):
                # single int selects one batch element, keeping the full sequence
                _index_slices += (slice(index, index + 1),) + (slice(None, None),)
            elif len(index) > 2:
                raise IndexError("Only batch and sequences can be indexed")
            elif isinstance(index, tuple):
                # ints are widened to length-1 slices so dims are preserved
                for x in index:
                    if isinstance(x, int):
                        _index_slices += (slice(x, x + 1),)
                    else:
                        _index_slices += (x,)

            new_rgb = self._rgb_image[_index_slices[0], _index_slices[1]]
            if new_rgb.shape[0] == 0:
                raise IndexError(
                    "Incorrect indexing at dimension 0, make sure range is within 0 and {0}".format(
                        self._B
                    )
                )
            if new_rgb.shape[1] == 0:
                raise IndexError(
                    "Incorrect indexing at dimension 1, make sure range is within 0 and {0}".format(
                        self._L
                    )
                )
            new_depth = self._depth_image[_index_slices[0], _index_slices[1]]
            new_intrinsics = self._intrinsics[_index_slices[0], :]
            other = RGBDImages(
                new_rgb,
                new_depth,
                new_intrinsics,
                channels_first=self.channels_first,
            )
            # propagate sliced poses / pixel_pos / cached maps when they exist
            for k in self._INTERNAL_TENSORS:
                if k in ["_rgb_image", "_depth_image", "_intrinsics"]:
                    continue
                v = getattr(self, k)
                if torch.is_tensor(v):
                    setattr(other, k, v[_index_slices[0], _index_slices[1]])
            return other
        else:
            raise IndexError(index)

    def __len__(self):
        # batch size
        return self._B

    @property
    def channels_first(self):
        r"""Gets bool indicating whether RGBDImages representation is channels first or not

        Returns:
            bool: True if RGBDImages representation is channels first, else False.
        """
        return self._channels_first

    @property
    def cdim(self):
        r"""Gets the channel dimension

        Returns:
            int: :math:`2` if self.channels_first is True, else :math:`4`.
        """
        return 2 if self.channels_first else 4

    @property
    def rgb_image(self):
        r"""Gets the rgb image

        Returns:
            torch.Tensor: tensor representation of `rgb_image`

        Shape:
            - Output: :math:`(B, L, H, W, 3)` if self.channels_first is False, else :math:`(B, L, 3, H, W)`
        """
        return self._rgb_image

    @property
    def depth_image(self):
        r"""Gets the depth image

        Returns:
            torch.Tensor: tensor representation of `depth_image`

        Shape:
            - Output: :math:`(B, L, H, W, 1)` if self.channels_first is False, else :math:`(B, L, 1, H, W)`
        """
        return self._depth_image

    @property
    def intrinsics(self):
        r"""Gets the `intrinsics`

        Returns:
            torch.Tensor: tensor representation of `intrinsics`

        Shape:
            - Output: :math:`(B, 1, 4, 4)`
        """
        return self._intrinsics

    @property
    def poses(self):
        r"""Gets the `poses`

        Returns:
            torch.Tensor: tensor representation of `poses`

        Shape:
            - Output: :math:`(B, L, 4, 4)`
        """
        return self._poses

    @property
    def pixel_pos(self):
        r"""Gets the `pixel_pos`

        Returns:
            torch.Tensor: tensor representation of `pixel_pos`

        Shape:
            - Output: :math:`(B, L, H, W, 3)` if self.channels_first is False, else :math:`(B, L, 3, H, W)`
        """
        return self._pixel_pos

    @property
    def valid_depth_mask(self):
        r"""Gets a mask which is True wherever `self.depth_image` is :math:`>0`

        Returns:
            torch.Tensor: Tensor of dtype bool with same shape as `self.depth_image`.
            Tensor is True wherever `self.depth_image` > 0, and False otherwise.

        Shape:
            - Output: :math:`(B, L, H, W, 1)` if self.channels_first is False, else :math:`(B, L, 1, H, W)`
        """
        if self._valid_depth_mask is None:
            self._valid_depth_mask = self._depth_image > 0
        return self._valid_depth_mask

    @property
    def has_poses(self):
        r"""Determines whether self has `poses` or not

        Returns:
            bool
        """
        return self._poses is not None

    @property
    def vertex_map(self):
        r"""Gets the local vertex maps (computed lazily on first access)

        Returns:
            torch.Tensor: tensor representation of local coordinated vertex maps

        Shape:
            - Output: :math:`(B, L, H, W, 3)` if self.channels_first is False, else :math:`(B, L, 3, H, W)`
        """
        if self._vertex_map is None:
            self._compute_vertex_map()
        return self._vertex_map

    @property
    def normal_map(self):
        r"""Gets the local normal maps (computed lazily on first access)

        Returns:
            torch.Tensor: tensor representation of local coordinated normal maps

        Shape:
            - Output: :math:`(B, L, H, W, 3)` if self.channels_first is False, else :math:`(B, L, 3, H, W)`
        """
        if self._normal_map is None:
            self._compute_normal_map()
        return self._normal_map

    @property
    def global_vertex_map(self):
        r"""Gets the global vertex maps (computed lazily on first access)

        Returns:
            torch.Tensor: tensor representation of global coordinated vertex maps

        Shape:
            - Output: :math:`(B, L, H, W, 3)` if self.channels_first is False, else :math:`(B, L, 3, H, W)`
        """
        if self._global_vertex_map is None:
            self._compute_global_vertex_map()
        return self._global_vertex_map

    @property
    def global_normal_map(self):
        r"""Gets the global normal maps (computed lazily on first access)

        Returns:
            torch.Tensor: tensor representation of global coordinated normal maps

        Shape:
            - Output: :math:`(B, L, H, W, 3)` if self.channels_first is False, else :math:`(B, L, 3, H, W)`
        """
        if self._global_normal_map is None:
            self._compute_global_normal_map()
        return self._global_normal_map

    @rgb_image.setter
    def rgb_image(self, value):
        r"""Updates `rgb_image` of self.

        Args:
            value (torch.Tensor): New rgb image values

        Shape:
            - value: :math:`(B, L, H, W, 3)` if self.channels_first is False, else :math:`(B, L, 3, H, W)`
        """
        if value is not None:
            self._assert_shape(value, self._rgb_image_shape)
        self._rgb_image = value

    @depth_image.setter
    def depth_image(self, value):
        r"""Updates `depth_image` of self, invalidating cached vertex/normal maps.

        Args:
            value (torch.Tensor): New depth image values

        Shape:
            - value: :math:`(B, L, H, W, 1)` if self.channels_first is False, else :math:`(B, L, 1, H, W)`
        """
        if value is not None:
            self._assert_shape(value, self._depth_image_shape)
        self._depth_image = value
        # derived maps depend on depth -> force recomputation
        self._vertex_map = None
        self._normal_map = None
        self._global_vertex_map = None
        self._global_normal_map = None

    @intrinsics.setter
    def intrinsics(self, value):
        r"""Updates `intrinsics` of self, invalidating cached vertex/normal maps.

        Args:
            value (torch.Tensor): New intrinsics values

        Shape:
            - value: :math:`(B, 1, 4, 4)`
        """
        if value is not None:
            self._assert_shape(value, self._intrinsics_shape)
        self._intrinsics = value
        # derived maps depend on intrinsics -> force recomputation
        self._vertex_map = None
        self._normal_map = None
        self._global_vertex_map = None
        self._global_normal_map = None

    @poses.setter
    def poses(self, value):
        r"""Updates `poses` of self, invalidating cached *global* maps only.

        Args:
            value (torch.Tensor): New pose values

        Shape:
            - value: :math:`(B, L, 4, 4)`
        """
        if value is not None:
            self._assert_shape(value, self._poses_shape)
        self._poses = value
        # local maps are pose-independent; only global ones must be recomputed
        self._global_vertex_map = None
        self._global_normal_map = None
[docs] def detach(self): r"""Detachs RGBDImages object. All internal tensors are detached individually. Returns: gradslam.RGBDImages: detached gradslam.RGBDImages object """ other = self.clone() for k in self._INTERNAL_TENSORS: v = getattr(self, k) if torch.is_tensor(v): setattr(other, k, v.detach()) return other
[docs] def clone(self): r"""Returns deep copy of RGBDImages object. All internal tensors are cloned individually. Returns: gradslam.RGBDImages: cloned gradslam.RGBDImages object """ other = RGBDImages( rgb_image=self._rgb_image.clone(), depth_image=self._depth_image.clone(), intrinsics=self._intrinsics.clone(), channels_first=self.channels_first, ) for k in self._INTERNAL_TENSORS: if k in ["_rgb_image", "_depth_image", "_intrinsics"]: continue v = getattr(self, k) if torch.is_tensor(v): setattr(other, k, v.clone()) return other
[docs] def to(self, device: Union[torch.device, str], copy: bool = False): r"""Match functionality of torch.Tensor.to(device) If copy = True or the self Tensor is on a different device, the returned tensor is a copy of self with the desired torch.device. If copy = False and the self Tensor already has the correct torch.device, then self is returned. Args: device (torch.device or str): Device id for the new tensor. copy (bool): Boolean indicator whether or not to clone self. Default False. Returns: gradslam.RGBDImages """ # hack to know which gpu is used when device("cuda") device = torch.Tensor().to(device).device if not copy and self.device == device: return self other = self.clone() other.device = device for k in self._INTERNAL_TENSORS: v = getattr(self, k) if torch.is_tensor(v): setattr(other, k, v.to(device)) return other
[docs] def cpu(self): r"""Match functionality of torch.Tensor.cpu() Returns: gradslam.RGBDImages """ return self.to(torch.device("cpu"))
[docs] def cuda(self): r"""Match functionality of torch.Tensor.cuda() Returns: gradslam.RGBDImages """ return self.to(torch.device("cuda"))
[docs] def to_channels_last(self, copy: bool = False): r"""Converts to channels last representation If copy = True or self channels_first is True, the returned RGBDImages object is a copy of self with channels last representation. If copy = False and self channels_first is already False, then self is returned. Args: copy (bool): Boolean indicator whether or not to clone self. Default False. Returns: gradslam.RGBDImages """ if not (copy or self.channels_first): return self return self.clone().to_channels_last_()
[docs] def to_channels_first(self, copy: bool = False): r"""Converts to channels first representation If copy = True or self channels_first is False, the returned RGBDImages object is a copy of self with channels first representation. If copy = False and self channels_first is already True, then self is returned. Args: copy (bool): Boolean indicator whether or not to clone self. Default False. Returns: gradslam.RGBDImages """ if not copy and self.channels_first: return self return self.clone().to_channels_first_()
[docs] def to_channels_last_(self): r"""Converts to channels last representation. In place operation. Returns: gradslam.RGBDImages """ if not self.channels_first: return self ordering = (0, 1, 3, 4, 2) # B L C H W -> B L H W C permute = RGBDImages._permute_if_not_None self._rgb_image = permute(self._rgb_image, ordering) self._depth_image = permute(self._depth_image, ordering) self._vertex_map = permute(self._vertex_map, ordering) self._global_vertex_map = permute(self._global_vertex_map, ordering) self._normal_map = permute(self._normal_map, ordering) self._global_normal_map = permute(self._global_normal_map, ordering) self._channels_first = False self._rgb_image_shape = tuple(self._rgb_image.shape) self._depth_image_shape = tuple(self._depth_image.shape) return self
[docs] def to_channels_first_(self): r"""Converts to channels first representation. In place operation. Returns: gradslam.RGBDImages """ if self.channels_first: return self ordering = (0, 1, 4, 2, 3) # B L H W C -> B L C H W permute = RGBDImages._permute_if_not_None self._rgb_image = permute(self._rgb_image, ordering) self._depth_image = permute(self._depth_image, ordering) self._vertex_map = permute(self._vertex_map, ordering) self._global_vertex_map = permute(self._global_vertex_map, ordering) self._normal_map = permute(self._normal_map, ordering) self._global_normal_map = permute(self._global_normal_map, ordering) self._channels_first = True self._rgb_image_shape = tuple(self._rgb_image.shape) self._depth_image_shape = tuple(self._depth_image.shape) return self
@staticmethod def _permute_if_not_None( tensor: Optional[torch.Tensor], ordering: tuple, contiguous: bool = True ): r"""Permutes input if it is not None based on given ordering Args: tensor (torch.Tensor or None): Tensor to be permuted, or None ordering (tuple): The desired ordering of dimensions contiguous (bool): Whether to call `.contiguous()` on permuted tensor before returning. Default: True Returns: torch.Tensor or None: Permuted tensor or None """ if tensor is None: return None assert torch.is_tensor(tensor) return ( tensor.permute(*ordering).contiguous() if contiguous else tensor.permute(*ordering) ) def _compute_vertex_map(self): r"""Coverts a batch of depth images into a batch of vertex maps.""" B, L = self.shape[:2] device = self._depth_image.device if self._pixel_pos is None: meshgrid = ( create_meshgrid(self.h, self.w, normalized_coords=False) .view(1, 1, self.h, self.w, 2) .repeat(B, L, 1, 1, 1) .to(device) ) self._pixel_pos = torch.cat( [ meshgrid[..., 1:], meshgrid[..., 0:1], torch.ones_like(meshgrid[..., 0].unsqueeze(-1)), ], -1, ) Kinv = inverse_intrinsics(self._intrinsics)[..., :3, :3] # TODO: Time tests for all einsums. Might not be efficient (especially on cpu). Kinv = Kinv.repeat(1, L, 1, 1) # Add an extra channel of ones to meshgrid for z values if self.channels_first: self._vertex_map = ( torch.einsum("bsjc,bshwc->bsjhw", Kinv, self._pixel_pos) * self._depth_image ) else: self._vertex_map = ( torch.einsum("bsjc,bshwc->bshwj", Kinv, self._pixel_pos) * self._depth_image ) # zero out missing depth values self._vertex_map = self._vertex_map * self.valid_depth_mask.to( self._vertex_map.dtype ) def _compute_global_vertex_map(self): r"""Coverts a batch of local vertex maps into a batch of global vertex maps.""" if self._poses is None: self._global_vertex_map = self.vertex_map.clone() return local_vertex_map = self.vertex_map B, L = self.shape[:2] rmat = self._poses[..., :3, :3] tvec = self._poses[..., :3, 3] # TODO: Time tests for all einsums. 
Might not be efficient (especially on cpu). # Add an extra channel of ones to meshgrid for z values if self.channels_first: self._global_vertex_map = torch.einsum( "bsjc,bschw->bsjhw", rmat, local_vertex_map ) self._global_vertex_map = self._global_vertex_map + tvec.view(B, L, 3, 1, 1) else: self._global_vertex_map = torch.einsum( "bsjc,bshwc->bshwj", rmat, local_vertex_map ) self._global_vertex_map = self._global_vertex_map + tvec.view(B, L, 1, 1, 3) # zero out missing depth values self._global_vertex_map = self._global_vertex_map * self.valid_depth_mask.to( self._global_vertex_map.dtype ) def _compute_normal_map(self): r"""Converts a batch of vertex maps to a batch of normal maps.""" dhoriz: torch.Tensor = torch.zeros_like(self.vertex_map) dverti: torch.Tensor = torch.zeros_like(self.vertex_map) if self.channels_first: dhoriz[..., :-1] = self.vertex_map[..., 1:] - self.vertex_map[..., :-1] dverti[..., :-1, :] = ( self.vertex_map[..., 1:, :] - self.vertex_map[..., :-1, :] ) dhoriz[..., -1] = dhoriz[..., -2] dverti[..., -1, :] = dverti[..., -2, :] dim = 2 else: dhoriz[..., :-1, :] = ( self.vertex_map[..., 1:, :] - self.vertex_map[..., :-1, :] ) dverti[..., :-1, :, :] = ( self.vertex_map[..., 1:, :, :] - self.vertex_map[..., :-1, :, :] ) dhoriz[..., -1, :] = dhoriz[..., -2, :] dverti[..., -1, :, :] = dverti[..., -2, :, :] dim = -1 normal_map: torch.Tensor = torch.cross(dhoriz, dverti, dim=dim) norm: torch.Tensor = normal_map.norm(dim=dim).unsqueeze(dim) self._normal_map: torch.Tensor = normal_map / torch.where( norm == 0, torch.ones_like(norm), norm ) # zero out missing depth values self._normal_map = self._normal_map * self.valid_depth_mask.to( self._normal_map.dtype ) def _compute_global_normal_map(self): r"""Coverts a batch of local noraml maps into a batch of global normal maps.""" if self._poses is None: self._global_normal_map = self.normal_map.clone() return local_normal_map = self.normal_map B, L = self.shape[:2] rmat = self._poses[..., :3, :3] if 
self.channels_first: self._global_normal_map = torch.einsum( "bsjc,bschw->bsjhw", rmat, local_normal_map ) else: self._global_normal_map = torch.einsum( "bsjc,bshwc->bshwj", rmat, local_normal_map )
[docs] def plotly( self, index: int, include_depth: bool = True, as_figure: bool = True, ms_per_frame: int = 50, ): r"""Converts `index`-th sequence of rgbd images to either a `plotly.graph_objects.Figure` or a list of dicts containing `plotly.graph_objects.Image` objects of rgb and (optionally) depth images: .. code-block:: python frames = [ {'name': 0, 'data': [rgbImage0, depthImage0], 'traces': [0, 1]}, {'name': 1, 'data': [rgbImage1, depthImage1], 'traces': [0, 1]}, {'name': 2, 'data': [rgbImage2, depthImage2], 'traces': [0, 1]}, ... ] Returned `frames` can be passed to `go.Figure(frames=frames)`. Args: index (int): Index of which rgbd image (from the batch of rgbd images) to convert to plotly representation. include_depth (bool): If True, will include depth images in the returned object. Default: True as_figure (bool): If True, returns a `plotly.graph_objects.Figure` object which can easily be visualized by calling `.show()` on. Otherwise, returns a list of dicts (`frames`) which can be passed to `go.Figure(frames=frames)`. Default: True ms_per_frame (int): Milliseconds per frame when play button is hit. Only applicable if `as_figure=True`. Default: 50 Returns: plotly.graph_objects.Figure or list of dict: If `as_figure` is True, will return `plotly.graph_objects.Figure` object from the `index`-th sequence of rgbd images. Else, returns a list of dicts (`frames`). 
""" if not isinstance(index, int): raise TypeError("Index should be int, but was {}.".format(type(index))) def frame_args(duration): return { "frame": {"duration": duration, "redraw": True}, "mode": "immediate", "fromcurrent": True, "transition": {"duration": duration, "easing": "linear"}, } torch_rgb = self.rgb_image[index] if (torch_rgb.max() < 1.1).item(): torch_rgb = torch_rgb * 255 torch_rgb = torch.clamp(torch_rgb, min=0.0, max=255.0) numpy_rgb = torch_rgb.detach().cpu().numpy().astype("uint8") Image_rgb = [numpy_to_plotly_image(rgb, i) for i, rgb in enumerate(numpy_rgb)] if not include_depth: frames = [{"data": [frame], "name": i} for i, frame in enumerate(Image_rgb)] else: torch_depth = self.depth_image[index, ..., 0] scale = 10 ** torch.log10(255.0 / torch_depth.detach().max()).floor().item() numpy_depth = (torch_depth * scale).detach().cpu().numpy().astype("uint8") Image_depth = [ numpy_to_plotly_image(d, i, True, scale) for i, d in enumerate(numpy_depth) ] frames = [ {"name": i, "data": list(frame), "traces": [0, 1]} for i, frame in enumerate(zip(Image_rgb, Image_depth)) ] if not as_figure: return frames steps = [ {"args": [[i], frame_args(0)], "label": i, "method": "animate"} for i in range(self._L) ] sliders = [ { "active": 0, "yanchor": "top", "xanchor": "left", "currentvalue": {"prefix": "Frame: "}, "pad": {"b": 10, "t": 60}, "len": 0.9, "x": 0.1, "y": 0, "steps": steps, } ] updatemenus = [ { "buttons": [ { "args": [None, frame_args(ms_per_frame)], "label": "&#9654;", "method": "animate", }, { "args": [[None], frame_args(0)], "label": "&#9724;", "method": "animate", }, ], "direction": "left", "pad": {"r": 10, "t": 70}, "showactive": False, "type": "buttons", "x": 0.1, "xanchor": "right", "y": 0, "yanchor": "top", } ] if not include_depth: fig = make_subplots(rows=1, cols=1, subplot_titles=("RGB",)) fig.add_traces(frames[0]["data"][0]) else: fig = make_subplots( rows=2, cols=1, subplot_titles=("RGB", "Depth"), shared_xaxes=True, shared_yaxes=False, 
vertical_spacing=0.1, ) fig.add_trace(frames[0]["data"][0], row=1, col=1) # initial rgb frame fig.add_trace(frames[0]["data"][1], row=2, col=1) # initial depth frame fig.update_layout(scene=dict(aspectmode="data")) fig.update_layout( autosize=False, height=1080 ) # autosize is not perfect with subplots fig.update(frames=frames) fig.update_layout(updatemenus=updatemenus, sliders=sliders) return fig
# TODO: rotation + transformation: keep in mind to apply to vertices, normals *and* poses def _assert_shape(self, value: torch.Tensor, shape: tuple): r"""Asserts if value is a tensor with same shape as `shape` Args: value (torch.Tensor): Tensor to check shape of shape (tuple): Expected shape of value """ if not isinstance(value, torch.Tensor): raise TypeError("value must be torch.Tensor. Got {}".format(type(value))) if value.shape != shape: msg = "Expected value to have shape {0}. Got {1} instead" raise ValueError(msg.format(shape, value.shape))