Add GPU-optimization for split op (apache#19131)
* Add GPU-optimization for split op

* Complete operator

* unit-test: use parametrize

* fix lint

* fix lint

* fix lint
MoisesHer authored and chinakook committed Nov 17, 2020
1 parent 4bde521 commit 093ad1d
Showing 2 changed files with 234 additions and 1 deletion.
218 changes: 217 additions & 1 deletion src/operator/tensor/matrix_op.cu
@@ -137,6 +137,222 @@ void SliceDimTwoCsrImpl<gpu>(const mxnet::TShape &begin, const mxnet::TShape &en
});
}

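// Kernel argument struct passed by value: holds the input pointer, up to
// MaxSections output pointers, and the cumulative split positions along the
// split axis (elements in [indices[i], indices[i+1]) go to outputs[i]).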
template <typename DType>
struct split_tensor_data {
static const int MaxSections = 128;
size_t num_sections;
DType* outputs[MaxSections];
size_t indices[MaxSections+1];
DType* inputs[1];
};

template <bool split_last_axis, typename LType, typename DType>
__global__ void split_tensor_kernel(size_t input_size,
const split_tensor_data<DType> params,
size_t split_axis_size,
size_t tail_size,
size_t last_axis_size,
size_t blocks_last_axis) {
const int entries_per_load = sizeof(LType)/sizeof(DType);
const LType* in_aligned = reinterpret_cast<const LType*>(params.inputs[0]);
const size_t last_axis_size_aligned = entries_per_load > 0 ?
last_axis_size / entries_per_load : last_axis_size;
if (split_last_axis) {
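    // Splitting along the last axis: each thread loads one vectorized element
    // (entries_per_load DTypes), binary-searches which section its position in
    // the last axis falls into, and writes it to that section's output.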
size_t input_offset_leading = (blockIdx.x / blocks_last_axis) * last_axis_size_aligned;
size_t position_last_axis = (blockIdx.x % blocks_last_axis) * blockDim.x * entries_per_load +
params.indices[0] + threadIdx.x * entries_per_load;
if (position_last_axis < params.indices[params.num_sections]) {
size_t position_last_axis_aligned = entries_per_load > 0 ?
position_last_axis / entries_per_load :
position_last_axis;
LType input_data = in_aligned[input_offset_leading + position_last_axis_aligned];
// Binary search to find section of each thread
size_t lower = 0;
size_t upper = params.num_sections - 1;
while (lower < upper) {
size_t mid = (lower + upper + 1) / 2;
if (position_last_axis >= params.indices[mid])
lower = mid;
else
upper = mid - 1;
}
size_t section = upper;
size_t section_size = params.indices[section + 1] - params.indices[section];
LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
size_t section_size_aligned = entries_per_load > 0 ? section_size / entries_per_load :
section_size;
size_t index_aligned = entries_per_load > 0 ? params.indices[section] / entries_per_load :
params.indices[section];
size_t output_offset_leading = (blockIdx.x / blocks_last_axis) * section_size_aligned;
size_t output_position = output_offset_leading + position_last_axis_aligned - index_aligned;
out_aligned[output_position] = input_data;
}
} else {
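    // Splitting along a non-last axis: each block copies a vectorized chunk of
    // one last-axis row; the block's position along the split axis selects the
    // section (again via binary search) and thus the output tensor and offset.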
size_t split_axis_size_iter = params.indices[params.num_sections] - params.indices[0];
size_t blocks_per_leading_dim = (split_axis_size_iter * tail_size * blocks_last_axis);
// input offsets: leading (axes pre-split-axis), at split-axis, tail, and blocks_last_axis
size_t input_offset_leading = (blockIdx.x / blocks_per_leading_dim) *
split_axis_size * tail_size * last_axis_size_aligned;
size_t pos_in_split_axis = (blockIdx.x / (tail_size * blocks_last_axis)) %
split_axis_size_iter + params.indices[0];
size_t input_offset_split_axis = pos_in_split_axis * tail_size * last_axis_size_aligned;
size_t offset_tail = ((blockIdx.x / blocks_last_axis) % tail_size) *
last_axis_size_aligned;
size_t input_offset = input_offset_leading + input_offset_split_axis + offset_tail +
(blockIdx.x % blocks_last_axis) * blockDim.x;
// Binary search to find section for this block
size_t lower = 0;
size_t upper = params.num_sections - 1;
while (lower < upper) {
size_t mid = (lower + upper + 1) / 2;
if (pos_in_split_axis >= params.indices[mid])
lower = mid;
else
upper = mid - 1;
}
size_t section = upper;
size_t section_size = params.indices[section + 1] - params.indices[section];
LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
// output offsets: leading (axes pre-split-axis), at split-axis, and blocks_last_axis
size_t output_offset_leading = (blockIdx.x / blocks_per_leading_dim) *
section_size * tail_size * last_axis_size_aligned;
size_t output_offset_split_axis = ((blockIdx.x % blocks_per_leading_dim) / blocks_last_axis -
((params.indices[section] - params.indices[0]) * tail_size)) *
last_axis_size_aligned;
size_t output_offset = output_offset_leading + output_offset_split_axis +
(blockIdx.x % blocks_last_axis) * blockDim.x;
if (threadIdx.x < last_axis_size_aligned) {
LType input_data = in_aligned[input_offset + threadIdx.x];
out_aligned[output_offset + threadIdx.x] = input_data;
}
}
}

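// Pick the widest load type (8/4/2/1 bytes) such that the last-axis size in
// bytes is a multiple of it; when splitting the last axis, every section size
// must also be a multiple so each section stays aligned for vectorized loads.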
template <typename DType>
int get_load_type_split(size_t last_axis_size,
bool splitting_last_axis,
size_t n_sections,
size_t* indices) {
using namespace mshadow;
int sections_largest_multiple = 8;
if (splitting_last_axis) {
for (size_t i = 0; i < n_sections; ++i) {
size_t size_section = indices[i+1] - indices[i];
if (size_section * sizeof(DType) % 8)
sections_largest_multiple = std::min(sections_largest_multiple, 4);
if (size_section * sizeof(DType) % 4)
sections_largest_multiple = std::min(sections_largest_multiple, 2);
if (size_section * sizeof(DType) % 2)
sections_largest_multiple = std::min(sections_largest_multiple, 1);
}
}
if (last_axis_size * sizeof(DType) % 8 == 0 && sections_largest_multiple == 8) {
return kFloat64;
} else if (last_axis_size * sizeof(DType) % 4 == 0 && sections_largest_multiple >= 4) {
return kFloat32;
} else if (last_axis_size * sizeof(DType) % 2 == 0 && sections_largest_multiple >= 2) {
return kFloat16;
} else {
return kUint8;
}
}

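// GPU forward for split: uses the generic SplitOpForwardImpl when the last
// axis is too small to benefit, otherwise launches the vectorized
// split_tensor_kernel above.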
inline void SplitOpForwardGPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim());
const TBlob& input_data = inputs[split_enum::kData];
int real_axis = param.axis;
if (real_axis < 0) {
real_axis += input_data.ndim();
}
size_t last_axis_size = input_data.shape_[inputs[0].ndim()-1];
size_t split_axis_size = input_data.shape_[real_axis];
size_t tail_size = 1; // does not include last dim
for (int i = real_axis + 1; i < input_data.ndim()-1; ++i) {
tail_size *= input_data.shape_[i];
}
if (last_axis_size < 128) {
// the custom kernel is not efficient with fewer than 128 elements in the last axis
SplitOpForwardImpl<gpu>(attrs, ctx, inputs, req, outputs, real_axis);
} else {
Stream<gpu> *s = ctx.get_stream<gpu>();
CHECK_LT(real_axis, input_data.ndim());
const mxnet::TShape& ishape = input_data.shape_;
const mxnet::TShape split_pts =
(param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
std::vector<size_t> indices;
for (const auto& split_pos : split_pts) {
indices.push_back(split_pos);
}
if (param.sections == 0) {
indices.push_back(ishape[real_axis]);
}
size_t n_sections = indices.size() - 1;
bool splitting_last_axis = (real_axis == inputs[0].ndim() - 1);

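    // The argument struct holds a fixed-size array of output pointers, so
    // process the sections in chunks of at most MaxSections per kernel launch.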
for (size_t sections_processed = 0; sections_processed < n_sections;) {
size_t remaining_sections = n_sections - sections_processed;
MSHADOW_TYPE_SWITCH(input_data.type_flag_, DType, {
// set parameters
split_tensor_data<DType> params{};
params.num_sections = std::min<size_t>(remaining_sections, params.MaxSections);
params.inputs[0] = input_data.dptr<DType>();
for (size_t i = 0; i < params.num_sections; ++i) {
params.outputs[i] = outputs[sections_processed + i].dptr<DType>();
params.indices[i] = indices[sections_processed + i];
}
params.indices[params.num_sections] = indices[sections_processed + params.num_sections];
// load type: we need to check that last axis size is multiple of ltype
// and if splitting_last_axis, all section sizes as well
int ltype = get_load_type_split<DType>(last_axis_size, splitting_last_axis,
params.num_sections, params.indices);
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
CHECK_LE(sizeof(DType), sizeof(LType));
const size_t entries_per_load = sizeof(LType) / sizeof(DType);
size_t block_size = 32;
size_t max_threads_block = 256;
size_t last_axis_elements = entries_per_load > 0 ? (last_axis_size / entries_per_load): 0;
if (splitting_last_axis) {
// may not be possible to include whole axis if too many sections
last_axis_elements = entries_per_load > 0 ?
((params.indices[params.num_sections] - params.indices[0]) / entries_per_load): 0;
}
while (block_size < last_axis_elements && (block_size < max_threads_block)) {
block_size += 32;
}
size_t blocks_last_axis = (last_axis_elements + block_size - 1) / block_size;
size_t n_blocks = blocks_last_axis;
for (int i = 0 ; i < input_data.ndim() - 1; ++i) {
if (i == real_axis) {
// may not be possible to include all sections if too many
n_blocks *= (params.indices[params.num_sections] - params.indices[0]);
} else {
n_blocks *= input_data.shape_[i];
}
}
if (splitting_last_axis) {
split_tensor_kernel<true, LType><<<n_blocks, block_size, 0, s->stream_>>>
(input_data.Size(), params, split_axis_size, tail_size,
last_axis_size, blocks_last_axis);
} else {
split_tensor_kernel<false, LType><<<n_blocks, block_size, 0, s->stream_>>>
(input_data.Size(), params, split_axis_size, tail_size,
last_axis_size, blocks_last_axis);
}
});
sections_processed += params.num_sections;
});
}
}
}

NNVM_REGISTER_OP(Reshape)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
@@ -219,7 +435,7 @@ NNVM_REGISTER_OP(space_to_depth)
.set_attr<FCompute>("FCompute<gpu>", SpaceToDepthOpForward<gpu>);

NNVM_REGISTER_OP(_split_v2)
.set_attr<FCompute>("FCompute<gpu>", SplitOpForward<gpu>);
.set_attr<FCompute>("FCompute<gpu>", SplitOpForwardGPU);

NNVM_REGISTER_OP(_split_v2_backward)
.set_attr<FCompute>("FCompute<gpu>", SplitOpBackward<gpu>);
17 changes: 17 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -2476,3 +2476,20 @@ def test_arange_like_dtype():
for v in out:
assert v.dtype == t

@with_seed()
@pytest.mark.serial
@pytest.mark.parametrize('dtype', ["float16", "float32", "float64"])
def test_split_v2_fwd(dtype):
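    # Pick a random shape, axis, and sorted split points, then compare the
    # symbolic split_v2 forward against numpy.split on the same data.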
dim = random.randint(2, 9)
shape = rand_shape_nd(dim)
axis = random.randint(-dim, dim-1)
axis_size = shape[axis]
samples = random.randint(0, axis_size - 1)
indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
indices = tuple(indices)
mx_data = rand_ndarray(shape, dtype=dtype)
np_data = mx_data.asnumpy()
np_out = np.split(np_data, indices_or_sections=indices, axis=axis)
data = mx.sym.Variable("data")
sym = mx.sym.split_v2(data, indices_or_sections=indices, axis=axis)
check_symbolic_forward(sym, {"data": mx_data}, np_out, rtol=1e-3, atol=1e-5)
