Source code for gradslam.odometry.icputils

from typing import Optional, Union

from chamferdist.chamfer import knn_points
import torch

from ..geometry.geometryutils import transform_pointcloud
from ..geometry.se3utils import se3_exp
from ..structures.pointclouds import Pointclouds
from ..structures.rgbdimages import RGBDImages


# Names exported by this module (its public API for `from ... import *`).
__all__ = [
    "solve_linear_system",
    "gauss_newton_solve",
    "point_to_plane_ICP",
    "point_to_plane_gradICP",
    "downsample_pointclouds",
    "downsample_rgbdimages",
]


def solve_linear_system(
    A: torch.Tensor, b: torch.Tensor, damp: Union[float, torch.Tensor] = 1e-8
):
    r"""Solves the damped normal equations of a linear system Ax = b, given the
    constraint matrix A and the coefficient vector b.

    Note that this solves the normal equations, not the linear system itself.
    That is, it solves :math:`(A^T A + \rho I) x = A^T b`, not :math:`Ax = b`,
    where :math:`\rho` is the damping coefficient `damp` and :math:`I` is the
    identity matrix of shape
    :math:`(\text{num_of_variables}, \text{num_of_variables})`.

    Args:
        A (torch.Tensor): The constraint matrix of the linear system.
        b (torch.Tensor): The coefficient vector of the linear system.
        damp (float or int or torch.Tensor): Damping coefficient :math:`\rho`
            added to each diagonal element of :math:`A^T A` to condition the
            system. A tensor value must be a scalar (`ndim == 0`) and is used
            as given. Default: 1e-8

    Returns:
        torch.Tensor: Solution vector of the damped normal equations.

    Raises:
        TypeError: If `A` or `b` is not a tensor, or `damp` is neither a
            number nor a tensor.
        ValueError: If `damp` is a non-scalar tensor, or if `A` / `b` do not
            have the shapes listed below.

    Shape:
        - A: :math:`(\text{num_of_equations}, \text{num_of_variables})`
        - b: :math:`(\text{num_of_equations}, 1)`
        - Output: :math:`(\text{num_of_variables}, 1)`
    """
    if not torch.is_tensor(A):
        raise TypeError(
            "Expected A to be of type torch.Tensor. Got {0}.".format(type(A))
        )
    if not torch.is_tensor(b):
        raise TypeError(
            "Expected b to be of type torch.Tensor. Got {0}.".format(type(b))
        )
    # Plain Python numbers (float or int, but not bool) and tensors are accepted.
    damp_is_number = isinstance(damp, (float, int)) and not isinstance(damp, bool)
    if not (damp_is_number or torch.is_tensor(damp)):
        raise TypeError(
            "Expected damp to be of type float, int or torch.Tensor. Got {0}.".format(
                type(damp)
            )
        )
    if torch.is_tensor(damp) and damp.ndim != 0:
        raise ValueError(
            "Expected torch.Tensor damp to have ndim=0 (scalar). Got {0}.".format(
                damp.ndim
            )
        )
    if A.ndim != 2:
        raise ValueError("A should have ndim=2, but had ndim={}".format(A.ndim))
    if b.ndim != 2:
        raise ValueError("b should have ndim=2, but had ndim={}".format(b.ndim))
    if b.shape[1] != 1:
        raise ValueError("b.shape[1] should be 1, but was {0}".format(b.shape[1]))
    if A.shape[0] != b.shape[0]:
        raise ValueError(
            "A.shape[0] and b.shape[0] should be equal ({0} != {1})".format(
                A.shape[0], b.shape[0]
            )
        )

    # A tensor damp is used untouched so that any autograd history it carries
    # is preserved; a Python number is lifted to a scalar tensor matching A.
    if not torch.is_tensor(damp):
        damp = torch.tensor(damp, dtype=A.dtype, device=A.device)

    # Construct the normal equations. The identity is created directly with
    # A's dtype and device, avoiding a float32 CPU allocation and a transfer.
    A_t = torch.transpose(A, 0, 1)
    identity = torch.eye(A.shape[1], dtype=A.dtype, device=A.device)
    At_A = torch.matmul(A_t, A) + identity * damp

    # Solve the normal equations (for now, by inversion!)
    return torch.matmul(torch.inverse(At_A), torch.matmul(A_t, b))
def gauss_newton_solve(
    src_pc: torch.Tensor,
    tgt_pc: torch.Tensor,
    tgt_normals: torch.Tensor,
    dist_thresh: Union[float, int, None] = None,
):
    r"""Builds the linear system of one Gauss-Newton step for point-to-plane
    alignment of `src_pc` to `tgt_pc`.

    Every source point is associated with the target point returned for it by
    `knn_points`. Source points whose returned distance is not below
    `dist_thresh` are dropped before the system is assembled.

    Args:
        src_pc (torch.Tensor): Source pointcloud (the pointcloud that needs
            warping).
        tgt_pc (torch.Tensor): Target pointcloud (the pointcloud to which the
            source pointcloud must be warped to).
        tgt_normals (torch.Tensor): Per-point normal vectors for each point in
            the target pointcloud.
        dist_thresh (float or int or None): Distance threshold for removing
            `src_pc` points distant from `tgt_pc`. If None, no point is
            removed. Default: None

    Returns:
        tuple: tuple containing:

        - A (torch.Tensor): coefficient matrix of the linear system
        - b (torch.Tensor): residual vector of the linear system
        - chamfer_indices (torch.Tensor): Index of the associated point in
          `tgt_pc` for each point in `src_pc` that was not filtered out.

    Shape:
        - src_pc: :math:`(1, N_s, 3)`
        - tgt_pc: :math:`(1, N_t, 3)`
        - tgt_normals: :math:`(1, N_t, 3)`
        - A: :math:`(N_sf, 6)` where :math:`N_sf \leq N_s`
        - b: :math:`(N_sf, 1)` where :math:`N_sf \leq N_s`
        - chamfer_indices: :math:`(N_sf,)` where :math:`N_sf \leq N_s`
    """
    named_inputs = (
        ("src_pc", src_pc),
        ("tgt_pc", tgt_pc),
        ("tgt_normals", tgt_normals),
    )

    for label, value in named_inputs:
        if not torch.is_tensor(value):
            raise TypeError(
                "Expected {0} to be of type torch.Tensor. Got {1}.".format(
                    label, type(value)
                )
            )
    if dist_thresh is not None and not isinstance(dist_thresh, (float, int)):
        raise TypeError(
            "Expected dist_thresh to be of type float or int. Got {0}.".format(
                type(dist_thresh)
            )
        )
    for label, value in named_inputs:
        if value.ndim != 3:
            raise ValueError(
                "{0} should have ndim=3, but had ndim={1}".format(label, value.ndim)
            )
    for label, value in named_inputs:
        if value.shape[0] != 1:
            raise ValueError(
                "{0}.shape[0] should be 1, but was {1} instead".format(
                    label, value.shape[0]
                )
            )
    if tgt_pc.shape[1] != tgt_normals.shape[1]:
        raise ValueError(
            "tgt_pc.shape[1] and tgt_normals.shape[1] must be equal. Got {0}!={1}".format(
                tgt_pc.shape[1], tgt_normals.shape[1]
            )
        )
    for label, value in named_inputs:
        if value.shape[2] != 3:
            raise ValueError(
                "{0}.shape[2] should be 3, but was {1} instead".format(
                    label, value.shape[2]
                )
            )

    src_pc = src_pc.contiguous()
    tgt_pc = tgt_pc.contiguous()
    tgt_normals = tgt_normals.contiguous()

    # Nearest-neighbour association of every source point with a target point.
    knn = knn_points(src_pc, tgt_pc)
    nn_dists = knn.dists.squeeze(-1)
    nn_idx = knn.idx.squeeze(-1)

    # Boolean mask over source points that survive the distance filter.
    if dist_thresh is None:
        keep = torch.ones_like(nn_dists[0], dtype=torch.bool)
    else:
        keep = nn_dists[0] < dist_thresh
    chamfer_indices = nn_idx[0][keep].long()

    # Surviving source points, and the target point / normal matched to each.
    kept_src = src_pc[0, keep]
    matched_pts = tgt_pc[0].index_select(0, chamfer_indices)
    matched_normals = tgt_normals[0].index_select(0, chamfer_indices)

    # Column vectors of shape (N_sf, 1) for each coordinate.
    sx, sy, sz = kept_src.split(1, dim=1)
    dx, dy, dz = matched_pts.split(1, dim=1)
    nx, ny, nz = matched_normals.split(1, dim=1)

    # Each row of A is [n, s x n]; b is the signed distance along the normal.
    rotation_cols = [nz * sy - ny * sz, nx * sz - nz * sx, ny * sx - nx * sy]
    A = torch.cat([nx, ny, nz] + rotation_cols, 1)
    b = nx * (dx - sx) + ny * (dy - sy) + nz * (dz - sz)

    return A, b, chamfer_indices
def point_to_plane_ICP(
    src_pc: torch.Tensor,
    tgt_pc: torch.Tensor,
    tgt_normals: torch.Tensor,
    initial_transform: Optional[torch.Tensor] = None,
    numiters: int = 20,
    damp: float = 1e-8,
    dist_thresh: Union[float, int, None] = None,
):
    r"""Computes a rigid transformation between `tgt_pc` (target pointcloud) and
    `src_pc` (source pointcloud) using a point-to-plane error metric and the LM
    (Levenberg-Marquardt) solver.

    Args:
        src_pc (torch.Tensor): Source pointcloud (the pointcloud that needs
            warping).
        tgt_pc (torch.Tensor): Target pointcloud (the pointcloud to which the
            source pointcloud must be warped to).
        tgt_normals (torch.Tensor): Per-point normal vectors for each point in
            the target pointcloud.
        initial_transform (torch.Tensor or None): The initial estimate of the
            transformation between 'src_pc' and 'tgt_pc'. If None, will use the
            identity matrix as the initial transform. Default: None
        numiters (int): Number of iterations to run the optimization for. Must
            be at least 1. Default: 20
        damp (float): Damping coefficient for nonlinear least-squares.
            Default: 1e-8
        dist_thresh (float or int or None): Distance threshold for removing
            `src_pc` points distant from `tgt_pc`. Default: None

    Returns:
        tuple: tuple containing:

        - transform (torch.Tensor): estimated transformation, i.e. the
          accepted incremental updates composed onto `initial_transform`
        - chamfer_indices (torch.Tensor): Index of the closest point in
          `tgt_pc` for each point in `src_pc` that was not filtered out, as
          computed when the linear system of the last iteration was formed.

    Raises:
        TypeError: If an argument has an unexpected type.
        ValueError: If `initial_transform` is given but is not of shape
            (4, 4), or if `numiters` is smaller than 1.

    Shape:
        - src_pc: :math:`(1, N_s, 3)`
        - tgt_pc: :math:`(1, N_t, 3)`
        - tgt_normals: :math:`(1, N_t, 3)`
        - initial_transform: :math:`(4, 4)`
        - transform: :math:`(4, 4)`
        - chamfer_indices: :math:`(N_sf,)` where :math:`N_sf \leq N_s`
    """
    if not torch.is_tensor(src_pc):
        raise TypeError(
            "Expected src_pc to be of type torch.Tensor. Got {0}.".format(type(src_pc))
        )
    if not torch.is_tensor(tgt_pc):
        raise TypeError(
            "Expected tgt_pc to be of type torch.Tensor. Got {0}.".format(type(tgt_pc))
        )
    if not torch.is_tensor(tgt_normals):
        raise TypeError(
            "Expected tgt_normals to be of type torch.Tensor. Got {0}.".format(
                type(tgt_normals)
            )
        )
    if not (torch.is_tensor(initial_transform) or initial_transform is None):
        raise TypeError(
            "Expected initial_transform to be of type torch.Tensor. Got {0}.".format(
                type(initial_transform)
            )
        )
    if not isinstance(numiters, int):
        raise TypeError(
            "Expected numiters to be of type int. Got {0}.".format(type(numiters))
        )
    if numiters < 1:
        # Without at least one iteration there are no correspondences to return.
        raise ValueError(
            "Expected numiters to be at least 1. Got {0}.".format(numiters)
        )
    # Shape checks only apply to a user-supplied transform; the documented
    # default of None is replaced by the identity further below.
    if initial_transform is not None:
        if initial_transform.ndim != 2:
            raise ValueError(
                "Expected initial_transform.ndim to be 2. Got {0}.".format(
                    initial_transform.ndim
                )
            )
        if not (initial_transform.shape[0] == 4 and initial_transform.shape[1] == 4):
            raise ValueError(
                "Expected initial_transform.shape to be (4, 4). Got {0}.".format(
                    initial_transform.shape
                )
            )

    src_pc = src_pc.contiguous()
    tgt_pc = tgt_pc.contiguous()
    tgt_normals = tgt_normals.contiguous()
    dtype = src_pc.dtype
    device = src_pc.device
    damp = torch.tensor(damp, dtype=dtype, device=device)

    # include initial transform
    if initial_transform is None:
        initial_transform = torch.eye(4, dtype=dtype, device=device)
    src_pc = transform_pointcloud(src_pc[0], initial_transform).unsqueeze(0)
    transform = initial_transform

    for _ in range(numiters):
        # Form the linear system and compute the residual
        A, b, chamfer_indices = gauss_newton_solve(
            src_pc, tgt_pc, tgt_normals, dist_thresh
        )
        residual = b[:, 0]

        # Solve the linear system
        xi = solve_linear_system(A, b, damp)

        # Apply exponential to find the transform
        residual_transform = se3_exp(xi)

        # Current error (sum of squared residuals)
        err = torch.dot(residual, residual)

        # Lookahead error (for LM): apply the candidate step, re-associate
        # points and measure the error of the warped cloud.
        one_step_pc = transform_pointcloud(src_pc[0], residual_transform).unsqueeze(0)
        _, one_step_b, _ = gauss_newton_solve(
            one_step_pc, tgt_pc, tgt_normals, dist_thresh
        )
        one_step_residual = one_step_b[:, 0]
        new_err = torch.dot(one_step_residual, one_step_residual)

        if new_err < err:
            # We are in a trust region: accept the step and relax the damping
            src_pc = one_step_pc
            damp = damp / 2
            transform = torch.mm(residual_transform, transform)
        else:
            # Reject the step and increase the damping
            damp = damp * 2

    return transform, chamfer_indices
def point_to_plane_gradICP(
    src_pc: torch.Tensor,
    tgt_pc: torch.Tensor,
    tgt_normals: torch.Tensor,
    initial_transform: Optional[torch.Tensor] = None,
    numiters: int = 20,
    damp: float = 1e-8,
    dist_thresh: Union[float, int, None] = None,
    lambda_max: Union[float, int] = 2.0,
    B: Union[float, int] = 1.0,
    B2: Union[float, int] = 1.0,
    nu: Union[float, int] = 200.0,
):
    r"""Computes a rigid transformation between `tgt_pc` (target pointcloud) and
    `src_pc` (source pointcloud) using a point-to-plane error metric and gradLM
    (:math:`\nabla LM`) solver (See gradLM section of
    `the gradSLAM paper <https://arxiv.org/abs/1910.10672>`__).

    With :math:`r_0` the current error and :math:`r_1` the lookahead error, the
    smooth gating functions used in every iteration are:

    .. math::

        Q_\lambda(r_0, r_1) & = \lambda_{min} + \frac{\lambda_{max} -
        \lambda_{min}}{1 + e^{-B (r_1 - r_0)}} \\
        Q_x(r_0, r_1) & = x_0 + \frac{\delta x_0}{\sqrt[\nu]{1 + e^{-B2 (r_1 - r_0)}}}

    The damping coefficient is multiplied by :math:`Q_\lambda` and the update
    :math:`\delta x_0` is scaled as in :math:`Q_x` before being applied.

    Args:
        src_pc (torch.Tensor): Source pointcloud (the pointcloud that needs
            warping).
        tgt_pc (torch.Tensor): Target pointcloud (the pointcloud to which the
            source pointcloud must be warped to).
        tgt_normals (torch.Tensor): Per-point normal vectors for each point in
            the target pointcloud.
        initial_transform (torch.Tensor or None): The initial estimate of the
            transformation between 'src_pc' and 'tgt_pc'. If None, will use the
            identity matrix as the initial transform. Default: None
        numiters (int): Number of iterations to run the optimization for. Must
            be at least 1. Default: 20
        damp (float): Damping coefficient for nonlinear least-squares.
            Default: 1e-8
        dist_thresh (float or int or None): Distance threshold for removing
            `src_pc` points distant from `tgt_pc`. Default: None
        lambda_max (float or int): Maximum value the damping function can
            assume (`lambda_min` will be :math:`\frac{1}{\text{lambda_max}}`)
        B (float or int): gradLM falloff control parameter
        B2 (float or int): gradLM control parameter
        nu (float or int): gradLM control parameter

    Returns:
        tuple: tuple containing:

        - transform (torch.Tensor): estimated transformation, i.e. the
          incremental updates composed onto `initial_transform`
        - chamfer_indices (torch.Tensor): Index of the closest point in
          `tgt_pc` for each point in `src_pc` that was not filtered out, as
          computed when the linear system of the last iteration was formed.

    Raises:
        TypeError: If an argument has an unexpected type.
        ValueError: If `initial_transform` is given but is not of shape
            (4, 4), or if `numiters` is smaller than 1.

    Shape:
        - src_pc: :math:`(1, N_s, 3)`
        - tgt_pc: :math:`(1, N_t, 3)`
        - tgt_normals: :math:`(1, N_t, 3)`
        - initial_transform: :math:`(4, 4)`
        - transform: :math:`(4, 4)`
        - chamfer_indices: :math:`(N_sf,)` where :math:`N_sf \leq N_s`
    """
    if not torch.is_tensor(src_pc):
        raise TypeError(
            "Expected src_pc to be of type torch.Tensor. Got {0}.".format(type(src_pc))
        )
    if not torch.is_tensor(tgt_pc):
        raise TypeError(
            "Expected tgt_pc to be of type torch.Tensor. Got {0}.".format(type(tgt_pc))
        )
    if not torch.is_tensor(tgt_normals):
        raise TypeError(
            "Expected tgt_normals to be of type torch.Tensor. Got {0}.".format(
                type(tgt_normals)
            )
        )
    if not (torch.is_tensor(initial_transform) or initial_transform is None):
        raise TypeError(
            "Expected initial_transform to be of type torch.Tensor. Got {0}.".format(
                type(initial_transform)
            )
        )
    if not isinstance(numiters, int):
        raise TypeError(
            "Expected numiters to be of type int. Got {0}.".format(type(numiters))
        )
    if not (isinstance(lambda_max, float) or isinstance(lambda_max, int)):
        raise TypeError(
            "Expected lambda_max to be of type float or int; got {0}".format(
                type(lambda_max)
            )
        )
    if not (isinstance(B, float) or isinstance(B, int)):
        raise TypeError(
            "Expected B to be of type float or int; got {0}".format(type(B))
        )
    if not (isinstance(B2, float) or isinstance(B2, int)):
        raise TypeError(
            "Expected B2 to be of type float or int; got {0}".format(type(B2))
        )
    if not (isinstance(nu, float) or isinstance(nu, int)):
        raise TypeError(
            "Expected nu to be of type float or int; got {0}".format(type(nu))
        )
    if numiters < 1:
        # Without at least one iteration there are no correspondences to return.
        raise ValueError(
            "Expected numiters to be at least 1. Got {0}.".format(numiters)
        )
    # Shape checks only apply to a user-supplied transform; the documented
    # default of None is replaced by the identity further below.
    if initial_transform is not None:
        if initial_transform.ndim != 2:
            raise ValueError(
                "Expected initial_transform.ndim to be 2. Got {0}.".format(
                    initial_transform.ndim
                )
            )
        if not (initial_transform.shape[0] == 4 and initial_transform.shape[1] == 4):
            raise ValueError(
                "Expected initial_transform.shape to be (4, 4). Got {0}.".format(
                    initial_transform.shape
                )
            )

    src_pc = src_pc.contiguous()
    tgt_pc = tgt_pc.contiguous()
    tgt_normals = tgt_normals.contiguous()
    dtype = src_pc.dtype
    device = src_pc.device
    damp = torch.tensor(damp, dtype=dtype, device=device)
    lambda_min = 1 / lambda_max

    # include initial transform
    if initial_transform is None:
        initial_transform = torch.eye(4, dtype=dtype, device=device)
    src_pc = transform_pointcloud(src_pc[0], initial_transform).unsqueeze(0)
    transform = initial_transform

    for _ in range(numiters):
        # Form the linear system and compute the residual
        A, b, chamfer_indices = gauss_newton_solve(
            src_pc, tgt_pc, tgt_normals, dist_thresh
        )
        residual = b[:, 0]

        # Solve the linear system
        xi = solve_linear_system(A, b, damp)

        # Apply exponential to find the (full-step) transform
        residual_transform = se3_exp(xi)

        # Current error (sum of squared residuals)
        err = torch.dot(residual, residual)

        # Lookahead error (for LM): apply the full step, re-associate points
        # and measure the error of the warped cloud.
        one_step_pc = transform_pointcloud(src_pc[0], residual_transform).unsqueeze(0)
        _, one_step_b, _ = gauss_newton_solve(
            one_step_pc, tgt_pc, tgt_normals, dist_thresh
        )
        one_step_residual = one_step_b[:, 0]
        new_err = torch.dot(one_step_residual, one_step_residual)

        # smooth LM routine
        errdiff = new_err - err
        # clamping to ensure gradient flow (keeps the exponentials finite)
        errdiff = errdiff.clamp(-70.0, 70.0)

        # Smoothly scale the damping between lambda_min and lambda_max
        damp_new = lambda_min + (lambda_max - lambda_min) / (
            1 + torch.exp(-B * errdiff)
        )
        damp = damp * damp_new

        # residual transform should be perturbed by sigmoid
        sigmoid = 1 / ((1 + torch.exp(-B2 * errdiff)) ** (1 / nu))

        # Apply the gated update and accumulate it into the total transform
        residual_transform = se3_exp(sigmoid * xi)
        src_pc = transform_pointcloud(src_pc[0], residual_transform).unsqueeze(0)
        transform = torch.mm(residual_transform, transform)

    return transform, chamfer_indices
def downsample_pointclouds(
    pointclouds: Pointclouds, pc2im_bnhw: torch.Tensor, ds_ratio: int
) -> Pointclouds:
    r"""Keeps only those active points of `pointclouds` whose projected pixel
    lies on the downsampled grid, and discards every other point.

    A row of `pc2im_bnhw` is retained when both its height index and its width
    index are multiples of `ds_ratio`. Points that do not appear in a retained
    row are removed from the output.

    Args:
        pointclouds (gradslam.Pointclouds): Pointclouds to downsample
        pc2im_bnhw (torch.Tensor): Active map points lookup table. Each row
            contains batch index `b`, point index (in pointclouds) `n`, and
            height and width index after projection to live frame `h` and `w`
            respectively.
        ds_ratio (int): Downsampling ratio

    Returns:
        gradslam.Pointclouds: Downsampled pointclouds

    Shape:
        - pc2im_bnhw: :math:`(\text{num_active_map_points}, 4)`
    """
    if not isinstance(pointclouds, Pointclouds):
        raise TypeError(
            "Expected pointclouds to be of type gradslam.Pointclouds. Got {0}.".format(
                type(pointclouds)
            )
        )
    if not torch.is_tensor(pc2im_bnhw):
        raise TypeError(
            "Expected pc2im_bnhw to be of type torch.Tensor. Got {0}.".format(
                type(pc2im_bnhw)
            )
        )
    if not isinstance(ds_ratio, int):
        raise TypeError(
            "Expected ds_ratio to be of type int. Got {0}.".format(type(ds_ratio))
        )
    if pc2im_bnhw.ndim != 2:
        raise ValueError(
            "Expected pc2im_bnhw to have ndim=2. Got {0}.".format(pc2im_bnhw.ndim)
        )
    if pc2im_bnhw.shape[1] != 4:
        raise ValueError(
            "pc2im_bnhw.shape[1] must be 4, but was {0}.".format(pc2im_bnhw.shape[1])
        )

    batch_size = len(pointclouds)

    # Retain rows whose h (column 2) and then w (column 3) hit the coarse grid.
    row_on_grid = pc2im_bnhw[..., 2] % ds_ratio == 0
    pc2im_bnhw = pc2im_bnhw[row_on_grid]
    col_on_grid = pc2im_bnhw[..., 3] % ds_ratio == 0
    pc2im_bnhw = pc2im_bnhw[col_on_grid]

    # Point indices (column 1) to keep, grouped by batch index (column 0).
    kept_indices = [
        pc2im_bnhw[pc2im_bnhw[..., 0] == b][..., 1] for b in range(batch_size)
    ]

    # Apply the same selection to points and, when present, normals and colors.
    maps_points = [
        pointclouds.points_list[b][idx] for b, idx in enumerate(kept_indices)
    ]
    if pointclouds.normals_list is None:
        maps_normals = None
    else:
        maps_normals = [
            pointclouds.normals_list[b][idx] for b, idx in enumerate(kept_indices)
        ]
    if pointclouds.colors_list is None:
        maps_colors = None
    else:
        maps_colors = [
            pointclouds.colors_list[b][idx] for b, idx in enumerate(kept_indices)
        ]

    return Pointclouds(points=maps_points, normals=maps_normals, colors=maps_colors)
def downsample_rgbdimages(rgbdimages: RGBDImages, ds_ratio: int) -> Pointclouds:
    r"""Subsamples the vertex, normal and color maps of an RGBDImages object on
    a regular pixel grid and packs the valid-depth pixels into a
    gradslam.Pointclouds object.

    Every `ds_ratio`-th pixel is taken along both image dimensions; of those,
    only pixels flagged by `rgbdimages.valid_depth_mask` are kept.

    Args:
        rgbdimages (gradslam.RGBDImages): RGBDImages to downsample
        ds_ratio (int): Downsampling ratio

    Returns:
        gradslam.Pointclouds: Downsampled points, normals and colors
    """
    if not isinstance(rgbdimages, RGBDImages):
        raise TypeError(
            "Expected rgbdimages to be of type gradslam.RGBDImages. Got {0}.".format(
                type(rgbdimages)
            )
        )
    if not isinstance(ds_ratio, int):
        raise TypeError(
            "Expected ds_ratio to be of type int. Got {0}.".format(type(ds_ratio))
        )
    if rgbdimages.shape[1] != 1:
        raise ValueError(
            "Sequence length of rgbdimages must be 1, but was {0}.".format(
                rgbdimages.shape[1]
            )
        )

    batch_size = len(rgbdimages)

    # Valid depths mask, subsampled on the same grid as the maps below
    valid = rgbdimages.valid_depth_mask.squeeze(-1)[..., ::ds_ratio, ::ds_ratio]

    def _gather(attr_name):
        # Per batch element: subsample the named map, then keep valid pixels.
        return [
            getattr(rgbdimages, attr_name)[b][..., ::ds_ratio, ::ds_ratio, :][valid[b]]
            for b in range(batch_size)
        ]

    points = _gather("global_vertex_map")
    normals = _gather("global_normal_map")
    colors = _gather("rgb_image")

    return Pointclouds(points=points, normals=normals, colors=colors)