Adds evaluation on challenging WxBS and EVD datasets #52

Open: wants to merge 8 commits into base: main
89 changes: 89 additions & 0 deletions README.md
@@ -188,6 +188,95 @@ AP_lines: 69.22

</details>

#### EVD

The dataset is downloaded automatically if it is not found on disk, and requires about 27 MB of free disk space.

<details>
<summary>[Evaluating LightGlue]</summary>

To evaluate LightGlue on EVD, run:
```bash
python -m gluefactory.eval.evd --conf gluefactory/configs/superpoint+lightglue-official.yaml
# [Review comment, Collaborator] You can also run this with:
# python -m gluefactory.eval.evd --conf superpoint+lightglue-official
```
You should expect the following results:
```
{'H_error_dlt@10px': 0.0808,
'H_error_dlt@1px': 0.0,
'H_error_dlt@20px': 0.1443,
'H_error_dlt@5px': 0.0,
'H_error_ransac@10px': 0.1045,
'H_error_ransac@1px': 0.0,
'H_error_ransac@20px': 0.1189,
'H_error_ransac@5px': 0.0553,
'H_error_ransac_mAA': 0.069675,
'mH_error_dlt': nan,
'mH_error_ransac': nan,
'mnum_keypoints': 2048.0,
'mnum_matches': 11.0,
'mprec@1px': 0.0,
'mprec@3px': 0.0,
'mransac_inl': 5.0,
'mransac_inl%': 0.089}
```
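
As a quick sanity check on the numbers above, `H_error_ransac_mAA` is consistent with the plain mean of the four `H_error_ransac@…px` values. The snippet below only redoes that arithmetic on values copied from the output. That the evaluation computes mAA this way is an assumption inferred from the numbers, not something confirmed by the code shown in this PR.

```python
# Values copied from the expected EVD output above.
ransac_auc = {
    "1px": 0.0,
    "5px": 0.0553,
    "10px": 0.1045,
    "20px": 0.1189,
}

# Assumption: mAA is the unweighted mean over the four thresholds.
mAA = sum(ransac_auc.values()) / len(ransac_auc)
print(round(mAA, 6))  # compare with 'H_error_ransac_mAA' in the output above
```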

Here are the results as Area Under the Curve (AUC) of the homography error at 1/5/10/20 pixels:
<details>
<summary>[LightGlue on EVD]</summary>

| Methods (2K features if not specified) | [PoseLib](../gluefactory/robust_estimators/homography/poselib.py) |
| ----------------------------------------------------------- | ---------------------------------------------------------------------- |
| [SuperPoint + SuperGlue](gluefactory/configs/superpoint+superglue-official.yaml) | 0.0 / 5.4 / 10.1 / 11.7 |
| [SuperPoint + LightGlue](gluefactory/configs/superpoint+lightglue-official.yaml) | 0.0 / 5.5 / 10.4 / 11.8 |
| [SIFT (4K) + LightGlue](gluefactory/configs/sift+lightglue-official.yaml) | 0.0 / 3.8 / 5.2 / 10.0 |
| [DoGHardNet + LightGlue](gluefactory/configs/doghardnet+lightglue-official.yaml) | 0.0 / 5.5 / 10.5 / 11.9 |
| [ALIKED + LightGlue](gluefactory/configs/aliked+lightglue-official.yaml) | 0.0 / 5.4 / 12.4 / 16.2 |
| [DISK + LightGlue](gluefactory/configs/disk+lightglue-official.yaml) | 0.0 / 0.0 / 6.9 / 10.1 |


</details>
</details>
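
For readers unfamiliar with the metric, the AUC at a threshold `t` is commonly computed as the area under the recall-versus-error curve up to `t`, normalized by `t`. The following is a minimal sketch under that assumption. The function name `error_auc` is hypothetical, the example errors are made up, and this is not necessarily the exact implementation used by the evaluation script.

```python
def error_auc(errors, thresholds):
    """Area under the recall-vs-error curve, normalized by each threshold.

    `errors` holds one non-negative error per image pair; failed pairs
    should be passed as float("inf") so that they never count as correct.
    """
    errors = sorted(errors)
    n = len(errors)
    # Curve points: recall reached after each sorted error, starting at (0, 0).
    xs = [0.0] + errors
    ys = [0.0] + [(i + 1) / n for i in range(n)]
    aucs = []
    for t in thresholds:
        # Keep the points with error <= t, then close the curve at x = t.
        k = sum(1 for e in xs if e <= t)
        x = xs[:k] + [t]
        y = ys[:k] + [ys[k - 1]]
        # Trapezoidal rule over the truncated curve.
        area = sum(
            (x[i + 1] - x[i]) * (y[i + 1] + y[i]) / 2 for i in range(len(x) - 1)
        )
        aucs.append(area / t)
    return aucs


# Made-up example: two successful pairs and two failures.
example = error_auc([0.5, 2.0, float("inf"), float("inf")], [1, 5])
print([round(100 * a, 1) for a in example])  # in percent, like the table
```

Multiplying by 100 gives values on the same scale as the table above.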

#### WxBS

The dataset is downloaded automatically if it is not found on disk, and requires about 40 MB of free disk space.

<details>
<summary>[Evaluating LightGlue]</summary>

To evaluate LightGlue on WxBS, run:
```bash
python -m gluefactory.eval.WxBS --conf gluefactory/configs/superpoint+lightglue-official.yaml
# [Review comment, Collaborator]
# python -m gluefactory.eval.wxbs --conf superpoint+lightglue-official
```
You should expect the following results:
```
{'epi_error@10px': 0.6141352941176471,
'epi_error@1px': 0.2968,
'epi_error@20px': 0.6937882352941176,
'epi_error@5px': 0.5143617647058826,
'epi_error_mAA': 0.5297713235294118,
# [Review comment, Collaborator, on lines +253 to +257] When I run this I get different results, close to what is reported in the table below.
# [Reply, Author] I will update this.
'mnum_keypoints': 2048.0,
'mnum_matches': 99.5,
'mransac_inl': 65.0,
'mransac_inl%': nan}
```

Here are the results as Area Under the Curve (AUC) of the epipolar error at 1/5/10/20 pixels:
<details>
<summary>[LightGlue on WxBS]</summary>

| Methods (2K features if not specified) | [PoseLib](../gluefactory/robust_estimators/fundamental_matrix/poselib.py) |
| ----------------------------------------------------------- | ---------------------------------------------------------------------- |
| [SuperPoint + SuperGlue](gluefactory/configs/superpoint+superglue-official.yaml) | 13.2 / 39.9 / 49.7 / 56.7 |
| [SuperPoint + LightGlue](gluefactory/configs/superpoint+lightglue-official.yaml) | 12.6 / 34.5 / 44.0 / 52.2 |
| [SIFT (4K) + LightGlue](gluefactory/configs/sift+lightglue-official.yaml) | 9.5 / 22.7 / 29.0 / 34.2 |
| [DoGHardNet + LightGlue](gluefactory/configs/doghardnet+lightglue-official.yaml) | 10.0 / 29.6 / 39.0 / 49.2 |
| [ALIKED + LightGlue](gluefactory/configs/aliked+lightglue-official.yaml) | 18.7 / 46.2 / 56.0 / 63.5 |
| [DISK + LightGlue](gluefactory/configs/disk+lightglue-official.yaml) | 15.1 / 39.3 / 48.2 / 55.2 |

</details>
</details>
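
The epipolar error of a correspondence can be measured in several ways. One common choice is the distance from a point to the epipolar line induced by its match. The following is a minimal sketch under that assumption, with a fundamental matrix `F` satisfying `x1ᵀ F x0 = 0`. The helper name `epipolar_distance` is hypothetical, and the evaluation in this PR may use a different variant, such as a symmetric distance.

```python
import math


def epipolar_distance(F, pt0, pt1):
    """Distance in pixels from pt1 to the epipolar line F @ pt0 in image 1."""
    x0 = (pt0[0], pt0[1], 1.0)
    # Epipolar line l = F @ x0, as coefficients (a, b, c) of a*x + b*y + c = 0.
    a, b, c = (sum(F[r][k] * x0[k] for k in range(3)) for r in range(3))
    return abs(a * pt1[0] + b * pt1[1] + c) / math.hypot(a, b)


# Fundamental matrix of a purely horizontal camera translation:
# corresponding points lie on the same image row.
F = [[0.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]]
dist = epipolar_distance(F, (10.0, 20.0), (15.0, 23.0))
print(dist)  # the match is 3 rows away from the epipolar line y = 20
```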

#### Image Matching Challenge 2021
Coming soon!

126 changes: 126 additions & 0 deletions gluefactory/datasets/evd.py
@@ -0,0 +1,126 @@
"""
Load image pairs with ground-truth homographies from the EVD dataset (it does not have any split).
"""
import argparse
import logging
import zipfile
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf

from ..settings import DATA_PATH
from ..utils.image import ImagePreprocessor, load_image
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
from .base_dataset import BaseDataset

logger = logging.getLogger(__name__)


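def read_homography(path):
    """Read a 3x3 homography from a text file with space- or tab-separated rows."""
    with open(path, 'r') as hf:
        # str.split() without arguments splits on any run of whitespace;
        # blank lines are skipped so that they do not produce empty rows.
        H = np.array([[float(x) for x in line.split()] for line in hf if line.strip()])
    # Normalize so that the bottom-right entry is 1.
    return H / H[2, 2]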

class EVD(BaseDataset, torch.utils.data.Dataset):
    default_conf = {
        "preprocessing": ImagePreprocessor.default_conf,
        "data_dir": "EVD",
        "subset": None,
        "grayscale": False,
    }
    url = "http://cmp.felk.cvut.cz/wbs/datasets/EVD.zip"

    def _init(self, conf):
        assert conf.batch_size == 1
        self.preprocessor = ImagePreprocessor(conf.preprocessing)
        self.root = DATA_PATH / conf.data_dir
        if not self.root.exists():
            logger.info("Downloading the EVD dataset.")
            self.download()
        self.pairs = self.index_dataset()
        if not self.pairs:
            raise ValueError("No image found!")

    def download(self):
        data_dir = self.root.parent
        data_dir.mkdir(exist_ok=True, parents=True)
        zip_path = data_dir / self.url.rsplit("/", 1)[-1]
        torch.hub.download_url_to_file(self.url, zip_path)
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(data_dir)
        os.unlink(zip_path)

    def index_dataset(self):
        # Folders '1' and '2' hold the two images of each pair under the same
        # file name; folder 'h' holds the matching homography as a .txt file.
        sets = sorted(os.listdir(os.path.join(self.root, '1')))
        img_pairs_list = []
        for s in sets:
            if s == '.DS_Store':
                continue
            img_pairs_list.append(
                (
                    os.path.join(self.root, '1', s),
                    os.path.join(self.root, '2', s),
                    os.path.join(self.root, 'h', s.replace('png', 'txt')),
                )
            )
        return img_pairs_list

    def __getitem__(self, idx):
        imgfname1, imgfname2, h_fname = self.pairs[idx]
        H = read_homography(h_fname)
        data0 = self.preprocessor(load_image(imgfname1))
        data1 = self.preprocessor(load_image(imgfname2))
        # Express H in the coordinates of the preprocessed images.
        H = data1["transform"] @ H @ np.linalg.inv(data0["transform"])
        # os.path.basename keeps this independent of the path separator.
        pair_name = os.path.basename(imgfname1).split('.')[0]
        return {
            "H_0to1": H.astype(np.float32),
            "scene": pair_name,
            "view0": data0,
            "view1": data1,
            "idx": idx,
            "name": pair_name,
        }

    def __len__(self):
        return len(self.pairs)

    def get_dataset(self, split):
        return self

def visualize(args):
    conf = {
        "batch_size": 1,
        "num_workers": 8,
        "prefetch_factor": 1,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = EVD(conf)
    loader = dataset.get_data_loader("test")
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images = []
        for _, data in zip(range(args.num_items), loader):
            images.append(
                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
            )
    plot_image_grid(images, dpi=args.dpi)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_items", type=int, default=8)
    parser.add_argument("--dpi", type=int, default=100)
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_intermixed_args()
    visualize(args)
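
The line `H = data1["transform"] @ H @ np.linalg.inv(data0["transform"])` in `__getitem__` re-expresses the ground-truth homography in the coordinates of the preprocessed images. Below is a small self-contained sketch of that step, assuming each `transform` is a 3x3 matrix that maps original pixel coordinates to preprocessed ones. The scale and shift values are made up for illustration, and the helpers `matmul` and `apply_h` are hypothetical.

```python
def matmul(A, B):
    """Product of two 3x3 matrices given as nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def apply_h(H, pt):
    """Apply a homography to a 2D point, with homogeneous normalization."""
    x, y, w = (H[r][0] * pt[0] + H[r][1] * pt[1] + H[r][2] for r in range(3))
    return (x / w, y / w)


# Hypothetical setup: both images are downscaled by 2 during preprocessing,
# and the original homography is a 10 px shift along x.
T0 = [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 1.0]]
T0_inv = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]
T1 = [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 1.0]]
H = [[1.0, 0.0, 10.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

# Same composition as in the dataset:
# preprocessed 0 -> original 0 -> original 1 -> preprocessed 1.
H_new = matmul(matmul(T1, H), T0_inv)
print(apply_h(H_new, (2.0, 3.0)))
```

In this example the 10 px shift in the original images becomes a 5 px shift between the half-resolution images.
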
148 changes: 148 additions & 0 deletions gluefactory/datasets/wxbs.py
@@ -0,0 +1,148 @@
"""
Load image pairs with ground-truth correspondences from the WxBS dataset (it does not have any split).
"""

import argparse
import logging
import os
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf

from ..geometry.homography import warp_points
from ..settings import DATA_PATH
from ..utils.image import ImagePreprocessor, load_image
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
from .base_dataset import BaseDataset

logger = logging.getLogger(__name__)


class WxBSDataset(BaseDataset, torch.utils.data.Dataset):
    """Wide multiple baselines stereo dataset."""

    url = 'http://cmp.felk.cvut.cz/wbs/datasets/WxBS_v1.1.zip'
    zip_fname = 'WxBS_v1.1.zip'
    validation_pairs = ['kyiv_dolltheater2', 'petrzin']
    default_conf = {
        "preprocessing": ImagePreprocessor.default_conf,
        "data_dir": "WxBS",
        "subset": None,
        "grayscale": False,
    }

    def _init(self, conf):
        self.preprocessor = ImagePreprocessor(conf.preprocessing)
        self.root = DATA_PATH / conf.data_dir
        if not self.root.exists():
            logger.info("Downloading the WxBS dataset.")
            self.download()
        self.pairs = self.index_dataset()
        if not self.pairs:
            raise ValueError("No image found!")

    def __len__(self):
        return len(self.pairs)

    def download(self):
        data_dir = self.root
        data_dir.mkdir(exist_ok=True, parents=True)
        zip_path = data_dir / self.url.rsplit("/", 1)[-1]
        torch.hub.download_url_to_file(self.url, zip_path)
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(data_dir)
        os.unlink(zip_path)

    def index_dataset(self):
        sets = sorted(
            x for x in os.listdir(self.root) if os.path.isdir(os.path.join(self.root, x))
        )

        img_pairs_list = []
        for s in sets[::-1]:
            if s == '.DS_Store':
                continue
            ss = os.path.join(self.root, s)
            for p in sorted(os.listdir(ss)):
                if p == '.DS_Store':
                    continue
                cur_dir = os.path.join(ss, p)
                # A pair is stored either as PNG or as JPEG images.
                for ext in ('png', 'jpg'):
                    if os.path.isfile(os.path.join(cur_dir, f'01.{ext}')):
                        img_pairs_list.append(
                            (
                                os.path.join(cur_dir, f'01.{ext}'),
                                os.path.join(cur_dir, f'02.{ext}'),
                                os.path.join(cur_dir, 'corrs.txt'),
                                os.path.join(cur_dir, 'crossval_errors.txt'),
                            )
                        )
                        break
        return img_pairs_list

    def __getitem__(self, idx):
        imgfname1, imgfname2, pts_fname, err_fname = self.pairs[idx]
        data0 = self.preprocessor(load_image(imgfname1))
        data1 = self.preprocessor(load_image(imgfname2))
        # The first two columns are points in image 0, the remaining columns
        # the corresponding points in image 1; both are moved to the
        # coordinates of the preprocessed images.
        pts = np.loadtxt(pts_fname)
        pts[:, :2] = warp_points(pts[:, :2], data0["transform"], False)
        pts[:, 2:] = warp_points(pts[:, 2:], data1["transform"], False)

        crossval_errors = np.loadtxt(err_fname)
        # Paths look like <root>/<scene>/<pair>/corrs.txt.
        parts = os.path.normpath(pts_fname).split(os.sep)
        pair_name = '_'.join(parts[-3:-1])
        scene_name = parts[-3]
        out = {
            "pts_0to1": pts,
            "scene": scene_name,
            "view0": data0,
            "view1": data1,
            "idx": idx,
            "name": pair_name,
            "crossval_errors": crossval_errors,
        }
        return out

    def get_dataset(self, split):
        assert split in ['val', 'test']
        return self


def visualize(args):
    conf = {
        "batch_size": 1,
        "num_workers": 8,
        "prefetch_factor": 1,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = WxBSDataset(conf)
    loader = dataset.get_data_loader("test")
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images = []
        for _, data in zip(range(args.num_items), loader):
            images.append(
                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
            )
    plot_image_grid(images, dpi=args.dpi)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_items", type=int, default=8)
    parser.add_argument("--dpi", type=int, default=100)
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_intermixed_args()
    visualize(args)
11 changes: 9 additions & 2 deletions gluefactory/eval/eval_pipeline.py
@@ -25,8 +25,15 @@ def save_eval(dir, summaries, figures, results):
         for k, v in results.items():
             arr = np.array(v)
             if not np.issubdtype(arr.dtype, np.number):
-                arr = arr.astype("object")
-            hfile.create_dataset(k, data=arr)
+                if not isinstance(v[0], str):
+                    arr = np.array([x.astype(np.float64) for x in v])
+                    dt = h5py.special_dtype(vlen=np.float64)
+                    hfile.create_dataset(k, data=arr, dtype=dt)
+                else:
+                    arr = arr.astype("object")
+                    hfile.create_dataset(k, data=arr)
+            else:
+                hfile.create_dataset(k, data=arr)
# [Review comment, Collaborator, on lines +28 to +36] With the fixes proposed in utils this should not be required.
         # just to be safe, not used in practice
         for k, v in summaries.items():
             hfile.attrs[k] = v