diff --git a/README.md b/README.md index 1ac59e5..a11acdf 100644 --- a/README.md +++ b/README.md @@ -48,11 +48,14 @@ The service uses the following storage backends: FILE_STORAGE_DEFAULT=django.core.files.storage.FileSystemStorage ``` - **FILE_STORAGE_DNN**: -This backend is dedicated to storing DNN model files. Ensure that the following two files are present in this storage and that they have exactly these names: - 1. *deploy.prototxt*: Defines the model architecture. Download it from the [model architecture link](https://raw.githubusercontent.com/sr6033/face-detection-with-OpenCV-and-DNN/master/deploy.prototxt.txt) - 2. *res10_300x300_ssd_iter_140000.caffemodel*: Contains the pre-trained model weights. Download it from the [model weights link](https://raw.githubusercontent.com/sr6033/face-detection-with-OpenCV-and-DNN/master/res10_300x300_ssd_iter_140000.caffemodel) +This backend is dedicated to storing DNN model files. Ensure that the following two files are present in this storage: + 1. *deploy.prototxt*: Defines the model architecture. + 2. *res10_300x300_ssd_iter_140000.caffemodel*: Contains the pre-trained model weights. - These files can be updated in the future by a dedicated pipeline that handles model training. The storage configuration for this backend is as follows: + The current process involves downloading files from a [GitHub repository](https://github.com/sr6033/face-detection-with-OpenCV-and-DNN) and saving them to Azure Blob Storage **FILE_STORAGE_DNN** using command `django-admin upgrade --with-worker-upgrade`, or through dedicated command `django-admin workerupgrade`. + In the future, an automated pipeline related to model training could handle file updates. + + The storage configuration for this backend is as follows: ``` FILE_STORAGE_DNN="storages.backends.azure_storage.AzureStorage?account_name=&account_key=&overwrite_files=true&azure_container=dnn" ``` diff --git a/src/hope_dedup_engine/apps/core/management/commands/upgrade.py b/src/hope_dedup_engine/apps/core/management/commands/upgrade.py index aa29e71..db7ca16 100644 --- a/src/hope_dedup_engine/apps/core/management/commands/upgrade.py +++ b/src/hope_dedup_engine/apps/core/management/commands/upgrade.py @@ -64,7 +64,13 @@ def add_arguments(self, parser: "ArgumentParser") -> None: default=True, help="Do not run collectstatic", ) - + parser.add_argument( + "--with-worker-upgrade", + action="store_true", + dest="worker_upgrade", + default=False, + help="Run upgrade for celery worker", + ) parser.add_argument( "--admin-email", action="store", @@ -86,6 +92,7 @@ def get_options(self, options: dict[str, Any]) -> None: self.prompt = not options["prompt"] self.static = options["static"] self.migrate = options["migrate"] + self.worker_upgrade = options["worker_upgrade"] self.debug = options["debug"] self.admin_email = str(options["admin_email"] or env("ADMIN_EMAIL", "")) @@ -120,6 +127,10 @@ def handle(self, *args: Any, **options: Any) -> None: # noqa: C901 } echo("Running upgrade", style_func=self.style.WARNING) + if self.worker_upgrade: + echo("Run upgrade for celery worker:") + call_command("workerupgrade") + call_command("env", check=True) if self.run_check: diff --git a/src/hope_dedup_engine/apps/core/management/commands/workerupgrade.py b/src/hope_dedup_engine/apps/core/management/commands/workerupgrade.py new file mode 100644 index 0000000..3c47738 --- /dev/null +++ b/src/hope_dedup_engine/apps/core/management/commands/workerupgrade.py @@ -0,0 +1,162 @@ +import logging +import sys +from argparse import ArgumentParser +from typing import Any, Final + +from django.conf import settings +from django.core.exceptions import ValidationError +from django.core.management import BaseCommand +from django.core.management.base import CommandError, SystemCheckError + +import requests +from storages.backends.azure_storage import AzureStorage + +logger = logging.getLogger(__name__) + + +MESSAGES: Final[dict[str, str]] = { + "already": "File '%s' already exists in FILE_STORAGE_DNN storage.", + "process": "Downloading file from '%s' to '%s' in FILE_STORAGE_DNN storage...", + "empty": "File at '%s' is empty (size is 0 bytes).", + "halted": "\n\n***\nSYSTEM HALTED\nUnable to start without DNN files...", +} + + +class Command(BaseCommand): + help = "Synchronizes DNN files from the git to azure storage" + dnn_files = None + + def add_arguments(self, parser: ArgumentParser) -> None: + """ + Adds custom command-line arguments to the management command. + + Args: + parser (ArgumentParser): The argument parser instance to which the arguments should be added. + + Adds the following arguments: + --force: A boolean flag that, when provided, forces the re-download of files even if they already exist + in Azure storage. Defaults to False. + --deployfile-url (str): The URL from which the deploy (prototxt) file is downloaded. + Defaults to the value set in the project settings. + --caffemodelfile-url (str): The URL from which the pre-trained model weights (caffemodel) are downloaded. + Defaults to the value set in the project settings. + --download-timeout (int): The maximum time allowed for downloading files, in seconds. + Defaults to 3 minutes (180 seconds). + --chunk-size (int): The size of each chunk to download in bytes. Defaults to 256 KB. + """ + parser.add_argument( + "--force", + action="store_true", + default=False, + help="Force the re-download of files even if they already exist", + ) + parser.add_argument( + "--deployfile-url", + type=str, + default=settings.DNN_FILES.get("prototxt", {}) + .get("sources", {}) + .get("github"), + help="The URL of the model architecture (deploy) file", + ) + parser.add_argument( + "--caffemodelfile-url", + type=str, + default=settings.DNN_FILES.get("caffemodel", {}) + .get("sources", {}) + .get("github"), + help="The URL of the pre-trained model weights (caffemodel) file", + ) + parser.add_argument( + "--download-timeout", + type=int, + default=3 * 60, # 3 minutes + help="The timeout for downloading files", + ) + parser.add_argument( + "--chunk-size", + type=int, + default=256 * 1024, # 256 KB + help="The size of each chunk to download in bytes", + ) + + def get_options(self, options: dict[str, Any]) -> None: + self.verbosity = options["verbosity"] + self.force = options["force"] + self.dnn_files = ( + { + "url": options["deployfile_url"], + "filename": settings.DNN_FILES.get("prototxt", {}) + .get("sources", {}) + .get("azure"), + }, + { + "url": options["caffemodelfile_url"], + "filename": settings.DNN_FILES.get("caffemodel", {}) + .get("sources", {}) + .get("azure"), + }, + ) + self.download_timeout = options["download_timeout"] + self.chunk_size = options["chunk_size"] + + def handle(self, *args: Any, **options: Any) -> None: + """ + Executes the command to download and store DNN files from a given source to Azure Blob Storage. + + Args: + *args (Any): Positional arguments passed to the command. + **options (dict[str, Any]): Keyword arguments passed to the command, including: + - force (bool): If True, forces the re-download of files even if they already exist in storage. + - deployfile_url (str): The URL of the DNN model architecture file to download. + - caffemodelfile_url (str): The URL of the pre-trained model weights to download. + - download_timeout (int): Timeout for downloading each file, in seconds. + - chunk_size (int): The size of chunks for streaming downloads, in bytes. + + Raises: + FileNotFoundError: If the downloaded file is empty (size is 0 bytes). + ValidationError: If any arguments are invalid or improperly configured. + CommandError: If an issue occurs with the Django command execution. + SystemCheckError: If a system check error is encountered during execution. + Exception: For any other errors that occur during the download or storage process. + """ + self.get_options(options) + if self.verbosity >= 1: + echo = self.stdout.write + else: + echo = lambda *a, **kw: None # noqa: E731 + + try: + dnn_storage = AzureStorage(**settings.STORAGES.get("dnn").get("OPTIONS")) + _, files = dnn_storage.listdir("") + for file in self.dnn_files: + if self.force or not file.get("filename") in files: + echo(MESSAGES["process"] % (file.get("url"), file.get("filename"))) + with requests.get( + file.get("url"), stream=True, timeout=self.download_timeout + ) as r: + r.raise_for_status() + if int(r.headers.get("Content-Length", 1)) == 0: + raise FileNotFoundError(MESSAGES["empty"] % file.get("url")) + with dnn_storage.open(file.get("filename"), "wb") as f: + for chunk in r.iter_content(chunk_size=self.chunk_size): + f.write(chunk) + else: + echo(MESSAGES["already"] % file.get("filename")) + except ValidationError as e: + self.halt(Exception("\n- ".join(["Wrong argument(s):", *e.messages]))) + except (CommandError, FileNotFoundError, SystemCheckError) as e: + self.halt(e) + except Exception as e: + self.halt(e) + + def halt(self, e: Exception) -> None: + """ + Handle an exception by logging the error and exiting the program. + + Args: + e (Exception): The exception that occurred. + """ + logger.exception(e) + self.stdout.write(self.style.ERROR(str(e))) + self.stdout.write(self.style.ERROR(MESSAGES["halted"])) + sys.exit(1) diff --git a/tests/test_command_worker.py b/tests/test_command_worker.py new file mode 100644 index 0000000..1ec9444 --- /dev/null +++ b/tests/test_command_worker.py @@ -0,0 +1,90 @@ +from io import StringIO +from typing import Final +from unittest import mock + +from django.core.exceptions import ValidationError +from django.core.management import call_command +from django.core.management.base import CommandError, SystemCheckError + +import pytest +from pytest_mock import MockerFixture + +DNN_FILES: Final[tuple[dict[str, str]]] = ( + {"url": "http://example.com/file1", "filename": "file1"}, + {"url": "http://example.com/file2", "filename": "file2"}, +) + + +@pytest.fixture +def mock_requests_get(): + with mock.patch("requests.get") as mock_get: + mock_response = mock_get.return_value.__enter__.return_value + mock_response.iter_content.return_value = [b"Hello, world!"] * 3 + mock_response.raise_for_status = lambda: None + yield mock_get + + +@pytest.fixture +def mock_azurite_manager(mocker: MockerFixture): + yield mocker.patch( + "hope_dedup_engine.apps.core.management.commands.workerupgrade.AzureStorage", + ) + + +@pytest.fixture +def mock_dnn_files(mocker: MockerFixture): + yield mocker.patch( + "hope_dedup_engine.apps.core.management.commands.workerupgrade.Command.dnn_files", + new_callable=mocker.PropertyMock, + ) + + +@pytest.mark.parametrize( + "force, expected_count, existing_files", + [ + (False, 2, []), + (False, 1, [DNN_FILES[0]["filename"]]), + (False, 0, [f["filename"] for f in DNN_FILES][:2]), + (True, 2, []), + (True, 2, [DNN_FILES[0]["filename"]]), + (True, 2, [f["filename"] for f in DNN_FILES][:2]), + ], +) +def test_workerupgrade_handle_success( + mock_requests_get, + mock_azurite_manager, + mock_dnn_files, + force, + expected_count, + existing_files, +): + mock_dnn_files.return_value = DNN_FILES + mock_azurite_manager().listdir.return_value = ([], existing_files) + out = StringIO() + + call_command("workerupgrade", stdout=out, force=force) + + assert "SYSTEM HALTED" not in out.getvalue() + assert mock_requests_get.call_count == expected_count + assert mock_azurite_manager().open.call_count == expected_count + + +@pytest.mark.parametrize( + "side_effect, expected_exception", + [ + (FileNotFoundError("File not found"), SystemExit), + (ValidationError("Invalid argument"), SystemExit), + (CommandError("Command execution failed"), SystemExit), + (SystemCheckError("System check failed"), SystemExit), + (Exception("Unknown error"), SystemExit), + ], +) +def test_workerupgrade_handle_exception( + mock_requests_get, mock_azurite_manager, side_effect, expected_exception +): + mock_azurite_manager.side_effect = side_effect + out = StringIO() + with pytest.raises(expected_exception): + call_command("workerupgrade", stdout=out) + + assert "SYSTEM HALTED" in out.getvalue() diff --git a/tests/test_commands.py b/tests/test_commands.py index 3c8b4a9..e2b745d 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -59,6 +59,7 @@ def test_upgrade_init( migrate=migrate, stdout=out, check=False, + worker_upgrade=False, verbosity=verbosity, ) assert "error" not in str(out.getvalue()) @@ -72,7 +73,13 @@ def test_upgrade(verbosity, migrate, monkeypatch, environment): out = StringIO() SuperUserFactory() with mock.patch.dict(os.environ, environment, clear=True): - call_command("upgrade", stdout=out, check=False, verbosity=verbosity) + call_command( + "upgrade", + stdout=out, + check=False, + worker_upgrade=False, + verbosity=verbosity, + ) assert "error" not in str(out.getvalue()) @@ -98,7 +105,14 @@ def test_upgrade_admin(db, mocked_responses, environment, admin): out = StringIO() with mock.patch.dict(os.environ, environment, clear=True): - call_command("upgrade", stdout=out, check=False, admin_email=email) + call_command( + "upgrade", + stdout=out, + check=False, + worker_upgrade=False, + static=False, + admin_email=email, + ) @pytest.mark.parametrize("verbosity", [0, 1], ids=["0", "1"])