From 4a8283517fb77e7527e278e233df502f8b7a3266 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Pautrat?= <32239569+rpautrat@users.noreply.github.com> Date: Wed, 1 Nov 2023 14:36:58 +0100 Subject: [PATCH] Fix GlueStick training config (#29) * Update training config of GlueStick * Remove unnecessary checks in GlueStick * Update ETH3D download link * Update link to undistorted ETH3D * Update link to download DeepLSD --------- Co-authored-by: pautratr --- .../superpoint+lsd+gluestick-homography.yaml | 8 ++++---- .../configs/superpoint+lsd+gluestick-megadepth.yaml | 13 +++++++++---- gluefactory/datasets/eth3d.py | 2 +- gluefactory/models/lines/deeplsd.py | 2 +- gluefactory/models/matchers/gluestick.py | 4 +--- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/gluefactory/configs/superpoint+lsd+gluestick-homography.yaml b/gluefactory/configs/superpoint+lsd+gluestick-homography.yaml index 11d53939..62bc883e 100644 --- a/gluefactory/configs/superpoint+lsd+gluestick-homography.yaml +++ b/gluefactory/configs/superpoint+lsd+gluestick-homography.yaml @@ -1,14 +1,14 @@ data: name: homographies homography: - difficulty: 0.5 - max_angle: 30 + difficulty: 0.7 + max_angle: 45 patch_shape: [640, 480] photometric: p: 0.75 train_size: 900000 val_size: 1000 - batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet) + batch_size: 160 # 20 per 10GB of GPU mem (12 for triplet) num_workers: 15 model: name: gluefactory.models.two_view_pipeline @@ -70,4 +70,4 @@ train: n_steps: 4 submodules: [] # clip_grad: 10 # Use only with mixed precision - # load_experiment: \ No newline at end of file + # load_experiment: \ No newline at end of file diff --git a/gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml b/gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml index 14ff90a2..5946826d 100644 --- a/gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml +++ b/gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml @@ -1,10 +1,15 @@ data: name: gluefactory.datasets.megadepth + train_num_per_scene: 300 + val_pairs: valid_pairs.txt views: 2 + min_overlap: 0.1 + max_overlap: 0.7 + num_overlap_bins: 3 preprocessing: resize: 640 square_pad: True - batch_size: 60 + batch_size: 160 num_workers: 15 model: name: gluefactory.models.two_view_pipeline @@ -53,9 +58,9 @@ model: train: seed: 0 epochs: 200 - log_every_iter: 10 - eval_every_iter: 100 - save_every_iter: 500 + log_every_iter: 400 + eval_every_iter: 700 + save_every_iter: 1400 lr: 1e-4 lr_schedule: type: exp # exp or multi_step diff --git a/gluefactory/datasets/eth3d.py b/gluefactory/datasets/eth3d.py index ca5e2648..44fd73f8 100644 --- a/gluefactory/datasets/eth3d.py +++ b/gluefactory/datasets/eth3d.py @@ -194,7 +194,7 @@ def download_eth3d(self): if tmp_dir.exists(): shutil.rmtree(tmp_dir) tmp_dir.mkdir(exist_ok=True, parents=True) - url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/" + url_base = "https://cvg-data.inf.ethz.ch/SOLD2/SOLD2_ETH3D_undistorted/" zip_name = "ETH3D_undistorted.zip" zip_path = tmp_dir / zip_name torch.hub.download_url_to_file(url_base + zip_name, zip_path) diff --git a/gluefactory/models/lines/deeplsd.py b/gluefactory/models/lines/deeplsd.py index 122f4b4f..d1aa57df 100644 --- a/gluefactory/models/lines/deeplsd.py +++ b/gluefactory/models/lines/deeplsd.py @@ -41,7 +41,7 @@ def download_model(self, path): if not path.parent.is_dir(): path.parent.mkdir(parents=True, exist_ok=True) - link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download" + link = "https://cvg-data.inf.ethz.ch/DeepLSD/deeplsd_md.tar" cmd = ["wget", link, "-O", path] print("Downloading DeepLSD model...") subprocess.run(cmd, check=True) diff --git a/gluefactory/models/matchers/gluestick.py b/gluefactory/models/matchers/gluestick.py index e16a8a52..b46af136 100644 --- a/gluefactory/models/matchers/gluestick.py +++ b/gluefactory/models/matchers/gluestick.py @@ -131,7 +131,7 @@ def _init(self, conf): state_dict = { k.replace("module.", ""): v for k, v in state_dict.items() } - self.load_state_dict(state_dict) + self.load_state_dict(state_dict, strict=False) def _forward(self, data): device = data["keypoints0"].device @@ -200,8 +200,6 @@ def _forward(self, data): kpts0 = normalize_keypoints(kpts0, image_size0) kpts1 = normalize_keypoints(kpts1, image_size1) - assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1) - assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1) desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"]) desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])