Skip to content

Commit

Permalink
code release
Browse files Browse the repository at this point in the history
  • Loading branch information
qinzheng93 committed Jun 15, 2023
1 parent 307b419 commit 3f20101
Show file tree
Hide file tree
Showing 15 changed files with 1,147 additions and 5 deletions.
121 changes: 116 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,126 @@
# Deep Graph-based Spatial Consistency for Robust Non-rigid Point Cloud Registration
# Deep Graph-Based Spatial Consistency for Robust Non-Rigid Point Cloud Registration

PyTorch implementation of the paper:

Deep Graph-based Spatial Consistency for Robust Non-rigid Point Cloud Registration.
[Deep Graph-Based Spatial Consistency for Robust Non-Rigid Point Cloud Registration](http://arxiv.org/abs/2303.09950)

[Zheng Qin](https://scholar.google.com/citations?user=DnHBAN0AAAAJ), [Hao Yu](https://scholar.google.com/citations?user=g7JfRn4AAAAJ), Changjian Wang, Yuxing Peng, and [Kai Xu](https://scholar.google.com/citations?user=GuVkg-8AAAAJ).
[Zheng Qin](https://scholar.google.com/citations?user=DnHBAN0AAAAJ), [Hao Yu](https://scholar.google.com/citations?user=g7JfRn4AAAAJ),
Changjian Wang, Yuxing Peng, and [Kai Xu](https://scholar.google.com/citations?user=GuVkg-8AAAAJ).

## Introduction

We study the problem of outlier correspondence pruning for non-rigid point cloud registration. In rigid registration, spatial consistency has been a commonly used criterion to discriminate outliers from inliers. It measures the compatibility of two correspondences by the discrepancy between the respective distances in two point clouds. However, spatial consistency no longer holds in non-rigid cases and outlier rejection for non-rigid registration has not been well studied. In this work, we propose Graph-based Spatial Consistency Network (GraphSCNet) to filter outliers for non-rigid registration. Our method is based on the fact that non-rigid deformations are usually locally rigid, or local shape preserving. We first design a local spatial consistency measure over the deformation graph of the point cloud, which evaluates the spatial compatibility only between the correspondences in the vicinity of a graph node. An attention-based non-rigid correspondence embedding module is then devised to learn a robust representation of non-rigid correspondences from local spatial consistency. Despite its simplicity, GraphSCNet effectively improves the quality of the putative correspondences and attains state-of-the-art performance on three challenging benchmarks.
We study the problem of outlier correspondence pruning for non-rigid point cloud registration. In rigid registration,
spatial consistency has been a commonly used criterion to discriminate outliers from inliers. It measures the
compatibility of two correspondences by the discrepancy between the respective distances in two point clouds. However,
spatial consistency no longer holds in non-rigid cases and outlier rejection for non-rigid registration has not been
well studied. In this work, we propose Graph-based Spatial Consistency Network (GraphSCNet) to filter outliers for
non-rigid registration. Our method is based on the fact that non-rigid deformations are usually locally rigid, or local
shape preserving. We first design a local spatial consistency measure over the deformation graph of the point cloud,
which evaluates the spatial compatibility only between the correspondences in the vicinity of a graph node. An
attention-based non-rigid correspondence embedding module is then devised to learn a robust representation of non-rigid
correspondences from local spatial consistency. Despite its simplicity, GraphSCNet effectively improves the quality of
the putative correspondences and attains state-of-the-art performance on three challenging benchmarks.

![](assets/teaser.png)

## News

2023.02.28: This work is accepted by CVPR 2023. Code and models will be released soon.
2023.06.15: Code and models on 4DMatch released.

2023.02.28: This work is accepted by CVPR 2023.

## Installation

Please use the following command for installation:

```bash
# 1. It is recommended to create a new environment
conda create -n geotransformer python==3.8
conda activate geotransformer

# 2. Install vision3d following https://github.com/qinzheng93/vision3d
```

The code has been tested on Python 3.8, PyTorch 1.13.1, Ubuntu 22.04, GCC 11.3 and CUDA 11.7, but it should work with
other configurations.

## 4DMatch

### Data preparation

The 4DMatch dataset can be downloaded from [DeformationPyramid](https://github.com/rabbityl/DeformationPyramid). We
provide the correspondences extracted by GeoTransformer in the release page. The data should be organized as follows:

```text
--data--4DMatch--train--abe_CoverToStand
| |--...
|--val--amy_Situps
| |--...
|--4DMatch-F--AJ_SoccerPass
| |--...
|--4DLoMatch-F--AJ_SoccerPass
| |--...
|--correspondences--val--amy_Situps
| |--...
|--4DMatch-F--AJ_SoccerPass
| |--...
|--4DLoMatch-F--AJ_SoccerPass
|--...
```

### Training

The code for 4DMatch is in `experiments/graphscnet.4dmatch.geotransformer`. Use the following command for training.

```bash
CUDA_VISIBLE_DEVICES=0 python trainval.py
```

### Testing

Use the following command for testing.

```bash
# 4DMatch
CUDA_VISIBLE_DEVICES=0 python test.py --test_epoch=EPOCH --benchmark=4DMatch-F
# 4DLoMatch
CUDA_VISIBLE_DEVICES=0 python test.py --test_epoch=EPOCH --benchmark=4DLoMatch-F
```

`EPOCH` is the epoch id.

We also provide pretrained weights in `weights`, use the following command to test the pretrained weights.

```bash
CUDA_VISIBLE_DEVICES=0 python test.py --checkpoint=/path/to/GraphSCNet/weights/graphscnet.pth --benchmark=4DMatch-F
```

Replace `4DMatch` with `4DLoMatch` to evaluate on 4DLoMatch.

### Results

| Benchmark | Prec | Recall | EPE | AccS | AccR | OR |
|:----------|:----:|:------:|:-----:|:----:|:----:|:----:|
| 4DMatch | 92.2 | 96.9 | 0.043 | 72.3 | 84.4 | 9.4 |
| 4DLoMatch | 82.6 | 86.8 | 0.121 | 41.0 | 58.3 | 21.0 |

## Citation

```bibtex
@inproceedings{qin2023deep,
title={Deep Graph-Based Spatial Consistency for Robust Non-Rigid Point Cloud Registration},
author={Zheng Qin and Hao Yu and Changjian Wang and Yuxing Peng and Kai Xu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month={June},
year={2023},
pages={5394-5403}
}
```

## Acknowledgements

- [vision3d](https://github.com/qinzheng93/vision3d)
- [GeoTransformer](https://github.com/qinzheng93/GeoTransformer)
- [PointDSC](https://github.com/XuyangBai/PointDSC)
- [lepard](https://github.com/rabbityl/lepard)
- [DeformationPyramid](https://github.com/rabbityl/DeformationPyramid)
Binary file added assets/teaser.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
135 changes: 135 additions & 0 deletions experiments/graphscnet.4dmatch.geotransformer/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import argparse
import os
import os.path as osp

from easydict import EasyDict as edict

from vision3d.utils.io import ensure_dir

_C = edict()

# exp
_C.exp = edict()
_C.exp.name = osp.basename(osp.dirname(osp.realpath(__file__)))
_C.exp.working_dir = osp.dirname(osp.realpath(__file__))
_C.exp.output_dir = osp.join("..", "..", "outputs", _C.exp.name)
_C.exp.checkpoint_dir = osp.join(_C.exp.output_dir, "checkpoints")
_C.exp.log_dir = osp.join(_C.exp.output_dir, "logs")
_C.exp.event_dir = osp.join(_C.exp.output_dir, "events")
_C.exp.cache_dir = osp.join(_C.exp.output_dir, "cache")
_C.exp.result_dir = osp.join(_C.exp.output_dir, "results")
_C.exp.seed = 7351

ensure_dir(_C.exp.output_dir)
ensure_dir(_C.exp.checkpoint_dir)
ensure_dir(_C.exp.log_dir)
ensure_dir(_C.exp.event_dir)
ensure_dir(_C.exp.cache_dir)
ensure_dir(_C.exp.result_dir)

# data
_C.data = edict()
_C.data.dataset_dir = "../../data/4DMatch"

# train data
_C.train = edict()
_C.train.batch_size = 1
_C.train.num_workers = 8
_C.train.use_augmentation = True
_C.train.return_corr_indices = True

# test data
_C.test = edict()
_C.test.batch_size = 1
_C.test.num_workers = 8
_C.test.return_corr_indices = True
_C.test.shape_names = None

# evaluation
_C.eval = edict()
_C.eval.acceptance_score = 0.4
_C.eval.acceptance_radius = 0.04
_C.eval.distance_limit = 0.1

# trainer
_C.trainer = edict()
_C.trainer.max_epoch = 40
_C.trainer.grad_acc_steps = 1

# optimizer
_C.optimizer = edict()
_C.optimizer.type = "Adam"
_C.optimizer.lr = 1e-4
_C.optimizer.weight_decay = 1e-6

# scheduler
_C.scheduler = edict()
_C.scheduler.type = "Step"
_C.scheduler.gamma = 0.95
_C.scheduler.step_size = 1

# model - Global
_C.model = edict()
_C.model.min_local_correspondences = 3
_C.model.max_local_correspondences = 128

_C.model.deformation_graph = edict()
_C.model.deformation_graph.num_anchors = 6
_C.model.deformation_graph.node_coverage = 0.08

# model - transformer
_C.model.transformer = edict()
_C.model.transformer.input_dim = 6
_C.model.transformer.hidden_dim = 256
_C.model.transformer.output_dim = 256
_C.model.transformer.num_heads = 4
_C.model.transformer.num_blocks = 3
_C.model.transformer.num_layers_per_block = 2
_C.model.transformer.sigma_d = 0.08
_C.model.transformer.dropout = None
_C.model.transformer.activation_fn = "ReLU"
_C.model.transformer.embedding_k = -1
_C.model.transformer.embedding_dim = 1

# model - classifier
_C.model.classifier = edict()
_C.model.classifier.input_dim = 256
_C.model.classifier.dropout = None

# Non-rigid ICP
_C.model.nicp = edict()
_C.model.nicp.corr_lambda = 5.0
_C.model.nicp.arap_lambda = 1.0
_C.model.nicp.lm_lambda = 0.01
_C.model.nicp.num_iterations = 5

# loss
_C.loss = edict()
_C.loss.focal_loss = edict()
_C.loss.focal_loss.weight = 1.0
_C.loss.consistency_loss = edict()
_C.loss.consistency_loss.weight = 1.0


def make_cfg():
return _C


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--link_output", dest="link_output", action="store_true", help="link output dir"
)
args = parser.parse_args()
return args


def main():
cfg = make_cfg()
args = parse_args()
if args.link_output:
os.symlink(cfg.output_dir, "output")


if __name__ == "__main__":
main()
Loading

0 comments on commit 3f20101

Please sign in to comment.