From 5f6a60859e61dc14279a2d547058229ad5c4d7d4 Mon Sep 17 00:00:00 2001 From: Auto Differentiation Dev Team <107129969+auto-differentiation-dev@users.noreply.github.com> Date: Thu, 28 Mar 2024 15:35:33 +0000 Subject: [PATCH] Adds static tape functions to control tape activation --- docs/ref/tape.md | 13 +++++++++++++ src/XAD/Tape.hpp | 19 ++++++++++++------- test/Tape_test.cpp | 27 +++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/docs/ref/tape.md b/docs/ref/tape.md index fd10107..512338c 100644 --- a/docs/ref/tape.md +++ b/docs/ref/tape.md @@ -116,6 +116,19 @@ or `!c++ nullptr` if no active tape has been set. Note that this is a thread-local pointer - calling this function in different threads gives different results. +#### `setActive` + +`#!c++ static void setActive(Tape* t)` static function that sets the given tape as the +globally active one. This is equivalent to `t.activate()`. + +It may throw [`TapeAlreadyActive`](exceptions.md) if another tape is +already active for the current thread. + +#### `deactivateAll` + +`#!c++ static void deactivateAll()` deactivates any currently active tapes. +Equivalent to `auto t = Tape::getActive(); if (t) t->deactivate();`. + #### `registerInput` `#!c++ void registerInput(active_type& inp)` registers the given variable with the tape and start recording dependents of it. A call to this function or its overloads is required in order to calculate adjoints. diff --git a/src/XAD/Tape.hpp b/src/XAD/Tape.hpp index 663f793..bd4c709 100644 --- a/src/XAD/Tape.hpp +++ b/src/XAD/Tape.hpp @@ -91,13 +91,8 @@ class Tape Tape& operator=(const Tape&) = delete; // recording control - XAD_INLINE void activate() - { - if (active_tape_ != nullptr) - throw TapeAlreadyActive(); - else - active_tape_ = this; - } + XAD_INLINE void activate() { setActive(this); } + XAD_INLINE void deactivate() { if (active_tape_ == this) @@ -106,6 +101,16 @@ class Tape XAD_INLINE bool isActive() const { return active_tape_ == this; } XAD_INLINE static Tape* getActive() { return active_tape_; } + XAD_INLINE static void setActive(Tape* t) + { + if (active_tape_ != nullptr) + throw TapeAlreadyActive(); + else + active_tape_ = t; + } + + XAD_INLINE static void deactivateAll() { active_tape_ = nullptr; } + XAD_INLINE void registerInput(active_type& inp) { if (!inp.shouldRecord()) // already registered diff --git a/test/Tape_test.cpp b/test/Tape_test.cpp index 3109911..4a06bc7 100644 --- a/test/Tape_test.cpp +++ b/test/Tape_test.cpp @@ -62,6 +62,33 @@ TEST(Tape, canInitializeDeactivated) EXPECT_NE(nullptr, Tape::getActive()); } +TEST(Tape, canActivateStatically) +{ + using xad::Tape; + Tape s(false); + + EXPECT_FALSE(s.isActive()); + EXPECT_EQ(nullptr, Tape::getActive()); + + xad::Tape::setActive(&s); + + EXPECT_TRUE(s.isActive()); + EXPECT_NE(nullptr, Tape::getActive()); +} + +TEST(Tape, canDeactivateGlobally) +{ + using xad::Tape; + + EXPECT_EQ(nullptr, Tape::getActive()); + + Tape s; + + EXPECT_TRUE(s.isActive()); + Tape::deactivateAll(); + EXPECT_FALSE(s.isActive()); +} + TEST(Tape, isMovable) { xad::Tape s(false);