diff --git a/CHANGELOG.md b/CHANGELOG.md index d3dd463..1d8ad0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [0.14.4] - 2022-10-31 + +### Added +- Python 3.10 and 3.11 wheels! + - Only for supported torch versions. +- Support torch 1.13. +- Tiled NA2D for 3x3 kernels. + +### Changed +- Minor changes to the setup script to fix `pip install natten`. + ## [0.14.2] - 2022-10-15 ### Added diff --git a/README.md b/README.md index bbb90c3..b571520 100644 --- a/README.md +++ b/README.md @@ -34,8 +34,7 @@ The latest version of NATTEN runs pretty fast on Ampere with the latest torch an ## Requirements -NATTEN supports PyTorch version 1.8 and later, and Python versions 3.7, 3.8, and 3.9. -However, we highly recommend using Python 3.8 and PyTorch 1.12.1 + CUDA 11.6 for the best performance. +NATTEN supports PyTorch version 1.8 and later, and Python versions 3.7, 3.8, 3.9, 3.10(only torch >= 1.11), and 3.11 (only torch >= 1.13). **NOTE:** The current version of NATTEN comes with Linux-only wheels, and supports Pascal and above (`SM >= 60`, i.e. Tesla P100). Make sure your GPU is supported by referring to @@ -94,7 +93,7 @@ python -m unittest discover -v -s ./tests - [ ] Neighborhood Attention 3D (CPU) - [x] Dilation support - [x] Float16 support and utilization -- [ ] BFloat16 support +- [ ] BFloat16 support (awaiting CUDA 11.8/12 builds of torch) - [ ] Kepler and Maxwell (30<=SM<60) support - [ ] Windows builds @@ -104,8 +103,8 @@ Simply import `NeighborhoodAttention1D` or `NeighborhoodAttention2D` from `natte from natten import NeighborhoodAttention1D from natten import NeighborhoodAttention2D -na1d = NeighborhoodAttention1D(dim=128, kernel_size=7, dilation=2, num_heads=4).cuda() -na2d = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4).cuda() +na1d = NeighborhoodAttention1D(dim=128, kernel_size=7, dilation=2, num_heads=4) +na2d = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4) ``` ### FLOPs diff --git a/assets/README_pypi.md b/assets/README_pypi.md index 6a6a1e5..9dfeddc 100644 --- a/assets/README_pypi.md +++ b/assets/README_pypi.md @@ -34,8 +34,7 @@ The latest version of NATTEN runs pretty fast on Ampere with the latest torch an ## Requirements -NATTEN supports PyTorch version 1.8 and later, and Python versions 3.7, 3.8, and 3.9. -However, we highly recommend using Python 3.8 and PyTorch 1.12.1 + CUDA 11.6 for the best performance. +NATTEN supports PyTorch version 1.8 and later, and Python versions 3.7, 3.8, 3.9, 3.10(only torch >= 1.11), and 3.11 (only torch >= 1.13). **NOTE:** The current version of NATTEN comes with Linux-only wheels, and supports Pascal and above (`SM >= 60`, i.e. Tesla P100). Make sure your GPU is supported by referring to @@ -98,8 +97,8 @@ Simply import `NeighborhoodAttention1D` or `NeighborhoodAttention2D` from `natte from natten import NeighborhoodAttention1D from natten import NeighborhoodAttention2D -na1d = NeighborhoodAttention1D(dim=128, kernel_size=7, dilation=2, num_heads=4).cuda() -na2d = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4).cuda() +na1d = NeighborhoodAttention1D(dim=128, kernel_size=7, dilation=2, num_heads=4) +na2d = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4) ``` ### FLOPs diff --git a/dev/packaging/build_all_wheels_parallel.sh b/dev/packaging/build_all_wheels_parallel.sh index 45fb9d2..cf91af5 100755 --- a/dev/packaging/build_all_wheels_parallel.sh +++ b/dev/packaging/build_all_wheels_parallel.sh @@ -1,7 +1,6 @@ #!/bin/bash -e # Based on detectron2's builder: # github.com/facebookresearch/detectron2 -# Copyright (c) Facebook, Inc. and its affiliates. [[ -d "dev/packaging" ]] || { echo "Please run this script at natten root!" @@ -11,6 +10,7 @@ build_one() { cu=$1 pytorch_ver=$2 + cp310=${3:-0} case "$cu" in cu*) @@ -28,7 +28,13 @@ build_one() { echo "Launching container $container_name ..." container_id="$container_name"_"$cu"_"$pytorch_ver" - py_versions=(3.7 3.8 3.9) + if [ $cp310 -eq 2 ]; then + py_versions=(3.7 3.8 3.9 3.10 3.11) + elif [ $cp310 -eq 1 ]; then + py_versions=(3.7 3.8 3.9 3.10) + else + py_versions=(3.7 3.8 3.9) + fi for py in "${py_versions[@]}"; do docker run -itd \ @@ -51,11 +57,17 @@ EOF if [[ -n "$1" ]] && [[ -n "$2" ]]; then build_one "$1" "$2" else - build_one cu116 1.12.1 & build_one cu113 1.12.1 & build_one cu102 1.12.1 & build_one cpu 1.12.1 + # 1.13 and newer -- build python 3.11 wheels + build_one cu117 1.13 2 & build_one cu116 1.13 2 & build_one cpu 1.13 2 - build_one cu116 1.12 & build_one cu113 1.12 & build_one cu102 1.12 & build_one cpu 1.12 + # 1.11 and newer -- build python 3.10 wheels + build_one cu116 1.12.1 1 & build_one cu113 1.12.1 1 & build_one cu102 1.12.1 1 & build_one cpu 1.12.1 1 - build_one cu115 1.11 & build_one cu113 1.11 & build_one cu102 1.11 & build_one cpu 1.11 + build_one cu116 1.12 1 & build_one cu113 1.12 1 & build_one cu102 1.12 1 & build_one cpu 1.12 1 + + build_one cu115 1.11 1 & build_one cu113 1.11 1 & build_one cu102 1.11 1 & build_one cpu 1.11 1 + + # 1.10 and older build_one cu113 1.10.1 & build_one cu111 1.10.1 & build_one cu102 1.10.1 & build_one cpu 1.10.1 diff --git a/dev/packaging/build_cpu_wheel.sh b/dev/packaging/build_cpu_wheel.sh deleted file mode 100755 index 982f154..0000000 --- a/dev/packaging/build_cpu_wheel.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash -# Based on detectron2's builder: -# github.com/facebookresearch/detectron2 -set -ex - -ldconfig # https://github.com/NVIDIA/nvidia-docker/issues/854 - -script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -. "$script_dir/pkg_helpers.bash" - -echo "Build Settings:" -echo "CU_VERSION: $CU_VERSION" # e.g. cu101 -echo "PYTHON_VERSION: $PYTHON_VERSION" # e.g. 3.7 -echo "PYTORCH_VERSION: $PYTORCH_VERSION" # e.g. 1.4 - -setup_cuda -setup_wheel_python - -yum install ninja-build -y -ln -sv /usr/bin/ninja-build /usr/bin/ninja || true - -pip_install pip numpy -U -pip_install "torch==$PYTORCH_VERSION" \ - -f https://download.pytorch.org/whl/cpu/torch_stable.html - -python setup.py sdist bdist_wheel --plat-name=manylinux2014_x86_64 diff --git a/dev/packaging/build_default_wheel.sh b/dev/packaging/build_default_wheel.sh deleted file mode 100755 index 6316031..0000000 --- a/dev/packaging/build_default_wheel.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -e -# Based on detectron2's builder: -# github.com/facebookresearch/detectron2 -# Copyright (c) Facebook, Inc. and its affiliates. - -[[ -d "dev/packaging" ]] || { - echo "Please run this script at natten root!" - exit 1 -} - -pytorch_ver="1.8" -container_name=manylinux-cpu -cu="cpu" -py_versions=(3.7 3.8 3.9) - -echo "Launching container $container_name ..." -container_id="$container_name"_"$pytorch_ver" - -for py in "${py_versions[@]}"; do - docker run -itd \ - --name "$container_id" \ - --mount type=bind,source="$(pwd)",target=/natten \ - pytorch/$container_name - - cat < -__global__ void nattena_cuda_backward_kernel_fp16( +__global__ void natten2da_cuda_backward_kernel_fp16( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 value, @@ -229,7 +241,7 @@ __global__ void nattena_cuda_backward_kernel_fp16( template -__global__ void nattena_cuda_backward_kernel_fp32( +__global__ void natten2da_cuda_backward_kernel_fp32( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 value, @@ -270,7 +282,173 @@ __global__ void nattena_cuda_backward_kernel_fp32( /* TODO: FIX BANK CONFLICTS */ template -__global__ void nattena_cuda_backward_kernel_fp16_5x5_32( +__global__ void natten2da_cuda_backward_kernel_fp16_3x3_32( + const torch::PackedTensorAccessor32 d_out, + torch::PackedTensorAccessor32 d_attn, + const torch::PackedTensorAccessor32 value, + const int height, + const int width, + const int batch_size, + const int heads, + const int dilation_in) { + const int dilation = (DILATION>0) ? DILATION : dilation_in; + // Because batch heads have stride 1 per threadblock, we can just use blockIdx since blockDim will be 1 and threadIdx will + // always be 0. + // const int z = blockIdx.z * blockDim.z + threadIdx.z; + const int z = blockIdx.z; + const int b = z / heads; + const int h = z - b * heads; + // Not needed again because it will always be true. + // if (z < batch_size * heads) + // { + const int lti = threadIdx.y * (TILE_3*KERNEL_SIZE_3) + threadIdx.x; + const int stride2 = DIMHALF_32 * width; + const int batchHeadOffset = b * (stride2*height*heads) + h * (stride2*height); + const int si = int(blockIdx.y / dilation) * (TILE_3 * dilation) + (blockIdx.y % dilation); + const int sj = int(blockIdx.x / dilation) * (TILE_3 * dilation) + (blockIdx.x % dilation); + const int sni = get_window_start(si, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int snj = get_window_start(sj, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + __shared__ __half2 tile[TILE_3*TILE_3][DIM_32+3]; + __shared__ __half2 kTile[KTILE_3*KTILE_3][DIM_32+3]; + __half2* d_out2 = reinterpret_cast<__half2*>(d_out.data()); + __half2* value2 = reinterpret_cast<__half2*>(value.data()); + + /* d_out tile */ + const int qtx = lti / QSTRIDE_3_HALF; + const int qty = (lti - qtx * QSTRIDE_3_HALF) * QITERS_3_HALF; + if (qtx < TILE_3*TILE_3) + { + int qi = qtx / TILE_3; + const int qj = (qtx - qi * TILE_3) * dilation + sj; + qi = qi * dilation + si; + if (qi < height && qj < width){ + #pragma unroll + for (int ti=0; ti < QITERS_3_HALF; ++ti) + tile[qtx][qty+ti] = d_out2[batchHeadOffset + qi * stride2 + qj * DIMHALF_32 + qty+ti]; + } + } + /* value tile */ + const int ktx = lti / KSTRIDE_32; + const int kty = (lti - ktx * KSTRIDE_32) * KHALFITERS_32; + if (ktx < KTILE_3*KTILE_3) + { + int bi = ktx / KTILE_3; + const int bj = (ktx - bi * KTILE_3) * dilation + snj; + bi = bi * dilation + sni; + if (bi < height && bj < width){ + const int valueOffset = batchHeadOffset + bi * stride2 + bj * DIMHALF_32 + kty; + #pragma unroll + for (int ti=0; ti < KHALFITERS_32; ++ti) + kTile[ktx][kty + ti] = value2[valueOffset + ti]; + } + } + __syncthreads(); + const int ii = threadIdx.y / KERNEL_SIZE_3; + const int ki = threadIdx.y - ii * KERNEL_SIZE_3; + const int jj = threadIdx.x / KERNEL_SIZE_3; + const int kj = threadIdx.x - jj * KERNEL_SIZE_3; + const int i = si + ii*dilation, j = sj + jj*dilation; + if (i < height && j < width){ + const int ni = get_window_start(i, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int nj = get_window_start(j, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + __half2 updt = __float2half2_rn(0.f); + const int d_outIdx = ii*TILE_3 + jj; + const int valueIdx = int((ni+ki*dilation - sni)/dilation)*KTILE_3 + int((nj+kj*dilation - snj)/dilation); + + #pragma unroll + for (int dimOffset=0; dimOffset < DIMHALF_32; ++dimOffset) + updt = __hfma2(tile[d_outIdx][dimOffset], kTile[valueIdx][dimOffset], updt); + const int index = b * d_attn.stride(0) + h * d_attn.stride(1) + i * d_attn.stride(2) + j * d_attn.stride(3) + ki*KERNEL_SIZE_3+kj; + d_attn.data()[index] = static_cast(__hadd(updt.x, updt.y)); + } + //} +} + +/* TODO: CHECK BANK CONFLICTS */ +template +__global__ void natten2da_cuda_backward_kernel_fp32_3x3_32( + const torch::PackedTensorAccessor32 d_out, + torch::PackedTensorAccessor32 d_attn, + const torch::PackedTensorAccessor32 value, + const int height, + const int width, + const int batch_size, + const int heads, + const int dilation_in) { + const int dilation = (DILATION>0) ? DILATION : dilation_in; + // Because batch heads have stride 1 per threadblock, we can just use blockIdx since blockDim will be 1 and threadIdx will + // always be 0. + // const int z = blockIdx.z * blockDim.z + threadIdx.z; + const int z = blockIdx.z; + const int b = z / heads; + const int h = z - b * heads; + // Not needed again because it will always be true. + // if (z < batch_size * heads) + // { + const int lti = threadIdx.y * (TILE_3*KERNEL_SIZE_3) + threadIdx.x; + const int batchHeadOffset = b * d_out.stride(0) + h * d_out.stride(1); + const int si = int(blockIdx.y / dilation) * (TILE_3 * dilation) + (blockIdx.y % dilation); + const int sj = int(blockIdx.x / dilation) * (TILE_3 * dilation) + (blockIdx.x % dilation); + const int sni = get_window_start(si, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int snj = get_window_start(sj, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + __shared__ scalar_t tile[TILE_3*TILE_3][DIM_32+3]; + __shared__ scalar_t kTile[KTILE_3*KTILE_3][DIM_32+3]; + + /* d_out tile */ + const int qtx = lti / QSTRIDE_3; + const int qty = (lti - qtx * QSTRIDE_3) * QITERS_3; + if (qtx < TILE_3*TILE_3) + { + int qi = qtx / TILE_3; + const int qj = (qtx - qi * TILE_3) * dilation + sj; + qi = qi * dilation + si; + if (qi < height && qj < width){ + #pragma unroll + for (int ti=0; ti < QITERS_3; ++ti) + tile[qtx][qty+ti] = d_out.data()[batchHeadOffset + qi * d_out.stride(2) + qj * d_out.stride(3) + qty+ti]; + } + } + /* value tile */ + const int ktx = lti / KSTRIDE_32; + const int kty = (lti - ktx * KSTRIDE_32) * KITERS_32; + if (ktx < KTILE_3*KTILE_3) + { + int bi = ktx / KTILE_3; + const int bj = (ktx - bi * KTILE_3) * dilation + snj; + bi = bi * dilation + sni; + if (bi < height && bj < width){ + const int valueOffset = batchHeadOffset + bi * d_out.stride(2) + bj * d_out.stride(3) + kty; + #pragma unroll + for (int ti=0; ti < KITERS_32; ++ti) + kTile[ktx][kty + ti] = value.data()[valueOffset + ti]; + } + } + __syncthreads(); + const int ii = threadIdx.y / KERNEL_SIZE_3; + const int ki = threadIdx.y - ii * KERNEL_SIZE_3; + const int jj = threadIdx.x / KERNEL_SIZE_3; + const int kj = threadIdx.x - jj * KERNEL_SIZE_3; + const int i = si + ii*dilation, j = sj + jj*dilation; + if (i < height && j < width){ + const int ni = get_window_start(i, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int nj = get_window_start(j, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + scalar_t updt = scalar_t(0); + const int d_outIdx = ii*TILE_3 + jj; + const int valueIdx = int((ni+ki*dilation - sni)/dilation)*KTILE_3 + int((nj+kj*dilation - snj)/dilation); + + #pragma unroll + for (int dimOffset=0; dimOffset < DIM_32; ++dimOffset) + updt += tile[d_outIdx][dimOffset] * kTile[valueIdx][dimOffset]; + + const int index = b * d_attn.stride(0) + h * d_attn.stride(1) + i * d_attn.stride(2) + j * d_attn.stride(3) + ki*KERNEL_SIZE_3+kj; + d_attn.data()[index] = updt; + } + //} +} + +/* TODO: FIX BANK CONFLICTS */ +template +__global__ void natten2da_cuda_backward_kernel_fp16_5x5_32( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 value, @@ -352,7 +530,7 @@ __global__ void nattena_cuda_backward_kernel_fp16_5x5_32( /* TODO: CHECK BANK CONFLICTS */ template -__global__ void nattena_cuda_backward_kernel_fp32_5x5_32( +__global__ void natten2da_cuda_backward_kernel_fp32_5x5_32( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 value, @@ -433,7 +611,7 @@ __global__ void nattena_cuda_backward_kernel_fp32_5x5_32( } template -__global__ void nattena_cuda_backward_kernel_fp16_7x7_9x9_32( +__global__ void natten2da_cuda_backward_kernel_fp16_7x7_9x9_32( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 value, @@ -518,7 +696,7 @@ __global__ void nattena_cuda_backward_kernel_fp16_7x7_9x9_32( } template -__global__ void nattena_cuda_backward_kernel_fp32_7x7_9x9_32( +__global__ void natten2da_cuda_backward_kernel_fp32_7x7_9x9_32( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 value, @@ -596,7 +774,7 @@ __global__ void nattena_cuda_backward_kernel_fp32_7x7_9x9_32( } template -__global__ void nattena_cuda_backward_kernel_fp16_11x11_13x13_32( +__global__ void natten2da_cuda_backward_kernel_fp16_11x11_13x13_32( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 value, @@ -683,7 +861,7 @@ __global__ void nattena_cuda_backward_kernel_fp16_11x11_13x13_32( } template -__global__ void nattena_cuda_backward_kernel_fp32_11x11_13x13_32( +__global__ void natten2da_cuda_backward_kernel_fp32_11x11_13x13_32( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 value, @@ -763,7 +941,7 @@ __global__ void nattena_cuda_backward_kernel_fp32_11x11_13x13_32( } template -__global__ void nattenv_cuda_backward_kernel_fp32( +__global__ void natten2dv_cuda_backward_kernel_fp32( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_value, const torch::PackedTensorAccessor32 attn, @@ -810,7 +988,7 @@ __global__ void nattenv_cuda_backward_kernel_fp32( } template -__global__ void nattenv_cuda_backward_kernel_fp16( +__global__ void natten2dv_cuda_backward_kernel_fp16( const torch::PackedTensorAccessor32 d_out, torch::PackedTensorAccessor32 d_value, const torch::PackedTensorAccessor32 attn, @@ -940,9 +1118,9 @@ std::vector natten2dav_cuda_backward_tiled_32( int zsize = batch_size * heads; CHECK_FEATMAP(height, width, kernel_size, dilation); TORCH_CHECK(dim == DIM_32, "natten2dav_cuda_backward_tiled_32", " only supports 32-dim attention heads."); - TORCH_CHECK(kernel_size == KERNEL_SIZE_7 || kernel_size == KERNEL_SIZE_5 || kernel_size == KERNEL_SIZE_9 || - kernel_size == KERNEL_SIZE_11 || kernel_size == KERNEL_SIZE_13, - "natten2dav_cuda_backward_tiled_32", " only supports kernel sizes 5, 7, 9, 11, and 13."); + TORCH_CHECK(kernel_size == KERNEL_SIZE_7 || kernel_size == KERNEL_SIZE_3 || kernel_size == KERNEL_SIZE_5 || + kernel_size == KERNEL_SIZE_9 || kernel_size == KERNEL_SIZE_11 || kernel_size == KERNEL_SIZE_13, + "natten2dav_cuda_backward_tiled_32", " only supports kernel sizes 3, 5, 7, 9, 11, and 13."); auto d_attn = torch::zeros_like(attn); auto d_value = torch::zeros_like(value); @@ -955,6 +1133,12 @@ std::vector natten2dav_cuda_backward_tiled_32( YTHREADS = XYTHREADS_7; BATCHTHREADS = BATCHTHREADS_7; } + else if (kernel_size == KERNEL_SIZE_3) + { + XTHREADS = XYTHREADS_3; + YTHREADS = XYTHREADS_3; + BATCHTHREADS = BATCHTHREADS_3; + } else if (kernel_size == KERNEL_SIZE_5) { XTHREADS = XYTHREADS_5; @@ -999,30 +1183,38 @@ std::vector natten2dav_cuda_backward_tiled_32( if (kernel_size == KERNEL_SIZE_7) { LAUNCH_DNA_KNS_TILED79(TILE_7, KTILE_7, KERNEL_SIZE_7, NEIGHBORHOOD_SIZE_7, dilation, - nattena_cuda_backward_kernel_fp32_7x7_9x9_32, + natten2da_cuda_backward_kernel_fp32_7x7_9x9_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_7, NEIGHBORHOOD_SIZE_7, dilation, nattenv_cuda_backward_kernel_fp32, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_7, NEIGHBORHOOD_SIZE_7, dilation, natten2dv_cuda_backward_kernel_fp32, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dim, n_value); } else if (kernel_size == KERNEL_SIZE_9) { LAUNCH_DNA_KNS_TILED79(TILE_9, KTILE_9, KERNEL_SIZE_9, NEIGHBORHOOD_SIZE_9, dilation, - nattena_cuda_backward_kernel_fp32_7x7_9x9_32, + natten2da_cuda_backward_kernel_fp32_7x7_9x9_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_9, NEIGHBORHOOD_SIZE_9, dilation, nattenv_cuda_backward_kernel_fp32, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_9, NEIGHBORHOOD_SIZE_9, dilation, natten2dv_cuda_backward_kernel_fp32, grid_value, + block, 0, stream,d_out_a, d_value_a, attn_a, height, width, + heads, dilation, dim, n_value); + } + else if (kernel_size == KERNEL_SIZE_3) + { + LAUNCH_DNA_DS(dilation, natten2da_cuda_backward_kernel_fp32_3x3_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, + batch_size, heads, dilation); + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation, natten2dv_cuda_backward_kernel_fp32, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dim, n_value); } else if (kernel_size == KERNEL_SIZE_5) { - LAUNCH_DNA_DS(dilation, nattena_cuda_backward_kernel_fp32_5x5_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, + LAUNCH_DNA_DS(dilation, natten2da_cuda_backward_kernel_fp32_5x5_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_5, NEIGHBORHOOD_SIZE_5, dilation, nattenv_cuda_backward_kernel_fp32, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_5, NEIGHBORHOOD_SIZE_5, dilation, natten2dv_cuda_backward_kernel_fp32, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dim, n_value); } @@ -1030,10 +1222,10 @@ std::vector natten2dav_cuda_backward_tiled_32( { LAUNCH_DNA_KNS_TILED1113(TILE_11_X, TILE_11_Y, KTILE_11_X, KTILE_11_Y, KERNEL_SIZE_11, NEIGHBORHOOD_SIZE_11, dilation, scalar_t, - nattena_cuda_backward_kernel_fp32_11x11_13x13_32, + natten2da_cuda_backward_kernel_fp32_11x11_13x13_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_11, NEIGHBORHOOD_SIZE_11, dilation, nattenv_cuda_backward_kernel_fp32, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_11, NEIGHBORHOOD_SIZE_11, dilation, natten2dv_cuda_backward_kernel_fp32, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dim, n_value); } @@ -1041,10 +1233,10 @@ std::vector natten2dav_cuda_backward_tiled_32( { LAUNCH_DNA_KNS_TILED1113(TILE_13_X, TILE_13_Y, KTILE_13_X, KTILE_13_Y, KERNEL_SIZE_13, NEIGHBORHOOD_SIZE_13, dilation, float, - nattena_cuda_backward_kernel_fp32_11x11_13x13_32, + natten2da_cuda_backward_kernel_fp32_11x11_13x13_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_13, NEIGHBORHOOD_SIZE_13, dilation, nattenv_cuda_backward_kernel_fp32, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_13, NEIGHBORHOOD_SIZE_13, dilation, natten2dv_cuda_backward_kernel_fp32, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dim, n_value); } @@ -1070,9 +1262,9 @@ std::vector natten2dav_cuda_backward_fp16_tiled_32( int zsize = batch_size * heads; CHECK_FEATMAP(height, width, kernel_size, dilation); TORCH_CHECK(dimhalf*2 == DIM_32, "natten2dav_cuda_backward_fp16_tiled_32", " only supports 32-dim attention heads."); - TORCH_CHECK(kernel_size == KERNEL_SIZE_7 || kernel_size == KERNEL_SIZE_5 || kernel_size == KERNEL_SIZE_9 || - kernel_size == KERNEL_SIZE_11 || kernel_size == KERNEL_SIZE_13, - "natten2dav_cuda_backward_fp16_tiled_32", " only supports kernel sizes 5, 7, 9, 11, and 13."); + TORCH_CHECK(kernel_size == KERNEL_SIZE_7 || kernel_size == KERNEL_SIZE_3 || kernel_size == KERNEL_SIZE_5 || + kernel_size == KERNEL_SIZE_9 || kernel_size == KERNEL_SIZE_11 || kernel_size == KERNEL_SIZE_13, + "natten2dav_cuda_backward_fp16_tiled_32", " only supports kernel sizes 3, 5, 7, 9, 11, and 13."); auto d_attn = torch::zeros_like(attn); auto d_value = torch::zeros_like(value); @@ -1085,6 +1277,12 @@ std::vector natten2dav_cuda_backward_fp16_tiled_32( YTHREADS = XYTHREADS_7; BATCHTHREADS = BATCHTHREADS_7; } + else if (kernel_size == KERNEL_SIZE_3) + { + XTHREADS = XYTHREADS_3; + YTHREADS = XYTHREADS_3; + BATCHTHREADS = BATCHTHREADS_3; + } else if (kernel_size == KERNEL_SIZE_5) { XTHREADS = XYTHREADS_5; @@ -1128,50 +1326,59 @@ std::vector natten2dav_cuda_backward_fp16_tiled_32( const auto attn_a = attn.packed_accessor32(); if (kernel_size == KERNEL_SIZE_7){ LAUNCH_DNA_KNS_TILED79(TILE_7, KTILE_7, KERNEL_SIZE_7, NEIGHBORHOOD_SIZE_7, dilation, - nattena_cuda_backward_kernel_fp16_7x7_9x9_32, + natten2da_cuda_backward_kernel_fp16_7x7_9x9_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_7, NEIGHBORHOOD_SIZE_7, dilation, nattenv_cuda_backward_kernel_fp16, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_7, NEIGHBORHOOD_SIZE_7, dilation, natten2dv_cuda_backward_kernel_fp16, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dimhalf, nhalf_value); } else if (kernel_size == KERNEL_SIZE_9){ LAUNCH_DNA_KNS_TILED79(TILE_9, KTILE_9, KERNEL_SIZE_9, NEIGHBORHOOD_SIZE_9, dilation, - nattena_cuda_backward_kernel_fp16_7x7_9x9_32, + natten2da_cuda_backward_kernel_fp16_7x7_9x9_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_9, NEIGHBORHOOD_SIZE_9, dilation, nattenv_cuda_backward_kernel_fp16, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_9, NEIGHBORHOOD_SIZE_9, dilation, natten2dv_cuda_backward_kernel_fp16, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dimhalf, nhalf_value); } else if (kernel_size == KERNEL_SIZE_5){ - LAUNCH_DNA_DS(dilation, nattena_cuda_backward_kernel_fp16_5x5_32, + LAUNCH_DNA_DS(dilation, natten2da_cuda_backward_kernel_fp16_5x5_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_5, NEIGHBORHOOD_SIZE_5, dilation, nattenv_cuda_backward_kernel_fp16, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_5, NEIGHBORHOOD_SIZE_5, dilation, natten2dv_cuda_backward_kernel_fp16, grid_value, + block, 0, stream,d_out_a, d_value_a, attn_a, height, width, + heads, dilation, dimhalf, nhalf_value); + } + else if (kernel_size == KERNEL_SIZE_3){ + LAUNCH_DNA_DS(dilation, natten2da_cuda_backward_kernel_fp16_3x3_32, + attn_blocks, attn_threads, 0, stream, + d_out_a, d_attn_a, value_a, height, width, + batch_size, heads, dilation); + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation, natten2dv_cuda_backward_kernel_fp16, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dimhalf, nhalf_value); } else if (kernel_size == KERNEL_SIZE_11){ LAUNCH_DNA_KNS_TILED1113(TILE_11_X, TILE_11_Y, KTILE_11_X, KTILE_11_Y, KERNEL_SIZE_11, NEIGHBORHOOD_SIZE_11, dilation, scalar_t, - nattena_cuda_backward_kernel_fp16_11x11_13x13_32, + natten2da_cuda_backward_kernel_fp16_11x11_13x13_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_11, NEIGHBORHOOD_SIZE_11, dilation, nattenv_cuda_backward_kernel_fp16, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_11, NEIGHBORHOOD_SIZE_11, dilation, natten2dv_cuda_backward_kernel_fp16, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dimhalf, nhalf_value); } else if (kernel_size == KERNEL_SIZE_13){ LAUNCH_DNA_KNS_TILED1113(TILE_13_X, TILE_13_Y, KTILE_13_X, KTILE_13_Y, KERNEL_SIZE_13, NEIGHBORHOOD_SIZE_13, dilation, scalar_t, - nattena_cuda_backward_kernel_fp16_11x11_13x13_32, + natten2da_cuda_backward_kernel_fp16_11x11_13x13_32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation); - _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_13, NEIGHBORHOOD_SIZE_13, dilation, nattenv_cuda_backward_kernel_fp16, grid_value, + _IN_LAUNCH_DNA_KNS(KERNEL_SIZE_13, NEIGHBORHOOD_SIZE_13, dilation, natten2dv_cuda_backward_kernel_fp16, grid_value, block, 0, stream,d_out_a, d_value_a, attn_a, height, width, heads, dilation, dimhalf, nhalf_value); } @@ -1218,9 +1425,9 @@ std::vector natten2dav_cuda_backward( const auto d_out_a = d_out.packed_accessor32(); const auto value_a = value.packed_accessor32(); const auto attn_a = attn.packed_accessor32(); - LAUNCH_DNA_KNS(kernel_size, dilation, nattena_cuda_backward_kernel_fp32, attn_blocks, attn_threads, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2da_cuda_backward_kernel_fp32, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation, dim); - LAUNCH_DNA_KNS(kernel_size, dilation, nattenv_cuda_backward_kernel_fp32, grid_value, block, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2dv_cuda_backward_kernel_fp32, grid_value, block, 0, stream, d_out_a, d_value_a, attn_a, height, width, heads, dilation, dim, n_value); })); return {d_attn, d_value}; @@ -1266,9 +1473,9 @@ std::vector natten2dav_cuda_backward_fp16( const auto d_out_a = d_out.packed_accessor32(); const auto value_a = value.packed_accessor32(); const auto attn_a = attn.packed_accessor32(); - LAUNCH_DNA_KNS(kernel_size, dilation, nattena_cuda_backward_kernel_fp16, attn_blocks, attn_threads, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2da_cuda_backward_kernel_fp16, attn_blocks, attn_threads, 0, stream, d_out_a, d_attn_a, value_a, height, width, batch_size, heads, dilation, dimhalf); - LAUNCH_DNA_KNS(kernel_size, dilation, nattenv_cuda_backward_kernel_fp16, grid_value, block, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2dv_cuda_backward_kernel_fp16, grid_value, block, 0, stream, d_out_a, d_value_a, attn_a, height, width, heads, dilation, dimhalf, nhalf_value); })); return {d_attn, d_value}; diff --git a/natten/src/cuda/natten2dqkrpb_cuda_kernel.cu b/natten/src/cuda/natten2dqkrpb_cuda_kernel.cu index 452bf66..8e17ea1 100644 --- a/natten/src/cuda/natten2dqkrpb_cuda_kernel.cu +++ b/natten/src/cuda/natten2dqkrpb_cuda_kernel.cu @@ -25,22 +25,26 @@ namespace natten { #define KERNEL_SIZE_9 9 #define KERNEL_SIZE_7 7 #define KERNEL_SIZE_5 5 +#define KERNEL_SIZE_3 3 #define NEIGHBORHOOD_SIZE_13 6 #define NEIGHBORHOOD_SIZE_11 5 #define NEIGHBORHOOD_SIZE_9 4 #define NEIGHBORHOOD_SIZE_7 3 #define NEIGHBORHOOD_SIZE_5 2 +#define NEIGHBORHOOD_SIZE_3 1 // Always keep batchthreads 1, because we want each thread block to process one 1 sample 1 head #define BATCHTHREADS_13 1 #define BATCHTHREADS_11 1 #define BATCHTHREADS_9 1 #define BATCHTHREADS_7 1 #define BATCHTHREADS_5 1 +#define BATCHTHREADS_3 1 // Tile is the number of pixels across each axis that are processed within a single threadblock // So far the best tile size for Kernel size 7 is 3x3. #define TILE_9 3 #define TILE_7 3 #define TILE_5 4 +#define TILE_3 7 #define TILE_11_X 2 #define TILE_11_Y 3 @@ -50,6 +54,7 @@ namespace natten { #define KTILE_9 11 #define KTILE_7 9 #define KTILE_5 8 +#define KTILE_3 9 #define KTILE_11_X 12 #define KTILE_11_Y 13 @@ -63,6 +68,7 @@ namespace natten { #define XYTHREADS_9 27 #define XYTHREADS_7 21 #define XYTHREADS_5 20 +#define XYTHREADS_3 21 #define XTHREADS_11 33 #define YTHREADS_11 22 @@ -83,8 +89,13 @@ namespace natten { #define KSTRIDE_32 4 // For kernel size 5, we have to do 2 query dims per thread, because we have fewer threads in each threadblock than the total // number of queries. +// For kernel size 3, we have to read 2 query dims per thread #define QITERS_5 2 #define QSTRIDE_5 16 +#define QITERS_3 4 +#define QSTRIDE_3 8 +#define QITERS_3_HALF 2 +#define QSTRIDE_3_HALF 8 // This is just for the other kernels that are not using SMEM #define CUDA_NUM_THREADS_Q 512 @@ -189,6 +200,182 @@ __global__ void natten2dqkrpb_cuda_forward_kernel_fp32( } +/* TODO: FIX BANK CONFLICTS */ +template +__global__ void natten2dqkrpb_cuda_forward_kernel_fp16_3x3_32( + const torch::PackedTensorAccessor32 query, + const torch::PackedTensorAccessor32 key, + const torch::PackedTensorAccessor32 rpb, + torch::PackedTensorAccessor32 attn, + const int height, + const int width, + const int batch_size, + const int heads, + const int dilation_in) { + const int dilation = (DILATION>0) ? DILATION : dilation_in; + // Because batch heads have stride 1 per threadblock, we can just use blockIdx since blockDim will be 1 and threadIdx will + // always be 0. + // const int z = blockIdx.z * blockDim.z + threadIdx.z; + const int z = blockIdx.z; + const int b = z / heads; + const int h = z - b * heads; + // Not needed again because it will always be true. + // if (z < batch_size * heads) + // { + const int lti = threadIdx.y * (TILE_3*KERNEL_SIZE_3) + threadIdx.x; + const int stride2 = DIMHALF_32 * width; + const int batchHeadOffset = b * (stride2*height*heads) + h * (stride2*height); + const int si = int(blockIdx.y / dilation) * (TILE_3 * dilation) + (blockIdx.y % dilation); + const int sj = int(blockIdx.x / dilation) * (TILE_3 * dilation) + (blockIdx.x % dilation); + const int sni = get_window_start(si, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int snj = get_window_start(sj, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + __shared__ __half2 tile[TILE_3*TILE_3][DIM_32+3]; + __shared__ __half2 kTile[KTILE_3*KTILE_3][DIM_32+3]; + __half2* query2 = reinterpret_cast<__half2*>(query.data()); + __half2* key2 = reinterpret_cast<__half2*>(key.data()); + + /* query tile */ + const int qtx = lti / QSTRIDE_3_HALF; + const int qty = (lti - qtx * QSTRIDE_3_HALF) * QITERS_3_HALF; + if (qtx < TILE_3*TILE_3) + { + int qi = qtx / TILE_3; + const int qj = (qtx - qi * TILE_3) * dilation + sj; + qi = qi * dilation + si; + if (qi < height && qj < width){ + #pragma unroll + for (int ti=0; ti < QITERS_3_HALF; ++ti) + tile[qtx][qty+ti] = query2[batchHeadOffset + qi * stride2 + qj * DIMHALF_32 + qty+ti]; + } + } + /* key tile */ + const int ktx = lti / KSTRIDE_32; + const int kty = (lti - ktx * KSTRIDE_32) * KHALFITERS_32; + if (ktx < KTILE_3*KTILE_3) + { + int bi = ktx / KTILE_3; + const int bj = (ktx - bi * KTILE_3) * dilation + snj; + bi = bi * dilation + sni; + if (bi < height && bj < width){ + const int keyOffset = batchHeadOffset + bi * stride2 + bj * DIMHALF_32 + kty; + #pragma unroll + for (int ti=0; ti < KHALFITERS_32; ++ti) + kTile[ktx][kty + ti] = key2[keyOffset + ti]; + } + } + __syncthreads(); + const int ii = threadIdx.y / KERNEL_SIZE_3; + const int ki = threadIdx.y - ii * KERNEL_SIZE_3; + const int jj = threadIdx.x / KERNEL_SIZE_3; + const int kj = threadIdx.x - jj * KERNEL_SIZE_3; + const int i = si + ii*dilation, j = sj + jj*dilation; + if (i < height && j < width){ + const int ni = get_window_start(i, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int nj = get_window_start(j, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int pi = get_pb_start(i, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int pj = get_pb_start(j, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + __half2 updt = __float2half2_rn(0.f); + const int queryIdx = ii*TILE_3 + jj; + const int keyIdx = int((ni+ki*dilation - sni)/dilation)*KTILE_3 + int((nj+kj*dilation - snj)/dilation); + + #pragma unroll + for (int dimOffset=0; dimOffset < DIMHALF_32; ++dimOffset) + updt = __hfma2(tile[queryIdx][dimOffset], kTile[keyIdx][dimOffset], updt); + const int index = b * attn.stride(0) + h * attn.stride(1) + i * attn.stride(2) + j * attn.stride(3) + ki*KERNEL_SIZE_3+kj; + const int rpbIndex = h * rpb.stride(0) + (pi+ki) * rpb.stride(1) + (pj+kj) * rpb.stride(2); + attn.data()[index] = static_cast(__hadd(updt.x, updt.y)) + rpb.data()[rpbIndex]; + } + //} +} + +/* TODO: CHECK BANK CONFLICTS */ +template +__global__ void natten2dqkrpb_cuda_forward_kernel_fp32_3x3_32( + const torch::PackedTensorAccessor32 query, + const torch::PackedTensorAccessor32 key, + const torch::PackedTensorAccessor32 rpb, + torch::PackedTensorAccessor32 attn, + const int height, + const int width, + const int batch_size, + const int heads, + const int dilation_in) { + const int dilation = (DILATION>0) ? DILATION : dilation_in; + // Because batch heads have stride 1 per threadblock, we can just use blockIdx since blockDim will be 1 and threadIdx will + // always be 0. + // const int z = blockIdx.z * blockDim.z + threadIdx.z; + const int z = blockIdx.z; + const int b = z / heads; + const int h = z - b * heads; + // Not needed again because it will always be true. + // if (z < batch_size * heads) + // { + const int lti = threadIdx.y * (TILE_3*KERNEL_SIZE_3) + threadIdx.x; + const int batchHeadOffset = b * query.stride(0) + h * query.stride(1); + const int si = int(blockIdx.y / dilation) * (TILE_3 * dilation) + (blockIdx.y % dilation); + const int sj = int(blockIdx.x / dilation) * (TILE_3 * dilation) + (blockIdx.x % dilation); + const int sni = get_window_start(si, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int snj = get_window_start(sj, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + __shared__ scalar_t tile[TILE_3*TILE_3][DIM_32+3]; + __shared__ scalar_t kTile[KTILE_3*KTILE_3][DIM_32+3]; + + /* query tile */ + const int qtx = lti / QSTRIDE_3; + const int qty = (lti - qtx * QSTRIDE_3) * QITERS_3; + if (qtx < TILE_3*TILE_3) + { + int qi = qtx / TILE_3; + const int qj = (qtx - qi * TILE_3) * dilation + sj; + qi = qi * dilation + si; + if (qi < height && qj < width){ + #pragma unroll + for (int ti=0; ti < QITERS_3; ++ti) + tile[qtx][qty+ti] = query.data()[batchHeadOffset + qi * query.stride(2) + qj * query.stride(3) + qty+ti]; + } + } + /* key tile */ + const int ktx = lti / KSTRIDE_32; + const int kty = (lti - ktx * KSTRIDE_32) * KITERS_32; + if (ktx < KTILE_3*KTILE_3) + { + int bi = ktx / KTILE_3; + const int bj = (ktx - bi * KTILE_3) * dilation + snj; + bi = bi * dilation + sni; + if (bi < height && bj < width){ + const int keyOffset = batchHeadOffset + bi * query.stride(2) + bj * query.stride(3) + kty; + #pragma unroll + for (int ti=0; ti < KITERS_32; ++ti) + kTile[ktx][kty + ti] = key.data()[keyOffset + ti]; + } + } + __syncthreads(); + const int ii = threadIdx.y / KERNEL_SIZE_3; + const int ki = threadIdx.y - ii * KERNEL_SIZE_3; + const int jj = threadIdx.x / KERNEL_SIZE_3; + const int kj = threadIdx.x - jj * KERNEL_SIZE_3; + const int i = si + ii*dilation, j = sj + jj*dilation; + if (i < height && j < width){ + const int ni = get_window_start(i, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int nj = get_window_start(j, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int pi = get_pb_start(i, height, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + const int pj = get_pb_start(j, width, KERNEL_SIZE_3, NEIGHBORHOOD_SIZE_3, dilation); + scalar_t updt = scalar_t(0); + const int queryIdx = ii*TILE_3 + jj; + const int keyIdx = int((ni+ki*dilation - sni)/dilation)*KTILE_3 + int((nj+kj*dilation - snj)/dilation); + + #pragma unroll + for (int dimOffset=0; dimOffset < DIM_32; ++dimOffset) + updt += tile[queryIdx][dimOffset] * kTile[keyIdx][dimOffset]; + + const int index = b * attn.stride(0) + h * attn.stride(1) + i * attn.stride(2) + j * attn.stride(3) + ki*KERNEL_SIZE_3+kj; + const int rpbIndex = h * rpb.stride(0) + (pi+ki) * rpb.stride(1) + (pj+kj) * rpb.stride(2); + updt += rpb.data()[rpbIndex]; + attn.data()[index] = updt; + } + //} +} + + /* TODO: FIX BANK CONFLICTS */ template __global__ void natten2dqkrpb_cuda_forward_kernel_fp16_5x5_32( @@ -709,7 +896,7 @@ __global__ void natten2dqkrpb_cuda_forward_kernel_fp32_11x11_13x13_32( } template -__global__ void nattenq_cuda_backward_kernel_fp32( +__global__ void natten2dq_cuda_backward_kernel_fp32( torch::PackedTensorAccessor32 d_query, const torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 key, @@ -751,7 +938,7 @@ __global__ void nattenq_cuda_backward_kernel_fp32( } template -__global__ void nattenq_cuda_backward_kernel_fp16( +__global__ void natten2dq_cuda_backward_kernel_fp16( torch::PackedTensorAccessor32 d_query, const torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 key, @@ -797,7 +984,7 @@ __global__ void nattenq_cuda_backward_kernel_fp16( } template -__global__ void nattenrpb_cuda_backward_kernel_fp16( +__global__ void natten2drpb_cuda_backward_kernel_fp16( torch::PackedTensorAccessor32 d_rpb, const torch::PackedTensorAccessor32 d_attn, const int height, @@ -835,7 +1022,7 @@ __global__ void nattenrpb_cuda_backward_kernel_fp16( } template -__global__ void nattenrpb_cuda_backward_kernel( +__global__ void natten2drpb_cuda_backward_kernel( torch::PackedTensorAccessor32 d_rpb, const torch::PackedTensorAccessor32 d_attn, const int height, @@ -873,7 +1060,7 @@ __global__ void nattenrpb_cuda_backward_kernel( } template -__global__ void nattenk_cuda_backward_kernel_fp16( +__global__ void natten2dk_cuda_backward_kernel_fp16( torch::PackedTensorAccessor32 d_key, const torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 query, @@ -924,7 +1111,7 @@ __global__ void nattenk_cuda_backward_kernel_fp16( } template -__global__ void nattenk_cuda_backward_kernel_fp32( +__global__ void natten2dk_cuda_backward_kernel_fp32( torch::PackedTensorAccessor32 d_key, const torch::PackedTensorAccessor32 d_attn, const torch::PackedTensorAccessor32 query, @@ -1069,9 +1256,9 @@ torch::Tensor natten2dqkrpb_cuda_forward_tiled_32( int kernel_size = (RPB_MAX + 1) / 2; CHECK_FEATMAP(height, width, kernel_size, dilation); TORCH_CHECK(dim == DIM_32, "natten2dqkrpb_cuda_forward_fp32_tiled_32", " only supports 32-dim attention heads."); - TORCH_CHECK(kernel_size == KERNEL_SIZE_7 || kernel_size == KERNEL_SIZE_5 || kernel_size == KERNEL_SIZE_9 || - kernel_size == KERNEL_SIZE_11 || kernel_size == KERNEL_SIZE_13, - "natten2dqkrpb_cuda_forward_fp32_tiled_32", " only supports kernel sizes 5, 7, 9, 11, and 13."); + TORCH_CHECK(kernel_size == KERNEL_SIZE_7 || kernel_size == KERNEL_SIZE_3 || kernel_size == KERNEL_SIZE_5 || + kernel_size == KERNEL_SIZE_9 || kernel_size == KERNEL_SIZE_11 || kernel_size == KERNEL_SIZE_13, + "natten2dqkrpb_cuda_forward_fp32_tiled_32", " only supports kernel sizes 3, 5, 7, 9, 11, and 13."); int xsize = width * kernel_size; int ysize = height * kernel_size; int zsize = batch_size * heads; @@ -1088,6 +1275,12 @@ torch::Tensor natten2dqkrpb_cuda_forward_tiled_32( YTHREADS = XYTHREADS_7; BATCHTHREADS = BATCHTHREADS_7; } + else if (kernel_size == KERNEL_SIZE_3) + { + XTHREADS = XYTHREADS_3; + YTHREADS = XYTHREADS_3; + BATCHTHREADS = BATCHTHREADS_3; + } else if (kernel_size == KERNEL_SIZE_5) { XTHREADS = XYTHREADS_5; @@ -1130,8 +1323,14 @@ torch::Tensor natten2dqkrpb_cuda_forward_tiled_32( LAUNCH_DNA_KNS_TILED79(TILE_9, KTILE_9, KERNEL_SIZE_9, NEIGHBORHOOD_SIZE_9, dilation, natten2dqkrpb_cuda_forward_kernel_fp32_7x7_9x9_32, blocks, threads, 0, stream, query_a, key_a, rpb_a, attn_a, height, width, batch_size, heads, dilation); + else if (kernel_size == KERNEL_SIZE_3) + LAUNCH_DNA_DS(dilation, natten2dqkrpb_cuda_forward_kernel_fp32_3x3_32, + blocks, threads, 0, stream, + query_a, key_a, rpb_a, attn_a, height, width, batch_size, heads, dilation); else if (kernel_size == KERNEL_SIZE_5) - LAUNCH_DNA_DS(dilation, natten2dqkrpb_cuda_forward_kernel_fp32_5x5_32, blocks, threads, 0, stream, query_a, key_a, rpb_a, attn_a, height, width, batch_size, heads, dilation); + LAUNCH_DNA_DS(dilation, natten2dqkrpb_cuda_forward_kernel_fp32_5x5_32, + blocks, threads, 0, stream, + query_a, key_a, rpb_a, attn_a, height, width, batch_size, heads, dilation); else if (kernel_size == KERNEL_SIZE_11) LAUNCH_DNA_KNS_TILED1113(TILE_11_X, TILE_11_Y, KTILE_11_X, KTILE_11_Y, KERNEL_SIZE_11, NEIGHBORHOOD_SIZE_11, dilation, scalar_t, @@ -1163,9 +1362,9 @@ torch::Tensor natten2dqkrpb_cuda_forward_fp16_tiled_32( CHECK_FEATMAP(height, width, kernel_size, dilation); TORCH_CHECK(dimhalf*2 == query.size(4), "Dims per head must be an even number in FP16."); TORCH_CHECK(dimhalf*2 == DIM_32, "natten2dqkrpb_cuda_forward_fp16_tiled_32", " only supports 32-dim attention heads."); - TORCH_CHECK(kernel_size == KERNEL_SIZE_7 || kernel_size == KERNEL_SIZE_5 || kernel_size == KERNEL_SIZE_9 || - kernel_size == KERNEL_SIZE_11 || kernel_size == KERNEL_SIZE_13, - "natten2dqkrpb_cuda_forward_fp16_tiled_32", " only supports kernel sizes 5, 7, 9, 11, and 13."); + TORCH_CHECK(kernel_size == KERNEL_SIZE_7 || kernel_size == KERNEL_SIZE_3 || kernel_size == KERNEL_SIZE_5 || + kernel_size == KERNEL_SIZE_9 || kernel_size == KERNEL_SIZE_11 || kernel_size == KERNEL_SIZE_13, + "natten2dqkrpb_cuda_forward_fp16_tiled_32", " only supports kernel sizes 3, 5, 7, 9, 11, and 13."); int xsize = width * kernel_size; int ysize = height * kernel_size; int zsize = batch_size * heads; @@ -1182,6 +1381,12 @@ torch::Tensor natten2dqkrpb_cuda_forward_fp16_tiled_32( YTHREADS = XYTHREADS_7; BATCHTHREADS = BATCHTHREADS_7; } + else if (kernel_size == KERNEL_SIZE_3) + { + XTHREADS = XYTHREADS_3; + YTHREADS = XYTHREADS_3; + BATCHTHREADS = BATCHTHREADS_3; + } else if (kernel_size == KERNEL_SIZE_5) { XTHREADS = XYTHREADS_5; @@ -1224,8 +1429,14 @@ torch::Tensor natten2dqkrpb_cuda_forward_fp16_tiled_32( LAUNCH_DNA_KNS_TILED79(TILE_9, KTILE_9, KERNEL_SIZE_9, NEIGHBORHOOD_SIZE_9, dilation, natten2dqkrpb_cuda_forward_kernel_fp16_7x7_9x9_32, blocks, threads, 0, stream, query_a, key_a, rpb_a, attn_a, height, width, batch_size, heads, dilation); + else if (kernel_size == KERNEL_SIZE_3) + LAUNCH_DNA_DS(dilation, natten2dqkrpb_cuda_forward_kernel_fp16_3x3_32, + blocks, threads, 0, stream, + query_a, key_a, rpb_a, attn_a, height, width, batch_size, heads, dilation); else if (kernel_size == KERNEL_SIZE_5) - LAUNCH_DNA_DS(dilation, natten2dqkrpb_cuda_forward_kernel_fp16_5x5_32, blocks, threads, 0, stream, query_a, key_a, rpb_a, attn_a, height, width, batch_size, heads, dilation); + LAUNCH_DNA_DS(dilation, natten2dqkrpb_cuda_forward_kernel_fp16_5x5_32, + blocks, threads, 0, stream, + query_a, key_a, rpb_a, attn_a, height, width, batch_size, heads, dilation); else if (kernel_size == KERNEL_SIZE_11) LAUNCH_DNA_KNS_TILED1113(TILE_11_X, TILE_11_Y, KTILE_11_X, KTILE_11_Y, KERNEL_SIZE_11, NEIGHBORHOOD_SIZE_11, dilation, scalar_t, @@ -1282,11 +1493,11 @@ std::vector natten2dqkrpb_cuda_backward( const auto d_attn_a = d_attn.packed_accessor32(); const auto query_a = query.packed_accessor32(); const auto key_a = key.packed_accessor32(); - LAUNCH_DNA_KNS(kernel_size, dilation, nattenrpb_cuda_backward_kernel, grid_rpb, blockr, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2drpb_cuda_backward_kernel, grid_rpb, blockr, 0, stream, d_rpb_a, d_attn_a, height, width, dilation, batch_size, d_rpb.numel(), n_rpb); - LAUNCH_DNA_KNS(kernel_size, dilation, nattenq_cuda_backward_kernel_fp32, grid_query, blockq, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2dq_cuda_backward_kernel_fp32, grid_query, blockq, 0, stream, d_query_a, d_attn_a, key_a, height, width, heads, dilation, dim, n_query); - LAUNCH_DNA_KNS(kernel_size, dilation, nattenk_cuda_backward_kernel_fp32, grid_key, blockk, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2dk_cuda_backward_kernel_fp32, grid_key, blockk, 0, stream, d_key_a, d_attn_a, query_a, height, width, heads, dilation, dim, n_key); })); return {d_query, d_key, d_rpb}; @@ -1333,11 +1544,11 @@ std::vector natten2dqkrpb_cuda_backward_fp16( const auto d_attn_a = d_attn.packed_accessor32(); const auto query_a = query.packed_accessor32(); const auto key_a = key.packed_accessor32(); - LAUNCH_DNA_KNS(kernel_size, dilation, nattenrpb_cuda_backward_kernel_fp16, grid_rpb, blockr, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2drpb_cuda_backward_kernel_fp16, grid_rpb, blockr, 0, stream, d_rpb_a, d_attn_a, height, width, dilation, batch_size, d_rpb.numel(), n_rpb); - LAUNCH_DNA_KNS(kernel_size, dilation, nattenq_cuda_backward_kernel_fp16, grid_query, blockq, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2dq_cuda_backward_kernel_fp16, grid_query, blockq, 0, stream, d_query_a, d_attn_a, key_a, height, width, heads, dilation, dimhalf, nhalf_query); - LAUNCH_DNA_KNS(kernel_size, dilation, nattenk_cuda_backward_kernel_fp16, grid_key, blockk, 0, stream, + LAUNCH_DNA_KNS(kernel_size, dilation, natten2dk_cuda_backward_kernel_fp16, grid_key, blockk, 0, stream, d_key_a, d_attn_a, query_a, height, width, heads, dilation, dimhalf, nhalf_key); })); return {d_query, d_key, d_rpb}; diff --git a/natten/src/cuda/natten_commons.cuh b/natten/src/cuda/natten_commons.cuh index f53fbb0..229a7a6 100644 --- a/natten/src/cuda/natten_commons.cuh +++ b/natten/src/cuda/natten_commons.cuh @@ -10,7 +10,14 @@ LICENSE file in the root directory of this source tree. #include #include -#define AT_DISPATCH_HALF_TYPES(SCALARTYPE1, TYPE, NAME, ...) \ +#if defined(TORCH_113) + +#define AT_DISPATCH_HALF_TYPES(SCALARTYPE1, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)) + +#else + +#define AT_DISPATCH_HALF_TYPES(SCALARTYPE1, TYPE, NAME, ...) \ [&] { \ const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ @@ -27,6 +34,7 @@ LICENSE file in the root directory of this source tree. } \ }() +#endif #define CUDA_NUM_THREADS 1024 diff --git a/natten/src/natten.cpp b/natten/src/natten.cpp index b98f7d3..1208429 100644 --- a/natten/src/natten.cpp +++ b/natten/src/natten.cpp @@ -18,9 +18,9 @@ namespace natten { m.def("natten1dav_forward", &natten1dav_forward, "NATTEN1DAV forward"); m.def("natten1dav_backward", &natten1dav_backward, "NATTEN1DAV backward"); - m.def("natten2dqkrpb_forward", &natten2dqkrpb_forward, "NATTENQK+RPB forward"); - m.def("natten2dqkrpb_backward", &natten2dqkrpb_backward, "NATTENQK+RPB backward"); - m.def("natten2dav_forward", &natten2dav_forward, "NATTENAV forward"); - m.def("natten2dav_backward", &natten2dav_backward, "NATTENAV backward"); + m.def("natten2dqkrpb_forward", &natten2dqkrpb_forward, "NATTEN2DQK+RPB forward"); + m.def("natten2dqkrpb_backward", &natten2dqkrpb_backward, "NATTEN2DQK+RPB backward"); + m.def("natten2dav_forward", &natten2dav_forward, "NATTEN2DAV forward"); + m.def("natten2dav_backward", &natten2dav_backward, "NATTEN2DAV backward"); } } // namespace natten diff --git a/natten/src/natten2dav.h b/natten/src/natten2dav.h index b78e3fc..9c11abb 100644 --- a/natten/src/natten2dav.h +++ b/natten/src/natten2dav.h @@ -1,5 +1,5 @@ /* -NATTEN-AV TORCH EXTENSION +NATTEN2D-AV TORCH EXTENSION This source code is licensed under the license found in the LICENSE file in the root directory of this source tree. @@ -99,7 +99,10 @@ std::vector natten2dav_backward( int dim = value.size(4); int kernel_size = sqrt(attn.size(4)); bool half = ::detail::scalar_type(value.scalar_type()) == at::ScalarType::Half; - if ((kernel_size == 7 || kernel_size == 5 || kernel_size == 9 || kernel_size == 11 || kernel_size == 13) && dim == 32){ + if (( + kernel_size == 7 || kernel_size == 3 || kernel_size == 5 || + kernel_size == 9 || kernel_size == 11 || kernel_size == 13 + ) && dim == 32){ if (half) return natten2dav_cuda_backward_fp16_tiled_32(d_out, attn, value, dilation); return natten2dav_cuda_backward_tiled_32(d_out, attn, value, dilation); diff --git a/natten/src/natten2dqkrpb.h b/natten/src/natten2dqkrpb.h index 94203e1..1c32ce9 100644 --- a/natten/src/natten2dqkrpb.h +++ b/natten/src/natten2dqkrpb.h @@ -1,5 +1,5 @@ /* -NATTEN-QKRPB TORCH EXTENSION +NATTEN2D-QKRPB TORCH EXTENSION This source code is licensed under the license found in the LICENSE file in the root directory of this source tree. @@ -82,7 +82,10 @@ torch::Tensor natten2dqkrpb_forward( int dim = query.size(4); int kernel_size = (rpb.size(1) + 1) / 2; bool half = ::detail::scalar_type(query.scalar_type()) == at::ScalarType::Half; - if ((kernel_size == 7 || kernel_size == 5 || kernel_size == 9 || kernel_size == 11 || kernel_size == 13) && dim == 32){ + if (( + kernel_size == 7 || kernel_size == 3 || kernel_size == 5 || + kernel_size == 9 || kernel_size == 11 || kernel_size == 13 + ) && dim == 32){ if (half) return natten2dqkrpb_cuda_forward_fp16_tiled_32(query, key, rpb, dilation); return natten2dqkrpb_cuda_forward_tiled_32(query, key, rpb, dilation); diff --git a/setup.py b/setup.py index ea3b641..70e16da 100644 --- a/setup.py +++ b/setup.py @@ -23,12 +23,17 @@ from pathlib import Path this_directory = Path(__file__).parent -long_description = (this_directory / "assets/README_pypi.md").read_text() +try: + long_description = (this_directory / "assets/README_pypi.md").read_text() +except: + long_description = "Neighborhood Attention Extension." torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] assert torch_ver >= [1, 8], "NATTEN requires PyTorch >= 1.8" AVX_INT = torch_ver >= [1, 10] +TORCH_113 = torch_ver >= [1, 13] HAS_CUDA = (torch.cuda.is_available() and (CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1") +NATTEN_VERSION_SUFFIX = os.getenv("NATTEN_VERSION_SUFFIX", "") def get_version(): @@ -36,7 +41,9 @@ def get_version(): init_py = open(init_py_path, "r").readlines() version_line = [l.strip() for l in init_py if l.startswith("__version__")][0] version = version_line.split("=")[-1].strip().strip("'\"") - PYTORCH_VERSION = ''.join(torch.__version__.split('.')[:2]) + if NATTEN_VERSION_SUFFIX != "1": + return f'{version}{NATTEN_VERSION_SUFFIX}' + PYTORCH_VERSION = ''.join(torch.__version__.split('+')[0].split('.')) if HAS_CUDA: CUDA_VERSION = ''.join(torch.version.cuda.split('.')[:2]) @@ -70,6 +77,8 @@ def get_extension(): extension = CppExtension extra_compile_args = {"cxx": ["-O3"]} define_macros = [] + if TORCH_113: + define_macros += [("TORCH_113", 1)] if AVX_INT: define_macros += [("AVX_INT", 1)] else: diff --git a/tests/test_na2d.py b/tests/test_na2d.py index 5df47d2..d05925e 100644 --- a/tests/test_na2d.py +++ b/tests/test_na2d.py @@ -174,5 +174,109 @@ def test_cpu_cuda_allclose(self): b, li, lj = 4, 14, 16 _priv_test_allclose_cpu_cuda(b, li, lj) + def test_natten2dqkrpb_tiled3x3_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 6, 6, 32, 3, 2 + _priv_test_gradcheck_natten2dqkrpb(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, False) + + def test_natten2dav_tiled3x3_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 6, 6, 32, 3, 2 + _priv_test_gradcheck_natten2dav(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, False) + + def test_natten2dqkrpb_tiled5x5_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 10, 10, 32, 5, 2 + _priv_test_gradcheck_natten2dqkrpb(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, False) + + def test_natten2dav_tiled5x5_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 10, 10, 32, 5, 2 + _priv_test_gradcheck_natten2dav(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, False) + + def test_natten2dqkrpb_tiled7x7_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 14, 14, 32, 7, 2 + _priv_test_gradcheck_natten2dqkrpb(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, False) + + def test_natten2dav_tiled7x7_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 14, 14, 32, 7, 2 + _priv_test_gradcheck_natten2dav(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, False) + + def test_natten2dqkrpb_tiled9x9_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 18, 18, 32, 9, 2 + _priv_test_gradcheck_natten2dqkrpb(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, False) + + def test_natten2dav_tiled9x9_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 18, 18, 32, 9, 2 + _priv_test_gradcheck_natten2dav(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, False) + + def test_natten2dqkrpb_tiled11x11_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 22, 22, 32, 11, 2 + # TODO: Disable FAST MODE + # Presently we do fast mode because otherwise this test will throw an OOM + # Tested on an 80GB A100 + _priv_test_gradcheck_natten2dqkrpb(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, + #False + True) + + def test_natten2dav_tiled11x11_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 22, 22, 32, 11, 2 + # TODO: Disable FAST MODE + # Presently we do fast mode because otherwise this test will throw an OOM + # Tested on an 80GB A100 + _priv_test_gradcheck_natten2dav(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, + #False + True) + + def test_natten2dqkrpb_tiled13x13_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 26, 26, 32, 13, 2 + # TODO: Disable FAST MODE + # Presently we do fast mode because otherwise this test will throw an OOM + # Tested on an 80GB A100 + _priv_test_gradcheck_natten2dqkrpb(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, + #False + True) + + def test_natten2dav_tiled13x13_gradcheck_cuda(self): + if not HAS_CUDA: + self.skipTest("CUDA not available.") + b, h, li, lj, d, k, di = 1, 1, 26, 26, 32, 13, 2 + # TODO: Disable FAST MODE + # Presently we do fast mode because otherwise this test will throw an OOM + # Tested on an 80GB A100 + _priv_test_gradcheck_natten2dav(b, h, li, lj, d, k, di, torch.float64, 'cuda', + 1e-6, 1e-5, 1e-3, 1e-8, + #False + True) + if __name__ == "__main__": unittest.main()