Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

78: Allow per deduplication set face distance threshold configuration #85

Merged
merged 3 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/hope_dedup_engine/apps/api/deduplication/adapters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Generator

from constance import config

from hope_dedup_engine.apps.api.deduplication.registry import DuplicateKeyPair
from hope_dedup_engine.apps.api.models import DeduplicationSet
from hope_dedup_engine.apps.faces.services.duplication_detector import (
Expand All @@ -20,8 +22,14 @@ def run(self) -> Generator[DuplicateKeyPair, None, None]:
"reference_pk", "filename"
)
}
face_distance_threshold: float = (
self.deduplication_set.config
and self.deduplication_set.config.face_distance_threshold
) or config.FACE_DISTANCE_THRESHOLD
# ignored key pairs are not handled correctly in DuplicationDetector
detector = DuplicationDetector(tuple[str](filename_to_reference_pk.keys()), ())
detector = DuplicationDetector(
tuple[str](filename_to_reference_pk.keys()), face_distance_threshold
)
for first_filename, second_filename, distance in detector.find_duplicates():
yield filename_to_reference_pk[first_filename], filename_to_reference_pk[
second_filename
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Generated by Django 5.0.7 on 2024-09-24 09:05

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
    # Adds the ``Config`` model and links it to ``DeduplicationSet`` through a
    # nullable one-to-one ``config`` field.

    dependencies = [
        ("api", "0004_remove_deduplicationset_error_and_more"),
    ]

    operations = [
        migrations.CreateModel(
            name="Config",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                # nullable: a config row may exist without a threshold value
                ("face_distance_threshold", models.FloatField(null=True)),
            ],
        ),
        migrations.AddField(
            model_name="deduplicationset",
            name="config",
            # SET_NULL: deleting a config row keeps the deduplication set and
            # resets its ``config`` to NULL
            field=models.OneToOneField(
                null=True, on_delete=django.db.models.deletion.SET_NULL, to="api.config"
            ),
        ),
    ]
5 changes: 5 additions & 0 deletions src/hope_dedup_engine/apps/api/models/deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
REFERENCE_PK_LENGTH: Final[int] = 100


class Config(models.Model):
    """
    Settings that can be attached to a single deduplication set
    """

    # nullable, so a config row can exist without a threshold value
    face_distance_threshold = models.FloatField(null=True)


class DeduplicationSet(models.Model):
"""
Bucket for entries we want to deduplicate
Expand Down Expand Up @@ -52,6 +56,7 @@ class State(models.IntegerChoices):
)
updated_at = models.DateTimeField(auto_now=True)
notification_url = models.CharField(max_length=255, null=True, blank=True)
config = models.OneToOneField(Config, null=True, on_delete=models.SET_NULL)

def __str__(self) -> str:
return f"ID: {self.pk}" if not self.name else f"{self.name}"
Expand Down
23 changes: 22 additions & 1 deletion src/hope_dedup_engine/apps/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,24 @@

from hope_dedup_engine.apps.api.models import DeduplicationSet
from hope_dedup_engine.apps.api.models.deduplication import (
Config,
Duplicate,
IgnoredKeyPair,
Image,
)

CONFIG = "config"


class ConfigSerializer(serializers.ModelSerializer):
    """Serializer for ``Config`` exposing every model field except ``id``."""

    class Meta:
        model = Config
        exclude = ("id",)


class DeduplicationSetSerializer(serializers.ModelSerializer):
state = serializers.CharField(source="get_state_display", read_only=True)
config = ConfigSerializer(required=False)

class Meta:
model = DeduplicationSet
Expand All @@ -25,11 +35,22 @@ class Meta:
"updated_by",
)

def create(self, validated_data) -> DeduplicationSet:
    """
    Create a deduplication set together with its optional nested config.

    :param validated_data: validated serializer data; may contain a nested
        ``config`` mapping.
    :return: the newly created ``DeduplicationSet``; its ``config`` is a new
        ``Config`` row when non-empty config data was supplied, else ``None``.
    """
    # pop() with a default always removes the key, even when its value is
    # falsy (e.g. an empty mapping). Leaving it in ``validated_data`` would
    # pass ``config`` twice to objects.create() below and raise TypeError.
    config_data = validated_data.pop(CONFIG, None)
    config = Config.objects.create(**config_data) if config_data else None
    return DeduplicationSet.objects.create(config=config, **validated_data)


class CreateConfigSerializer(ConfigSerializer):
    """Config serializer for the create payload; adds nothing to ``ConfigSerializer``."""

    pass


class CreateDeduplicationSetSerializer(serializers.ModelSerializer):
config = CreateConfigSerializer(required=False)

class Meta:
model = DeduplicationSet
fields = ("reference_pk", "notification_url")
fields = ("config", "reference_pk", "notification_url")


class ImageSerializer(serializers.ModelSerializer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import face_recognition
import numpy as np
from constance import config

from hope_dedup_engine.apps.faces.managers import StorageManager
from hope_dedup_engine.apps.faces.services.image_processor import ImageProcessor
Expand All @@ -20,7 +19,10 @@ class DuplicationDetector:
logger: logging.Logger = logging.getLogger(__name__)

def __init__(
self, filenames: tuple[str], ignore_pairs: tuple[tuple[str, str], ...] = tuple()
self,
filenames: tuple[str],
face_distance_threshold: float,
ignore_pairs: tuple[tuple[str, str], ...] = (),
) -> None:
"""
Initialize the DuplicationDetector with the given filenames and ignore pairs.
Expand All @@ -31,9 +33,10 @@ def __init__(
The pairs of filenames to ignore. Defaults to an empty tuple.
"""
self.filenames = filenames
self.face_distance_threshold = face_distance_threshold
self.ignore_set = IgnorePairsValidator.validate(ignore_pairs)
self.storages = StorageManager()
self.image_processor = ImageProcessor()
self.image_processor = ImageProcessor(face_distance_threshold)

def _encodings_filename(self, filename: str) -> str:
"""
Expand Down Expand Up @@ -122,7 +125,7 @@ def find_duplicates(self) -> Generator[tuple[str, str, float], None, None]:
encodings_all = self._load_encodings_all()

for path1, path2 in combinations(existed_images_name, 2):
min_distance = config.FACE_DISTANCE_THRESHOLD
min_distance = self.face_distance_threshold
encodings1 = encodings_all.get(path1)
encodings2 = encodings_all.get(path2)
if encodings1 is None or encodings2 is None:
Expand All @@ -136,7 +139,7 @@ def find_duplicates(self) -> Generator[tuple[str, str, float], None, None]:
) < min_distance:
min_distance = current_min

if min_distance < config.FACE_DISTANCE_THRESHOLD:
if min_distance < self.face_distance_threshold:
yield (path1, path2, round(min_distance, 5))
except Exception as e:
self.logger.exception(
Expand Down
4 changes: 2 additions & 2 deletions src/hope_dedup_engine/apps/faces/services/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ImageProcessor:

logger: logging.Logger = logging.getLogger(__name__)

def __init__(self) -> None:
def __init__(self, face_distance_threshold: float) -> None:
"""
Initialize the ImageProcessor with the required configurations.
"""
Expand All @@ -75,7 +75,7 @@ def __init__(self) -> None:
model=config.FACE_ENCODINGS_MODEL,
)
self.face_detection_confidence: float = config.FACE_DETECTION_CONFIDENCE
self.distance_threshold: float = config.FACE_DISTANCE_THRESHOLD
self.distance_threshold: float = face_distance_threshold
self.nms_threshold: float = config.NMS_THRESHOLD

def _get_face_detections_dnn(
Expand Down
4 changes: 3 additions & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NoDuplicateFinder,
)
from testutils.factories.api import (
ConfigFactory,
DeduplicationSetFactory,
DuplicateFactory,
IgnoredKeyPairFactory,
Expand All @@ -26,14 +27,15 @@
register(ExternalSystemFactory)
register(UserFactory)
register(DeduplicationSetFactory, external_system=LazyFixture("external_system"))
register(ImageFactory, deduplication_Set=LazyFixture("deduplication_set"))
register(ImageFactory, deduplication_set=LazyFixture("deduplication_set"))
# The override keyword must match the factory attribute name exactly
# (``deduplication_set``, lowercase "set"), as in the ``ImageFactory``
# registration above; the misspelled ``deduplication_Set`` names no attribute.
register(
    ImageFactory,
    _name="second_image",
    deduplication_set=LazyFixture("deduplication_set"),
)
register(DuplicateFactory, deduplication_set=LazyFixture("deduplication_set"))
register(IgnoredKeyPairFactory, deduplication_set=LazyFixture("deduplication_set"))
register(ConfigFactory)


@fixture
Expand Down
48 changes: 43 additions & 5 deletions tests/api/test_adapters.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
from random import random
from unittest.mock import MagicMock

from constance.test.unittest import override_config
from pytest import fixture
from pytest_mock import MockerFixture

from hope_dedup_engine.apps.api.deduplication.adapters import DuplicateFaceFinder
from hope_dedup_engine.apps.api.models import DeduplicationSet, Image


@fixture
def duplication_detector(mocker: MockerFixture) -> MagicMock:
    """Replace ``DuplicationDetector`` in the adapters module with a mock."""
    # ``mocker`` undoes its patches when the test ends, so no teardown step is
    # needed; returning (rather than yielding) also matches the annotation.
    return mocker.patch(
        "hope_dedup_engine.apps.api.deduplication.adapters.DuplicationDetector"
    )


def test_duplicate_face_finder_uses_duplication_detector(
deduplication_set: DeduplicationSet,
image: Image,
second_image: Image,
mocker: MockerFixture,
duplication_detector: MagicMock,
) -> None:
duplication_detector = mocker.patch(
"hope_dedup_engine.apps.api.deduplication.adapters.DuplicationDetector"
)
duplication_detector.return_value.find_duplicates.return_value = iter(
(
(
Expand All @@ -27,7 +36,8 @@ def test_duplicate_face_finder_uses_duplication_detector(
found_pairs = tuple(finder.run())

duplication_detector.assert_called_once_with(
(image.filename, second_image.filename), ()
(image.filename, second_image.filename),
deduplication_set.config.face_distance_threshold,
)
duplication_detector.return_value.find_duplicates.assert_called_once()
assert len(found_pairs) == 1
Expand All @@ -36,3 +46,31 @@ def test_duplicate_face_finder_uses_duplication_detector(
second_image.reference_pk,
1 - distance,
)


def _run_duplicate_face_finder(deduplication_set: DeduplicationSet) -> None:
    """Run ``DuplicateFaceFinder`` for the given set and drain its results."""
    # run() is a generator: nothing executes until it is iterated to the end
    for _ in DuplicateFaceFinder(deduplication_set).run():
        pass


def test_duplication_detector_is_initiated_with_correct_face_distance_threshold_value(
    deduplication_set: DeduplicationSet,
    duplication_detector: MagicMock,
) -> None:
    """Check which threshold reaches the detector: per-set value first, global otherwise."""
    # case 1: the deduplication set config carries its own threshold
    _run_duplicate_face_finder(deduplication_set)
    duplication_detector.assert_called_once_with(
        (), deduplication_set.config.face_distance_threshold
    )

    global_threshold = random()
    with override_config(FACE_DISTANCE_THRESHOLD=global_threshold):
        # case 2: config exists but its threshold is unset -> global value
        deduplication_set.config.face_distance_threshold = None
        duplication_detector.reset_mock()
        _run_duplicate_face_finder(deduplication_set)
        duplication_detector.assert_called_once_with((), global_threshold)

        # case 3: the deduplication set has no config at all -> global value
        deduplication_set.config = None
        duplication_detector.reset_mock()
        _run_duplicate_face_finder(deduplication_set)
        duplication_detector.assert_called_once_with((), global_threshold)
31 changes: 24 additions & 7 deletions tests/api/test_deduplication_set_create.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,63 @@
from api_const import DEDUPLICATION_SET_LIST_VIEW, JSON
from pytest import mark
from rest_framework import status
from rest_framework.reverse import reverse
from rest_framework.test import APIClient
from testutils.factories.api import DeduplicationSetFactory

from hope_dedup_engine.apps.api.models import DeduplicationSet
from hope_dedup_engine.apps.api.serializers import DeduplicationSetSerializer
from hope_dedup_engine.apps.api.serializers import CreateDeduplicationSetSerializer


def test_can_create_deduplication_set(api_client: APIClient) -> None:
previous_amount = DeduplicationSet.objects.count()
data = DeduplicationSetSerializer(DeduplicationSetFactory.build()).data
data = CreateDeduplicationSetSerializer(DeduplicationSetFactory.build()).data

response = api_client.post(
reverse(DEDUPLICATION_SET_LIST_VIEW), data=data, format=JSON
)

assert response.status_code == status.HTTP_201_CREATED
assert DeduplicationSet.objects.count() == previous_amount + 1
data = response.json()
assert data["state"] == DeduplicationSet.State.CLEAN.label


def test_missing_fields_handling(api_client: APIClient) -> None:
data = DeduplicationSetSerializer(DeduplicationSetFactory.build()).data
data = CreateDeduplicationSetSerializer(DeduplicationSetFactory.build()).data
del data["reference_pk"]

response = api_client.post(
reverse(DEDUPLICATION_SET_LIST_VIEW), data=data, format=JSON
)

assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()
assert len(errors) == 1
assert "reference_pk" in errors


def test_invalid_values_handling(api_client: APIClient) -> None:
data = DeduplicationSetSerializer(DeduplicationSetFactory.build()).data
data["reference_pk"] = None
@mark.parametrize("field", ("reference_pk", "config"))
def test_invalid_values_handling(field: str, api_client: APIClient) -> None:
data = CreateDeduplicationSetSerializer(DeduplicationSetFactory.build()).data
data[field] = None

response = api_client.post(
reverse(DEDUPLICATION_SET_LIST_VIEW), data=data, format=JSON
)

assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()
assert len(errors) == 1
assert "reference_pk" in errors
assert field in errors


def test_can_set_deduplication_set_without_config(api_client: APIClient) -> None:
    """Creating a deduplication set succeeds when ``config`` is left out."""
    payload = CreateDeduplicationSetSerializer(DeduplicationSetFactory.build()).data
    del payload["config"]
    url = reverse(DEDUPLICATION_SET_LIST_VIEW)

    response = api_client.post(url, data=payload, format=JSON)

    assert response.status_code == status.HTTP_201_CREATED
9 changes: 9 additions & 0 deletions tests/extras/testutils/factories/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from hope_dedup_engine.apps.api.models import DeduplicationSet, HDEToken
from hope_dedup_engine.apps.api.models.deduplication import (
Config,
Duplicate,
IgnoredKeyPair,
Image,
Expand All @@ -17,11 +18,19 @@ class Meta:
model = HDEToken


class ConfigFactory(DjangoModelFactory):
    """Factory for ``Config`` with a fuzzy ``face_distance_threshold``."""

    # random float generated with bounds low=0.1 and high=1.0
    face_distance_threshold = fuzzy.FuzzyFloat(low=0.1, high=1.0)

    class Meta:
        model = Config


class DeduplicationSetFactory(DjangoModelFactory):
reference_pk = fuzzy.FuzzyText()
external_system = SubFactory(ExternalSystemFactory)
state = DeduplicationSet.State.CLEAN
notification_url = fuzzy.FuzzyText(prefix="https://")
config = SubFactory(ConfigFactory)

class Meta:
model = DeduplicationSet
Expand Down
Loading
Loading