From de0216d8a13574e26aefb3634e7e8d443d2a3d95 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 14 Oct 2019 13:32:15 +0200 Subject: [PATCH] pytorch 1.3 support --- cpu/compat.h | 5 +++++ cpu/fps.cpp | 17 ++++++++-------- cpu/graclus.cpp | 15 +++++++------- cpu/rw.cpp | 13 ++++++------ cpu/sampler.cpp | 8 +++++--- cuda/coloring.cuh | 6 ++++-- cuda/compat.cuh | 5 +++++ cuda/fps_kernel.cu | 13 ++++++------ cuda/grid_kernel.cu | 8 +++++--- cuda/knn_kernel.cu | 10 ++++++---- cuda/nearest_kernel.cu | 8 +++++--- cuda/proposal.cuh | 11 ++++++---- cuda/radius_kernel.cu | 10 ++++++---- cuda/response.cuh | 11 ++++++---- cuda/rw_kernel.cu | 7 ++++--- setup.py | 42 ++++++++++++++++++++++++++++----------- torch_cluster/__init__.py | 2 +- 17 files changed, 121 insertions(+), 70 deletions(-) create mode 100644 cpu/compat.h create mode 100644 cuda/compat.cuh diff --git a/cpu/compat.h b/cpu/compat.h new file mode 100644 index 00000000..1be09913 --- /dev/null +++ b/cpu/compat.h @@ -0,0 +1,5 @@ +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/cpu/fps.cpp b/cpu/fps.cpp index c89d607c..65102399 100644 --- a/cpu/fps.cpp +++ b/cpu/fps.cpp @@ -1,5 +1,6 @@ #include +#include "compat.h" #include "utils.h" at::Tensor get_dist(at::Tensor x, ptrdiff_t index) { @@ -7,19 +8,19 @@ at::Tensor get_dist(at::Tensor x, ptrdiff_t index) { } at::Tensor fps(at::Tensor x, at::Tensor batch, float ratio, bool random) { - auto batch_size = batch[-1].data()[0] + 1; + auto batch_size = batch[-1].DATA_PTR()[0] + 1; auto deg = degree(batch, batch_size); auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0); auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong); auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0); - auto out = at::empty(cum_k[-1].data()[0], batch.options()); + auto out = at::empty(cum_k[-1].DATA_PTR()[0], batch.options()); - auto cum_deg_d = cum_deg.data(); - auto k_d = k.data(); - auto cum_k_d = cum_k.data(); - auto out_d = out.data(); + auto cum_deg_d = cum_deg.DATA_PTR(); + auto k_d = k.DATA_PTR(); + auto cum_k_d = cum_k.DATA_PTR(); + auto out_d = out.DATA_PTR(); for (ptrdiff_t b = 0; b < batch_size; b++) { auto index = at::range(cum_deg_d[b], cum_deg_d[b + 1] - 1, out.options()); @@ -27,14 +28,14 @@ at::Tensor fps(at::Tensor x, at::Tensor batch, float ratio, bool random) { ptrdiff_t start = 0; if (random) { - start = at::randperm(y.size(0), batch.options()).data()[0]; + start = at::randperm(y.size(0), batch.options()).DATA_PTR()[0]; } out_d[cum_k_d[b]] = cum_deg_d[b] + start; auto dist = get_dist(y, start); for (ptrdiff_t i = 1; i < k_d[b]; i++) { - ptrdiff_t argmax = dist.argmax().data()[0]; + ptrdiff_t argmax = dist.argmax().DATA_PTR()[0]; out_d[cum_k_d[b] + i] = cum_deg_d[b] + argmax; dist = at::min(dist, get_dist(y, argmax)); } diff --git a/cpu/graclus.cpp b/cpu/graclus.cpp index 3f886909..18a8d0a6 100644 --- a/cpu/graclus.cpp +++ b/cpu/graclus.cpp @@ -1,18 +1,19 @@ #include +#include "compat.h" #include "utils.h" at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { std::tie(row, col) = remove_self_loops(row, col); std::tie(row, col) = rand(row, col); std::tie(row, col) = to_csr(row, col, num_nodes); - auto row_data = row.data(), col_data = col.data(); + auto row_data = row.DATA_PTR(), col_data = col.DATA_PTR(); auto perm = at::randperm(num_nodes, row.options()); - auto perm_data = perm.data(); + auto perm_data = perm.DATA_PTR(); auto cluster = at::full(num_nodes, -1, row.options()); - auto cluster_data = cluster.data(); + auto cluster_data = cluster.DATA_PTR(); for (int64_t i = 0; i < num_nodes; i++) { auto u = perm_data[i]; @@ -41,16 +42,16 @@ at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) { std::tie(row, col, weight) = remove_self_loops(row, col, weight); std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes); - auto row_data = row.data(), col_data = col.data(); + auto row_data = row.DATA_PTR(), col_data = col.DATA_PTR(); auto perm = at::randperm(num_nodes, row.options()); - auto perm_data = perm.data(); + auto perm_data = perm.DATA_PTR(); auto cluster = at::full(num_nodes, -1, row.options()); - auto cluster_data = cluster.data(); + auto cluster_data = cluster.DATA_PTR(); AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] { - auto weight_data = weight.data(); + auto weight_data = weight.DATA_PTR(); for (int64_t i = 0; i < num_nodes; i++) { auto u = perm_data[i]; diff --git a/cpu/rw.cpp b/cpu/rw.cpp index bfd8c67a..ae89297c 100644 --- a/cpu/rw.cpp +++ b/cpu/rw.cpp @@ -1,5 +1,6 @@ #include +#include "compat.h" #include "utils.h" at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start, @@ -12,12 +13,12 @@ at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start, auto out = at::full({start.size(0), (int64_t)walk_length + 1}, -1, start.options()); - auto deg_d = deg.data(); - auto cum_deg_d = cum_deg.data(); - auto col_d = col.data(); - auto start_d = start.data(); - auto rand_d = rand.data(); - auto out_d = out.data(); + auto deg_d = deg.DATA_PTR(); + auto cum_deg_d = cum_deg.DATA_PTR(); + auto col_d = col.DATA_PTR(); + auto start_d = start.DATA_PTR(); + auto rand_d = rand.DATA_PTR(); + auto out_d = out.DATA_PTR(); for (ptrdiff_t n = 0; n < start.size(0); n++) { int64_t cur = start_d[n]; diff --git a/cpu/sampler.cpp b/cpu/sampler.cpp index 133e8b9a..f97a94ba 100644 --- a/cpu/sampler.cpp +++ b/cpu/sampler.cpp @@ -1,10 +1,12 @@ #include +#include "compat.h" + at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size, float factor) { - auto start_ptr = start.data(); - auto cumdeg_ptr = cumdeg.data(); + auto start_ptr = start.DATA_PTR(); + auto cumdeg_ptr = cumdeg.DATA_PTR(); std::vector e_ids; for (ptrdiff_t i = 0; i < start.size(0); i++) { @@ -29,7 +31,7 @@ at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size, e_ids.insert(e_ids.end(), v.begin(), v.end()); } else { auto sample = at::randperm(num_neighbors, start.options()); - auto sample_ptr = sample.data(); + auto sample_ptr = sample.DATA_PTR(); for (size_t j = 0; j < size_i; j++) { e_ids.push_back(sample_ptr[j] + low); } diff --git a/cuda/coloring.cuh b/cuda/coloring.cuh index eac1cfb8..63622723 100644 --- a/cuda/coloring.cuh +++ b/cuda/coloring.cuh @@ -2,6 +2,8 @@ #include +#include "compat.cuh" + #define THREADS 1024 #define BLOCKS(N) (N + THREADS - 1) / THREADS @@ -30,8 +32,8 @@ int64_t colorize(at::Tensor cluster) { auto props = at::full(numel, BLUE_PROB, cluster.options().dtype(at::kFloat)); auto bernoulli = props.bernoulli(); - colorize_kernel<<>>(cluster.data(), - bernoulli.data(), numel); + colorize_kernel<<>>( + cluster.DATA_PTR(), bernoulli.DATA_PTR(), numel); int64_t out; cudaMemcpyFromSymbol(&out, done, sizeof(out), 0, cudaMemcpyDeviceToHost); diff --git a/cuda/compat.cuh b/cuda/compat.cuh new file mode 100644 index 00000000..1be09913 --- /dev/null +++ b/cuda/compat.cuh @@ -0,0 +1,5 @@ +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/cuda/fps_kernel.cu b/cuda/fps_kernel.cu index f7b15583..7d67b265 100644 --- a/cuda/fps_kernel.cu +++ b/cuda/fps_kernel.cu @@ -1,6 +1,7 @@ #include #include "atomics.cuh" +#include "compat.cuh" #include "utils.cuh" #define THREADS 1024 @@ -164,7 +165,7 @@ fps_kernel(const scalar_t *__restrict__ x, const int64_t *__restrict__ cum_deg, at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) { cudaSetDevice(x.get_device()); auto batch_sizes = (int64_t *)malloc(sizeof(int64_t)); - cudaMemcpy(batch_sizes, batch[-1].data(), sizeof(int64_t), + cudaMemcpy(batch_sizes, batch[-1].DATA_PTR(), sizeof(int64_t), cudaMemcpyDeviceToHost); auto batch_size = batch_sizes[0] + 1; @@ -185,15 +186,15 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) { auto tmp_dist = at::empty(x.size(0), x.options()); auto k_sum = (int64_t *)malloc(sizeof(int64_t)); - cudaMemcpy(k_sum, cum_k[-1].data(), sizeof(int64_t), + cudaMemcpy(k_sum, cum_k[-1].DATA_PTR(), sizeof(int64_t), cudaMemcpyDeviceToHost); auto out = at::empty(k_sum[0], k.options()); AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "fps_kernel", [&] { - FPS_KERNEL(x.size(1), x.data(), cum_deg.data(), - cum_k.data(), start.data(), - dist.data(), tmp_dist.data(), - out.data()); + FPS_KERNEL(x.size(1), x.DATA_PTR(), cum_deg.DATA_PTR(), + cum_k.DATA_PTR(), start.DATA_PTR(), + dist.DATA_PTR(), tmp_dist.DATA_PTR(), + out.DATA_PTR()); }); return out; diff --git a/cuda/grid_kernel.cu b/cuda/grid_kernel.cu index 76d7cd34..becb4cab 100644 --- a/cuda/grid_kernel.cu +++ b/cuda/grid_kernel.cu @@ -2,6 +2,8 @@ #include #include +#include "compat.cuh" + #define THREADS 1024 #define BLOCKS(N) (N + THREADS - 1) / THREADS @@ -31,10 +33,10 @@ at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start, AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] { grid_kernel<<>>( - cluster.data(), + cluster.DATA_PTR(), at::cuda::detail::getTensorInfo(pos), - size.data(), start.data(), end.data(), - cluster.numel()); + size.DATA_PTR(), start.DATA_PTR(), + end.DATA_PTR(), cluster.numel()); }); return cluster; diff --git a/cuda/knn_kernel.cu b/cuda/knn_kernel.cu index 7c058d9e..ab7b32ce 100644 --- a/cuda/knn_kernel.cu +++ b/cuda/knn_kernel.cu @@ -1,5 +1,6 @@ #include +#include "compat.cuh" #include "utils.cuh" #define THREADS 1024 @@ -79,7 +80,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x, at::Tensor batch_y, bool cosine) { cudaSetDevice(x.get_device()); auto batch_sizes = (int64_t *)malloc(sizeof(int64_t)); - cudaMemcpy(batch_sizes, batch_x[-1].data(), sizeof(int64_t), + cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR(), sizeof(int64_t), cudaMemcpyDeviceToHost); auto batch_size = batch_sizes[0] + 1; @@ -94,9 +95,10 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x, AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] { knn_kernel<<>>( - x.data(), y.data(), batch_x.data(), - batch_y.data(), dist.data(), row.data(), - col.data(), k, x.size(1), cosine); + x.DATA_PTR(), y.DATA_PTR(), + batch_x.DATA_PTR(), batch_y.DATA_PTR(), + dist.DATA_PTR(), row.DATA_PTR(), + col.DATA_PTR(), k, x.size(1), cosine); }); auto mask = col != -1; diff --git a/cuda/nearest_kernel.cu b/cuda/nearest_kernel.cu index cf5d6e2f..6b207586 100644 --- a/cuda/nearest_kernel.cu +++ b/cuda/nearest_kernel.cu @@ -1,5 +1,6 @@ #include +#include "compat.cuh" #include "utils.cuh" #define THREADS 1024 @@ -62,7 +63,7 @@ at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x, at::Tensor batch_y) { cudaSetDevice(x.get_device()); auto batch_sizes = (int64_t *)malloc(sizeof(int64_t)); - cudaMemcpy(batch_sizes, batch_x[-1].data(), sizeof(int64_t), + cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR(), sizeof(int64_t), cudaMemcpyDeviceToHost); auto batch_size = batch_sizes[0] + 1; @@ -73,8 +74,9 @@ at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x, AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] { nearest_kernel<<>>( - x.data(), y.data(), batch_x.data(), - batch_y.data(), out.data(), x.size(1)); + x.DATA_PTR(), y.DATA_PTR(), + batch_x.DATA_PTR(), batch_y.DATA_PTR(), + out.DATA_PTR(), x.size(1)); }); return out; diff --git a/cuda/proposal.cuh b/cuda/proposal.cuh index 6f72d232..00191eae 100644 --- a/cuda/proposal.cuh +++ b/cuda/proposal.cuh @@ -2,6 +2,8 @@ #include +#include "compat.cuh" + #define THREADS 1024 #define BLOCKS(N) (N + THREADS - 1) / THREADS @@ -36,8 +38,8 @@ __global__ void propose_kernel(int64_t *__restrict__ cluster, int64_t *proposal, void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row, at::Tensor col) { propose_kernel<<>>( - cluster.data(), proposal.data(), row.data(), - col.data(), cluster.numel()); + cluster.DATA_PTR(), proposal.DATA_PTR(), + row.DATA_PTR(), col.DATA_PTR(), cluster.numel()); } template @@ -79,7 +81,8 @@ void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row, at::Tensor col, at::Tensor weight) { AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] { propose_kernel<<>>( - cluster.data(), proposal.data(), row.data(), - col.data(), weight.data(), cluster.numel()); + cluster.DATA_PTR(), proposal.DATA_PTR(), + row.DATA_PTR(), col.DATA_PTR(), + weight.DATA_PTR(), cluster.numel()); }); } diff --git a/cuda/radius_kernel.cu b/cuda/radius_kernel.cu index fcf1a14e..4802589c 100644 --- a/cuda/radius_kernel.cu +++ b/cuda/radius_kernel.cu @@ -1,5 +1,6 @@ #include +#include "compat.cuh" #include "utils.cuh" #define THREADS 1024 @@ -50,7 +51,7 @@ at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius, size_t max_num_neighbors) { cudaSetDevice(x.get_device()); auto batch_sizes = (int64_t *)malloc(sizeof(int64_t)); - cudaMemcpy(batch_sizes, batch_x[-1].data(), sizeof(int64_t), + cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR(), sizeof(int64_t), cudaMemcpyDeviceToHost); auto batch_size = batch_sizes[0] + 1; @@ -64,9 +65,10 @@ at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius, AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] { radius_kernel<<>>( - x.data(), y.data(), batch_x.data(), - batch_y.data(), row.data(), col.data(), - radius, max_num_neighbors, x.size(1)); + x.DATA_PTR(), y.DATA_PTR(), + batch_x.DATA_PTR(), batch_y.DATA_PTR(), + row.DATA_PTR(), col.DATA_PTR(), radius, + max_num_neighbors, x.size(1)); }); auto mask = row != -1; diff --git a/cuda/response.cuh b/cuda/response.cuh index 881b9bf3..b62a820b 100644 --- a/cuda/response.cuh +++ b/cuda/response.cuh @@ -2,6 +2,8 @@ #include +#include "compat.cuh" + #define THREADS 1024 #define BLOCKS(N) (N + THREADS - 1) / THREADS @@ -38,8 +40,8 @@ __global__ void respond_kernel(int64_t *__restrict__ cluster, int64_t *proposal, void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row, at::Tensor col) { respond_kernel<<>>( - cluster.data(), proposal.data(), row.data(), - col.data(), cluster.numel()); + cluster.DATA_PTR(), proposal.DATA_PTR(), + row.DATA_PTR(), col.DATA_PTR(), cluster.numel()); } template @@ -84,7 +86,8 @@ void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row, at::Tensor col, at::Tensor weight) { AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] { respond_kernel<<>>( - cluster.data(), proposal.data(), row.data(), - col.data(), weight.data(), cluster.numel()); + cluster.DATA_PTR(), proposal.DATA_PTR(), + row.DATA_PTR(), col.DATA_PTR(), + weight.DATA_PTR(), cluster.numel()); }); } diff --git a/cuda/rw_kernel.cu b/cuda/rw_kernel.cu index 66068252..e7aa64d8 100644 --- a/cuda/rw_kernel.cu +++ b/cuda/rw_kernel.cu @@ -1,5 +1,6 @@ #include +#include "compat.cuh" #include "utils.cuh" #define THREADS 1024 @@ -37,9 +38,9 @@ at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start, at::full({(int64_t)walk_length + 1, start.size(0)}, -1, start.options()); uniform_rw_kernel<<>>( - row.data(), col.data(), deg.data(), - start.data(), rand.data(), out.data(), - walk_length, start.numel()); + row.DATA_PTR(), col.DATA_PTR(), deg.DATA_PTR(), + start.DATA_PTR(), rand.DATA_PTR(), + out.DATA_PTR(), walk_length, start.numel()); return out.t().contiguous(); } diff --git a/setup.py b/setup.py index 93962d22..938813bd 100644 --- a/setup.py +++ b/setup.py @@ -2,34 +2,52 @@ import torch from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +extra_compile_args = [] +if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): + extra_compile_args += ['-DVERSION_GE_1_3'] + ext_modules = [ - CppExtension('torch_cluster.graclus_cpu', ['cpu/graclus.cpp']), + CppExtension('torch_cluster.graclus_cpu', ['cpu/graclus.cpp'], + extra_compile_args=extra_compile_args), CppExtension('torch_cluster.grid_cpu', ['cpu/grid.cpp']), - CppExtension('torch_cluster.fps_cpu', ['cpu/fps.cpp']), - CppExtension('torch_cluster.rw_cpu', ['cpu/rw.cpp']), - CppExtension('torch_cluster.sampler_cpu', ['cpu/sampler.cpp']), + CppExtension('torch_cluster.fps_cpu', ['cpu/fps.cpp'], + extra_compile_args=extra_compile_args), + CppExtension('torch_cluster.rw_cpu', ['cpu/rw.cpp'], + extra_compile_args=extra_compile_args), + CppExtension('torch_cluster.sampler_cpu', ['cpu/sampler.cpp'], + extra_compile_args=extra_compile_args), ] cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension} if CUDA_HOME is not None: ext_modules += [ CUDAExtension('torch_cluster.graclus_cuda', - ['cuda/graclus.cpp', 'cuda/graclus_kernel.cu']), + ['cuda/graclus.cpp', 'cuda/graclus_kernel.cu'], + extra_compile_args=extra_compile_args), CUDAExtension('torch_cluster.grid_cuda', - ['cuda/grid.cpp', 'cuda/grid_kernel.cu']), + ['cuda/grid.cpp', 'cuda/grid_kernel.cu'], + extra_compile_args=extra_compile_args), CUDAExtension('torch_cluster.fps_cuda', - ['cuda/fps.cpp', 'cuda/fps_kernel.cu']), + ['cuda/fps.cpp', 'cuda/fps_kernel.cu'], + extra_compile_args=extra_compile_args), CUDAExtension('torch_cluster.nearest_cuda', - ['cuda/nearest.cpp', 'cuda/nearest_kernel.cu']), + ['cuda/nearest.cpp', 'cuda/nearest_kernel.cu'], + extra_compile_args=extra_compile_args), CUDAExtension('torch_cluster.knn_cuda', - ['cuda/knn.cpp', 'cuda/knn_kernel.cu']), + ['cuda/knn.cpp', 'cuda/knn_kernel.cu'], + extra_compile_args=extra_compile_args), CUDAExtension('torch_cluster.radius_cuda', - ['cuda/radius.cpp', 'cuda/radius_kernel.cu']), + ['cuda/radius.cpp', 'cuda/radius_kernel.cu'], + extra_compile_args=extra_compile_args), CUDAExtension('torch_cluster.rw_cuda', - ['cuda/rw.cpp', 'cuda/rw_kernel.cu']), + ['cuda/rw.cpp', 'cuda/rw_kernel.cu'], + extra_compile_args=extra_compile_args), ] -__version__ = '1.4.4' +__version__ = '1.4.5' url = 'https://github.com/rusty1s/pytorch_cluster' install_requires = ['scipy'] diff --git a/torch_cluster/__init__.py b/torch_cluster/__init__.py index 40e1e0fd..16a1d485 100644 --- a/torch_cluster/__init__.py +++ b/torch_cluster/__init__.py @@ -7,7 +7,7 @@ from .rw import random_walk from .sampler import neighbor_sampler -__version__ = '1.4.4' +__version__ = '1.4.5' __all__ = [ 'graclus_cluster',