Source code for gradslam.slam.icpslam

from typing import Optional, Union
import warnings

import torch
import torch.nn as nn
from kornia.geometry.linalg import compose_transformations

from ..odometry.icp import ICPOdometryProvider
from ..odometry.gradicp import GradICPOdometryProvider
from ..odometry.icputils import downsample_pointclouds, downsample_rgbdimages
from ..structures.pointclouds import Pointclouds
from ..structures.rgbdimages import RGBDImages
from .fusionutils import find_active_map_points, update_map_aggregate

# Public names exported by a star-import of this module.
__all__ = ["ICPSLAM"]


class ICPSLAM(nn.Module):
    r"""ICP-SLAM for batched sequences of RGB-D images.

    Args:
        odom (str): Odometry method to be used from {'gt', 'icp', 'gradicp'}. Default: 'gradicp'
        dsratio (int): Downsampling ratio to apply to input frames before ICP. Only used if `odom` is
            'icp' or 'gradicp'. Default: 4
        numiters (int): Number of iterations to run the optimization for. Only used if `odom` is
            'icp' or 'gradicp'. Default: 20
        damp (float or torch.Tensor): Damping coefficient for nonlinear least-squares. Only used if `odom` is
            'icp' or 'gradicp'. Default: 1e-8
        dist_thresh (float or int or None): Distance threshold for removing `src_pc` points distant from `tgt_pc`.
            Only used if `odom` is 'icp' or 'gradicp'. Default: None
        lambda_max (float or int): Maximum value the damping function can assume (`lambda_min` will be
            :math:`\frac{1}{\text{lambda_max}}`). Only used if `odom` is 'gradicp'.
        B (float or int): gradLM falloff control parameter (see GradICPOdometryProvider description).
            Only used if `odom` is 'gradicp'.
        B2 (float or int): gradLM control parameter (see GradICPOdometryProvider description).
            Only used if `odom` is 'gradicp'.
        nu (float or int): gradLM control parameter (see GradICPOdometryProvider description).
            Only used if `odom` is 'gradicp'.
        device (torch.device or str or None): The desired device of internal tensors. If None, sets device to be
            the CPU. Default: None

    Raises:
        ValueError: If `odom` is not one of 'gt', 'icp' or 'gradicp'.

    Examples::

        >>> rgbdimages = RGBDImages(colors, depths, intrinsics, poses)
        >>> slam = ICPSLAM(odom='gt')
        >>> pointclouds, poses = slam(rgbdimages)
        >>> o3d.visualization.draw_geometries([pointclouds.o3d(0)])

        >>> rgbdimages = RGBDImages(colors, depths, intrinsics, poses)
        >>> slam = ICPSLAM(odom='gt')
        >>> pointclouds = Pointclouds()
        >>> pointclouds, new_poses = slam.step(pointclouds, rgbdimages[:, 0], None)
        >>> rgbdimages.poses[:, :1] = new_poses
        >>> pointclouds, new_poses = slam.step(pointclouds, rgbdimages[:, 1], rgbdimages[:, 0])

        >>> rgbdimages = RGBDImages(colors, depths, intrinsics, poses)
        >>> slam = ICPSLAM(odom='gradicp')
        >>> pointclouds = Pointclouds()
        >>> pointclouds, new_poses = slam.step(pointclouds, rgbdimages[:, 0], None)
        >>> rgbdimages.poses[:, :1] = new_poses
        >>> pointclouds, new_poses = slam.step(pointclouds, rgbdimages[:, 1], rgbdimages[:, 0])
    """

    # TODO: Try to have nn.Module features supported
    def __init__(
        self,
        *,
        odom: str = "gradicp",
        dsratio: int = 4,
        numiters: int = 20,
        damp: Union[float, torch.Tensor] = 1e-8,
        dist_thresh: Union[float, int, None] = None,
        lambda_max: Union[float, int] = 2.0,
        B: Union[float, int] = 1.0,
        B2: Union[float, int] = 1.0,
        nu: Union[float, int] = 200.0,
        device: Union[torch.device, str, None] = None,
    ):
        super().__init__()
        if odom not in ("gt", "icp", "gradicp"):
            msg = "odometry method ({}) not supported for ICPSLAM. ".format(odom)
            msg += "Currently supported odometry modules for ICPSLAM are: 'gt', 'icp', 'gradicp'"
            raise ValueError(msg)

        # No odometry provider is needed for 'gt': poses are read off the frames.
        odomprov = None
        if odom == "icp":
            odomprov = ICPOdometryProvider(numiters, damp, dist_thresh)
        elif odom == "gradicp":
            odomprov = GradICPOdometryProvider(
                numiters, damp, dist_thresh, lambda_max, B, B2, nu
            )

        self.odom = odom
        self.odomprov = odomprov
        self.dsratio = dsratio
        device = torch.device(device) if device is not None else torch.device("cpu")
        # Round-trip through an (empty) tensor so the stored device is the one
        # torch actually places tensors on for this specification.
        self.device = torch.Tensor().to(device).device

    def forward(self, frames: RGBDImages):
        r"""Builds global map pointclouds from a batch of input RGBDImages with a batch size
        of :math:`B` and sequence length of :math:`L`.

        Args:
            frames (gradslam.RGBDImages): Input batch of frames with a sequence length of `L`.

        Returns:
            tuple: tuple containing:

            - pointclouds (gradslam.Pointclouds): Pointclouds object containing :math:`B` global maps
            - poses (torch.Tensor): Poses computed by the odometry method

        Raises:
            TypeError: If `frames` is not a gradslam.RGBDImages.

        Shape:
            - poses: :math:`(B, L, 4, 4)`
        """
        if not isinstance(frames, RGBDImages):
            raise TypeError(
                "Expected frames to be of type gradslam.RGBDImages. Got {0}.".format(
                    type(frames)
                )
            )
        pointclouds = Pointclouds(device=self.device)
        batch_size, seq_len = frames.shape[:2]
        # Allocate straight on the target device instead of on CPU followed by a copy.
        recovered_poses = torch.empty(batch_size, seq_len, 4, 4, device=self.device)
        prev_frame = None
        for s in range(seq_len):
            live_frame = frames[:, s].to(self.device)
            if s == 0 and live_frame.poses is None:
                # No pose supplied for the first frame: start from identity.
                live_frame.poses = (
                    torch.eye(4, dtype=torch.float, device=self.device)
                    .view(1, 1, 4, 4)
                    .repeat(batch_size, 1, 1, 1)
                )
            pointclouds, live_frame.poses = self.step(
                pointclouds, live_frame, prev_frame, inplace=True
            )
            # With 'gt' odometry the previous frame is never consulted, so keep
            # passing None (passing a frame would only trigger a warning).
            prev_frame = live_frame if self.odom != "gt" else None
            recovered_poses[:, s] = live_frame.poses[:, 0]
        return pointclouds, recovered_poses

    def step(
        self,
        pointclouds: Pointclouds,
        live_frame: RGBDImages,
        prev_frame: Optional[RGBDImages] = None,
        inplace: bool = False,
    ):
        r"""Updates global map pointclouds with a SLAM step on `live_frame`.
        If `prev_frame` is not None, computes the relative transformation between `live_frame`
        and `prev_frame` using the selected odometry provider. If `prev_frame` is None, use the pose from `live_frame`.

        Args:
            pointclouds (gradslam.Pointclouds): Input batch of pointcloud global maps
            live_frame (gradslam.RGBDImages): Input batch of live frames (at time step :math:`t`). Must have sequence
                length of 1.
            prev_frame (gradslam.RGBDImages or None): Input batch of previous frames (at time step :math:`t-1`).
                Must have sequence length of 1. If None, will (skip calling odometry provider and) use the pose
                from `live_frame`. Default: None
            inplace (bool): Forwarded to the map update; can optionally update the pointclouds in-place.
                Note that `live_frame.poses` is always overwritten with the computed poses, regardless of
                this flag. Default: False

        Returns:
            tuple: tuple containing:

            - pointclouds (gradslam.Pointclouds): Updated :math:`B` global maps
            - poses (torch.Tensor): Poses for the live_frame batch

        Raises:
            TypeError: If `live_frame` is not a gradslam.RGBDImages.

        Shape:
            - poses: :math:`(B, 1, 4, 4)`
        """
        if not isinstance(live_frame, RGBDImages):
            raise TypeError(
                "Expected live_frame to be of type gradslam.RGBDImages. Got {0}.".format(
                    type(live_frame)
                )
            )
        # Localize first: the map update below reads the freshly assigned poses.
        live_frame.poses = self._localize(pointclouds, live_frame, prev_frame)
        pointclouds = self._map(pointclouds, live_frame, inplace)
        return pointclouds, live_frame.poses

    def _localize(
        self,
        pointclouds: Pointclouds,
        live_frame: RGBDImages,
        prev_frame: Optional[RGBDImages],
    ):
        r"""Compute the poses for `live_frame`. If `prev_frame` is not None, computes the relative
        transformation between `live_frame` and `prev_frame` using the selected odometry provider.
        If `prev_frame` is None, use the pose from `live_frame`.

        Args:
            pointclouds (gradslam.Pointclouds): Input batch of pointcloud global maps
            live_frame (gradslam.RGBDImages): Input batch of live frames (at time step :math:`t`). Must have sequence
                length of 1.
            prev_frame (gradslam.RGBDImages or None): Input batch of previous frames (at time step :math:`t-1`).
                Must have sequence length of 1. If None, will (skip calling odometry provider and) use the pose
                from `live_frame`.

        Returns:
            torch.Tensor: Poses for the live_frame batch

        Raises:
            TypeError: If `pointclouds`, `live_frame` or `prev_frame` has an unexpected type.
            ValueError: If `prev_frame` is given (with 'icp'/'gradicp' odometry) but has no poses, or if
                `live_frame` has no poses when they are needed (`prev_frame` is None or `odom='gt'`).

        Shape:
            - Output: :math:`(B, 1, 4, 4)`
        """
        if not isinstance(pointclouds, Pointclouds):
            raise TypeError(
                "Expected pointclouds to be of type gradslam.Pointclouds. Got {0}.".format(
                    type(pointclouds)
                )
            )
        if not isinstance(live_frame, RGBDImages):
            raise TypeError(
                "Expected live_frame to be of type gradslam.RGBDImages. Got {0}.".format(
                    type(live_frame)
                )
            )
        if not isinstance(prev_frame, (RGBDImages, type(None))):
            raise TypeError(
                "Expected prev_frame to be of type gradslam.RGBDImages or None. Got {0}.".format(
                    type(prev_frame)
                )
            )

        if prev_frame is not None:
            if self.odom == "gt":
                warnings.warn(
                    "`prev_frame` is not used when using `odom='gt'` (should be None)"
                )
            elif not prev_frame.has_poses:
                raise ValueError("`prev_frame` should have poses, but did not.")

        # A non-empty map with no previous frame is suspicious for ICP-based
        # odometry: warn, then fall back to the poses stored on `live_frame`.
        if prev_frame is None and pointclouds.has_points and self.odom != "gt":
            msg = "`prev_frame` was None despite `{}` odometry method. Using `live_frame` poses.".format(
                self.odom
            )
            warnings.warn(msg)

        if prev_frame is None or self.odom == "gt":
            if not live_frame.has_poses:
                raise ValueError(
                    "`live_frame` must have poses when `prev_frame` is None or `odom='gt'`."
                )
            return live_frame.poses

        if self.odom in ["icp", "gradicp"]:
            # Seed the live frame with the previous poses before downsampling it
            # for the odometry provider.
            live_frame.poses = prev_frame.poses
            frames_pc = downsample_rgbdimages(live_frame, self.dsratio)
            pc2im_bnhw = find_active_map_points(pointclouds, prev_frame)
            maps_pc = downsample_pointclouds(pointclouds, pc2im_bnhw, self.dsratio)
            transform = self.odomprov.provide(maps_pc, frames_pc)

            # Drop the singleton sequence dimension for composition, then restore it.
            return compose_transformations(
                transform.squeeze(1), prev_frame.poses.squeeze(1)
            ).unsqueeze(1)

    def _map(
        self, pointclouds: Pointclouds, live_frame: RGBDImages, inplace: bool = False
    ):
        r"""Updates global map pointclouds by aggregating them with points from `live_frame`.

        Args:
            pointclouds (gradslam.Pointclouds): Input batch of pointcloud global maps
            live_frame (gradslam.RGBDImages): Input batch of live frames (at time step :math:`t`). Must have sequence
                length of 1.
            inplace (bool): Can optionally update the pointclouds in-place. Default: False

        Returns:
            gradslam.Pointclouds: Updated :math:`B` global maps
        """
        return update_map_aggregate(pointclouds, live_frame, inplace)