diff --git a/tests/conftest.py b/tests/conftest.py index acebbb6ff..5e5c31495 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,19 +18,19 @@ @pytest.fixture() -def participant_metrics_sample(): +def json_participant_metrics_sample(): """Return a valid participant metric object.""" return json.dumps( [ { "measurement": "participant", - "time": 1234326435, + "time": 1582017483 * 1_000_000_000, "tags": {"id": "127.0.0.1:1345"}, "fields": {"CPU_1": 90.8, "CPU_2": 90, "CPU_3": "23", "CPU_4": 0.00,}, }, { "measurement": "participant", - "time": 3542626236, + "time": 1582017484 * 1_000_000_000, "tags": {"id": "127.0.0.1:1345"}, "fields": {"CPU_1": 90.8, "CPU_2": 90, "CPU_3": "23", "CPU_4": 0.00,}, }, diff --git a/tests/test_grpc.py b/tests/test_grpc.py index d6adf6d1b..e642f00fc 100644 --- a/tests/test_grpc.py +++ b/tests/test_grpc.py @@ -349,7 +349,7 @@ def test_start_training_round_failed_precondition( # pylint: disable=unused-arg @pytest.mark.integration -def test_end_training_round(coordinator_service, participant_metrics_sample): +def test_end_training_round(coordinator_service, json_participant_metrics_sample): """[summary] .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) @@ -369,7 +369,7 @@ def test_end_training_round(coordinator_service, participant_metrics_sample): rendezvous(channel) # call EndTrainingRound service method on coordinator end_training_round( - channel, test_weights, number_samples, participant_metrics_sample + channel, test_weights, number_samples, json_participant_metrics_sample ) # check local model received... 
diff --git a/tests/test_metric_store.py b/tests/test_metric_store.py index 25704baed..0887a5f4b 100644 --- a/tests/test_metric_store.py +++ b/tests/test_metric_store.py @@ -4,6 +4,7 @@ from influxdb import InfluxDBClient import pytest +from xain_proto.fl.coordinator_pb2 import State from xain_fl.config import MetricsConfig from xain_fl.coordinator.metrics_store import MetricsStore, MetricsStoreError @@ -29,101 +30,183 @@ def invalid_json_participant_metrics_sample(): ) +@pytest.fixture() +def participant_metrics_sample(): + """Return a valid metric object.""" + return {"state": State.FINISHED} + + @mock.patch.object(InfluxDBClient, "write_points", return_value=True) -def test_valid_participant_metrics( - write_points_mock, participant_metrics_sample, +def test_write_received_participant_metrics( + write_points_mock, json_participant_metrics_sample, ): # pylint: disable=redefined-outer-name,unused-argument - """Check that write_points does not raise an exception on a valid metric object.""" + """Test write_received_participant_metrics method.""" metric_store = MetricsStore( MetricsConfig(enable=True, host="", port=1, user="", password="", db_name="") ) - metric_store.write_participant_metrics(participant_metrics_sample) - write_points_mock.assert_called_once() + metric_store.write_received_participant_metrics(json_participant_metrics_sample) + write_points_mock.assert_called_with( + [ + { + "measurement": "participant", + "time": 1582017483 * 1_000_000_000, + "tags": {"id": "127.0.0.1:1345"}, + "fields": {"CPU_1": 90.8, "CPU_2": 90, "CPU_3": "23", "CPU_4": 0.00,}, + }, + { + "measurement": "participant", + "time": 1582017484 * 1_000_000_000, + "tags": {"id": "127.0.0.1:1345"}, + "fields": {"CPU_1": 90.8, "CPU_2": 90, "CPU_3": "23", "CPU_4": 0.00,}, + }, + ] + ) @mock.patch.object(InfluxDBClient, "write_points", side_effect=Exception()) -def test_write_points_exception_handling_write_participant_metrics( - write_points_mock, participant_metrics_sample, +def
test_write_received_participant_metrics_write_points_exception( + write_points_mock, json_participant_metrics_sample, ): # pylint: disable=redefined-outer-name,unused-argument - """Check that raised exceptions of the write_points method are caught in the - write_participant_metrics method.""" + """Check that raised exceptions of the write_points method are re-raised as MetricsStoreError in + the write_received_participant_metrics method.""" metric_store = MetricsStore( MetricsConfig(enable=True, host="", port=1, user="", password="", db_name="") ) with pytest.raises(MetricsStoreError): - metric_store.write_participant_metrics(participant_metrics_sample) + metric_store.write_received_participant_metrics(json_participant_metrics_sample) @mock.patch.object(InfluxDBClient, "write_points", return_value=True) -def test_invalid_json_exception_handling(write_points_mock): - """Check that raised exceptions of the write_points method are caught in the - write_participant_metrics method.""" +def test_write_received_participant_metrics_invalid_json_exception(write_points_mock): + """Check that invalid JSON is rejected with a MetricsStoreError and write_points is not called in + the write_received_participant_metrics method.""" metric_store = MetricsStore( MetricsConfig(enable=True, host="", port=1, user="", password="", db_name="") ) with pytest.raises(MetricsStoreError): - metric_store.write_participant_metrics('{"a": 1') + metric_store.write_received_participant_metrics('{"a": 1') write_points_mock.assert_not_called() with pytest.raises(MetricsStoreError): - metric_store.write_participant_metrics("{1: 1}") + metric_store.write_received_participant_metrics("{1: 1}") write_points_mock.assert_not_called() @mock.patch.object(InfluxDBClient, "write_points", return_value=True) -def test_empty_metrics_exception_handling( +def test_write_received_participant_metrics_empty_metrics_exception( write_points_mock, empty_json_participant_metrics_sample, ): # pylint:
disable=redefined-outer-name,unused-argument - """Check that raised exceptions of the write_points method are caught in the - write_participant_metrics method.""" + """Check that the empty metrics sample is rejected with a MetricsStoreError and write_points is not called in + the write_received_participant_metrics method.""" metric_store = MetricsStore( MetricsConfig(enable=True, host="", port=1, user="", password="", db_name="") ) with pytest.raises(MetricsStoreError): - metric_store.write_participant_metrics(empty_json_participant_metrics_sample) + metric_store.write_received_participant_metrics( + empty_json_participant_metrics_sample + ) write_points_mock.assert_not_called() @mock.patch.object(InfluxDBClient, "write_points", return_value=True) -def test_invalid_schema_exception_handling( +def test_write_received_participant_metrics_invalid_schema_exception( write_points_mock, invalid_json_participant_metrics_sample, ): # pylint: disable=redefined-outer-name,unused-argument - """Check that raised exceptions of the write_points method are caught in the - write_participant_metrics method.""" + """Check that the invalid metrics sample is rejected with a MetricsStoreError and write_points is not called in + the write_received_participant_metrics method.""" metric_store = MetricsStore( MetricsConfig(enable=True, host="", port=1, user="", password="", db_name="") ) with pytest.raises(MetricsStoreError): - metric_store.write_participant_metrics(invalid_json_participant_metrics_sample) + metric_store.write_received_participant_metrics( + invalid_json_participant_metrics_sample + ) write_points_mock.assert_not_called() +@mock.patch("xain_fl.coordinator.metrics_store.time.time", return_value=1582017483.0) @mock.patch.object(InfluxDBClient, "write_points", return_value=True) -def test_valid_coordinator_metrics( - write_points_mock, coordinator_metrics_sample, +def test_write_coordinator_metrics( + write_points_mock, time_mock, coordinator_metrics_sample, ): # pylint:
disable=redefined-outer-name,unused-argument - """Check that write_points does not raise an exception on a valid metric object.""" + """Test write_coordinator_metrics method.""" metric_store = MetricsStore( MetricsConfig(enable=True, host="", port=1, user="", password="", db_name="") ) - metric_store.write_coordinator_metrics(coordinator_metrics_sample, tags={"1": "2"}) - write_points_mock.assert_called_once() + metric_store.write_metrics( + "coordinator", coordinator_metrics_sample, tags={"meta_data": "1"} + ) + + write_points_mock.assert_called_with( + [ + { + "measurement": "coordinator", + "time": 1582017483 * 1_000_000_000, + "tags": {"meta_data": "1"}, + "fields": { + "state": State.ROUND, + "round": 2, + "number_of_selected_participants": 0, + }, + } + ] + ) @mock.patch.object(InfluxDBClient, "write_points", side_effect=Exception()) -def test_write_points_exception_handling_write_coordinator_metrics( +def test_write_coordinator_metrics_write_points_exception( write_points_mock, coordinator_metrics_sample, ): # pylint: disable=redefined-outer-name,unused-argument - """Check that raised exceptions of the write_points method are caught in the - write_coordinator_metrics method.""" + """Check that raised exceptions of the write_points method are re-raised as MetricsStoreError in + the write_coordinator_metrics method.""" + + metric_store = MetricsStore( + MetricsConfig(enable=True, host="", port=1, user="", password="", db_name="") + ) + with pytest.raises(MetricsStoreError): + metric_store.write_metrics("coordinator", coordinator_metrics_sample) + + +@mock.patch("xain_fl.coordinator.metrics_store.time.time", return_value=1582017483.0) +@mock.patch.object(InfluxDBClient, "write_points", return_value=True) +def test_write_participant_metrics( + write_points_mock, time_mock, participant_metrics_sample, +): # pylint: disable=redefined-outer-name,unused-argument + """Test write_participant_metrics method.""" + metric_store = MetricsStore( + 
MetricsConfig(enable=True, host="", port=1, user="", password="", db_name="") + ) + + metric_store.write_metrics( + "participant", participant_metrics_sample, tags={"id": "1234-1234-1234"} + ) + + write_points_mock.assert_called_with( + [ + { + "measurement": "participant", + "time": 1582017483 * 1_000_000_000, + "tags": {"id": "1234-1234-1234"}, + "fields": {"state": State.FINISHED,}, + } + ] + ) + + +@mock.patch.object(InfluxDBClient, "write_points", side_effect=Exception()) +def test_write_participant_metrics_write_points_exception( + write_points_mock, participant_metrics_sample, +): # pylint: disable=redefined-outer-name,unused-argument + """Check that raised exceptions of the write_points method are re-raised as MetricsStoreError in + the write_participant_metrics method.""" metric_store = MetricsStore( MetricsConfig(enable=True, host="", port=1, user="", password="", db_name="") ) with pytest.raises(MetricsStoreError): - metric_store.write_coordinator_metrics(coordinator_metrics_sample) + metric_store.write_metrics("participant", participant_metrics_sample) diff --git a/xain_fl/coordinator/coordinator.py b/xain_fl/coordinator/coordinator.py index ec113a0c1..e89707bf0 100644 --- a/xain_fl/coordinator/coordinator.py +++ b/xain_fl/coordinator/coordinator.py @@ -150,11 +150,12 @@ def __init__( # pylint: disable=too-many-arguments self.epochs_current_round: int = epochs self._write_metrics_fail_silently( + "coordinator", { "state": self.state, "round": self.current_round, "number_of_selected_participants": self.participants.len(), - } + }, ) def get_minimum_connected_participants(self) -> int: @@ -244,10 +245,13 @@ def remove_participant(self, participant_id: str) -> None: if self.participants.len() < self.minimum_connected_participants: self.state = State.STANDBY - self._write_metrics_fail_silently({"state": self.state}) + self._write_metrics_fail_silently("coordinator", {"state": self.state}) self._write_metrics_fail_silently( - 
{"number_of_selected_participants": self.participants.len()} + "participant", {"state": State.FINISHED}, tags={"id": participant_id} + ) + self._write_metrics_fail_silently( + "coordinator", {"number_of_selected_participants": self.participants.len()} ) def select_participant_ids_and_init_round(self) -> None: @@ -294,7 +298,8 @@ def _handle_rendezvous( current_participants_count=self.participants.len(), ) self._write_metrics_fail_silently( - {"number_of_selected_participants": self.participants.len()} + "coordinator", + {"number_of_selected_participants": self.participants.len()}, ) # Select participants and change the state to ROUND if the latest added participant @@ -306,7 +311,7 @@ def _handle_rendezvous( self.round.add_selected(ids) self.state = State.ROUND - self._write_metrics_fail_silently({"state": self.state}) + self._write_metrics_fail_silently("coordinator", {"state": self.state}) else: reply = RendezvousReply.LATER logger.info( @@ -321,7 +326,7 @@ def _handle_rendezvous( return RendezvousResponse(reply=reply) def _handle_heartbeat( - self, _message: HeartbeatRequest, participant_id: str + self, message: HeartbeatRequest, participant_id: str ) -> HeartbeatResponse: """Handles a Heartbeat request. @@ -331,13 +336,19 @@ def _handle_heartbeat( - ``STANDBY``: if the participant is not selected for the current round. Args: - _message: The request to handle. Currently not used. + message: The request to handle. Its state and round are written as participant metrics. participant_id: The id of the participant making the request. Returns: The response to the participant.
""" + self._write_metrics_fail_silently( + "participant", + {"state": message.state, "round": message.round}, + tags={"id": participant_id}, + ) + self.participants.update_expires(participant_id) if self.state == State.FINISHED or participant_id in self.round.participant_ids: @@ -353,7 +364,7 @@ def _handle_heartbeat( current_participants_count=self.participants.len(), ) self._write_metrics_fail_silently( - {"number_of_selected_participants": self.participants.len()} + "coordinator", {"number_of_selected_participants": self.participants.len()} ) # send heartbeat response advertising the current state return HeartbeatResponse(state=state, round=self.current_round) @@ -425,7 +436,7 @@ def _handle_end_training_round( try: if message.metrics != "[]": - self.metrics_store.write_participant_metrics(message.metrics) + self.metrics_store.write_received_participant_metrics(message.metrics) except MetricsStoreError as err: logger.warn( "Can not write metrics", participant_id=participant_id, error=repr(err) @@ -456,11 +467,13 @@ def _handle_end_training_round( if self.current_round >= self.num_rounds - 1: logger.info("Last round over", round=self.current_round) self.state = State.FINISHED - self._write_metrics_fail_silently({"state": self.state}) + self._write_metrics_fail_silently("coordinator", {"state": self.state}) else: self.current_round += 1 self.epoch_base += self.epochs_current_round - self._write_metrics_fail_silently({"round": self.current_round}) + self._write_metrics_fail_silently( + "coordinator", {"round": self.current_round} + ) # reinitialize the round self.select_participant_ids_and_init_round() @@ -469,11 +482,13 @@ def _handle_end_training_round( def _write_metrics_fail_silently( self, + owner: str, metrics: Dict[str, Union[str, int, float]], tags: Optional[Dict[str, str]] = None, ) -> None: """ - Write the metrics to a metric store that are collected on the coordinator site. 
+ Write the metrics to a metric store that are collected on the coordinator site and owned by + the given owner. If an exception is raised, it will be caught and the error logged. FIXME: Helper function to make sure that the coordinator does not crash due to exception of @@ -481,13 +496,15 @@ def _write_metrics_fail_silently( Args: + owner: The name of the owner of the metrics e.g. coordinator or participant. metrics: A dictionary with the metric names as keys and the metric values as values. tags: A dictionary to append optional metadata to the metric. Defaults to None. """ + try: - self.metrics_store.write_coordinator_metrics(metrics, tags) + self.metrics_store.write_metrics(owner, metrics, tags) except MetricsStoreError as err: - logger.warn("Can not write coordinator metrics", error=repr(err)) + logger.warn("Can not write metrics", error=repr(err), owner=owner) def pb_enum_to_str(pb_enum: EnumDescriptor, member_value: int) -> str: diff --git a/xain_fl/coordinator/metrics_store.py b/xain_fl/coordinator/metrics_store.py index 35363dca7..00ea83d0d 100644 --- a/xain_fl/coordinator/metrics_store.py +++ b/xain_fl/coordinator/metrics_store.py @@ -2,11 +2,12 @@ from abc import ABC, abstractmethod import json +from json import JSONDecodeError import time -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union from influxdb import InfluxDBClient -from jsonschema import validate +from jsonschema import ValidationError, validate from structlog import get_logger from xain_fl.config import MetricsConfig @@ -15,11 +16,11 @@ logger: StructLogger = get_logger(__name__) -class AbstractMetricsStore(ABC): # pylint: disable=too-few-public-methods +class AbstractMetricsStore(ABC): """An abstract metric store.""" @abstractmethod - def write_participant_metrics(self, metrics_as_json: str) -> None: + def write_received_participant_metrics(self, metrics_as_json: str) -> None: """ Write the participant metrics on behalf of the participant into a metric 
store. @@ -33,31 +34,27 @@ def write_participant_metrics(self, metrics_as_json: str) -> None: """ @abstractmethod - def write_coordinator_metrics( + def write_metrics( self, + owner: str, metrics: Dict[str, Union[str, int, float]], tags: Optional[Dict[str, str]] = None, ) -> None: """ - Write the metrics to a metric store that are collected on the coordinator site. + Write metrics into a metric store. Args: + owner: The name of the owner of the metrics e.g. coordinator or participant. metrics: A dictionary with the metric names as keys and the metric values as values. tags: A dictionary to append optional metadata to the metric. Defaults to None. - - Raises: - - MetricsStoreError: If the writing of the metrics to InfluxDB has failed. """ -class NullObjectMetricsStore( - AbstractMetricsStore -): # pylint: disable=too-few-public-methods +class NullObjectMetricsStore(AbstractMetricsStore): """A metric store that does nothing.""" - def write_participant_metrics(self, metrics_as_json: str) -> None: + def write_received_participant_metrics(self, metrics_as_json: str) -> None: """ A method that has no effect. @@ -66,8 +63,9 @@ def write_participant_metrics(self, metrics_as_json: str) -> None: metrics_as_json: The metrics of a specific participant. """ - def write_coordinator_metrics( + def write_metrics( self, + owner: str, metrics: Dict[str, Union[str, int, float]], tags: Optional[Dict[str, str]] = None, ) -> None: @@ -76,12 +74,13 @@ def write_coordinator_metrics( Args: + owner: The name of the owner of the metrics e.g. coordinator or participant. metrics: A dictionary with the metric names as keys and the metric values as values. tags: A dictionary to append optional metadata to the metric. Defaults to None. 
""" -class MetricsStore(AbstractMetricsStore): # pylint: disable=too-few-public-methods +class MetricsStore(AbstractMetricsStore): """A metric store that uses InfluxDB to store the metrics.""" def __init__(self, config: MetricsConfig): @@ -116,7 +115,7 @@ def __init__(self, config: MetricsConfig): "minItems": 1, } - def write_participant_metrics(self, metrics_as_json: str) -> None: + def write_received_participant_metrics(self, metrics_as_json: str) -> None: """ Write the participant metrics on behalf of the participant into InfluxDB. @@ -132,13 +131,15 @@ def write_participant_metrics(self, metrics_as_json: str) -> None: try: metrics = json.loads(metrics_as_json) validate(instance=metrics, schema=self.schema) - self.influx_client.write_points(metrics) - except Exception as err: # pylint: disable=broad-except + except (ValidationError, JSONDecodeError) as err: logger.error("Exception", error=repr(err)) raise MetricsStoreError("Can not write participant metrics.") from err + else: + self._write_metrics(metrics) - def write_coordinator_metrics( + def write_metrics( self, + owner: str, metrics: Dict[str, Union[str, int, float]], tags: Optional[Dict[str, str]] = None, ) -> None: @@ -147,6 +148,7 @@ def write_coordinator_metrics( Args: + owner: The name of the owner of the metrics e.g. coordinator or participant. metrics: A dictionary with the metric names as keys and the metric values as values. tags: A dictionary to append optional metadata to the metric. Defaults to None. @@ -154,22 +156,38 @@ def write_coordinator_metrics( MetricsStoreError: If the writing of the metrics to InfluxDB has failed. 
""" + if not tags: tags = {} current_time: int = int(time.time() * 1_000_000_000) - influx_point = { - "measurement": "coordinator", + influx_data_point = { + "measurement": owner, "time": current_time, "tags": tags, "fields": metrics, } + self._write_metrics([influx_data_point]) + + def _write_metrics(self, influx_points: List[dict]) -> None: + """ + Write the metrics to InfluxDB that are collected on the coordinator site. + + Args: + + influx_points: InfluxDB data points. + + Raises: + + MetricsStoreError: If the writing of the metrics to InfluxDB has failed. + """ + try: - self.influx_client.write_points([influx_point]) + self.influx_client.write_points(influx_points) except Exception as err: # pylint: disable=broad-except logger.error("Exception", error=repr(err)) - raise MetricsStoreError("Can not write coordinator metrics.") from err + raise MetricsStoreError("Can not write metrics.") from err class MetricsStoreError(Exception):