# Add a library node for np.transpose for >2 dims and a cuTENSOR implementation #1303
Base: `master`
`environments/__init__.py`:

```diff
@@ -4,3 +4,4 @@
 from .intel_mkl import *
 from .cublas import *
 from .rocblas import *
+from .cutensor import *
```
`environments/cutensor.py` (new file):

```python
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace.library
import ctypes.util


@dace.library.environment
class cuTENSOR:

    cmake_minimum_version = None
    cmake_packages = ["CUDA"]
    cmake_variables = {}
    cmake_includes = []
    cmake_libraries = ["cutensor"]
    cmake_compile_flags = ["-L/users/jbazinsk/libcutensor-linux-x86_64-1.7.0.1-archive/lib/11"]
```
> **Review comment** (on the hard-coded `-L` path above): maybe better to set up the …
```python
    cmake_link_flags = []
    cmake_files = []

    headers = {'frame': ["../include/dace_cutensor.h"], 'cuda': ["../include/dace_cutensor.h"]}
    state_fields = ["dace::blas::CutensorHandle cutensor_handle;"]
    init_code = ""
    finalize_code = ""
    dependencies = []

    @staticmethod
    def handle_setup_code(node):
        location = node.location
        if not location or "gpu" not in node.location:
            location = 0
        else:
            try:
                location = int(location["gpu"])
            except ValueError:
                raise ValueError("Invalid GPU identifier: {}".format(location))

        code = """\
const int __dace_cuda_device = {location};
cutensorHandle_t* __dace_cutensor_handle = __state->cutensor_handle.Get(__dace_cuda_device);\n"""

        return code.format(location=location)

    @staticmethod
    def _find_library():
        # *nix-based search
        blas_path = ctypes.util.find_library('cutensor')
        if blas_path:
            return [blas_path]

        # Windows-based search
        versions = (10, 11, 12)
        for version in versions:
            blas_path = ctypes.util.find_library(f'cutensor64_{version}')
            if blas_path:
                return [blas_path]
        return []
```
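Following up on the review comment about the hard-coded `-L` path: one way to avoid a user-specific absolute path is to derive the flag from an environment variable. A minimal sketch, assuming a hypothetical `CUTENSOR_ROOT` convention that neither this PR nor DaCe defines:

```python
import os

# Hypothetical alternative to the hard-coded compile flag: point the linker at
# $CUTENSOR_ROOT/lib if the (assumed) CUTENSOR_ROOT variable is set.
_cutensor_root = os.environ.get('CUTENSOR_ROOT')
cmake_compile_flags = [f"-L{os.path.join(_cutensor_root, 'lib')}"] if _cutensor_root else []
```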
`include/dace_cutensor.h` (new file):

```cpp
// Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.

#pragma once

#include <cuda_runtime.h>
#include <cutensor.h>

#include <cstddef>    // size_t
#include <stdexcept>  // std::runtime_error
#include <string>     // std::to_string
#include <unordered_map>

namespace dace {

namespace blas {

static void CheckCutensorError(cutensorStatus_t const& status) {
  if (status != CUTENSOR_STATUS_SUCCESS) {
    // cutensorGetErrorString(status) would yield a human-readable message
    // here instead of the raw status code (see review comment below).
    throw std::runtime_error("cuTENSOR failed with error code: " + std::to_string(status));
  }
}
```

> **Review comment:** isn't there a cutensorGetErrorString or something similar?
```cpp
static cutensorHandle_t* CreateCutensorHandle(int device) {
  if (cudaSetDevice(device) != cudaSuccess) {
    throw std::runtime_error("Failed to set CUDA device.");
  }
  cutensorHandle_t* handle;
  CheckCutensorError(cutensorCreate(&handle));
  return handle;
}

/**
 * cuTENSOR wrapper class for DaCe. Once constructed, the class can be used to
 * get or create a cuTENSOR library handle (cutensorHandle_t) for a given
 * GPU ID. The class is constructed when the cuTENSOR DaCe library is used.
 **/
class CutensorHandle {
 public:
  CutensorHandle() = default;
  CutensorHandle(CutensorHandle const&) = delete;

  cutensorHandle_t* Get(int device) {
    auto f = handles_.find(device);
    if (f == handles_.end()) {
      // Lazily construct a new cuTENSOR handle if the specified key does not
      // yet exist
      cutensorHandle_t* handle = CreateCutensorHandle(device);
      f = handles_.emplace(device, handle).first;
    }
    return f->second;
  }

  ~CutensorHandle() {
    for (auto& h : handles_) {
      CheckCutensorError(cutensorDestroy(h.second));
    }
  }

  CutensorHandle& operator=(CutensorHandle const&) = delete;

  std::unordered_map<int, cutensorHandle_t*> handles_;
};

}  // namespace blas

}  // namespace dace
```
The new `Permute` library node (new file):

```python
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.

import functools
from copy import deepcopy as dc
from typing import List

from dace.config import Config
import dace.library
import dace.properties
import dace.sdfg.nodes
from dace.libraries.blas import blas_helpers
from dace.transformation.transformation import ExpandTransformation
from .. import environments
import warnings


def _get_permute_input(node, state, sdfg):
    """Returns the permute input edge, array, and shape."""
    for edge in state.in_edges(node):
        if edge.dst_conn == "_inp":
            subset = dc(edge.data.subset)
            subset.squeeze()
            size = subset.size()
            outer_array = sdfg.data(dace.sdfg.find_input_arraynode(state, edge).data)
            return edge, outer_array, size
    raise ValueError("Permute input connector \"_inp\" not found.")


def _get_permute_output(node, state, sdfg):
    """Returns the permute output edge, array, and shape."""
    for edge in state.out_edges(node):
        if edge.src_conn == "_out":
            subset = dc(edge.data.subset)
            subset.squeeze()
            size = subset.size()
            outer_array = sdfg.data(dace.sdfg.find_output_arraynode(state, edge).data)
            return edge, outer_array, size
    raise ValueError("Permute output connector \"_out\" not found.")
```
```python
@dace.library.expansion
class ExpandPermutePure(ExpandTransformation):
    environments = []

    @staticmethod
    def make_sdfg(node, parent_state, parent_sdfg):

        in_edge, in_outer_array, in_shape = _get_permute_input(node, parent_state, parent_sdfg)
        out_edge, out_outer_array, out_shape = _get_permute_output(node, parent_state, parent_sdfg)
        dtype = node.dtype
        axes = node.axes
        sdfg = dace.SDFG(node.label + "_sdfg")
        state = sdfg.add_state(node.label + "_state")

        _, in_array = sdfg.add_array("_inp",
                                     in_shape,
                                     dtype,
                                     strides=in_outer_array.strides,
                                     storage=in_outer_array.storage)
        _, out_array = sdfg.add_array("_out",
                                      out_shape,
                                      dtype,
                                      strides=out_outer_array.strides,
                                      storage=out_outer_array.storage)

        num_elements = functools.reduce(lambda x, y: x * y, in_array.shape)
        if num_elements == 1:
            # Degenerate case: a single element only needs a copy tasklet
            inp = state.add_read("_inp")
            out = state.add_write("_out")
            tasklet = state.add_tasklet("permute", {"__inp"}, {"__out"}, "__out = __inp")
            state.add_edge(inp, None, tasklet, "__inp", dace.memlet.Memlet.from_array("_inp", in_array))
            state.add_edge(tasklet, "__out", out, None, dace.memlet.Memlet.from_array("_out", out_array))
        else:
            # Map over the input domain and write each element to its
            # permuted position in the output
            state.add_mapped_tasklet(
                "_permute_", {"_i{}".format(i): "0:{}".format(s)
                              for i, s in enumerate(in_array.shape)},
                dict(_tmp_in=dace.memlet.Memlet.simple(
                    "_inp", ", ".join("_i{}".format(i) for i, _ in enumerate(in_array.shape)))),
                "_tmp_out = _tmp_in",
                dict(_tmp_out=dace.memlet.Memlet.simple(
                    "_out", ", ".join("_i{}".format(axes[i]) for i, _ in enumerate(in_array.shape)))),
                external_edges=True)

        return sdfg

    @staticmethod
    def expansion(node, state, sdfg):
        node.validate(sdfg, state)
        return ExpandPermutePure.make_sdfg(node, state, sdfg)
```
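The mapped tasklet writes `_out[_i{axes[0]}, ..., _i{axes[n-1]}] = _inp[_i0, ..., _i{n-1}]`, which is exactly the `np.transpose` index mapping. A small self-contained check of that equivalence in plain NumPy (not part of the PR):

```python
import itertools
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
axes = (1, 2, 0)

# Reproduce the expansion's index mapping element by element:
# out[i[axes[0]], ..., i[axes[n-1]]] = a[i]
out = np.empty(tuple(a.shape[ax] for ax in axes), dtype=a.dtype)
for i in itertools.product(*(range(s) for s in a.shape)):
    out[tuple(i[ax] for ax in axes)] = a[i]

assert np.array_equal(out, np.transpose(a, axes))
```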
```python
@dace.library.expansion
class ExpandPermuteCuTENSOR(ExpandTransformation):

    environments = [environments.cutensor.cuTENSOR]

    @staticmethod
    def expansion(node, state, sdfg, **kwargs):
        node.validate(sdfg, state)
        dtype = node.dtype
        axes = node.axes

        cuda_dtype = blas_helpers.dtype_to_cudadatatype(dtype)

        in_edge, in_outer_array, in_shape = _get_permute_input(node, state, sdfg)
        out_edge, out_outer_array, out_shape = _get_permute_output(node, state, sdfg)

        num_dims = len(axes)
        modeA = ', '.join([str(x) for x in axes])
        modeC = ', '.join([str(x) for x in range(len(axes))])

        stridesA = ', '.join([str(x) for x in in_outer_array.strides])
        stridesC = ', '.join([str(x) for x in out_outer_array.strides])

        code_prefix = environments.cuTENSOR.handle_setup_code(node)
        code_call = f"""
            int modeC[] = {{ {modeC} }};
            int modeA[] = {{ {modeA} }};

            int64_t extentA[] = {{ {', '.join([str(x) for x in in_shape])} }};
            int64_t extentC[] = {{ {', '.join([str(x) for x in out_shape])} }};
            int64_t stridesA[] = {{ {stridesA} }};
            int64_t stridesC[] = {{ {stridesC} }};

            cutensorTensorDescriptor_t descA;
            dace::blas::CheckCutensorError(cutensorInitTensorDescriptor(__dace_cutensor_handle,
                                                                        &descA,
                                                                        {num_dims},
                                                                        extentA,
                                                                        stridesA,
                                                                        {cuda_dtype}, CUTENSOR_OP_IDENTITY));

            cutensorTensorDescriptor_t descC;
            dace::blas::CheckCutensorError(cutensorInitTensorDescriptor(__dace_cutensor_handle,
                                                                        &descC,
                                                                        {num_dims},
                                                                        extentC,
                                                                        stridesC,
                                                                        {cuda_dtype}, CUTENSOR_OP_IDENTITY));

            const float one = 1.0f;
            cutensorPermutation(
                __dace_cutensor_handle,
                &one,
                /*A=*/_inp,
                &descA,
                /*axes A=*/modeA,
                /*C=*/_out,
                &descC,
                /*axes C=*/modeC,
                /*computeType=*/{cuda_dtype},
                /*stream=*/__dace_current_stream
            );
        """

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code_prefix + code_call,
                                          language=dace.dtypes.Language.CPP)

        return tasklet
```
```python
@dace.library.node
class Permute(dace.sdfg.nodes.LibraryNode):
    # Global properties
    implementations = {
        "pure": ExpandPermutePure,
        "cuTENSOR": ExpandPermuteCuTENSOR,
    }
    default_implementation = None

    dtype = dace.properties.TypeClassProperty(allow_none=True)
    axes = dace.properties.ListProperty(element_type=int, allow_none=True, desc="Axes to permute.")

    def __init__(self, name, axes, dtype=None, location=None):
        super().__init__(name, location=location, inputs={'_inp'}, outputs={'_out'})
        self.dtype = dtype
        self.axes = axes

    def validate(self, sdfg, state):
        in_edges = state.in_edges(self)
        if len(in_edges) != 1:
            raise ValueError("Expected exactly one input to permute operation")
        in_size = None
        for _, _, _, dst_conn, memlet in state.in_edges(self):
            if dst_conn == '_inp':
                subset = dc(memlet.subset)
                subset.squeeze()
                in_size = subset.size()
        if in_size is None:
            raise ValueError("Input connector not found.")
        out_edges = state.out_edges(self)
        if len(out_edges) != 1:
            raise ValueError("Expected exactly one output from permute operation")
        out_memlet = out_edges[0].data

        out_subset = dc(out_memlet.subset)
        out_subset.squeeze()
        out_size = out_subset.size()
        if len(out_size) != len(in_size):
            raise ValueError("Permute operation requires input and output arrays of equal dimensionality.")
        if set(out_size) != set(in_size):
            raise ValueError("Expected input size to be a permutation of output size.")
```