
[MXNET-750] fix nested call on CachedOp. #11951

Merged (2 commits) on Jul 31, 2018
12 changes: 6 additions & 6 deletions src/imperative/cached_op.cc
@@ -821,12 +821,11 @@ OpStatePtr CachedOp::DynamicForward(
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
-  if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
-
+  // If we are already recording, we don't need RunGraph to record all
+  // computation again.
   RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
-           std::move(ref_count), &states, dispatch_modes);
-
-  Imperative::Get()->set_is_recording(recording);
+           std::move(ref_count), &states, dispatch_modes,
+           !recording || inlining_);
 
   return op_state;
 }
@@ -947,7 +946,8 @@ void CachedOp::DynamicBackward(
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
-           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
+           Imperative::Get()->is_recording());
 
   if (retain_graph) {
     buff.resize(num_forward_entries);
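
The first cached_op.cc hunk carries the substance of the fix: DynamicForward no longer clears the global recording flag before RunGraph and restores it afterwards; it passes the decision (`!recording || inlining_`) directly to RunGraph. A minimal, self-contained sketch of that refactoring pattern follows; it is illustrative only, and `g_recording`, `Node`, `run_ops*`, and `dynamic_forward` are hypothetical names rather than MXNet code.

#include <vector>

// Hypothetical stand-in for Imperative::Get()->is_recording().
thread_local bool g_recording = false;

struct Node {};  // placeholder graph node

// Before: the executor reads ambient state, so a caller that wants different
// behaviour has to mutate the global flag around the call.
void run_ops_old(const std::vector<Node>& nodes) {
  bool recording = g_recording;  // decision hidden inside the callee
  for (const Node& n : nodes) {
    (void)n;
    if (recording) { /* record this op for autograd */ }
  }
}

// After: the caller decides and passes the flag down, so nothing global is
// toggled while the graph runs and nested invocations each get exactly the
// value intended for them.
void run_ops(const std::vector<Node>& nodes, bool recording) {
  for (const Node& n : nodes) {
    (void)n;
    if (recording) { /* record this op for autograd */ }
  }
}

void dynamic_forward(const std::vector<Node>& nodes, bool inlining) {
  bool recording = g_recording;
  // Old pattern (what the hunk removes):
  //   if (recording && !inlining) g_recording = false;
  //   run_ops_old(nodes);
  //   g_recording = recording;
  // New pattern (what the hunk adds): make the decision explicit.
  run_ops(nodes, !recording || inlining);
}
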
3 changes: 2 additions & 1 deletion src/imperative/imperative.cc
@@ -495,7 +495,8 @@ std::vector<NDArray*> Imperative::Backward(
   int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
 
   RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
-           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
+           is_recording());
 
   Engine::Get()->set_bulk_size(prev_bulk_size);
   set_is_recording(prev_recording);
4 changes: 2 additions & 2 deletions src/imperative/imperative_utils.cc
@@ -30,7 +30,8 @@ void RunGraph(
     std::vector<OpReqType>&& array_reqs,
     std::vector<uint32_t>&& ref_count,
     std::vector<OpStatePtr> *p_states,
-    const DispatchModeVector &dispatch_modes) {
+    const DispatchModeVector &dispatch_modes,
+    bool recording) {
   using namespace nnvm;
   using namespace imperative;
   static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
@@ -40,7 +41,6 @@ void RunGraph(
   const auto imp = Imperative::Get();
 
   std::vector<OpStatePtr>& states = *p_states;
-  bool recording = imp->is_recording();
 
   std::vector<NDArray*> ndinputs, ndoutputs;
   ShapeVector arg_shapes;
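
Taken together with the header change below, these hunks alter RunGraph's contract: the function no longer reads `imp->is_recording()` itself; every caller now supplies the flag. For orientation, a reconstruction of the updated declaration follows. Only the parameters visible in this diff are written out; the leading graph/array/node-range parameters are elided because the diff does not show them, so treat this as a sketch rather than the verbatim MXNet prototype.

void RunGraph(const bool retain_graph,
              /* graph, arrays, and node-range parameters not shown in this diff */
              std::vector<OpReqType>&& array_reqs,
              std::vector<uint32_t>&& ref_count,
              std::vector<OpStatePtr> *p_states,
              const DispatchModeVector &dispatch_modes,
              bool recording);  // new: replaces the internal is_recording() read
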
3 changes: 2 additions & 1 deletion src/imperative/imperative_utils.h
@@ -994,7 +994,8 @@ void RunGraph(const bool retain_graph,
     std::vector<OpReqType>&& array_reqs,
     std::vector<uint32_t>&& ref_count,
     std::vector<OpStatePtr> *p_states,
-    const DispatchModeVector &dispatch_modes);
+    const DispatchModeVector &dispatch_modes,
+    bool recording);
 
 }  // namespace imperative
 }  // namespace mxnet
1 change: 1 addition & 0 deletions tests/python/unittest/test_contrib_control_flow.py
@@ -1159,6 +1159,7 @@ def check_contrib_rnn(cell_type, num_states):
 
     configs = [
         {},
+        {'inline_limit': 0},
         {'static_alloc': True},
         {'static_alloc': True, 'static_shape': True} ]
     for config in configs: