diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 17076933a115..d5e17da0c41b 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -137,6 +137,222 @@ void SliceDimTwoCsrImpl(const mxnet::TShape &begin, const mxnet::TShape &en
   });
 }
 
+template <typename DType>
+struct split_tensor_data {
+  static const int MaxSections = 128;
+  size_t num_sections;
+  DType* outputs[MaxSections];
+  size_t indices[MaxSections+1];
+  DType* inputs[1];
+};
+
+template <bool split_last_axis, typename DType, typename LType>
+__global__ void split_tensor_kernel(size_t input_size,
+                                    const split_tensor_data<DType> params,
+                                    size_t split_axis_size,
+                                    size_t tail_size,
+                                    size_t last_axis_size,
+                                    size_t blocks_last_axis) {
+  const int entries_per_load = sizeof(LType)/sizeof(DType);
+  const LType* in_aligned = reinterpret_cast<const LType*>(params.inputs[0]);
+  const size_t last_axis_size_aligned = entries_per_load > 0 ?
+                                        last_axis_size / entries_per_load : last_axis_size;
+  if (split_last_axis) {
+    size_t input_offset_leading = (blockIdx.x / blocks_last_axis) * last_axis_size_aligned;
+    size_t position_last_axis = (blockIdx.x % blocks_last_axis) * blockDim.x * entries_per_load +
+                                params.indices[0] + threadIdx.x * entries_per_load;
+    if (position_last_axis < params.indices[params.num_sections]) {
+      size_t position_last_axis_aligned = entries_per_load > 0 ?
+                                          position_last_axis / entries_per_load :
+                                          position_last_axis;
+      LType input_data = in_aligned[input_offset_leading + position_last_axis_aligned];
+      // Binary search to find section of each thread
+      size_t lower = 0;
+      size_t upper = params.num_sections - 1;
+      while (lower < upper) {
+        size_t mid = (lower + upper + 1) / 2;
+        if (position_last_axis >= params.indices[mid])
+          lower = mid;
+        else
+          upper = mid - 1;
+      }
+      size_t section = upper;
+      size_t section_size = params.indices[section + 1] - params.indices[section];
+      LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
+      size_t section_size_aligned = entries_per_load > 0 ? section_size / entries_per_load :
+                                                           section_size;
+      size_t index_aligned = entries_per_load > 0 ? params.indices[section] / entries_per_load :
+                                                    params.indices[section];
+      size_t output_offset_leading = (blockIdx.x / blocks_last_axis) * section_size_aligned;
+      size_t output_position = output_offset_leading + position_last_axis_aligned - index_aligned;
+      out_aligned[output_position] = input_data;
+    }
+  } else {
+    size_t split_axis_size_iter = params.indices[params.num_sections] - params.indices[0];
+    size_t blocks_per_leading_dim = (split_axis_size_iter * tail_size * blocks_last_axis);
+    // input offsets: leading (axes pre-split-axis), at split-axis, tail, and blocks_last_axis
+    size_t input_offset_leading = (blockIdx.x / blocks_per_leading_dim) *
+                                  split_axis_size * tail_size * last_axis_size_aligned;
+    size_t pos_in_split_axis = (blockIdx.x / (tail_size * blocks_last_axis)) %
+                               split_axis_size_iter + params.indices[0];
+    size_t input_offset_split_axis = pos_in_split_axis * tail_size * last_axis_size_aligned;
+    size_t offset_tail = ((blockIdx.x / blocks_last_axis) % tail_size) *
+                         last_axis_size_aligned;
+    size_t input_offset = input_offset_leading + input_offset_split_axis + offset_tail +
+                          (blockIdx.x % blocks_last_axis) * blockDim.x;
+    // Binary search to find section for this block
+    size_t lower = 0;
+    size_t upper = params.num_sections - 1;
+    while (lower < upper) {
+      size_t mid = (lower + upper + 1) / 2;
+      if (pos_in_split_axis >= params.indices[mid])
+        lower = mid;
+      else
+        upper = mid - 1;
+    }
+    size_t section = upper;
+    size_t section_size = params.indices[section + 1] - params.indices[section];
+    LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
+    // output offsets: leading (axes pre-split-axis), at split-axis, and blocks_last_axis
+    size_t output_offset_leading = (blockIdx.x / blocks_per_leading_dim) *
+                                   section_size * tail_size * last_axis_size_aligned;
+    size_t output_offset_split_axis = ((blockIdx.x % blocks_per_leading_dim) / blocks_last_axis -
+                                       ((params.indices[section] - params.indices[0]) * tail_size)) *
+                                      last_axis_size_aligned;
+    size_t output_offset = output_offset_leading + output_offset_split_axis +
+                           (blockIdx.x % blocks_last_axis) * blockDim.x;
+    if (threadIdx.x < last_axis_size_aligned) {
+      LType input_data = in_aligned[input_offset + threadIdx.x];
+      out_aligned[output_offset + threadIdx.x] = input_data;
+    }
+  }
+}
+
+template <typename DType>
+int get_load_type_split(size_t last_axis_size,
+                        bool splitting_last_axis,
+                        size_t n_sections,
+                        size_t* indices) {
+  using namespace mshadow;
+  int sections_largest_multiple = 8;
+  if (splitting_last_axis) {
+    for (size_t i = 0; i < n_sections; ++i) {
+      size_t size_section = indices[i+1] - indices[i];
+      if (size_section * sizeof(DType) % 8)
+        sections_largest_multiple = std::min(sections_largest_multiple, 4);
+      if (size_section * sizeof(DType) % 4)
+        sections_largest_multiple = std::min(sections_largest_multiple, 2);
+      if (size_section * sizeof(DType) % 2)
+        sections_largest_multiple = std::min(sections_largest_multiple, 1);
+    }
+  }
+  if (last_axis_size * sizeof(DType) % 8 == 0 && sections_largest_multiple == 8) {
+    return kFloat64;
+  } else if (last_axis_size * sizeof(DType) % 4 == 0 && sections_largest_multiple >= 4) {
+    return kFloat32;
+  } else if (last_axis_size * sizeof(DType) % 2 == 0 && sections_largest_multiple >= 2) {
+    return kFloat16;
+  } else {
+    return kUint8;
+  }
+}
+
+inline void SplitOpForwardGPU(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim());
+  const TBlob& input_data = inputs[split_enum::kData];
+  int real_axis = param.axis;
+  if (real_axis < 0) {
+    real_axis += input_data.ndim();
+  }
+  size_t last_axis_size = input_data.shape_[inputs[0].ndim()-1];
+  size_t split_axis_size = input_data.shape_[real_axis];
+  size_t tail_size = 1;  // does not include last dim
+  for (int i = real_axis + 1; i < input_data.ndim()-1; ++i) {
+    tail_size *= input_data.shape_[i];
+  }
+  if (last_axis_size < 128) {
+    // custom kernel will not be efficient with fewer than 128 elements in the last axis
+    SplitOpForwardImpl<gpu>(attrs, ctx, inputs, req, outputs, real_axis);
+  } else {
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+    CHECK_LT(real_axis, input_data.ndim());
+    const mxnet::TShape& ishape = input_data.shape_;
+    const mxnet::TShape split_pts =
+      (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
+    std::vector<size_t> indices;
+    for (const auto& split_pos : split_pts) {
+      indices.push_back(split_pos);
+    }
+    if (param.sections == 0) {
+      indices.push_back(ishape[real_axis]);
+    }
+    size_t n_sections = indices.size() - 1;
+    bool splitting_last_axis = (real_axis == inputs[0].ndim() - 1);
+
+    for (size_t sections_processed = 0; sections_processed < n_sections;) {
+      size_t remaining_sections = n_sections - sections_processed;
+      MSHADOW_TYPE_SWITCH(input_data.type_flag_, DType, {
+        // set parameters
+        split_tensor_data<DType> params{};
+        params.num_sections = std::min(remaining_sections,
+                                       static_cast<size_t>(params.MaxSections));
+        params.inputs[0] = input_data.dptr<DType>();
+        for (size_t i = 0; i < params.num_sections; ++i) {
+          params.outputs[i] = outputs[sections_processed + i].dptr<DType>();
+          params.indices[i] = indices[sections_processed + i];
+        }
+        params.indices[params.num_sections] = indices[sections_processed + params.num_sections];
+        // load type: we need to check that last axis size is a multiple of ltype
+        // and, if splitting_last_axis, all section sizes as well
+        int ltype = get_load_type_split<DType>(last_axis_size, splitting_last_axis,
+                                               params.num_sections, params.indices);
+        MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+          CHECK_LE(sizeof(DType), sizeof(LType));
+          const size_t entries_per_load = sizeof(LType) / sizeof(DType);
+          size_t block_size = 32;
+          size_t max_threads_block = 256;
+          size_t last_axis_elements = entries_per_load > 0 ?
+                                      (last_axis_size / entries_per_load) : 0;
+          if (splitting_last_axis) {
+            // may not be possible to include the whole axis if there are too many sections
+            last_axis_elements = entries_per_load > 0 ?
+              ((params.indices[params.num_sections] - params.indices[0]) / entries_per_load) : 0;
+          }
+          while (block_size < last_axis_elements && (block_size < max_threads_block)) {
+            block_size += 32;
+          }
+          size_t blocks_last_axis = (last_axis_elements + block_size - 1) / block_size;
+          size_t n_blocks = blocks_last_axis;
+          for (int i = 0; i < input_data.ndim() - 1; ++i) {
+            if (i == real_axis) {
+              // may not be possible to include all sections if there are too many
+              n_blocks *= (params.indices[params.num_sections] - params.indices[0]);
+            } else {
+              n_blocks *= input_data.shape_[i];
+            }
+          }
+          if (splitting_last_axis) {
+            split_tensor_kernel<true, DType, LType>
+              <<<n_blocks, block_size, 0, s->stream_>>>
+              (input_data.Size(), params, split_axis_size, tail_size,
+               last_axis_size, blocks_last_axis);
+          } else {
+            split_tensor_kernel<false, DType, LType>
+              <<<n_blocks, block_size, 0, s->stream_>>>
+              (input_data.Size(), params, split_axis_size, tail_size,
+               last_axis_size, blocks_last_axis);
+          }
+        });
+        sections_processed += params.num_sections;
+      });
+    }
+  }
+}
 
 NNVM_REGISTER_OP(Reshape)
 .set_attr<FCompute>("FCompute", UnaryOp::IdentityCompute<gpu>);
@@ -219,7 +435,7 @@ NNVM_REGISTER_OP(space_to_depth)
 .set_attr<FCompute>("FCompute", SpaceToDepthOpForward<gpu>);
 
 NNVM_REGISTER_OP(_split_v2)
-.set_attr<FCompute>("FCompute", SplitOpForward<gpu>);
+.set_attr<FCompute>("FCompute", SplitOpForwardGPU);
 
 NNVM_REGISTER_OP(_split_v2_backward)
 .set_attr<FCompute>("FCompute", SplitOpBackward<gpu>);
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index b6d0011f1a2f..9c5e11a44bae 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -2476,3 +2476,20 @@ def test_arange_like_dtype():
     for v in out:
         assert v.dtype == t
 
+@with_seed()
+@pytest.mark.serial
+@pytest.mark.parametrize('dtype', ["float16", "float32", "float64"])
+def test_split_v2_fwd(dtype):
+    dim = random.randint(2, 9)
+    shape = rand_shape_nd(dim)
+    axis = random.randint(-dim, dim-1)
+    axis_size = shape[axis]
+    samples = random.randint(0, axis_size - 1)
+    indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
+    indices = tuple(indices)
+    mx_data = rand_ndarray(shape, dtype=dtype)
+    np_data = mx_data.asnumpy()
+    np_out = np.split(np_data, indices_or_sections=indices, axis=axis)
+    data = mx.sym.Variable("data")
+    sym = mx.sym.split_v2(data, indices_or_sections=indices, axis=axis)
+    check_symbolic_forward(sym, {"data": mx_data}, np_out, rtol=1e-3, atol=1e-5)
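Note (not part of the patch): a minimal usage sketch of the forward path this change targets. It assumes an MXNet build that includes the patch, at least one visible GPU, and the standard `mx.nd.split_v2` / `mx.gpu` front-end API; the shape and split positions are chosen only for illustration.

    import mxnet as mx

    # Split an (8, 256) tensor along its last axis at positions 64 and 192,
    # giving three sections of widths 64, 128 and 64. Because the last axis
    # has >= 128 elements, the GPU forward pass takes the custom
    # SplitOpForwardGPU kernel path added by this patch.
    x = mx.nd.random.uniform(shape=(8, 256), ctx=mx.gpu(0))
    outs = mx.nd.split_v2(x, indices_or_sections=(64, 192), axis=1)
    mx.nd.waitall()
    print([o.shape for o in outs])   # [(8, 64), (8, 128), (8, 64)]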