From e3c3b13330d11cb82e2d154681b58162a69f7da8 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 4 Jun 2019 21:41:18 +0200 Subject: [PATCH] flow arg for radius --- setup.py | 2 +- test/test_knn.py | 2 -- test/test_radius.py | 12 +++++++++--- torch_cluster/__init__.py | 2 +- torch_cluster/radius.py | 19 +++++++++++++------ 5 files changed, 24 insertions(+), 13 deletions(-) diff --git a/setup.py b/setup.py index 354d0b92..b040c832 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ ['cuda/rw.cpp', 'cuda/rw_kernel.cu']), ] -__version__ = '1.4.1' +__version__ = '1.4.2' url = 'https://github.com/rusty1s/pytorch_cluster' install_requires = ['scipy'] diff --git a/test/test_knn.py b/test/test_knn.py index 44400bbe..bc6cfacd 100644 --- a/test/test_knn.py +++ b/test/test_knn.py @@ -45,12 +45,10 @@ def test_knn_graph(dtype, device): row, col = knn_graph(x, k=2, flow='target_to_source') col = col.view(-1, 2).sort(dim=-1)[0].view(-1) - assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] row, col = knn_graph(x, k=2, flow='source_to_target') row = row.view(-1, 2).sort(dim=-1)[0].view(-1) - assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] diff --git a/test/test_radius.py b/test/test_radius.py index 5981bdd2..6e5264ce 100644 --- a/test/test_radius.py +++ b/test/test_radius.py @@ -47,6 +47,12 @@ def test_radius_graph(dtype, device): [+1, -1], ], dtype, device) - out = radius_graph(x, r=2) - assert coalesce(out).tolist() == [[0, 0, 1, 1, 2, 2, 3, 3], - [1, 3, 0, 2, 1, 3, 0, 2]] + row, col = radius_graph(x, r=2, flow='target_to_source') + col = col.view(-1, 2).sort(dim=-1)[0].view(-1) + assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] + assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] + + row, col = radius_graph(x, r=2, flow='source_to_target') + row = row.view(-1, 2).sort(dim=-1)[0].view(-1) + assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] + assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] diff --git a/torch_cluster/__init__.py b/torch_cluster/__init__.py index 3dccfd28..a51d9ef5 100644 --- a/torch_cluster/__init__.py +++ b/torch_cluster/__init__.py @@ -7,7 +7,7 @@ from .sampler import neighbor_sampler from .rw import random_walk -__version__ = '1.4.1' +__version__ = '1.4.2' __all__ = [ 'graclus_cluster', diff --git a/torch_cluster/radius.py b/torch_cluster/radius.py index 7e0d53a4..5067ba56 100644 --- a/torch_cluster/radius.py +++ b/torch_cluster/radius.py @@ -73,7 +73,12 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32): return torch.stack([row[mask], col[mask]], dim=0) -def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32): +def radius_graph(x, + r, + batch=None, + loop=False, + max_num_neighbors=32, + flow='source_to_target'): r"""Computes graph edges to all points within a given distance. Args: @@ -87,6 +92,9 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32): self-loops. (default: :obj:`False`) max_num_neighbors (int, optional): The maximum number of neighbors to return for each element in :obj:`y`. (default: :obj:`32`) + flow (string, optional): The flow direction when using in combination + with message passing (:obj:`"source_to_target"` or + :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) :rtype: :class:`LongTensor` @@ -102,11 +110,10 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32): >>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False) """ - edge_index = radius(x, x, r, batch, batch, max_num_neighbors + 1) - row, col = edge_index + assert flow in ['source_to_target', 'target_to_source'] + row, col = radius(x, x, r, batch, batch, max_num_neighbors + 1) + row, col = (col, row) if flow == 'source_to_target' else (row, col) if not loop: - row, col = edge_index mask = row != col row, col = row[mask], col[mask] - edge_index = torch.stack([row, col], dim=0) - return edge_index + return torch.stack([row, col], dim=0)