This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add GPU-optimization for split op #19131

Merged · 6 commits · Oct 11, 2020
218 changes: 217 additions & 1 deletion src/operator/tensor/matrix_op.cu
@@ -137,6 +137,222 @@ void SliceDimTwoCsrImpl<gpu>(const mxnet::TShape &begin, const mxnet::TShape &en
});
}

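// Kernel argument block for the GPU split kernels below: it is passed to the
// kernel by value, so the per-section output pointers and split indices travel
// as kernel arguments; a single launch handles at most MaxSections sections
// (larger splits are processed in batches by the host code further down).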
template <typename DType>
struct split_tensor_data {
static const int MaxSections = 128;
size_t num_sections;
DType* outputs[MaxSections];
size_t indices[MaxSections+1];
DType* inputs[1];
};

template <bool split_last_axis, typename LType, typename DType>
__global__ void split_tensor_kernel(size_t input_size,
const split_tensor_data<DType> params,
size_t split_axis_size,
size_t tail_size,
size_t last_axis_size,
size_t blocks_last_axis) {
const int entries_per_load = sizeof(LType)/sizeof(DType);
const LType* in_aligned = reinterpret_cast<const LType*>(params.inputs[0]);
const size_t last_axis_size_aligned = entries_per_load > 0 ?
last_axis_size / entries_per_load : last_axis_size;
if (split_last_axis) {
size_t input_offset_leading = (blockIdx.x / blocks_last_axis) * last_axis_size_aligned;
size_t position_last_axis = (blockIdx.x % blocks_last_axis) * blockDim.x * entries_per_load +
params.indices[0] + threadIdx.x * entries_per_load;
if (position_last_axis < params.indices[params.num_sections]) {
size_t position_last_axis_aligned = entries_per_load > 0 ?
position_last_axis / entries_per_load :
position_last_axis;
LType input_data = in_aligned[input_offset_leading + position_last_axis_aligned];
// Binary search to find section of each thread
size_t lower = 0;
size_t upper = params.num_sections - 1;
while (lower < upper) {
size_t mid = (lower + upper + 1) / 2;
if (position_last_axis >= params.indices[mid])
lower = mid;
else
upper = mid - 1;
}
size_t section = upper;
size_t section_size = params.indices[section + 1] - params.indices[section];
LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
size_t section_size_aligned = entries_per_load > 0 ? section_size / entries_per_load :
section_size;
size_t index_aligned = entries_per_load > 0 ? params.indices[section] / entries_per_load :
params.indices[section];
size_t output_offset_leading = (blockIdx.x / blocks_last_axis) * section_size_aligned;
size_t output_position = output_offset_leading + position_last_axis_aligned - index_aligned;
out_aligned[output_position] = input_data;
}
} else {
size_t split_axis_size_iter = params.indices[params.num_sections] - params.indices[0];
size_t blocks_per_leading_dim = (split_axis_size_iter * tail_size * blocks_last_axis);
// input offsets: leading (axes pre-split-axis), at split-axis, tail, and blocks_last_axis
size_t input_offset_leading = (blockIdx.x / blocks_per_leading_dim) *
split_axis_size * tail_size * last_axis_size_aligned;
size_t pos_in_split_axis = (blockIdx.x / (tail_size * blocks_last_axis)) %
split_axis_size_iter + params.indices[0];
size_t input_offset_split_axis = pos_in_split_axis * tail_size * last_axis_size_aligned;
size_t offset_tail = ((blockIdx.x / blocks_last_axis) % tail_size) *
last_axis_size_aligned;
size_t input_offset = input_offset_leading + input_offset_split_axis + offset_tail +
(blockIdx.x % blocks_last_axis) * blockDim.x;
// Binary search to find section for this block
size_t lower = 0;
size_t upper = params.num_sections - 1;
while (lower < upper) {
size_t mid = (lower + upper + 1) / 2;
if (pos_in_split_axis >= params.indices[mid])
lower = mid;
else
upper = mid - 1;
}
size_t section = upper;
size_t section_size = params.indices[section + 1] - params.indices[section];
LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
// output offsets: leading (axes pre-split-axis), at split-axis, and blocks_last_axis
size_t output_offset_leading = (blockIdx.x / blocks_per_leading_dim) *
section_size * tail_size * last_axis_size_aligned;
size_t output_offset_split_axis = ((blockIdx.x % blocks_per_leading_dim) / blocks_last_axis -
((params.indices[section] - params.indices[0]) * tail_size)) *
last_axis_size_aligned;
size_t output_offset = output_offset_leading + output_offset_split_axis +
(blockIdx.x % blocks_last_axis) * blockDim.x;
if (threadIdx.x < last_axis_size_aligned) {
LType input_data = in_aligned[input_offset + threadIdx.x];
out_aligned[output_offset + threadIdx.x] = input_data;
}
}
}
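The section lookup used twice in the kernel above can be read in isolation: given the cumulative split indices (section start offsets plus a final end offset), a short binary search maps a position along the split axis to the section that owns it. A minimal host-side sketch of that logic, assuming the same indices layout as split_tensor_data; the sample values in main are made up, and this is illustrative only, not part of the patch:

#include <cstddef>
#include <cstdio>

// Returns the section whose half-open range [indices[s], indices[s+1]) contains pos.
size_t find_section(size_t pos, const size_t* indices, size_t num_sections) {
  size_t lower = 0;
  size_t upper = num_sections - 1;
  while (lower < upper) {
    size_t mid = (lower + upper + 1) / 2;
    if (pos >= indices[mid])
      lower = mid;       // pos is at or past the start of section mid
    else
      upper = mid - 1;   // pos lies before section mid
  }
  return upper;  // lower == upper at this point
}

int main() {
  const size_t indices[] = {0, 10, 25, 40};  // three sections: [0,10), [10,25), [25,40)
  printf("%zu %zu %zu\n",
         find_section(3, indices, 3),    // 0
         find_section(10, indices, 3),   // 1
         find_section(39, indices, 3));  // 2
  return 0;
}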

template <typename DType>
int get_load_type_split(size_t last_axis_size,
bool splitting_last_axis,
size_t n_sections,
size_t* indices) {
using namespace mshadow;
int sections_largest_multiple = 8;
if (splitting_last_axis) {
for (size_t i = 0; i < n_sections; ++i) {
size_t size_section = indices[i+1] - indices[i];
if (size_section * sizeof(DType) % 8)
sections_largest_multiple = std::min(sections_largest_multiple, 4);
if (size_section * sizeof(DType) % 4)
sections_largest_multiple = std::min(sections_largest_multiple, 2);
if (size_section * sizeof(DType) % 2)
sections_largest_multiple = std::min(sections_largest_multiple, 1);
}
}
if (last_axis_size * sizeof(DType) % 8 == 0 && sections_largest_multiple == 8) {
return kFloat64;
} else if (last_axis_size * sizeof(DType) % 4 == 0 && sections_largest_multiple >= 4) {
return kFloat32;
} else if (last_axis_size * sizeof(DType) % 2 == 0 && sections_largest_multiple >= 2) {
return kFloat16;
} else {
return kUint8;
}
}
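The intent of get_load_type_split is to pick the widest vectorized load whose width in bytes divides the last-axis byte count and, when splitting along the last axis, every section's byte count as well. A simplified standalone illustration of that divisibility rule follows; widest_load_bytes and the element counts in main are made up for this example, and the real code returns mshadow type flags rather than byte widths:

#include <cstddef>
#include <cstdio>
#include <vector>

// Widest load width in bytes (from {8, 4, 2, 1}) that evenly divides
// every candidate byte count.
size_t widest_load_bytes(size_t elem_bytes, const std::vector<size_t>& elem_counts) {
  size_t width = 8;
  for (size_t n : elem_counts) {
    size_t bytes = n * elem_bytes;
    while (width > 1 && bytes % width != 0)
      width /= 2;
  }
  return width;
}

int main() {
  // float16 sections of 128 and 96 elements: 256 and 192 bytes, both
  // multiples of 8 -> 8-byte loads (4 half values per load), i.e. kFloat64.
  printf("%zu\n", widest_load_bytes(2, {128, 96}));  // 8
  // float32 section of 33 elements: 132 bytes -> only 4-byte loads (kFloat32).
  printf("%zu\n", widest_load_bytes(4, {33, 64}));   // 4
  return 0;
}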

inline void SplitOpForwardGPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim());
const TBlob& input_data = inputs[split_enum::kData];
int real_axis = param.axis;
if (real_axis < 0) {
real_axis += input_data.ndim();
}
size_t last_axis_size = input_data.shape_[inputs[0].ndim()-1];
size_t split_axis_size = input_data.shape_[real_axis];
size_t tail_size = 1; // does not include last dim
for (int i = real_axis + 1; i < input_data.ndim()-1; ++i) {
tail_size *= input_data.shape_[i];
}
if (last_axis_size < 128) {
// custom kernel will not be efficient with fewer than 128 elements in the last axis
SplitOpForwardImpl<gpu>(attrs, ctx, inputs, req, outputs, real_axis);
} else {
Stream<gpu> *s = ctx.get_stream<gpu>();
CHECK_LT(real_axis, input_data.ndim());
const mxnet::TShape& ishape = input_data.shape_;
const mxnet::TShape split_pts =
(param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
std::vector<size_t> indices;
for (const auto& split_pos : split_pts) {
indices.push_back(split_pos);
}
if (param.sections == 0) {
indices.push_back(ishape[real_axis]);
}
size_t n_sections = indices.size() - 1;
bool splitting_last_axis = (real_axis == inputs[0].ndim() - 1);

for (size_t sections_processed = 0; sections_processed < n_sections;) {
size_t remaining_sections = n_sections - sections_processed;
MSHADOW_TYPE_SWITCH(input_data.type_flag_, DType, {
// set parameters
split_tensor_data<DType> params{};
params.num_sections = std::min<size_t>(remaining_sections, params.MaxSections);
params.inputs[0] = input_data.dptr<DType>();
for (size_t i = 0; i < params.num_sections; ++i) {
params.outputs[i] = outputs[sections_processed + i].dptr<DType>();
params.indices[i] = indices[sections_processed + i];
}
params.indices[params.num_sections] = indices[sections_processed + params.num_sections];
// load type: we need to check that last axis size is multiple of ltype
// and if splitting_last_axis, all section sizes as well
int ltype = get_load_type_split<DType>(last_axis_size, splitting_last_axis,
params.num_sections, params.indices);
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
CHECK_LE(sizeof(DType), sizeof(LType));
const size_t entries_per_load = sizeof(LType) / sizeof(DType);
size_t block_size = 32;
size_t max_threads_block = 256;
size_t last_axis_elements = entries_per_load > 0 ? (last_axis_size / entries_per_load): 0;
if (splitting_last_axis) {
// may not be possible to include whole axis if too many sections
last_axis_elements = entries_per_load > 0 ?
((params.indices[params.num_sections] - params.indices[0]) / entries_per_load): 0;
}
while (block_size < last_axis_elements && (block_size < max_threads_block)) {
block_size += 32;
}
size_t blocks_last_axis = (last_axis_elements + block_size - 1) / block_size;
size_t n_blocks = blocks_last_axis;
for (int i = 0 ; i < input_data.ndim() - 1; ++i) {
if (i == real_axis) {
// may not be possible to include all sections if too many
n_blocks *= (params.indices[params.num_sections] - params.indices[0]);
} else {
n_blocks *= input_data.shape_[i];
}
}
if (splitting_last_axis) {
split_tensor_kernel<true, LType><<<n_blocks, block_size, 0, s->stream_>>>
(input_data.Size(), params, split_axis_size, tail_size,
last_axis_size, blocks_last_axis);
} else {
split_tensor_kernel<false, LType><<<n_blocks, block_size, 0, s->stream_>>>
(input_data.Size(), params, split_axis_size, tail_size,
last_axis_size, blocks_last_axis);
}
});
sections_processed += params.num_sections;
});
}
}
}
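For reference, the grid dimensioning above can be reproduced on the host in a few lines: the thread block covers (part of) the vectorized last axis, and the grid multiplies in the leading dims, the split-axis positions of the current batch of sections, and the tail dims. A standalone sketch with made-up numbers; the shape, dtype, and 256-thread cap below are assumptions chosen for illustration:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t entries_per_load = 2;     // e.g. float32 read through 8-byte loads
  const size_t last_axis_size   = 256;
  const size_t last_axis_elements = last_axis_size / entries_per_load;  // 128 loads

  size_t block_size = 32;
  const size_t max_threads_block = 256;
  while (block_size < last_axis_elements && block_size < max_threads_block)
    block_size += 32;
  const size_t blocks_last_axis = (last_axis_elements + block_size - 1) / block_size;

  // shape (4, 6, 3, 256), splitting axis 1, all 6 positions in one batch:
  const size_t n_blocks = 4 * 6 * 3 * blocks_last_axis;
  printf("block_size=%zu blocks_last_axis=%zu n_blocks=%zu\n",
         block_size, blocks_last_axis, n_blocks);  // 128, 1, 72
  return 0;
}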

NNVM_REGISTER_OP(Reshape)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
@@ -219,7 +435,7 @@ NNVM_REGISTER_OP(space_to_depth)
.set_attr<FCompute>("FCompute<gpu>", SpaceToDepthOpForward<gpu>);

NNVM_REGISTER_OP(_split_v2)
.set_attr<FCompute>("FCompute<gpu>", SplitOpForward<gpu>);
.set_attr<FCompute>("FCompute<gpu>", SplitOpForwardGPU);

NNVM_REGISTER_OP(_split_v2_backward)
.set_attr<FCompute>("FCompute<gpu>", SplitOpBackward<gpu>);
18 changes: 18 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -2319,3 +2319,21 @@ def test_fp16_spmm():
out = mxsps.dot(inp, weight)
out_np = mx.nd.dot(inp, weight)
assert_almost_equal(out.asnumpy(), out_np, rtol=1e-3, atol=1e-5)

@with_seed()
@pytest.mark.serial
szha marked this conversation as resolved.
Member: no need to mark as serial
Member: I noticed that mark.serial is triggered for all tests in this file, so we may keep mark.serial.
Member: We run tests based on the tag, and it has nothing to do with the file. serial is only needed when a test invocation is long-running and consumes a lot of memory. Since this is no longer the case after parametrizing the input, the serial tag is not needed. Please open a follow-up PR to finish the change.

@pytest.mark.parametrize('dtype', ["float16", "float32", "float64"])
def test_split_v2_fwd(dtype):
dim = random.randint(2, 9)
shape = rand_shape_nd(dim)
axis = random.randint(-dim, dim-1)
axis_size = shape[axis]
samples = random.randint(0, axis_size - 1)
indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
indices = tuple(indices)
mx_data = rand_ndarray(shape, dtype=dtype)
np_data = mx_data.asnumpy()
np_out = np.split(np_data, indices_or_sections=indices, axis=axis)
data = mx.sym.Variable("data")
sym = mx.sym.split_v2(data, indices_or_sections=indices, axis=axis)
check_symbolic_forward(sym, {"data": mx_data}, np_out, rtol=1e-3, atol=1e-5)