Skip to content

Commit

Permalink
Handle the num_execs decrement in the deleter passed to create, not in the class destructor
Browse files Browse the repository at this point in the history
This lets us use std::mutex instead of std::recursive_mutex.

Co-authored-by: Terry Cojean <terry.cojean@kit.edu>
  • Loading branch information
yhmtsai and tcojean committed Jul 29, 2021
1 parent 23d3197 commit 2da69cd
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 26 deletions.
6 changes: 4 additions & 2 deletions cuda/base/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ std::shared_ptr<CudaExecutor> CudaExecutor::create(
alloc_mode),
[device_id](CudaExecutor *exec) {
auto device_reset = exec->get_device_reset();
std::lock_guard<std::recursive_mutex> guard(
std::lock_guard<std::mutex> guard(
device_class::get_mutex(device_id));
delete exec;
if (!CudaExecutor::get_num_execs(device_id) && device_reset) {
auto &num_execs = device_class::get_num_execs(device_id);
num_execs--;
if (!num_execs && device_reset) {
cuda::device_guard g(device_id);
cudaDeviceReset();
}
Expand Down
8 changes: 4 additions & 4 deletions devices/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
namespace gko {


std::recursive_mutex &NvidiaDevice::get_mutex(int i)
std::mutex &NvidiaDevice::get_mutex(int i)
{
static std::recursive_mutex mutex[max_devices];
static std::mutex mutex[max_devices];
return mutex[i];
}

Expand All @@ -52,9 +52,9 @@ int &NvidiaDevice::get_num_execs(int i)
return num_execs[i];
}

std::recursive_mutex &AmdDevice::get_mutex(int i)
std::mutex &AmdDevice::get_mutex(int i)
{
static std::recursive_mutex mutex[max_devices];
static std::mutex mutex[max_devices];
return mutex[i];
}

Expand Down
6 changes: 4 additions & 2 deletions hip/base/executor.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ std::shared_ptr<HipExecutor> HipExecutor::create(
new HipExecutor(device_id, std::move(master), device_reset, alloc_mode),
[device_id](HipExecutor *exec) {
auto device_reset = exec->get_device_reset();
std::lock_guard<std::recursive_mutex> guard(
std::lock_guard<std::mutex> guard(
device_class::get_mutex(device_id));
delete exec;
if (!HipExecutor::get_num_execs(device_id) && device_reset) {
auto &num_execs = device_class::get_num_execs(device_id);
num_execs--;
if (!num_execs && device_reset) {
hip::device_guard g(device_id);
hipDeviceReset();
}
Expand Down
4 changes: 2 additions & 2 deletions include/ginkgo/core/base/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class NvidiaDevice {
*
* @return mutex reference
*/
static std::recursive_mutex &get_mutex(int i);
static std::mutex &get_mutex(int i);

/**
* get_num_execs gets the static num_execs reference at i.
Expand Down Expand Up @@ -104,7 +104,7 @@ class AmdDevice {
*
* @return mutex reference
*/
static std::recursive_mutex &get_mutex(int i);
static std::mutex &get_mutex(int i);

/**
* get_num_execs gets the static num_execs reference at i.
Expand Down
22 changes: 6 additions & 16 deletions include/ginkgo/core/base/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1302,8 +1302,6 @@ class CudaExecutor : public detail::ExecutorBase<CudaExecutor>,
bool device_reset = false,
allocation_mode alloc_mode = default_cuda_alloc_mode);

~CudaExecutor() { decrease_num_execs(this->get_device_id()); }

std::shared_ptr<Executor> get_master() noexcept override;

std::shared_ptr<const Executor> get_master() const noexcept override;
Expand Down Expand Up @@ -1457,22 +1455,19 @@ class CudaExecutor : public detail::ExecutorBase<CudaExecutor>,

static void increase_num_execs(unsigned device_id)
{
std::lock_guard<std::recursive_mutex> guard(
device_class::get_mutex(device_id));
std::lock_guard<std::mutex> guard(device_class::get_mutex(device_id));
device_class::get_num_execs(device_id)++;
}

static void decrease_num_execs(unsigned device_id)
{
std::lock_guard<std::recursive_mutex> guard(
device_class::get_mutex(device_id));
std::lock_guard<std::mutex> guard(device_class::get_mutex(device_id));
device_class::get_num_execs(device_id)--;
}

static unsigned get_num_execs(unsigned device_id)
{
std::lock_guard<std::recursive_mutex> guard(
device_class::get_mutex(device_id));
std::lock_guard<std::mutex> guard(device_class::get_mutex(device_id));
return device_class::get_num_execs(device_id);
}

Expand Down Expand Up @@ -1525,8 +1520,6 @@ class HipExecutor : public detail::ExecutorBase<HipExecutor>,
bool device_reset = false,
allocation_mode alloc_mode = default_hip_alloc_mode);

~HipExecutor() { decrease_num_execs(this->get_device_id()); }

std::shared_ptr<Executor> get_master() noexcept override;

std::shared_ptr<const Executor> get_master() const noexcept override;
Expand Down Expand Up @@ -1684,22 +1677,19 @@ class HipExecutor : public detail::ExecutorBase<HipExecutor>,

static void increase_num_execs(int device_id)
{
std::lock_guard<std::recursive_mutex> guard(
device_class::get_mutex(device_id));
std::lock_guard<std::mutex> guard(device_class::get_mutex(device_id));
device_class::get_num_execs(device_id)++;
}

static void decrease_num_execs(int device_id)
{
std::lock_guard<std::recursive_mutex> guard(
device_class::get_mutex(device_id));
std::lock_guard<std::mutex> guard(device_class::get_mutex(device_id));
device_class::get_num_execs(device_id)--;
}

static int get_num_execs(int device_id)
{
std::lock_guard<std::recursive_mutex> guard(
device_class::get_mutex(device_id));
std::lock_guard<std::mutex> guard(device_class::get_mutex(device_id));
return device_class::get_num_execs(device_id);
}

Expand Down

0 comments on commit 2da69cd

Please sign in to comment.