Add a library node for np.transpose for >2 dims and a cuTENSOR implementation #1303

Status: Draft. Wants to merge 4 commits into base: master.
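For context, a minimal, illustrative example of the behavior this PR targets (not part of the diff): a greater-than-2-dimensional np.transpose inside a dace.program, which the frontend change below routes through the new Permute library node instead of an inline mapped tasklet. The program name and shapes are invented, and the sketch assumes the frontend accepts a positional axes tuple as the replacement signature suggests.

import numpy as np
import dace


@dace.program
def transpose3d(a: dace.float64[2, 3, 4]):
    # 3-D transpose: previously expanded inline, now lowered to the Permute node
    return np.transpose(a, (1, 2, 0))


if __name__ == '__main__':
    a = np.random.rand(2, 3, 4)
    b = transpose3d(a)
    assert b.shape == (3, 4, 2)
    assert np.allclose(b, a.transpose(1, 2, 0))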
15 changes: 7 additions & 8 deletions dace/frontend/python/replacements.py
@@ -784,14 +784,13 @@ def _transpose(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, inpname: str, a
         state.add_edge(acc1, None, tasklet, '_inp', Memlet.from_array(inpname, arr1))
         state.add_edge(tasklet, '_out', acc2, None, Memlet.from_array(outname, arr2))
     else:
-        state.add_mapped_tasklet(
-            "_transpose_", {"_i{}".format(i): "0:{}".format(s)
-                            for i, s in enumerate(arr1.shape)},
-            dict(_in=Memlet.simple(inpname, ", ".join("_i{}".format(i) for i, _ in enumerate(arr1.shape)))),
-            "_out = _in",
-            dict(_out=Memlet.simple(outname, ", ".join("_i{}".format(axes[i]) for i, _ in enumerate(arr1.shape)))),
-            external_edges=True)
-
+        acc1 = state.add_read(inpname)
+        acc2 = state.add_write(outname)
+        import dace.libraries.blas  # Avoid import loop
+        tasklet = dace.libraries.blas.Permute('_Permute_', axes=axes, dtype=restype)
+        state.add_node(tasklet)
+        state.add_edge(acc1, None, tasklet, '_inp', Memlet.from_array(inpname, arr1))
+        state.add_edge(tasklet, '_out', acc2, None, Memlet.from_array(outname, arr2))
     return outname
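For reference (illustrative, not part of the diff): with the frontend now emitting a Permute library node, the usual DaCe mechanism for choosing a library-node expansion should apply, either globally through the class-level default or per node. The snippet below assumes the Permute node and implementation names introduced in this PR.

import dace.libraries.blas as blas

# Globally, before parsing/compiling the program:
blas.Permute.default_implementation = 'cuTENSOR'

# Or per node, after obtaining the SDFG:
# for node, _ in sdfg.all_nodes_recursive():
#     if isinstance(node, blas.Permute):
#         node.implementation = 'cuTENSOR'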


1 change: 1 addition & 0 deletions dace/libraries/blas/environments/__init__.py
@@ -4,3 +4,4 @@
 from .intel_mkl import *
 from .cublas import *
 from .rocblas import *
+from .cutensor import *
54 changes: 54 additions & 0 deletions dace/libraries/blas/environments/cutensor.py
@@ -0,0 +1,54 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
[Review comment (Collaborator)] Suggested change: update the header to "# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved."

import dace.library
import ctypes.util


@dace.library.environment
class cuTENSOR:

    cmake_minimum_version = None
    cmake_packages = ["CUDA"]
    cmake_variables = {}
    cmake_includes = []
    cmake_libraries = ["cutensor"]
    cmake_compile_flags = ["-L/users/jbazinsk/libcutensor-linux-x86_64-1.7.0.1-archive/lib/11"]
[Review comment (Collaborator)] maybe better to set up the LIBRARY_PATH envvar locally.
    cmake_link_flags = []
    cmake_files = []

    headers = {'frame': ["../include/dace_cutensor.h"], 'cuda': ["../include/dace_cutensor.h"]}
    state_fields = ["dace::blas::CutensorHandle cutensor_handle;"]
    init_code = ""
    finalize_code = ""
    dependencies = []

    @staticmethod
    def handle_setup_code(node):
        location = node.location
        if not location or "gpu" not in node.location:
            location = 0
        else:
            try:
                location = int(location["gpu"])
            except ValueError:
                raise ValueError("Invalid GPU identifier: {}".format(location))

        code = """\
const int __dace_cuda_device = {location};
cutensorHandle_t* __dace_cutensor_handle = __state->cutensor_handle.Get(__dace_cuda_device);\n"""

        return code.format(location=location)

    @staticmethod
    def _find_library():
        # *nix-based search
        blas_path = ctypes.util.find_library('cutensor')
        if blas_path:
            return [blas_path]

        # Windows-based search
        versions = (10, 11, 12)
        for version in versions:
            blas_path = ctypes.util.find_library(f'cutensor64_{version}')
            if blas_path:
                return [blas_path]
        return []
67 changes: 67 additions & 0 deletions dace/libraries/blas/include/dace_cutensor.h
@@ -0,0 +1,67 @@
// Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
[Review comment (Collaborator)] Suggested change: update the header to "// Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved."

#pragma once

#include <cuda_runtime.h>
#include <cutensor.h>

#include <cstddef> // size_t
#include <stdexcept> // std::runtime_error
#include <string> // std::to_string
#include <unordered_map>

namespace dace {

namespace blas {

static void CheckCutensorError(cutensorStatus_t const& status) {
  if (status != CUTENSOR_STATUS_SUCCESS) {
    throw std::runtime_error("cuTENSOR failed with error code: " + std::to_string(status));
  }
}

[Review comment (Collaborator), on the error message above] isn't there a cutensorGetErrorString or something similar?

static cutensorHandle_t* CreateCutensorHandle(int device) {
  if (cudaSetDevice(device) != cudaSuccess) {
    throw std::runtime_error("Failed to set CUDA device.");
  }
  cutensorHandle_t* handle;
  CheckCutensorError(cutensorCreate(&handle));
  return handle;
}



/**
 * cuTENSOR wrapper class for DaCe. Once constructed, the class can be used to
 * get or create a cuTENSOR library handle (cutensorHandle_t) for a given
 * GPU ID. The class is constructed when the cuTENSOR DaCe library is used.
 **/
class CutensorHandle {
 public:
  CutensorHandle() = default;
  CutensorHandle(CutensorHandle const&) = delete;

  cutensorHandle_t* Get(int device) {
    auto f = handles_.find(device);
    if (f == handles_.end()) {
      // Lazily construct new cutensor handle if the specified key does not
      // yet exist
      cutensorHandle_t* handle = CreateCutensorHandle(device);
      f = handles_.emplace(device, handle).first;
    }
    return f->second;
  }

  ~CutensorHandle() {
    for (auto& h : handles_) {
      CheckCutensorError(cutensorDestroy(h.second));
    }
  }

  CutensorHandle& operator=(CutensorHandle const&) = delete;

  std::unordered_map<int, cutensorHandle_t*> handles_;
};

} // namespace blas

} // namespace dace
2 changes: 1 addition & 1 deletion dace/libraries/blas/nodes/__init__.py
@@ -6,6 +6,6 @@
 from .ger import Ger
 from .batched_matmul import BatchedMatMul
 from .transpose import Transpose
-
+from .permute import Permute
 from .axpy import Axpy
 from .einsum import Einsum
203 changes: 203 additions & 0 deletions dace/libraries/blas/nodes/permute.py
@@ -0,0 +1,203 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
[Review comment (Collaborator)] Suggested change: update the header to "# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved."

import functools
from copy import deepcopy as dc
from typing import List

from dace.config import Config
import dace.library
import dace.properties
import dace.sdfg.nodes
from dace.libraries.blas import blas_helpers
from dace.transformation.transformation import ExpandTransformation
from .. import environments
import warnings


def _get_permute_input(node, state, sdfg):
    """Returns the permute input edge, array, and shape."""
    for edge in state.in_edges(node):
        if edge.dst_conn == "_inp":
            subset = dc(edge.data.subset)
            subset.squeeze()
            size = subset.size()
            outer_array = sdfg.data(dace.sdfg.find_input_arraynode(state, edge).data)
            return edge, outer_array, size
    raise ValueError("Permute input connector \"_inp\" not found.")


def _get_permute_output(node, state, sdfg):
    """Returns the permute output edge, array, and shape."""
    for edge in state.out_edges(node):
        if edge.src_conn == "_out":
            subset = dc(edge.data.subset)
            subset.squeeze()
            size = subset.size()
            outer_array = sdfg.data(dace.sdfg.find_output_arraynode(state, edge).data)
            return edge, outer_array, size
    raise ValueError("Permute output connector \"_out\" not found.")


@dace.library.expansion
class ExpandPermutePure(ExpandTransformation):
    environments = []

    @staticmethod
    def make_sdfg(node, parent_state, parent_sdfg):
        in_edge, in_outer_array, in_shape = _get_permute_input(node, parent_state, parent_sdfg)
        out_edge, out_outer_array, out_shape = _get_permute_output(node, parent_state, parent_sdfg)
        dtype = node.dtype
        axes = node.axes
        sdfg = dace.SDFG(node.label + "_sdfg")
        state = sdfg.add_state(node.label + "_state")

        _, in_array = sdfg.add_array("_inp",
                                     in_shape,
                                     dtype,
                                     strides=in_outer_array.strides,
                                     storage=in_outer_array.storage)
        _, out_array = sdfg.add_array("_out",
                                      out_shape,
                                      dtype,
                                      strides=out_outer_array.strides,
                                      storage=out_outer_array.storage)

        num_elements = functools.reduce(lambda x, y: x * y, in_array.shape)
        if num_elements == 1:
            inp = state.add_read("_inp")
            out = state.add_write("_out")
            tasklet = state.add_tasklet("permute", {"__inp"}, {"__out"}, "__out = __inp")
            state.add_edge(inp, None, tasklet, "__inp", dace.memlet.Memlet.from_array("_inp", in_array))
            state.add_edge(tasklet, "__out", out, None, dace.memlet.Memlet.from_array("_out", out_array))
        else:
            state.add_mapped_tasklet(
                "_permute_", {"_i{}".format(i): "0:{}".format(s)
                              for i, s in enumerate(in_array.shape)},
                dict(_tmp_in=dace.memlet.Memlet.simple("_inp", ", ".join("_i{}".format(i) for i, _ in enumerate(in_array.shape)))),
                "_tmp_out = _tmp_in",
                dict(_tmp_out=dace.memlet.Memlet.simple("_out", ", ".join("_i{}".format(axes[i]) for i, _ in enumerate(in_array.shape)))),
                external_edges=True)

        return sdfg

    @staticmethod
    def expansion(node, state, sdfg):
        node.validate(sdfg, state)
        return ExpandPermutePure.make_sdfg(node, state, sdfg)
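As a quick check of the index mapping the pure expansion encodes: the mapped tasklet writes _out[_i{axes[0]}, ..., _i{axes[n-1]}] = _in[_i0, ..., _i{n-1}], which matches numpy's transpose convention that output dimension k takes input dimension axes[k]. A small illustrative snippet (not part of the PR), using plain numpy only:

import numpy as np

a = np.arange(24).reshape(2, 3, 4)
axes = (1, 2, 0)
out = np.transpose(a, axes)
assert out.shape == (3, 4, 2)
# Every input element a[j0, j1, j2] lands at out[j1, j2, j0],
# i.e. out[_i{axes[0]}, _i{axes[1]}, _i{axes[2]}] = a[_i0, _i1, _i2]:
for j0 in range(2):
    for j1 in range(3):
        for j2 in range(4):
            assert out[j1, j2, j0] == a[j0, j1, j2]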


@dace.library.expansion
class ExpandPermuteCuTENSOR(ExpandTransformation):

    environments = [environments.cutensor.cuTENSOR]

    @staticmethod
    def expansion(node, state, sdfg, **kwargs):
        node.validate(sdfg, state)
        dtype = node.dtype
        axes = node.axes

        cuda_dtype = blas_helpers.dtype_to_cudadatatype(dtype)

        in_edge, in_outer_array, in_shape = _get_permute_input(node, state, sdfg)
        out_edge, out_outer_array, out_shape = _get_permute_output(node, state, sdfg)

        num_dims = len(axes)
        modeA = ', '.join([str(x) for x in axes])
        modeC = ', '.join([str(x) for x in range(len(axes))])

        stridesA = ', '.join([str(x) for x in in_outer_array.strides])
        stridesC = ', '.join([str(x) for x in out_outer_array.strides])

        code_prefix = environments.cuTENSOR.handle_setup_code(node)
        code_call = f"""
        int modeC[] = {{ {modeC} }};
        int modeA[] = {{ {modeA} }};

        int64_t extentA[] = {{ {', '.join([str(x) for x in in_shape])} }};
        int64_t extentC[] = {{ {', '.join([str(x) for x in out_shape])} }};
        int64_t stridesA[] = {{ {stridesA} }};
        int64_t stridesC[] = {{ {stridesC} }};

        cutensorTensorDescriptor_t descA;
        dace::blas::CheckCutensorError(cutensorInitTensorDescriptor(__dace_cutensor_handle,
                                                                    &descA,
                                                                    {num_dims},
                                                                    extentA,
                                                                    stridesA,
                                                                    {cuda_dtype}, CUTENSOR_OP_IDENTITY));

        cutensorTensorDescriptor_t descC;
        dace::blas::CheckCutensorError(cutensorInitTensorDescriptor(__dace_cutensor_handle,
                                                                    &descC,
                                                                    {num_dims},
                                                                    extentC,
                                                                    stridesC,
                                                                    {cuda_dtype}, CUTENSOR_OP_IDENTITY));

        const float one = 1.0f;
        cutensorPermutation(
            __dace_cutensor_handle,
            &one,
            /*A=*/_inp,
            &descA,
            /*axes A=*/modeA,
            /*C=*/_out,
            &descC,
            /*axes C=*/modeC,
            /*computeType=*/{cuda_dtype},
            /*stream=*/__dace_current_stream
        );
        """

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code_prefix + code_call,
                                          language=dace.dtypes.Language.CPP)

        return tasklet


@dace.library.node
class Permute(dace.sdfg.nodes.LibraryNode):
    # Global properties
    implementations = {
        "pure": ExpandPermutePure,
        "cuTENSOR": ExpandPermuteCuTENSOR,
    }
    default_implementation = None

    dtype = dace.properties.TypeClassProperty(allow_none=True)
    axes = dace.properties.ListProperty(element_type=int, allow_none=True,
                                        desc="Axes to permute.")

    def __init__(self, name, axes, dtype=None, location=None):
        super().__init__(name, location=location, inputs={'_inp'}, outputs={'_out'})
        self.dtype = dtype
        self.axes = axes

    def validate(self, sdfg, state):
        in_edges = state.in_edges(self)
        if len(in_edges) != 1:
            raise ValueError("Expected exactly one input to permute operation")
        in_size = None
        for _, _, _, dst_conn, memlet in state.in_edges(self):
            if dst_conn == '_inp':
                subset = dc(memlet.subset)
                subset.squeeze()
                in_size = subset.size()
        if in_size is None:
            raise ValueError("Input connector not found.")
        out_edges = state.out_edges(self)
        if len(out_edges) != 1:
            raise ValueError("Expected exactly one output from permute operation")
        out_memlet = out_edges[0].data

        out_subset = dc(out_memlet.subset)
        out_subset.squeeze()
        out_size = out_subset.size()
        if len(out_size) != len(in_size):
            raise ValueError("Permute operation is only supported on arrays of equal dimensionality.")
        if set(out_size) != set(in_size):
            raise ValueError("Expected the output size to be a permutation of the input size.")