diff --git a/Dockerfile b/Dockerfile index 0c9bcd1..17bd6d9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,10 +1,12 @@ FROM ubuntu:22.04 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends --assume-yes \ - pip iputils-ping curl wget wkhtmltopdf + pip wkhtmltopdf \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* COPY requirements.txt /home/requirements.txt -RUN pip3 install -r /home/requirements.txt +RUN pip3 install -r /home/requirements.txt --no-cache-dir RUN ln -s /usr/bin/python3 /usr/bin/python WORKDIR /flcore diff --git a/client.py b/client.py index 93b4071..8c12178 100644 --- a/client.py +++ b/client.py @@ -3,6 +3,7 @@ from pathlib import Path import flwr as fl import yaml +import random import flcore.datasets as datasets from flcore.client_selector import get_model_client @@ -21,7 +22,7 @@ if config["production_mode"]: node_name = os.getenv("NODE_NAME") - num_client = int(node_name.split("_")[-1]) + num_client = random.randint(0, 2) data_path = os.getenv("DATA_PATH") flower_ssl_cacert = os.getenv("FLOWER_SSL_CACERT") root_certificate = Path(f"{flower_ssl_cacert}").read_bytes() diff --git a/config.yaml b/config.yaml index 4c561dc..9f010fd 100644 --- a/config.yaml +++ b/config.yaml @@ -115,4 +115,4 @@ local_port: 8081 data_path: dataset/icrc-dataset/ -production_mode: False # Turn on to use environment variables such as data path, server address, certificates etc. +production_mode: True # Turn on to use environment variables such as data path, server address, certificates etc. diff --git a/flcore/datasets.py b/flcore/datasets.py index 3ecebd1..d648161 100644 --- a/flcore/datasets.py +++ b/flcore/datasets.py @@ -276,11 +276,13 @@ def load_kaggle_hf(data_path, center_id, config) -> Dataset: if id == -1: id = 'switzerland' + elif id == 0: + id = 'hungarian' elif id == 1: id = 'hungarian' elif id == 2: id = 'va' - elif id == 0: + elif id == 3: id = 'cleveland' elif id == None: pass diff --git a/requirements.txt b/requirements.txt index 13078ec..81f149a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ pandas==2.0.1 PyYAML==6.0.1 scikit_learn==1.2.2 torch==2.0.1 +--extra-index-url https://download.pytorch.org/whl/cpu torchmetrics==0.11.4 tqdm==4.65.0 xgboost==1.7.5 diff --git a/server.py b/server.py index 0b9784a..f9f0f6a 100644 --- a/server.py +++ b/server.py @@ -49,11 +49,19 @@ def check_config(config): data_path = os.getenv("DATA_PATH") central_ip = os.getenv("FLOWER_CENTRAL_SERVER_IP") central_port = os.getenv("FLOWER_CENTRAL_SERVER_PORT") + ca_cert = os.getenv("FLOWER_SSL_CACERT") # ca.crt + server_cert = os.getenv("FLOWER_SSL_SERVER_CERT") # server.pem + server_key = os.getenv("FLOWER_SSL_SERVER_KEY") # server.key certificates = ( - Path('.cache/certificates/rootCA_cert.pem').read_bytes(), - Path('.cache/certificates/server_cert.pem').read_bytes(), - Path('.cache/certificates/server_key.pem').read_bytes(), + Path(f"{ca_cert}").read_bytes(), + Path(f"{server_cert}").read_bytes(), + Path(f"{server_key}").read_bytes(), ) + # certificates = ( + # Path('.cache/certificates/rootCA_cert.pem').read_bytes(), + # Path('.cache/certificates/server_cert.pem').read_bytes(), + # Path('.cache/certificates/server_key.pem').read_bytes(), + # ) else: data_path = config["data_path"] central_ip = "LOCALHOST"