Actor status (#1071)
* add status detection for remote instance

* upgrade paddle version in documentation compilation env

* copyright

* remove code for debug
TomorrowIsAnOtherDay authored Mar 10, 2023
1 parent 2e503da commit eb1ebbf
Showing 11 changed files with 207 additions and 56 deletions.
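At a high level, the change wires liveness detection from the client's heartbeat server back to each remote instance: the gRPC heartbeat server (running in a child process) pushes the id of any instance whose heartbeat times out onto a shared `mp.Queue`, and a daemon thread in the client drains that queue and flips the instance's `job_is_alive` flag. A minimal, self-contained sketch of that queue-and-watcher pattern (simplified from the diff below; names shortened):

```python
import multiprocessing as mp
import threading

def watch_dead_jobs(dead_job_queue, instance_id_to_job):
    # Mirrors Client._update_job_status: block on the queue and mark
    # any reported instance as dead; 'exit' is the shutdown sentinel.
    while True:
        instance_id = dead_job_queue.get()
        if isinstance(instance_id, str) and instance_id == 'exit':
            break
        instance_id_to_job[instance_id].value = False

dead_job_queue = mp.Queue()
instance_id_to_job = {'00001': mp.Value('i', True)}  # one live instance

watcher = threading.Thread(
    target=watch_dead_jobs, args=(dead_job_queue, instance_id_to_job))
watcher.daemon = True
watcher.start()

dead_job_queue.put('00001')  # pretend the heartbeat server saw a timeout
dead_job_queue.put('exit')   # what Client.destroy() sends on shutdown
watcher.join()
print(bool(instance_id_to_job['00001'].value))  # -> False
```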
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -1 +1 @@
paddlepaddle==2.1
paddlepaddle==2.4.2
60 changes: 43 additions & 17 deletions parl/remote/client.py
@@ -51,16 +51,18 @@ class Client(object):
def __init__(self, master_address, process_id, distributed_files=[]):
"""
Args:
master_addr (str): ip address of the master node.
job_heartbeat_server_addr(str): server address for heartbeat detection from jobs.
process_id (str): process id in which client is created.
Should use os.getpid() to get the process id.
distributed_files (list): A list of files to be distributed at all
remote instances (e.g. the configuration
master_addr (str): IP address of the master node.
job_heartbeat_server_addr(str): Server address for heartbeat detection from jobs.
process_id (str): Process id in which client is created. Should use os.getpid() to get the process id.
distributed_files (list): A list of files to be distributed at all remote instances (e.g. the configuration
file for initialization).
"""
self.dead_job_queue = mp.Queue()
self.client_is_alive = mp.Value('i', True)
self._create_heartbeat_server()
th = threading.Thread(target=self._update_job_status, args=(self.dead_job_queue, ))
th.setDaemon(True)
th.start()
self.master_address = master_address
self.process_id = process_id
self.ctx = zmq.Context()
@@ -71,6 +73,8 @@ def __init__(self, master_address, process_id, distributed_files=[]):
self._create_sockets(master_address)
self.connected_to_master = True
self.check_env_consistency()
self.instance_count = 0
self.instance_id_to_job = dict()

thread = threading.Thread(target=self._update_client_status_to_master)
thread.setDaemon(True)
Expand All @@ -82,6 +86,7 @@ def __init__(self, master_address, process_id, distributed_files=[]):
def destroy(self):
"""Destructor function"""
self.connected_to_master = False
self.dead_job_queue.put('exit')
self.master_heartbeat_thread.exit()
for th in self.threads:
th.join()
@@ -209,6 +214,16 @@ def master_heartbeat_exit_callback_func():
"check if master is started and ensure the input "
"address {} is correct.".format(master_address))

def _update_job_status(self, dead_job_queue):
while True:
instance_id = dead_job_queue.get()
# the client calls the destroy function
if isinstance(instance_id, str) and instance_id == 'exit':
break
logger.error("[Client] lost connection with a remote instance. ID: {}".format(instance_id))
job_is_alive = self.instance_id_to_job[instance_id]
job_is_alive.value = False

def check_env_consistency(self):
'''Verify that the parl & python versions, as well as some other packages in the 'worker' process,
match those of the 'master' process'''
@@ -269,43 +284,53 @@ def _update_client_status_to_master(self):

time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)

def _check_and_monitor_job(self, job_ping_address, max_memory, gpu):
def _check_job(self, job_ping_address, max_memory, gpu):
"""
We have to check if this job is still alive before establishing connection with it.
Check if this job is still alive before establishing connection with it.
Returns: instance_id (str): a unique instance id, or -1 if the job is not ready for connection.
"""
# job_ping_socket: sends ping signal to job
job_ping_socket = self.ctx.socket(zmq.REQ)
job_ping_socket.linger = 0
job_ping_socket.setsockopt(zmq.RCVTIMEO, int(0.9 * 1000))
job_ping_socket.connect("tcp://" + job_ping_address)
instance_id = self._generate_instance_id()
try:
job_ping_socket.send_multipart([
remote_constants.HEARTBEAT_TAG,
to_byte(self.job_heartbeat_server_addr),
to_byte(str(max_memory)),
to_byte(gpu)
to_byte(gpu),
to_byte(instance_id)
], )
job_ping_socket.recv_multipart()
except zmq.error.Again:
job_ping_socket.close(0)
logger.error(
"[Client] connects to a finished job, will try again, job_ping_address:{}".format(job_ping_address))
return False
job_ping_socket.close(0)
return True
instance_id = -1
finally:
job_ping_socket.close(0)
return instance_id

def _create_heartbeat_server(self):
""" Create the grpc-based heartbeat server at the subprocess.
"""
job_heartbeat_port = mp.Value('i', 0)
self.actor_num = mp.Value('i', 0)
self.job_heartbeat_process = HeartbeatServerProcess(job_heartbeat_port, self.actor_num, self.client_is_alive)
self.job_heartbeat_process = HeartbeatServerProcess(job_heartbeat_port, self.actor_num,
self.client_is_alive, self.dead_job_queue)
self.job_heartbeat_process.daemon = True
self.job_heartbeat_process.start()
assert job_heartbeat_port.value != 0, "failed to initialize the heartbeat server for jobs."
self.job_heartbeat_server_addr = "{}:{}".format(get_ip_address(), job_heartbeat_port.value)
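`_create_heartbeat_server` starts the heartbeat server in a child process and learns which port it bound through a shared `mp.Value`. The same handshake in isolation (a plain-socket sketch, not the real gRPC server; the event is added here to avoid racing the child):

```python
import multiprocessing as mp
import socket

def serve(port_value, ready):
    # Bind to an OS-assigned port and publish it to the parent process.
    sock = socket.socket()
    sock.bind(("", 0))
    port_value.value = sock.getsockname()[1]
    ready.set()
    # ... a real server would now start accepting heartbeat connections ...

if __name__ == '__main__':
    port = mp.Value('i', 0)
    ready = mp.Event()
    proc = mp.Process(target=serve, args=(port, ready), daemon=True)
    proc.start()
    ready.wait(timeout=5)
    assert port.value != 0, "failed to initialize the heartbeat server"
    print("heartbeat server bound to port", port.value)
```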

def submit_job(self, max_memory, n_gpu):
def _generate_instance_id(self):
"""Return an unique instance id for the remote instance"""
self.instance_count += 1
unique_id = f"{self.instance_count:05}"
return unique_id

def submit_job(self, max_memory, n_gpu, job_is_alive):
"""Send a job to the Master node.
When a `@parl.remote_class` object is created, the global client
@@ -340,9 +365,10 @@ def submit_job(self, max_memory, n_gpu):
job_ping_address = job_info.ping_heartbeat_address

self.lock.acquire()
check_result = self._check_and_monitor_job(job_ping_address, max_memory, job_info.allocated_gpu.gpu)
instance_id = self._check_job(job_ping_address, max_memory, job_info.allocated_gpu.gpu)
self.lock.release()
if check_result:
if instance_id != -1:
self.instance_id_to_job[instance_id] = job_is_alive
return job_info
# no vacant CPU resources, cannot submit a new job
elif tag == remote_constants.CPU_TAG:
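The probe in `_check_job` is a plain ZeroMQ REQ/REP round trip with a 0.9-second receive timeout; `zmq.error.Again` on the receive means the job is treated as not ready and `-1` comes back instead of an instance id. Both ends of that handshake in one sketch (frame values are placeholders mirroring the five frames sent above):

```python
import zmq

ctx = zmq.Context()

# Job side: a REP socket that answers pings (see Job._reply_ping).
rep = ctx.socket(zmq.REP)
port = rep.bind_to_random_port("tcp://127.0.0.1")

# Client side: a REQ socket with a short receive timeout.
req = ctx.socket(zmq.REQ)
req.linger = 0
req.setsockopt(zmq.RCVTIMEO, int(0.9 * 1000))  # 900 ms, as in _check_job
req.connect("tcp://127.0.0.1:{}".format(port))

# HEARTBEAT_TAG, heartbeat server addr, max_memory, gpu, instance_id
req.send_multipart([b"[HEARTBEAT]", b"127.0.0.1:8010", b"None", b"", b"00001"])
print(rep.recv_multipart())           # job receives the ping payload
rep.send_multipart([b"[HEARTBEAT]"])  # ...and acknowledges
try:
    print(req.recv_multipart())       # client sees the ack in time
except zmq.error.Again:
    print("no reply within 900 ms -> treat the job as not ready")
```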
3 changes: 2 additions & 1 deletion parl/remote/grpc_heartbeat/heartbeat.proto
@@ -23,7 +23,8 @@ service GrpcHeartbeat {
message Request {
string client_id = 1;
bytes tag = 2;
string extra_msg = 4;
string instance_id = 4; // used in heartbeat detection between the job and client.
string extra_msg = 8;
}

// The response message
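Renumbering a proto field changes the wire format: an old binary's `extra_msg` (field 4) would be decoded as `instance_id` by a new one. That is presumably acceptable here because the job and client always ship from the same source tree. With the regenerated bindings the new field is set like any other string (a sketch; the tag bytes are a placeholder, not PARL's real constant):

```python
from parl.remote.grpc_heartbeat import heartbeat_pb2

request = heartbeat_pb2.Request(
    client_id="client-7f3a",  # hypothetical client id
    tag=b"heartbeat",         # placeholder tag bytes
    instance_id="00001",      # new field: set by the client, echoed by the job
    extra_msg="")
print(request.SerializeToString())
```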
26 changes: 15 additions & 11 deletions parl/remote/grpc_heartbeat/heartbeat_pb2.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: heartbeat.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
@@ -33,7 +30,7 @@
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x0fheartbeat.proto\"<\n\x07Request\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x0b\n\x03tag\x18\x02 \x01(\x0c\x12\x11\n\textra_msg\x18\x04 \x01(\t\"\x14\n\x05Reply\x12\x0b\n\x03tag\x18\x01 \x01(\x0c\x32+\n\rGrpcHeartbeat\x12\x1a\n\x04Send\x12\x08.Request\x1a\x06.Reply\"\x00\x62\x06proto3'
serialized_pb=b'\n\x0fheartbeat.proto\"Q\n\x07Request\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x0b\n\x03tag\x18\x02 \x01(\x0c\x12\x13\n\x0binstance_id\x18\x04 \x01(\t\x12\x11\n\textra_msg\x18\x08 \x01(\t\"\x14\n\x05Reply\x12\x0b\n\x03tag\x18\x01 \x01(\x0c\x32+\n\rGrpcHeartbeat\x12\x1a\n\x04Send\x12\x08.Request\x1a\x06.Reply\"\x00\x62\x06proto3'
)


@@ -62,12 +59,19 @@
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='extra_msg', full_name='Request.extra_msg', index=2,
name='instance_id', full_name='Request.instance_id', index=2,
number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='extra_msg', full_name='Request.extra_msg', index=3,
number=8, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
@@ -81,7 +85,7 @@
oneofs=[
],
serialized_start=19,
serialized_end=79,
serialized_end=100,
)


@@ -112,8 +116,8 @@
extension_ranges=[],
oneofs=[
],
serialized_start=81,
serialized_end=101,
serialized_start=102,
serialized_end=122,
)

DESCRIPTOR.message_types_by_name['Request'] = _REQUEST
@@ -143,8 +147,8 @@
index=0,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=103,
serialized_end=146,
serialized_start=124,
serialized_end=167,
methods=[
_descriptor.MethodDescriptor(
name='Send',
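The regenerated descriptors above encode exactly that renumbering, which is also why the `serialized_start`/`serialized_end` offsets shift. The wire-level effect is easy to spot-check (a sketch against the regenerated module):

```python
from parl.remote.grpc_heartbeat import heartbeat_pb2

msg = heartbeat_pb2.Request(instance_id="00001")
# A length-delimited field 4 starts with the key byte (4 << 3) | 2 == 0x22.
assert msg.SerializeToString()[0] == 0x22
```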
10 changes: 6 additions & 4 deletions parl/remote/grpc_heartbeat/heartbeat_server.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,11 +26,12 @@


class GrpcHeartbeatServer(heartbeat_pb2_grpc.GrpcHeartbeatServicer):
def __init__(self, client_count=None, host_is_alive=True):
def __init__(self, client_count=None, host_is_alive=True, dead_job_queue=None):
self.last_heartbeat_time = time.time()
self.last_heartbeat_table = dict()
self.exit_flag = False
self.client_count = client_count
self.dead_job_queue = dead_job_queue
self.host_is_alive = host_is_alive
self.host_pid = None

@@ -77,6 +78,7 @@ def timeout_time_mp(self):
to_del_client.append(client_id)
for client_id in to_del_client:
del self.last_heartbeat_table[client_id]
self.dead_job_queue.put(client_id)
self.client_count.value = len(self.last_heartbeat_table)

class HeartbeatServerThread(threading.Thread):
@@ -144,7 +146,7 @@ def exit(self):
self.heartbeat_server.exit()

class HeartbeatServerProcess(mp.Process):
def __init__(self, port, client_count, host_is_alive):
def __init__(self, port, client_count, host_is_alive, dead_job_queue):
"""Create a process to run the heartbeat server.
Args:
port(mp.Value): notify the main process of the server port.
@@ -155,7 +157,7 @@ def __init__(self, port, client_count, host_is_alive):
futures.ThreadPoolExecutor(max_workers=500),
options=[('grpc.max_receive_message_length', -1),
('grpc.max_send_message_length', -1)])
self.heartbeat_server = GrpcHeartbeatServer(client_count, host_is_alive)
self.heartbeat_server = GrpcHeartbeatServer(client_count, host_is_alive, dead_job_queue)

heartbeat_pb2_grpc.add_GrpcHeartbeatServicer_to_server(
self.heartbeat_server, self.grpc_server)
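The server side of the detection is the sweep in `timeout_time_mp` shown above: any peer whose last beat is too old is dropped from `last_heartbeat_table` and its id goes onto `dead_job_queue`, where the client's watcher thread picks it up. The same logic in isolation (timeout value assumed; PARL keeps its real constants in `remote_constants`):

```python
import time

HEARTBEAT_RCVTIMEO_S = 30  # assumed timeout, not PARL's real constant

def sweep_dead_clients(last_heartbeat_table, dead_job_queue):
    """Drop timed-out peers, report their ids, return the survivor count."""
    now = time.time()
    to_del_client = [
        client_id for client_id, last_beat in last_heartbeat_table.items()
        if now - last_beat > HEARTBEAT_RCVTIMEO_S
    ]
    for client_id in to_del_client:
        del last_heartbeat_table[client_id]
        dead_job_queue.put(client_id)  # consumed by Client._update_job_status
    return len(last_heartbeat_table)
```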
4 changes: 3 additions & 1 deletion parl/remote/job.py
@@ -71,6 +71,7 @@ def __init__(self, worker_address, log_server_address):
max_memory (float): Maximum memory (MB) can be used by each remote instance.
gpu (str): id list of GPUs can be used by each remote instance.
job_id (str): Unique ID for the job.
instance_id (str): Unique instance ID to which the job connects.
"""
self.max_memory = None
self.gpu = ""
@@ -200,6 +201,7 @@ def _reply_ping(self, socket):
if max_memory != 'None':
self.max_memory = float(max_memory)
self.gpu = to_str(message[3])
self.instance_id = to_str(message[4])
socket.send_multipart([remote_constants.HEARTBEAT_TAG])

def client_heartbeat_exit_callback_func():
@@ -214,7 +216,7 @@ def client_heartbeat_exit_callback_func():

# a thread that sends heartbeat signals from the client
self.client_heartbeat_client_thread = HeartbeatClientThread(
client_id=self.job_id,
client_id=self.instance_id,
heartbeat_server_addr=client_heartbeat_server_addr,
heartbeat_exit_callback_func=client_heartbeat_exit_callback_func)
self.client_heartbeat_client_thread.setDaemon(True)
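On the job side, the instance id arrives as a fifth frame in the ping and then replaces `job_id` as the id the job heartbeats under, so the client's heartbeat table is keyed by the id the client itself generated. The frame layout as a sketch (order taken from `_check_job` and `_reply_ping`):

```python
def parse_ping(message):
    """Unpack the five ping frames sent by Client._check_job (a sketch)."""
    tag = message[0]                      # HEARTBEAT_TAG, left as bytes
    heartbeat_addr = message[1].decode()  # client heartbeat server, "ip:port"
    max_memory = message[2].decode()      # "None" or a float rendered as text
    gpu = message[3].decode()             # GPU id list; may be empty
    instance_id = message[4].decode()     # client-generated, e.g. "00001"
    return tag, heartbeat_addr, max_memory, gpu, instance_id
```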
2 changes: 2 additions & 0 deletions parl/remote/message.py
@@ -53,6 +53,7 @@ def __init__(self,
worker_address(str): Worker's server address that receive command from the master.
pid(int): Optional. Process id of the job.
is_alive(True): Optional. This flag is used in worker to make sure that only alive jobs can be added into the worker_status.
instance_id (str): Optional. The ID generated by the client, which represents the instance to which the job connects.
"""
self.job_address = job_address
self.worker_heartbeat_address = worker_heartbeat_address
@@ -64,6 +65,7 @@
self.log_server_address = log_server_address
self.allocated_cpu = None # Record CPU(s) used in a job
self.allocated_gpu = None # Record GPU(s) used in a job
self.instance_id = None


class InitializedWorker(object):
1 change: 1 addition & 0 deletions parl/remote/remote_decorator.py
@@ -58,6 +58,7 @@ class LimitedActor(object):
max_memory (float): Maximum memory (MB) can be used by each remote
instance, the unit is in MB and default value is
none(unlimited).
n_gpu (int): The number of GPUs required to run the remote instance.
Returns:
A remote wrapper for the remote class.
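The decorator docstring now documents `n_gpu` alongside `max_memory`. A hedged usage sketch (cluster address and resource figures are hypothetical; `n_gpu` requires a GPU-enabled cluster):

```python
import parl

@parl.remote_class(max_memory=350, n_gpu=1)
class Actor(object):
    def step(self):
        return "running on a remote instance"

# Assumes an xparl cluster is already up, e.g.: xparl start --port 8010
parl.connect("localhost:8010")
actor = Actor()
print(actor.step())
```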
