Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds static tape functions to control tape activation #102

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/ref/tape.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ or `!c++ nullptr` if no active tape has been set.
Note that this is a thread-local pointer - calling this function in different
threads gives different results.

#### `setActive`

`#!c++ static void setActive(Tape* t)` static function that sets the given tape as the
globally active one. This is equivalent to `t.activate()`.

It may throw [`TapeAlreadyActive`](exceptions.md) if another tape is
already active for the current thread.

#### `deactivateAll`

`#!c++ static void deactivateAll()` deactivates any currently active tapes.
Equivalent to `auto t = Tape::getActive(); if (t) t->deactivate();`.

#### `registerInput`

`#!c++ void registerInput(active_type& inp)` registers the given variable with the tape and start recording dependents of it. A call to this function or its overloads is required in order to calculate adjoints.
Expand Down
19 changes: 12 additions & 7 deletions src/XAD/Tape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,8 @@ class Tape
Tape& operator=(const Tape&) = delete;

// recording control
XAD_INLINE void activate()
{
if (active_tape_ != nullptr)
throw TapeAlreadyActive();
else
active_tape_ = this;
}
XAD_INLINE void activate() { setActive(this); }

XAD_INLINE void deactivate()
{
if (active_tape_ == this)
Expand All @@ -106,6 +101,16 @@ class Tape
XAD_INLINE bool isActive() const { return active_tape_ == this; }
XAD_INLINE static Tape* getActive() { return active_tape_; }

XAD_INLINE static void setActive(Tape* t)
{
if (active_tape_ != nullptr)
throw TapeAlreadyActive();
else
active_tape_ = t;
}

XAD_INLINE static void deactivateAll() { active_tape_ = nullptr; }

XAD_INLINE void registerInput(active_type& inp)
{
if (!inp.shouldRecord()) // already registered
Expand Down
27 changes: 27 additions & 0 deletions test/Tape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,33 @@ TEST(Tape, canInitializeDeactivated)
EXPECT_NE(nullptr, Tape<float>::getActive());
}

TEST(Tape, canActivateStatically)
{
using xad::Tape;
Tape<float> s(false);

EXPECT_FALSE(s.isActive());
EXPECT_EQ(nullptr, Tape<float>::getActive());

xad::Tape<float>::setActive(&s);

EXPECT_TRUE(s.isActive());
EXPECT_NE(nullptr, Tape<float>::getActive());
}

TEST(Tape, canDeactivateGlobally)
{
using xad::Tape;

EXPECT_EQ(nullptr, Tape<double>::getActive());

Tape<double> s;

EXPECT_TRUE(s.isActive());
Tape<double>::deactivateAll();
EXPECT_FALSE(s.isActive());
}

TEST(Tape, isMovable)
{
xad::Tape<double> s(false);
Expand Down
Loading