Skip to content

Feature/allocation model save tests #704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion coldfront/core/allocation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ast import literal_eval
from enum import Enum

from django.conf import settings
from django.contrib.auth.models import User
from django.core.exceptions import ValidationError
from django.db import models
Expand Down Expand Up @@ -131,7 +132,7 @@ def save(self, *args, **kwargs):
if self.pk:
old_obj = Allocation.objects.get(pk=self.pk)
if old_obj.status.name != self.status.name and self.status.name == "Expired":
for func_string in ALLOCATION_FUNCS_ON_EXPIRE:
for func_string in settings.ALLOCATION_FUNCS_ON_EXPIRE:
func_to_run = import_string(func_string)
func_to_run(self.pk)

Expand Down
153 changes: 152 additions & 1 deletion coldfront/core/allocation/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
"""Unit tests for the allocation models"""

import datetime
import typing

from django.core.exceptions import ValidationError
from django.test import TestCase
from django.test import TestCase, override_settings
from django.utils import timezone

from coldfront.core.allocation.models import (
Expand Down Expand Up @@ -140,3 +141,153 @@ def test_status_is_active_and_start_date_equals_end_date_no_error(self):
status=self.active_status, start_date=start_and_end_date, end_date=start_and_end_date, project=self.project
)
actual_allocation.full_clean()


class AllocationFuncOnExpireException(Exception):
"""Custom exception for testing allocation expiration function in the AllocationModelSaveMethodTests class."""

pass


def allocation_func_on_expire_exception(*args, **kwargs):
"""Test function to be called on allocation expiration in the AllocationModelSaveMethodTests class."""
raise AllocationFuncOnExpireException("This is a test exception for allocation expiration.")


def get_dotted_path(func):
"""Return the dotted path string for a Python function in the AllocationModelSaveMethodTests class."""
return f"{func.__module__}.{func.__qualname__}"


NUMBER_OF_INVOCATIONS = 12


def count_invocations(*args, **kwargs):
count_invocations.invocation_count = getattr(count_invocations, "invocation_count", 0) + 1 # type: ignore


def count_invocations_negative(*args, **kwargs):
count_invocations_negative.invocation_count = getattr(count_invocations_negative, "invocation_count", 0) - 1 # type: ignore


def list_of_same_expire_funcs(func: typing.Callable, size=NUMBER_OF_INVOCATIONS) -> list[str]:
return [get_dotted_path(func) for _ in range(size)]


def list_of_different_expire_funcs() -> list[str]:
"""Return a list of different functions to be called on allocation expiration.
The list will have a length of NUMBER_OF_INVOCATIONS, with the last function being allocation_func_on_expire_exception.
If NUMBER_OF_INVOCATIONS is odd, the list will contain (NUMBER_OF_INVOCATIONS // 2) instances of count_invocations and (NUMBER_OF_INVOCATIONS // 2) instances of count_invocations_negative.
If NUMBER_OF_INVOCATIONS is even, the list will contain (NUMBER_OF_INVOCATIONS // 2) instances of count_invocations and ((NUMBER_OF_INVOCATIONS // 2)-1) instances of count_invocations_negative.
"""
expire_funcs: list[str] = []
for i in range(NUMBER_OF_INVOCATIONS):
if i == (NUMBER_OF_INVOCATIONS - 1):
expire_funcs.append(get_dotted_path(allocation_func_on_expire_exception))
elif i % 2 == 0:
expire_funcs.append(get_dotted_path(count_invocations))
else:
expire_funcs.append(get_dotted_path(count_invocations_negative))
return expire_funcs


class AllocationModelSaveMethodTests(TestCase):
path_to_allocation_models_funcs_on_expire: str = "coldfront.core.allocation.models.ALLOCATION_FUNCS_ON_EXPIRE"

def setUp(self):
count_invocations.invocation_count = 0 # type: ignore
count_invocations_negative.invocation_count = 0 # type: ignore

@classmethod
def setUpTestData(cls):
"""Set up allocation to test clean method"""
cls.active_status: AllocationStatusChoice = AllocationStatusChoiceFactory(name="Active")
cls.expired_status: AllocationStatusChoice = AllocationStatusChoiceFactory(name="Expired")
cls.other_status: AllocationStatusChoice = AllocationStatusChoiceFactory(name="Other")
cls.project: Project = ProjectFactory()

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=list_of_same_expire_funcs(allocation_func_on_expire_exception, 1))
def test_on_expiration_calls_single_func_in_funcs_on_expire(self):
"""Test that the allocation save method calls the functions specified in ALLOCATION_FUNCS_ON_EXPIRE when it expires."""
allocation = AllocationFactory(status=self.active_status)
with self.assertRaises(AllocationFuncOnExpireException):
allocation.status = self.expired_status
allocation.save()

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=list_of_same_expire_funcs(count_invocations))
def test_on_expiration_calls_multiple_funcs_in_funcs_on_expire(self):
"""Test that the allocation save method calls a function multiple times when ALLOCATION_FUNCS_ON_EXPIRE has multiple instances of it."""
allocation = AllocationFactory(status=self.active_status)
allocation.status = self.expired_status
allocation.save()
self.assertEqual(count_invocations.invocation_count, NUMBER_OF_INVOCATIONS) # type: ignore

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=list_of_different_expire_funcs())
def test_on_expiration_calls_multiple_different_funcs_in_funcs_on_expire(self):
"""Test that the allocation save method calls all the different functions present in the list ALLOCATION_FUNCS_ON_EXPIRE."""
allocation = AllocationFactory(status=self.active_status)
allocation.status = self.expired_status

# the last function in the list is allocation_func_on_expire_exception, which raises an exception
with self.assertRaises(AllocationFuncOnExpireException):
allocation.save()

# the other functions will have been called a different number of times depending on whether NUMBER_OF_INVOCATIONS is odd or even
if NUMBER_OF_INVOCATIONS % 2 == 0:
expected_positive_invocations = NUMBER_OF_INVOCATIONS // 2
expected_negative_invocations = -((NUMBER_OF_INVOCATIONS // 2) - 1)
self.assertEqual(count_invocations.invocation_count, expected_positive_invocations) # type: ignore
self.assertEqual(count_invocations_negative.invocation_count, expected_negative_invocations) # type: ignore
else:
expected_positive_invocations = NUMBER_OF_INVOCATIONS // 2
expected_negative_invocations = -(NUMBER_OF_INVOCATIONS // 2)
self.assertEqual(count_invocations.invocation_count, expected_positive_invocations) # type: ignore
self.assertEqual(count_invocations_negative.invocation_count, expected_negative_invocations) # type: ignore

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=list_of_same_expire_funcs(allocation_func_on_expire_exception, 1))
def test_no_expire_no_funcs_on_expire_called(self):
"""Test that the allocation save method does not call any functions when the allocation is not expired."""
allocation = AllocationFactory(status=self.active_status)
allocation.save()

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=list_of_same_expire_funcs(allocation_func_on_expire_exception, 1))
def test_allocation_changed_but_always_expired_no_funcs_on_expire_called(self):
"""Test that the allocation save method does not call any functions when the allocation is always expired."""
allocation = AllocationFactory(status=self.expired_status)
allocation.justification = "This allocation is always expired."
allocation.save()

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=list_of_same_expire_funcs(allocation_func_on_expire_exception, 1))
def test_allocation_changed_but_never_expired_no_funcs_on_expire_called(self):
"""Test that the allocation save method does not call any functions when the allocation is never expired."""
allocation = AllocationFactory(status=self.active_status)
allocation.status = self.other_status
allocation.save()

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=list_of_same_expire_funcs(allocation_func_on_expire_exception, 1))
def test_allocation_always_expired_no_funcs_on_expire_called(self):
"""Test that the allocation save method does not call any functions when the allocation is always expired."""
allocation = AllocationFactory(status=self.expired_status)
allocation.justification = "This allocation is always expired."
allocation.save()

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=list_of_same_expire_funcs(allocation_func_on_expire_exception, 1))
def test_allocation_reactivated_no_funcs_on_expire_called(self):
"""Test that the allocation save method does not call any functions when the allocation is reactivated."""
allocation = AllocationFactory(status=self.expired_status)
allocation.status = self.active_status
allocation.save()

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=[])
def test_new_allocation_is_in_database(self):
"""Test that a new allocation is saved in the database."""
allocation: Allocation = AllocationFactory(status=self.active_status)
allocation.save()
self.assertTrue(Allocation.objects.filter(id=allocation.id).exists())

@override_settings(ALLOCATION_FUNCS_ON_EXPIRE=[])
def test_multiple_new_allocations_are_in_database(self):
"""Test that multiple new allocations are saved in the database."""
allocations = [AllocationFactory(status=self.active_status) for _ in range(25)]
for allocation in allocations:
self.assertTrue(Allocation.objects.filter(id=allocation.id).exists())