From 16ba73e9809d36a2d4791db6552f7fcb9e3636f9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 31 Jul 2018 04:34:05 +0000 Subject: [PATCH] Improve sparse embedding index out of bound error message; --- src/operator/tensor/indexing_op.cc | 38 ++++++++++++++++++++---- src/operator/tensor/indexing_op.cu | 46 ++++++++++++++++++++++++------ 2 files changed, 70 insertions(+), 14 deletions(-) diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 0f96e2cc2f72..ef59145bb4a9 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -28,6 +28,27 @@ namespace mxnet { namespace op { +/* + * \brief returns true if all indices are between [min, max] + * \param data_ptr the indices to check + * \param data_size the number of indices to examine + * \param min the expected min value for indices + * \param max the expected max value for indices + */ +template +bool CheckIndexOutOfBound(const DType* data_ptr, size_t data_size, + const DType min, const DType max) { + bool is_valid = true; + for (size_t i = 0; i < data_size; i++) { + if (data_ptr[i] > max || data_ptr[i] < min) { + is_valid = false; + break; + } + } + return is_valid; +} + + template<> void SparseEmbeddingOpForwardRspImpl(const OpContext& ctx, const TBlob& data, @@ -48,18 +69,16 @@ void SparseEmbeddingOpForwardRspImpl(const OpContext& ctx, return; } // check out-of-bound indices - bool is_valid = true; MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { DType min = 0; DType max = static_cast(weight.shape()[0] - 1); // check with single thread is faster since data is small DType* data_ptr = data.dptr(); size_t data_size = data.shape_.Size(); - for (size_t i = 0; i < data_size; i++) { - if (data_ptr[i] > max || data_ptr[i] < min) is_valid = false; - } + bool is_valid = CheckIndexOutOfBound(data_ptr, data_size, + min, max); + CHECK(is_valid) << "SparseEmbedding input contains data out of bound"; }) - CHECK(is_valid) << "SparseEmbedding input contains data out 
of bound"; // the weight is actually dense if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) { EmbeddingOpForwardDnsImpl(s, data, weight.data(), req, output); @@ -101,6 +120,15 @@ inline void SparseEmbeddingOpBackwardRspImpl(const bool deterministic, MSHADOW_TYPE_SWITCH(data.type_flag_, IType, { MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, { MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, { + // check out of bound indices + { + IType min = 0; + IType max = static_cast(output.shape()[0] - 1); + // check with single thread is faster since data is small + IType* data_ptr = data.dptr(); + bool is_valid = CheckIndexOutOfBound(data_ptr, data.shape_.Size(), min, max); + CHECK(is_valid) << "Embedding input contains data out of bound"; + } // mark row flags Fill(s, TBlob(row_flg, Shape1(num_rows), cpu::kDevMask), kWriteTo, 0); Kernel::Launch(s, data_size, row_flg, data.dptr()); diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 39fd81ef2001..bdc7f6e843c0 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -36,7 +36,7 @@ namespace op { struct is_valid_check { template - MSHADOW_XINLINE static void Map(int i, int32_t* out, const DType* data, + MSHADOW_XINLINE static void Map(int i, char* out, const DType* data, const DType min, const DType max) { if (data[i] < min || data[i] > max) *out = 1; } @@ -116,6 +116,27 @@ struct AddTakeGradRspDeterministicKernel { } }; +/* + * \brief returns true if all indices are between [min, max] + * \param s the stream + * \param data_ptr the indices on the stream + * \param data_size the number of indices to examine + * \param min the expected min value for indices + * \param max the expected max value for indices + * \param is_valid_ptr the temporary workspace + */ +template +bool CheckIndexOutOfBound(mshadow::Stream *s, const DType* data_ptr, size_t data_size, + const DType min, const DType max, char* is_valid_ptr) { + using namespace mxnet_op; + 
int32_t is_valid = 0; + Kernel::Launch(s, 1, is_valid_ptr); + Kernel::Launch(s, data_size, is_valid_ptr, data_ptr, min, max); + CUDA_CALL(cudaMemcpy(&is_valid, is_valid_ptr, sizeof(char), + cudaMemcpyDeviceToHost)); + return is_valid == 0; +} + template<> void SparseEmbeddingOpForwardRspImpl(const OpContext& ctx, const TBlob& data, @@ -136,21 +157,17 @@ void SparseEmbeddingOpForwardRspImpl(const OpContext& ctx, return; } // check out-of-bound indices - int32_t is_valid = 0; MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { DType min = 0; DType max = static_cast(weight.shape()[0] - 1); DType* data_ptr = data.dptr(); size_t data_size = data.shape_.Size(); Tensor workspace = ctx.requested[0] - .get_space_typed(Shape1(sizeof(int32_t)), s); - int32_t* is_valid_ptr = reinterpret_cast(workspace.dptr_); - Kernel::Launch(s, 1, is_valid_ptr); - Kernel::Launch(s, data_size, is_valid_ptr, data_ptr, min, max); - CUDA_CALL(cudaMemcpy(&is_valid, is_valid_ptr, sizeof(int32_t), - cudaMemcpyDeviceToHost)); + .get_space_typed(Shape1(1), s); + char* is_valid_ptr = reinterpret_cast(workspace.dptr_); + bool is_valid = CheckIndexOutOfBound(s, data_ptr, data_size, min, max, is_valid_ptr); + CHECK(is_valid) << "SparseEmbedding input contains data out of bound"; }) - CHECK_EQ(is_valid, 0) << "SparseEmbedding input contains data out of bound"; // the weight is actually dense if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) { EmbeddingOpForwardDnsImpl(s, data, weight.data(), req, output); @@ -207,6 +224,17 @@ void SparseEmbeddingDeterministicKernelLaunch(const OpContext& ctx, sorted_data_storage_bytes); temp_storage = workspace.dptr_ + total_storage_bytes - temp_workspace_bytes; + // check out-of-bound indices + { + IType min = 0; + IType max = static_cast(output.shape()[0] - 1); + IType* data_ptr = data.dptr(); + size_t data_size = data.shape_.Size(); + bool is_valid = CheckIndexOutOfBound(s, data_ptr, data_size, min, max, + reinterpret_cast(temp_storage)); + CHECK(is_valid) << "Embedding 
input contains data out of bound"; + } + // make a copy of the data, to be sorted TBlob sorted_data_blob(sorted_data, Shape1(data_size), gpu::kDevMask); auto sorted_data_tensor = sorted_data_blob.FlatTo1D(s);