diff --git a/admin/notifications/views.py b/admin/notifications/views.py index 7a3a13a8df8..3546878e9af 100644 --- a/admin/notifications/views.py +++ b/admin/notifications/views.py @@ -1,17 +1,17 @@ -from osf.models.notifications import NotificationSubscription +from osf.models.notifications import NotificationSubscriptionLegacy from django.db.models import Count def delete_selected_notifications(selected_ids): - NotificationSubscription.objects.filter(id__in=selected_ids).delete() + NotificationSubscriptionLegacy.objects.filter(id__in=selected_ids).delete() def detect_duplicate_notifications(node_id=None): - query = NotificationSubscription.objects.values('_id').annotate(count=Count('_id')).filter(count__gt=1) + query = NotificationSubscriptionLegacy.objects.values('_id').annotate(count=Count('_id')).filter(count__gt=1) if node_id: query = query.filter(node_id=node_id) detailed_duplicates = [] for dup in query: - notifications = NotificationSubscription.objects.filter( + notifications = NotificationSubscriptionLegacy.objects.filter( _id=dup['_id'] ).order_by('created') diff --git a/admin_tests/notifications/test_views.py b/admin_tests/notifications/test_views.py index 08ad695edd1..42d182a77e5 100644 --- a/admin_tests/notifications/test_views.py +++ b/admin_tests/notifications/test_views.py @@ -1,10 +1,11 @@ import pytest from django.test import RequestFactory -from osf.models import OSFUser, NotificationSubscription, Node +from osf.models import OSFUser, Node from admin.notifications.views import ( delete_selected_notifications, detect_duplicate_notifications, ) +from osf.models.notifications import NotificationSubscriptionLegacy from tests.base import AdminTestCase pytestmark = pytest.mark.django_db @@ -18,19 +19,19 @@ def setUp(self): self.request_factory = RequestFactory() def test_delete_selected_notifications(self): - notification1 = NotificationSubscription.objects.create(user=self.user, node=self.node, event_name='event1') - notification2 = NotificationSubscription.objects.create(user=self.user, node=self.node, event_name='event2') - notification3 = NotificationSubscription.objects.create(user=self.user, node=self.node, event_name='event3') + notification1 = NotificationSubscriptionLegacy.objects.create(user=self.user, node=self.node, event_name='event1') + notification2 = NotificationSubscriptionLegacy.objects.create(user=self.user, node=self.node, event_name='event2') + notification3 = NotificationSubscriptionLegacy.objects.create(user=self.user, node=self.node, event_name='event3') delete_selected_notifications([notification1.id, notification2.id]) - assert not NotificationSubscription.objects.filter(id__in=[notification1.id, notification2.id]).exists() - assert NotificationSubscription.objects.filter(id=notification3.id).exists() + assert not NotificationSubscriptionLegacy.objects.filter(id__in=[notification1.id, notification2.id]).exists() + assert NotificationSubscriptionLegacy.objects.filter(id=notification3.id).exists() def test_detect_duplicate_notifications(self): - NotificationSubscription.objects.create(user=self.user, node=self.node, event_name='event1') - NotificationSubscription.objects.create(user=self.user, node=self.node, event_name='event1') - NotificationSubscription.objects.create(user=self.user, node=self.node, event_name='event2') + NotificationSubscriptionLegacy.objects.create(user=self.user, node=self.node, event_name='event1') + NotificationSubscriptionLegacy.objects.create(user=self.user, node=self.node, event_name='event1') + NotificationSubscriptionLegacy.objects.create(user=self.user, node=self.node, event_name='event2') duplicates = detect_duplicate_notifications() diff --git a/api/institutions/authentication.py b/api/institutions/authentication.py index a5588c2b034..b052834f181 100644 --- a/api/institutions/authentication.py +++ b/api/institutions/authentication.py @@ -20,10 +20,10 @@ from osf import features from osf.exceptions import InstitutionAffiliationStateError -from osf.models import Institution +from osf.models import Institution, NotificationType from osf.models.institution import SsoFilterCriteriaAction -from website.mails import send_mail, WELCOME_OSF4I, DUPLICATE_ACCOUNTS_OSF4I, ADD_SSO_EMAIL_OSF4I +from website.mails import send_mail, DUPLICATE_ACCOUNTS_OSF4I, ADD_SSO_EMAIL_OSF4I from website.settings import OSF_SUPPORT_EMAIL, DOMAIN from website.util.metrics import institution_source_tag @@ -334,14 +334,13 @@ def authenticate(self, request): user.save() # Send confirmation email for all three: created, confirmed and claimed - send_mail( - to_addr=user.username, - mail=WELCOME_OSF4I, - user=user, - domain=DOMAIN, - osf_support_email=OSF_SUPPORT_EMAIL, - storage_flag_is_active=flag_is_active(request, features.STORAGE_I18N), - ) + notification_type = NotificationType.objects.filter(name='welcome_osf4i') + if not notification_type.exists(): + raise NotificationType.DoesNotExist( + 'NotificationType with name welcome_osf4i does not exist.', + ) + notification_type = notification_type.first() + notification_type.emit(user=user, message_frequency='instantly', event_context={'domain': DOMAIN, 'osf_support_email': OSF_SUPPORT_EMAIL, 'storage_flag_is_active': flag_is_active(request, features.STORAGE_I18N)}) # Add the email to the user's account if it is identified by the eppn if email_to_add: diff --git a/api/subscriptions/fields.py b/api/subscriptions/fields.py new file mode 100644 index 00000000000..ddbcd4f4aa5 --- /dev/null +++ b/api/subscriptions/fields.py @@ -0,0 +1,11 @@ +from rest_framework import serializers as ser + +class FrequencyField(ser.ChoiceField): + def __init__(self, **kwargs): + super().__init__(choices=['none', 'instantly', 'daily', 'weekly', 'monthly'], **kwargs) + + def to_representation(self, frequency: str): + return frequency or 'none' + + def to_internal_value(self, freq): + return super().to_internal_value(freq) diff --git a/api/subscriptions/permissions.py b/api/subscriptions/permissions.py index 19dc7bcbd58..b22831f2766 100644 --- a/api/subscriptions/permissions.py +++ b/api/subscriptions/permissions.py @@ -1,13 +1,10 @@ from rest_framework import permissions -from osf.models.notifications import NotificationSubscription +from osf.models.notification_subscription import NotificationSubscription class IsSubscriptionOwner(permissions.BasePermission): def has_object_permission(self, request, view, obj): assert isinstance(obj, NotificationSubscription), f'obj must be a NotificationSubscription; got {obj}' - user_id = request.user.id - return obj.none.filter(id=user_id).exists() \ - or obj.email_transactional.filter(id=user_id).exists() \ - or obj.email_digest.filter(id=user_id).exists() + return obj.user == request.user diff --git a/api/subscriptions/serializers.py b/api/subscriptions/serializers.py index da7aadbb1a4..ede0782ae65 100644 --- a/api/subscriptions/serializers.py +++ b/api/subscriptions/serializers.py @@ -1,58 +1,55 @@ +from django.contrib.contenttypes.models import ContentType from rest_framework import serializers as ser -from rest_framework.exceptions import ValidationError from api.nodes.serializers import RegistrationProviderRelationshipField from api.collections_providers.fields import CollectionProviderRelationshipField from api.preprints.serializers import PreprintProviderRelationshipField +from osf.models import Node from website.util import api_v2_url from api.base.serializers import JSONAPISerializer, LinksField - -NOTIFICATION_TYPES = { - 'none': 'none', - 'instant': 'email_transactional', - 'daily': 'email_digest', -} - - -class FrequencyField(ser.Field): - def to_representation(self, obj): - user_id = self.context['request'].user.id - if obj.email_transactional.filter(id=user_id).exists(): - return 'instant' - if obj.email_digest.filter(id=user_id).exists(): - return 'daily' - return 'none' - - def to_internal_value(self, frequency): - notification_type = NOTIFICATION_TYPES.get(frequency) - if notification_type: - return {'notification_type': notification_type} - raise ValidationError(f'Invalid frequency "{frequency}"') +from .fields import FrequencyField class SubscriptionSerializer(JSONAPISerializer): filterable_fields = frozenset([ 'id', 'event_name', + 'frequency', ]) - id = ser.CharField(source='_id', read_only=True) + id = ser.CharField( + read_only=True, + source='legacy_id', + help_text='The id of the subscription fixed for backward compatibility', + ) event_name = ser.CharField(read_only=True) - frequency = FrequencyField(source='*', required=True) - links = LinksField({ - 'self': 'get_absolute_url', - }) + frequency = FrequencyField(source='message_frequency', required=True) class Meta: type_ = 'subscription' + links = LinksField({ + 'self': 'get_absolute_url', + }) + def get_absolute_url(self, obj): return obj.absolute_api_v2_url def update(self, instance, validated_data): user = self.context['request'].user - notification_type = validated_data.get('notification_type') - instance.add_user_to_subscription(user, notification_type, save=True) + frequency = validated_data.get('frequency') or 'none' + instance.message_frequency = frequency + + if frequency != 'none' and instance.content_type == ContentType.objects.get_for_model(Node): + node = Node.objects.get( + id=instance.id, + content_type=instance.content_type, + ) + user_subs = node.parent_node.child_node_subscriptions + if node._id not in user_subs.setdefault(user._id, []): + user_subs[user._id].append(node._id) + node.parent_node.save() + return instance diff --git a/api/subscriptions/views.py b/api/subscriptions/views.py index c1d7e833b49..57a4dbf36c7 100644 --- a/api/subscriptions/views.py +++ b/api/subscriptions/views.py @@ -1,8 +1,11 @@ +from django.db.models import Value, When, Case, F, Q, OuterRef, Subquery +from django.db.models.fields import CharField, IntegerField +from django.db.models.functions import Concat, Cast +from django.contrib.contenttypes.models import ContentType from rest_framework import generics from rest_framework import permissions as drf_permissions from rest_framework.exceptions import NotFound -from django.core.exceptions import ObjectDoesNotExist -from django.db.models import Q +from django.core.exceptions import ObjectDoesNotExist, PermissionDenied from framework.auth.oauth_scopes import CoreScopes from api.base.views import JSONAPIBaseView @@ -16,12 +19,13 @@ ) from api.subscriptions.permissions import IsSubscriptionOwner from osf.models import ( - NotificationSubscription, CollectionProvider, PreprintProvider, RegistrationProvider, - AbstractProvider, + AbstractProvider, AbstractNode, Preprint, OSFUser, ) +from osf.models.notification_type import NotificationType +from osf.models.notification_subscription import NotificationSubscription class SubscriptionList(JSONAPIBaseView, generics.ListAPIView, ListFilterMixin): @@ -37,32 +41,59 @@ class SubscriptionList(JSONAPIBaseView, generics.ListAPIView, ListFilterMixin): required_read_scopes = [CoreScopes.SUBSCRIPTIONS_READ] required_write_scopes = [CoreScopes.NULL] - def get_default_queryset(self): - user = self.request.user - return NotificationSubscription.objects.filter( - Q(none=user) | - Q(email_digest=user) | - Q( - email_transactional=user, - ), - ).distinct() - def get_queryset(self): - return self.get_queryset_from_request() + user_guid = self.request.user._id + provider_ct = ContentType.objects.get(app_label='osf', model='abstractprovider') + + provider_subquery = AbstractProvider.objects.filter( + id=Cast(OuterRef('object_id'), IntegerField()), + ).values('_id')[:1] + + node_subquery = AbstractNode.objects.filter( + id=Cast(OuterRef('object_id'), IntegerField()), + ).values('guids___id')[:1] + + return NotificationSubscription.objects.filter(user=self.request.user).annotate( + event_name=Case( + When( + notification_type__name=NotificationType.Type.NODE_FILES_UPDATED.value, + then=Value('files_updated'), + ), + When( + notification_type__name=NotificationType.Type.USER_FILE_UPDATED.value, + then=Value('global_file_updated'), + ), + default=F('notification_type__name'), + output_field=CharField(), + ), + legacy_id=Case( + When( + notification_type__name=NotificationType.Type.NODE_FILES_UPDATED.value, + then=Concat(Subquery(node_subquery), Value('_file_updated')), + ), + When( + notification_type__name=NotificationType.Type.USER_FILE_UPDATED.value, + then=Value(f'{user_guid}_global'), + ), + When( + Q(notification_type__name=NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.value) & + Q(content_type=provider_ct), + then=Concat(Subquery(provider_subquery), Value('_new_pending_submissions')), + ), + default=F('notification_type__name'), + output_field=CharField(), + ), + ) class AbstractProviderSubscriptionList(SubscriptionList): - def get_default_queryset(self): - user = self.request.user + def get_queryset(self): + provider = AbstractProvider.objects.get(_id=self.kwargs['provider_id']) return NotificationSubscription.objects.filter( - provider___id=self.kwargs['provider_id'], - provider__type=self.provider_class._typedmodels_type, - ).filter( - Q(none=user) | - Q(email_digest=user) | - Q(email_transactional=user), - ).distinct() - + object_id=provider, + provider__type=ContentType.objects.get_for_model(provider.__class__), + user=self.request.user, + ) class SubscriptionDetail(JSONAPIBaseView, generics.RetrieveUpdateAPIView): view_name = 'notification-subscription-detail' @@ -79,10 +110,63 @@ class SubscriptionDetail(JSONAPIBaseView, generics.RetrieveUpdateAPIView): def get_object(self): subscription_id = self.kwargs['subscription_id'] + user_guid = self.request.user._id + + provider_ct = ContentType.objects.get(app_label='osf', model='abstractprovider') + node_ct = ContentType.objects.get(app_label='osf', model='abstractnode') + + provider_subquery = AbstractProvider.objects.filter( + id=Cast(OuterRef('object_id'), IntegerField()), + ).values('_id')[:1] + + node_subquery = AbstractNode.objects.filter( + id=Cast(OuterRef('object_id'), IntegerField()), + ).values('guids___id')[:1] + + guid_id, *event_parts = subscription_id.split('_') + event = '_'.join(event_parts) if event_parts else '' + + subscription_obj = AbstractNode.load(guid_id) or Preprint.load(guid_id) or OSFUser.load(guid_id) + + if event != 'global': + obj_filter = Q( + object_id=getattr(subscription_obj, 'id', None), + content_type=ContentType.objects.get_for_model(subscription_obj.__class__), + notification_type__name__icontains=event, + ) + else: + obj_filter = Q() + try: - obj = NotificationSubscription.objects.get(_id=subscription_id) + obj = NotificationSubscription.objects.annotate( + legacy_id=Case( + When( + notification_type__name=NotificationType.Type.NODE_FILES_UPDATED.value, + content_type=node_ct, + then=Concat(Subquery(node_subquery), Value('_file_updated')), + ), + When( + notification_type__name=NotificationType.Type.USER_FILE_UPDATED.value, + then=Value(f'{user_guid}_global'), + ), + When( + notification_type__name=NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.value, + content_type=provider_ct, + then=Concat(Subquery(provider_subquery), Value('_new_pending_submissions')), + ), + default=Value(f'{user_guid}_global'), + output_field=CharField(), + ), + ).filter(obj_filter) + except ObjectDoesNotExist: raise NotFound + + try: + obj = obj.filter(user=self.request.user).get() + except ObjectDoesNotExist: + raise PermissionDenied + self.check_object_permissions(self.request, obj) return obj @@ -100,33 +184,6 @@ class AbstractProviderSubscriptionDetail(SubscriptionDetail): required_write_scopes = [CoreScopes.SUBSCRIPTIONS_WRITE] provider_class = None - def __init__(self, *args, **kwargs): - assert issubclass(self.provider_class, AbstractProvider), 'Class must be subclass of AbstractProvider' - super().__init__(*args, **kwargs) - - def get_object(self): - subscription_id = self.kwargs['subscription_id'] - if self.kwargs.get('provider_id'): - provider = self.provider_class.objects.get(_id=self.kwargs.get('provider_id')) - try: - obj = NotificationSubscription.objects.get( - _id=subscription_id, - provider_id=provider.id, - ) - except ObjectDoesNotExist: - raise NotFound - else: - try: - obj = NotificationSubscription.objects.get( - _id=subscription_id, - provider__type=self.provider_class._typedmodels_type, - ) - except ObjectDoesNotExist: - raise NotFound - self.check_object_permissions(self.request, obj) - return obj - - class CollectionProviderSubscriptionDetail(AbstractProviderSubscriptionDetail): provider_class = CollectionProvider serializer_class = CollectionSubscriptionSerializer diff --git a/api/users/views.py b/api/users/views.py index 8dea51613df..04fdb101d6f 100644 --- a/api/users/views.py +++ b/api/users/views.py @@ -99,6 +99,7 @@ OSFUser, Email, Tag, + NotificationType, ) from osf.utils.tokens import TokenHandler from osf.utils.tokens.handlers import sanction_handler @@ -822,7 +823,7 @@ def get(self, request, *args, **kwargs): raise ValidationError('Request must include email in query params.') institutional = bool(request.query_params.get('institutional', None)) - mail_template = mails.FORGOT_PASSWORD if not institutional else mails.FORGOT_PASSWORD_INSTITUTION + mail_template = 'forgot_password' if not institutional else 'forgot_password_institution' status_message = language.RESET_PASSWORD_SUCCESS_STATUS_MESSAGE.format(email=email) kind = 'success' @@ -842,12 +843,15 @@ def get(self, request, *args, **kwargs): user_obj.email_last_sent = timezone.now() user_obj.save() reset_link = f'{settings.RESET_PASSWORD_URL}{user_obj._id}/{user_obj.verification_key_v2['token']}/' - mails.send_mail( - to_addr=email, - mail=mail_template, - reset_link=reset_link, - can_change_preferences=False, - ) + + notification_type = NotificationType.objects.filter(name=mail_template) + if not notification_type.exists(): + raise NotificationType.DoesNotExist( + f'NotificationType with name {mail_template} does not exist.', + ) + notification_type = notification_type.first() + notification_type.emit(user=user_obj, message_frequency='instantly', event_context={'can_change_preferences': False, 'reset_link': reset_link}) + return Response(status=status.HTTP_200_OK, data={'message': status_message, 'kind': kind, 'institutional': institutional}) @method_decorator(csrf_protect) @@ -1059,13 +1063,13 @@ def _process_external_identity(self, user, external_identity, service_url): if external_status == 'CREATE': service_url += '&' + urlencode({'new': 'true'}) elif external_status == 'LINK': - mails.send_mail( - user=user, - to_addr=user.username, - mail=mails.EXTERNAL_LOGIN_LINK_SUCCESS, - external_id_provider=provider, - can_change_preferences=False, - ) + notification_type = NotificationType.objects.filter(name='external_confirm_success') + if not notification_type.exists(): + raise NotificationType.DoesNotExist( + 'NotificationType with name external_confirm_success does not exist.', + ) + notification_type = notification_type.first() + notification_type.emit(user=user, message_frequency='instantly', event_context={'can_change_preferences': False, 'external_id_provider': provider}) enqueue_task(update_affiliation_for_orcid_sso_users.s(user._id, provider_id)) @@ -1380,13 +1384,13 @@ def post(self, request, *args, **kwargs): if external_status == 'CREATE': service_url += '&{}'.format(urlencode({'new': 'true'})) elif external_status == 'LINK': - mails.send_mail( - user=user, - to_addr=user.username, - mail=mails.EXTERNAL_LOGIN_LINK_SUCCESS, - external_id_provider=provider, - can_change_preferences=False, - ) + notification_type = NotificationType.objects.filter(name='external_confirm_success') + if not notification_type.exists(): + raise NotificationType.DoesNotExist( + 'NotificationType with name external_confirm_success does not exist.', + ) + notification_type = notification_type.first() + notification_type.emit(user=user, message_frequency='instantly', event_context={'can_change_preferences': False, 'external_id_provider': provider}) enqueue_task(update_affiliation_for_orcid_sso_users.s(user._id, provider_id)) diff --git a/api_tests/draft_registrations/views/test_draft_registration_contributor_list.py b/api_tests/draft_registrations/views/test_draft_registration_contributor_list.py index 71fe7450b6d..4126ba5fedb 100644 --- a/api_tests/draft_registrations/views/test_draft_registration_contributor_list.py +++ b/api_tests/draft_registrations/views/test_draft_registration_contributor_list.py @@ -209,6 +209,7 @@ def create_serializer(self): @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestDraftContributorCreateEmail(DraftRegistrationCRUDTestCase, TestNodeContributorCreateEmail): @pytest.fixture() def url_project_contribs(self, project_public): @@ -217,7 +218,7 @@ def url_project_contribs(self, project_public): def test_add_contributor_sends_email( self, app, user, user_two, - url_project_contribs, mock_send_grid): + url_project_contribs, mock_notification_send): # Overrides TestNodeContributorCreateEmail url = f'{url_project_contribs}?send_email=draft_registration' payload = { @@ -238,7 +239,7 @@ def test_add_contributor_sends_email( res = app.post_json_api(url, payload, auth=user.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 # Overrides TestNodeContributorCreateEmail def test_add_contributor_signal_if_default( @@ -265,7 +266,7 @@ def test_add_contributor_signal_if_default( # Overrides TestNodeContributorCreateEmail def test_add_unregistered_contributor_sends_email( - self, mock_send_grid, app, user, url_project_contribs): + self, mock_notification_send, app, user, url_project_contribs): url = f'{url_project_contribs}?send_email=draft_registration' payload = { 'data': { @@ -278,7 +279,7 @@ def test_add_unregistered_contributor_sends_email( } res = app.post_json_api(url, payload, auth=user.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 # Overrides TestNodeContributorCreateEmail @mock.patch('website.project.signals.unreg_contributor_added.send') @@ -301,7 +302,7 @@ def test_add_unregistered_contributor_signal_if_default( # Overrides TestNodeContributorCreateEmail def test_add_unregistered_contributor_without_email_no_email( - self, mock_send_grid, app, user, url_project_contribs): + self, mock_notification_send, app, user, url_project_contribs): url = f'{url_project_contribs}?send_email=draft_registration' payload = { 'data': { @@ -316,7 +317,7 @@ def test_add_unregistered_contributor_without_email_no_email( res = app.post_json_api(url, payload, auth=user.auth) assert contributor_added in mock_signal.signals_sent() assert res.status_code == 201 - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 class TestDraftContributorBulkCreate(DraftRegistrationCRUDTestCase, TestNodeContributorBulkCreate): diff --git a/api_tests/draft_registrations/views/test_draft_registration_list.py b/api_tests/draft_registrations/views/test_draft_registration_list.py index d19c6d994d5..148593bd752 100644 --- a/api_tests/draft_registrations/views/test_draft_registration_list.py +++ b/api_tests/draft_registrations/views/test_draft_registration_list.py @@ -158,6 +158,7 @@ def test_draft_with_deleted_registered_node_shows_up_in_draft_list( @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestDraftRegistrationCreateWithNode(AbstractDraftRegistrationTestCase): @pytest.fixture() @@ -336,11 +337,11 @@ def test_logged_in_non_contributor_cannot_create_draft( ) assert res.status_code == 403 - def test_create_project_based_draft_does_not_email_initiator(self, app, user, url_draft_registrations, payload, mock_send_grid): - mock_send_grid.reset_mock() + def test_create_project_based_draft_does_not_email_initiator(self, app, user, url_draft_registrations, payload, mock_notification_send): + mock_notification_send.reset_mock() app.post_json_api(f'{url_draft_registrations}?embed=branched_from&embed=initiator', payload, auth=user.auth) - assert not mock_send_grid.called + assert not mock_notification_send.called def test_affiliated_institutions_are_copied_from_node_no_institutions(self, app, user, url_draft_registrations, payload): """ @@ -403,6 +404,7 @@ def test_affiliated_institutions_are_copied_from_user(self, app, user, url_draft @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestDraftRegistrationCreateWithoutNode(AbstractDraftRegistrationTestCase): @pytest.fixture() def url_draft_registrations(self): @@ -429,21 +431,14 @@ def test_admin_can_create_draft( assert draft.creator == user assert draft.has_permission(user, ADMIN) is True - def test_create_no_project_draft_emails_initiator(self, app, user, url_draft_registrations, payload, mock_send_grid): + def test_create_no_project_draft_emails_initiator(self, app, user, url_draft_registrations, payload, mock_notification_send): # Intercepting the send_mail call from website.project.views.contributor.notify_added_contributor app.post_json_api( f'{url_draft_registrations}?embed=branched_from&embed=initiator', payload, auth=user.auth ) - assert mock_send_grid.called - - # Python 3.6 does not support mock.call_args.args/kwargs - # Instead, mock.call_args[0] is positional args, mock.call_args[1] is kwargs - # (note, this is compatible with later versions) - mock_send_kwargs = mock_send_grid.call_args[1] - assert mock_send_kwargs['subject'] == 'You have a new registration draft.' - assert mock_send_kwargs['to_addr'] == user.email + assert mock_notification_send.called def test_create_draft_with_provider( self, app, user, url_draft_registrations, non_default_provider, payload_with_non_default_provider diff --git a/api_tests/institutions/views/test_institution_relationship_nodes.py b/api_tests/institutions/views/test_institution_relationship_nodes.py index c62d760710d..4d2703d5599 100644 --- a/api_tests/institutions/views/test_institution_relationship_nodes.py +++ b/api_tests/institutions/views/test_institution_relationship_nodes.py @@ -26,6 +26,7 @@ def make_registration_payload(*node_ids): @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestInstitutionRelationshipNodes: @pytest.fixture() @@ -393,7 +394,7 @@ def test_email_sent_on_affiliation_addition(self, app, user, institution, node_w assert res.status_code == 201 mock_send_grid.assert_called_once() - def test_email_sent_on_affiliation_removal(self, app, admin, institution, node_public, url_institution_nodes, mock_send_grid): + def test_email_sent_on_affiliation_removal(self, app, admin, institution, node_public, url_institution_nodes, mock_notification_send): current_institution = InstitutionFactory() node_public.affiliated_institutions.add(current_institution) @@ -411,6 +412,3 @@ def test_email_sent_on_affiliation_removal(self, app, admin, institution, node_p # Assert response is successful assert res.status_code == 204 - - call_args = mock_send_grid.call_args[1] - assert call_args['to_addr'] == admin.email diff --git a/api_tests/nodes/views/test_node_contributors_list.py b/api_tests/nodes/views/test_node_contributors_list.py index 81910a6ef55..bfbd5d72dae 100644 --- a/api_tests/nodes/views/test_node_contributors_list.py +++ b/api_tests/nodes/views/test_node_contributors_list.py @@ -1203,6 +1203,7 @@ def test_add_contributor_validation( @pytest.mark.enable_bookmark_creation @pytest.mark.enable_enqueue_task @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestNodeContributorCreateEmail(NodeCRUDTestCase): @pytest.fixture() @@ -1210,7 +1211,7 @@ def url_project_contribs(self, project_public): return f'/{API_BASE}nodes/{project_public._id}/contributors/' def test_add_contributor_no_email_if_false( - self, mock_send_grid, app, user, url_project_contribs + self, mock_notification_send, app, user, url_project_contribs ): url = f'{url_project_contribs}?send_email=false' payload = { @@ -1221,10 +1222,10 @@ def test_add_contributor_no_email_if_false( } res = app.post_json_api(url, payload, auth=user.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 def test_add_contributor_sends_email( - self, mock_send_grid, app, user, user_two, url_project_contribs + self, mock_notification_send, app, user, user_two, url_project_contribs ): url = f'{url_project_contribs}?send_email=default' payload = { @@ -1239,7 +1240,7 @@ def test_add_contributor_sends_email( res = app.post_json_api(url, payload, auth=user.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 @mock.patch('website.project.signals.contributor_added.send') def test_add_contributor_signal_if_default( @@ -1281,7 +1282,7 @@ def test_add_contributor_signal_preprint_email_disallowed( ) def test_add_unregistered_contributor_sends_email( - self, mock_send_grid, app, user, url_project_contribs + self, mock_notification_send, app, user, url_project_contribs ): url = f'{url_project_contribs}?send_email=default' payload = { @@ -1292,7 +1293,7 @@ def test_add_unregistered_contributor_sends_email( } res = app.post_json_api(url, payload, auth=user.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 @mock.patch('website.project.signals.unreg_contributor_added.send') def test_add_unregistered_contributor_signal_if_default( @@ -1328,7 +1329,7 @@ def test_add_unregistered_contributor_signal_preprint_email_disallowed( ) def test_add_contributor_invalid_send_email_param( - self, mock_send_grid, app, user, url_project_contribs + self, mock_notification_send, app, user, url_project_contribs ): url = f'{url_project_contribs}?send_email=true' payload = { @@ -1342,10 +1343,10 @@ def test_add_contributor_invalid_send_email_param( assert ( res.json['errors'][0]['detail'] == 'true is not a valid email preference.' ) - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 def test_add_unregistered_contributor_without_email_no_email( - self, mock_send_grid, app, user, url_project_contribs + self, mock_notification_send, app, user, url_project_contribs ): url = f'{url_project_contribs}?send_email=default' payload = { @@ -1361,7 +1362,7 @@ def test_add_unregistered_contributor_without_email_no_email( res = app.post_json_api(url, payload, auth=user.auth) assert contributor_added in mock_signal.signals_sent() assert res.status_code == 201 - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 @pytest.mark.django_db diff --git a/api_tests/nodes/views/test_node_forks_list.py b/api_tests/nodes/views/test_node_forks_list.py index 8fc9f9eb35b..af2ee960ff8 100644 --- a/api_tests/nodes/views/test_node_forks_list.py +++ b/api_tests/nodes/views/test_node_forks_list.py @@ -204,6 +204,7 @@ def test_forks_list_does_not_show_registrations_of_forks( @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestNodeForkCreate: @pytest.fixture() @@ -403,12 +404,13 @@ def test_can_fork_registration( registration.title def test_read_only_contributor_can_fork_private_registration( - self, app, private_project, fork_data, private_project_url): + self, app, private_project, fork_data, private_project_url, mock_notification_send): read_contrib = AuthUserFactory() private_project.add_contributor( read_contrib, - permissions=permissions.READ, save=True) + permissions=permissions.READ, save=True + ) res = app.post_json_api( private_project_url, fork_data, auth=read_contrib.auth) @@ -416,10 +418,11 @@ def test_read_only_contributor_can_fork_private_registration( assert res.json['data']['id'] == private_project.forks.first()._id assert res.json['data']['attributes']['title'] == 'Fork of ' + \ private_project.title + assert mock_notification_send.called def test_send_email_success( self, app, user, public_project_url, - fork_data_with_title, public_project, mock_send_grid): + fork_data_with_title, public_project, mock_notification_send): res = app.post_json_api( public_project_url, @@ -427,13 +430,10 @@ def test_send_email_success( auth=user.auth) assert res.status_code == 201 assert res.json['data']['id'] == public_project.forks.first()._id - call_args = mock_send_grid.call_args[1] - assert call_args['to_addr'] == user.email - assert call_args['subject'] == 'Your fork has completed' def test_send_email_failed( self, app, user, public_project_url, - fork_data_with_title, public_project, mock_send_grid): + fork_data_with_title, public_project, mock_notification_send): with mock.patch.object(NodeForksSerializer, 'save', side_effect=Exception()): with pytest.raises(Exception): @@ -441,4 +441,4 @@ def test_send_email_failed( public_project_url, fork_data_with_title, auth=user.auth) - assert mock_send_grid.called + assert mock_notification_send.called diff --git a/api_tests/nodes/views/test_node_relationship_institutions.py b/api_tests/nodes/views/test_node_relationship_institutions.py index 3bf25dc5adf..aab1c0202e2 100644 --- a/api_tests/nodes/views/test_node_relationship_institutions.py +++ b/api_tests/nodes/views/test_node_relationship_institutions.py @@ -114,6 +114,7 @@ def create_payload(self, institutions): } @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestNodeRelationshipInstitutions(RelationshipInstitutionsTestMixin): def test_node_with_no_permissions(self, app, unauthorized_user_with_affiliation, institution_one, node_institutions_url): @@ -203,31 +204,22 @@ def test_user_with_institution_and_permissions( assert institution_two in node.affiliated_institutions.all() def test_user_with_institution_and_permissions_through_patch(self, app, user, institution_one, institution_two, - node, node_institutions_url, mock_send_grid): + node, node_institutions_url, mock_notification_send): - mock_send_grid.reset_mock() res = app.patch_json_api( node_institutions_url, self.create_payload([institution_one, institution_two]), auth=user.auth ) assert res.status_code == 200 - assert mock_send_grid.call_count == 2 + assert mock_notification_send.call_count == 2 - first_call_args = mock_send_grid.call_args_list[0][1] - assert first_call_args['to_addr'] == user.email - assert first_call_args['subject'] == 'Project Affiliation Changed' - - second_call_args = mock_send_grid.call_args_list[1][1] - assert second_call_args['to_addr'] == user.email - assert second_call_args['subject'] == 'Project Affiliation Changed' - - def test_remove_institutions_with_affiliated_user(self, app, user, institution_one, node, node_institutions_url, mock_send_grid): + def test_remove_institutions_with_affiliated_user(self, app, user, institution_one, node, node_institutions_url, mock_notification_send): node.affiliated_institutions.add(institution_one) node.save() assert institution_one in node.affiliated_institutions.all() - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() res = app.put_json_api( node_institutions_url, { @@ -236,25 +228,21 @@ def test_remove_institutions_with_affiliated_user(self, app, user, institution_o auth=user.auth ) - first_call_args = mock_send_grid.call_args_list[0][1] - assert first_call_args['to_addr'] == user.email - assert first_call_args['subject'] == 'Project Affiliation Changed' - assert res.status_code == 200 assert node.affiliated_institutions.count() == 0 - def test_using_post_making_no_changes_returns_201(self, app, user, institution_one, node, node_institutions_url, mock_send_grid): + def test_using_post_making_no_changes_returns_201(self, app, user, institution_one, node, node_institutions_url, mock_notification_send): node.affiliated_institutions.add(institution_one) node.save() assert institution_one in node.affiliated_institutions.all() - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() res = app.post_json_api( node_institutions_url, self.create_payload([institution_one]), auth=user.auth ) - mock_send_grid.assert_not_called() + mock_notification_send.assert_not_called() assert res.status_code == 201 assert institution_one in node.affiliated_institutions.all() @@ -276,56 +264,42 @@ def test_put_not_admin_but_affiliated(self, app, institution_one, node, node_ins assert institution_one in node.affiliated_institutions.all() def test_add_through_patch_one_inst_to_node_with_inst( - self, app, user, institution_one, institution_two, node, node_institutions_url, mock_send_grid): + self, app, user, institution_one, institution_two, node, node_institutions_url, mock_notification_send): node.affiliated_institutions.add(institution_one) node.save() assert institution_one in node.affiliated_institutions.all() assert institution_two not in node.affiliated_institutions.all() - mock_send_grid.reset_mock() res = app.patch_json_api( node_institutions_url, self.create_payload([institution_one, institution_two]), auth=user.auth ) - assert mock_send_grid.call_count == 1 - first_call_args = mock_send_grid.call_args_list[0][1] - assert first_call_args['to_addr'] == user.email - assert first_call_args['subject'] == 'Project Affiliation Changed' assert res.status_code == 200 assert institution_one in node.affiliated_institutions.all() assert institution_two in node.affiliated_institutions.all() def test_add_through_patch_one_inst_while_removing_other( - self, app, user, institution_one, institution_two, node, node_institutions_url, mock_send_grid): + self, app, user, institution_one, institution_two, node, node_institutions_url, mock_notification_send): node.affiliated_institutions.add(institution_one) node.save() assert institution_one in node.affiliated_institutions.all() assert institution_two not in node.affiliated_institutions.all() - mock_send_grid.reset_mock() res = app.patch_json_api( node_institutions_url, self.create_payload([institution_two]), auth=user.auth ) - assert mock_send_grid.call_count == 2 - - first_call_args = mock_send_grid.call_args_list[0][1] - assert first_call_args['to_addr'] == user.email - assert first_call_args['subject'] == 'Project Affiliation Changed' - - second_call_args = mock_send_grid.call_args_list[1][1] - assert second_call_args['to_addr'] == user.email - assert second_call_args['subject'] == 'Project Affiliation Changed' + assert mock_notification_send.call_count == 2 assert res.status_code == 200 assert institution_one not in node.affiliated_institutions.all() assert institution_two in node.affiliated_institutions.all() def test_add_one_inst_with_post_to_node_with_inst( - self, app, user, institution_one, institution_two, node, node_institutions_url, mock_send_grid): + self, app, user, institution_one, institution_two, node, node_institutions_url, mock_notification_send): node.affiliated_institutions.add(institution_one) node.save() assert institution_one in node.affiliated_institutions.all() @@ -336,9 +310,6 @@ def test_add_one_inst_with_post_to_node_with_inst( self.create_payload([institution_two]), auth=user.auth ) - call_args = mock_send_grid.call_args[1] - assert call_args['to_addr'] == user.email - assert call_args['subject'] == 'Project Affiliation Changed' assert res.status_code == 201 assert institution_one in node.affiliated_institutions.all() @@ -352,7 +323,7 @@ def test_delete_nothing(self, app, user, node_institutions_url): ) assert res.status_code == 204 - def test_delete_existing_inst(self, app, user, institution_one, node, node_institutions_url, mock_send_grid): + def test_delete_existing_inst(self, app, user, institution_one, node, node_institutions_url, mock_notification_send): node.affiliated_institutions.add(institution_one) node.save() @@ -362,10 +333,6 @@ def test_delete_existing_inst(self, app, user, institution_one, node, node_insti auth=user.auth ) - call_args = mock_send_grid.call_args[1] - assert call_args['to_addr'] == user.email - assert call_args['subject'] == 'Project Affiliation Changed' - assert res.status_code == 204 assert institution_one not in node.affiliated_institutions.all() diff --git a/api_tests/preprints/views/test_preprint_contributors_list.py b/api_tests/preprints/views/test_preprint_contributors_list.py index 6676b542b60..3f8baa30f07 100644 --- a/api_tests/preprints/views/test_preprint_contributors_list.py +++ b/api_tests/preprints/views/test_preprint_contributors_list.py @@ -1346,6 +1346,7 @@ def test_add_contributor_validation(self, preprint_published, validate_data): @pytest.mark.django_db @pytest.mark.enable_enqueue_task @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestPreprintContributorCreateEmail(NodeCRUDTestCase): @pytest.fixture() @@ -1353,7 +1354,7 @@ def url_preprint_contribs(self, preprint_published): return f'/{API_BASE}preprints/{preprint_published._id}/contributors/' def test_add_contributor_no_email_if_false( - self, mock_send_grid, app, user, url_preprint_contribs): + self, mock_notification_send, app, user, url_preprint_contribs): url = f'{url_preprint_contribs}?send_email=false' payload = { 'data': { @@ -1364,13 +1365,13 @@ def test_add_contributor_no_email_if_false( } } } - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() res = app.post_json_api(url, payload, auth=user.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 def test_add_contributor_needs_preprint_filter_to_send_email( - self, mock_send_grid, app, user, user_two, + self, mock_notification_send, app, user, user_two, url_preprint_contribs): url = f'{url_preprint_contribs}?send_email=default' payload = { @@ -1389,11 +1390,11 @@ def test_add_contributor_needs_preprint_filter_to_send_email( } } - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() res = app.post_json_api(url, payload, auth=user.auth, expect_errors=True) assert res.status_code == 400 assert res.json['errors'][0]['detail'] == 'default is not a valid email preference.' - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 @mock.patch('website.project.signals.contributor_added.send') def test_add_contributor_signal_if_preprint( @@ -1421,7 +1422,7 @@ def test_add_contributor_signal_if_preprint( assert 'preprint' == kwargs['email_template'] def test_add_unregistered_contributor_sends_email( - self, mock_send_grid, app, user, url_preprint_contribs): + self, mock_notification_send, app, user, url_preprint_contribs): url = f'{url_preprint_contribs}?send_email=preprint' payload = { 'data': { @@ -1433,10 +1434,10 @@ def test_add_unregistered_contributor_sends_email( } } - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() res = app.post_json_api(url, payload, auth=user.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 @mock.patch('website.project.signals.unreg_contributor_added.send') def test_add_unregistered_contributor_signal_if_preprint( @@ -1458,7 +1459,7 @@ def test_add_unregistered_contributor_signal_if_preprint( assert mock_send.call_count == 1 def test_add_contributor_invalid_send_email_param( - self, mock_send_grid, app, user, url_preprint_contribs): + self, mock_notification_send, app, user, url_preprint_contribs): url = f'{url_preprint_contribs}?send_email=true' payload = { 'data': { @@ -1469,16 +1470,16 @@ def test_add_contributor_invalid_send_email_param( } } } - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() res = app.post_json_api( url, payload, auth=user.auth, expect_errors=True) assert res.status_code == 400 assert res.json['errors'][0]['detail'] == 'true is not a valid email preference.' - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 def test_add_unregistered_contributor_without_email_no_email( - self, mock_send_grid, app, user, url_preprint_contribs): + self, mock_notification_send, app, user, url_preprint_contribs): url = f'{url_preprint_contribs}?send_email=preprint' payload = { 'data': { @@ -1489,16 +1490,16 @@ def test_add_unregistered_contributor_without_email_no_email( } } - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() with capture_signals() as mock_signal: res = app.post_json_api(url, payload, auth=user.auth) assert contributor_added in mock_signal.signals_sent() assert res.status_code == 201 - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 @mock.patch('osf.models.preprint.update_or_enqueue_on_preprint_updated') def test_publishing_preprint_sends_emails_to_contributors( - self, mock_update, mock_send_grid, app, user, url_preprint_contribs, preprint_unpublished): + self, mock_update, mock_notification_send, app, user, url_preprint_contribs, preprint_unpublished): url = f'/{API_BASE}preprints/{preprint_unpublished._id}/' user_two = AuthUserFactory() preprint_unpublished.add_contributor(user_two, permissions=permissions.WRITE, save=True) @@ -1537,7 +1538,7 @@ def test_contributor_added_signal_not_specified( assert mock_send.call_count == 1 def test_contributor_added_not_sent_if_unpublished( - self, mock_send_grid, app, user, preprint_unpublished): + self, mock_notification_send, app, user, preprint_unpublished): url = f'/{API_BASE}preprints/{preprint_unpublished._id}/contributors/?send_email=preprint' payload = { 'data': { @@ -1548,10 +1549,10 @@ def test_contributor_added_not_sent_if_unpublished( } } } - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() res = app.post_json_api(url, payload, auth=user.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 @pytest.mark.django_db diff --git a/api_tests/providers/tasks/test_bulk_upload.py b/api_tests/providers/tasks/test_bulk_upload.py index 221861ea313..100ad4ad530 100644 --- a/api_tests/providers/tasks/test_bulk_upload.py +++ b/api_tests/providers/tasks/test_bulk_upload.py @@ -64,6 +64,7 @@ def test_error_message_default(self): @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestBulkUploadTasks: @pytest.fixture() @@ -317,7 +318,7 @@ def test_bulk_creation_dry_run(self, registration_row_1, registration_row_2, upl assert upload_job_done_full.state == JobState.PICKED_UP assert not upload_job_done_full.email_sent - def test_bulk_creation_done_full(self, mock_send_grid, registration_row_1, registration_row_2, + def test_bulk_creation_done_full(self, mock_notification_send, registration_row_1, registration_row_2, upload_job_done_full, provider, initiator, read_contributor, write_contributor): bulk_create_registrations(upload_job_done_full.id, dry_run=False) @@ -335,9 +336,9 @@ def test_bulk_creation_done_full(self, mock_send_grid, registration_row_1, regis assert row.draft_registration.contributor_set.get(user=write_contributor).permission == WRITE assert row.draft_registration.contributor_set.get(user=read_contributor).permission == READ - mock_send_grid.assert_called() + mock_notification_send.assert_called() - def test_bulk_creation_done_partial(self, mock_send_grid, registration_row_3, + def test_bulk_creation_done_partial(self, mock_notification_send, registration_row_3, registration_row_invalid_extra_bib_1, upload_job_done_partial, provider, initiator, read_contributor, write_contributor): @@ -355,7 +356,7 @@ def test_bulk_creation_done_partial(self, mock_send_grid, registration_row_3, assert registration_row_3.draft_registration.contributor_set.get(user=write_contributor).permission == WRITE assert registration_row_3.draft_registration.contributor_set.get(user=read_contributor).permission == READ - mock_send_grid.assert_called() + mock_notification_send.assert_called() def test_bulk_creation_done_error(self, mock_send_grid, registration_row_invalid_extra_bib_2, registration_row_invalid_affiliation, upload_job_done_error, diff --git a/api_tests/registrations/views/test_registration_detail.py b/api_tests/registrations/views/test_registration_detail.py index 9112d0a3264..02c8ed42f6a 100644 --- a/api_tests/registrations/views/test_registration_detail.py +++ b/api_tests/registrations/views/test_registration_detail.py @@ -696,6 +696,7 @@ def test_read_write_contributor_can_edit_writeable_fields( @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestRegistrationWithdrawal(TestRegistrationUpdateTestCase): @pytest.fixture @@ -754,14 +755,14 @@ def test_initiate_withdraw_registration_fails( res = app.put_json_api(public_url, public_payload, auth=user.auth, expect_errors=True) assert res.status_code == 400 - def test_initiate_withdrawal_success(self, mock_send_grid, app, user, public_registration, public_url, public_payload): + def test_initiate_withdrawal_success(self, mock_notification_send, app, user, public_registration, public_url, public_payload): res = app.put_json_api(public_url, public_payload, auth=user.auth) assert res.status_code == 200 assert res.json['data']['attributes']['pending_withdrawal'] is True public_registration.refresh_from_db() assert public_registration.is_pending_retraction assert public_registration.registered_from.logs.first().action == 'retraction_initiated' - assert mock_send_grid.called + assert mock_notification_send.called @pytest.mark.usefixtures('mock_gravy_valet_get_verified_links') def test_initiate_withdrawal_with_embargo_ends_embargo( @@ -786,7 +787,7 @@ def test_initiate_withdrawal_with_embargo_ends_embargo( assert not public_registration.is_pending_embargo def test_withdraw_request_does_not_send_email_to_unregistered_admins( - self, mock_send_grid, app, user, public_registration, public_url, public_payload): + self, mock_notification_send, app, user, public_registration, public_url, public_payload): unreg = UnregUserFactory() with disconnected_from_listeners(contributor_added): public_registration.add_unregistered_contributor( @@ -803,7 +804,7 @@ def test_withdraw_request_does_not_send_email_to_unregistered_admins( # Only the creator gets an email; the unreg user does not get emailed assert public_registration._contributors.count() == 2 - assert mock_send_grid.call_count == 3 + assert mock_notification_send.call_count == 2 @pytest.mark.django_db diff --git a/api_tests/requests/views/test_node_request_institutional_access.py b/api_tests/requests/views/test_node_request_institutional_access.py index d868739e9bd..0cf0380d206 100644 --- a/api_tests/requests/views/test_node_request_institutional_access.py +++ b/api_tests/requests/views/test_node_request_institutional_access.py @@ -10,6 +10,7 @@ @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestNodeRequestListInstitutionalAccess(NodeRequestTestMixin): @pytest.fixture() @@ -206,37 +207,37 @@ def test_institutional_admin_unauth_institution(self, app, project, institution_ assert res.status_code == 403 assert 'Institutional request access is not enabled.' in res.json['errors'][0]['detail'] - def test_email_not_sent_without_recipient(self, mock_send_grid, app, project, institutional_admin, url, + def test_email_not_sent_without_recipient(self, mock_notification_send, app, project, institutional_admin, url, create_payload, institution): """ Test that an email is not sent when no recipient is listed when an institutional access request is made, but the request is still made anyway without email. """ + mock_notification_send.reset_mock() del create_payload['data']['relationships']['message_recipient'] - mock_send_grid.reset_mock() res = app.post_json_api(url, create_payload, auth=institutional_admin.auth) assert res.status_code == 201 # Check that an email is sent - assert not mock_send_grid.called + assert not mock_notification_send.called - def test_email_not_sent_outside_institution(self, mock_send_grid, app, project, institutional_admin, url, + def test_email_not_sent_outside_institution(self, mock_notification_send, app, project, institutional_admin, url, create_payload, user_without_affiliation, institution): """ Test that you are prevented from requesting a user with the correct institutional affiliation. """ create_payload['data']['relationships']['message_recipient']['data']['id'] = user_without_affiliation._id - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() res = app.post_json_api(url, create_payload, auth=institutional_admin.auth, expect_errors=True) assert res.status_code == 403 assert f'User {user_without_affiliation._id} is not affiliated with the institution.' in res.json['errors'][0]['detail'] # Check that an email is sent - assert not mock_send_grid.called + assert not mock_notification_send.called def test_email_sent_on_creation( self, - mock_send_grid, + mock_notification_send, app, project, institutional_admin, @@ -248,15 +249,14 @@ def test_email_sent_on_creation( """ Test that an email is sent to the appropriate recipients when an institutional access request is made. """ - mock_send_grid.reset_mock() res = app.post_json_api(url, create_payload, auth=institutional_admin.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 def test_bcc_institutional_admin( self, - mock_send_grid, + mock_notification_send, app, project, institutional_admin, @@ -269,15 +269,14 @@ def test_bcc_institutional_admin( Ensure BCC option works as expected, sending messages to sender giving them a copy for themselves. """ create_payload['data']['attributes']['bcc_sender'] = True - mock_send_grid.reset_mock() res = app.post_json_api(url, create_payload, auth=institutional_admin.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 def test_reply_to_institutional_admin( self, - mock_send_grid, + mock_notification_send, app, project, institutional_admin, @@ -290,11 +289,10 @@ def test_reply_to_institutional_admin( Ensure reply-to option works as expected, allowing a reply to header be added to the email. """ create_payload['data']['attributes']['reply_to'] = True - mock_send_grid.reset_mock() res = app.post_json_api(url, create_payload, auth=institutional_admin.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 def test_access_requests_disabled_raises_permission_denied( self, app, node_with_disabled_access_requests, user_with_affiliation, institutional_admin, create_payload @@ -313,7 +311,7 @@ def test_access_requests_disabled_raises_permission_denied( def test_placeholder_text_when_comment_is_empty( self, - mock_send_grid, + mock_notification_send, app, project, institutional_admin, @@ -327,11 +325,10 @@ def test_placeholder_text_when_comment_is_empty( """ # Test with empty comment create_payload['data']['attributes']['comment'] = '' - mock_send_grid.reset_mock() res = app.post_json_api(url, create_payload, auth=institutional_admin.auth) assert res.status_code == 201 - mock_send_grid.assert_called() + mock_notification_send.assert_called() def test_requester_can_resubmit(self, app, project, institutional_admin, url, create_payload): """ diff --git a/api_tests/requests/views/test_node_request_list.py b/api_tests/requests/views/test_node_request_list.py index 41ee66747d4..829fdf4ec4d 100644 --- a/api_tests/requests/views/test_node_request_list.py +++ b/api_tests/requests/views/test_node_request_list.py @@ -9,6 +9,7 @@ @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestNodeRequestListCreate(NodeRequestTestMixin): @pytest.fixture() def url(self, project): @@ -80,27 +81,25 @@ def test_requests_disabled_list(self, app, url, create_payload, project, admin): res = app.get(url, create_payload, auth=admin.auth, expect_errors=True) assert res.status_code == 403 - def test_email_sent_to_all_admins_on_submit(self, mock_send_grid, app, project, noncontrib, url, create_payload, second_admin): + def test_email_sent_to_all_admins_on_submit(self, mock_notification_send, app, project, noncontrib, url, create_payload, second_admin): project.is_public = True project.save() - mock_send_grid.reset_mock() res = app.post_json_api(url, create_payload, auth=noncontrib.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 2 + assert mock_notification_send.call_count == 2 def test_email_not_sent_to_parent_admins_on_submit(self, mock_send_grid, app, project, noncontrib, url, create_payload, second_admin): component = NodeFactory(parent=project, creator=second_admin) component.is_public = True project.save() url = f'/{API_BASE}nodes/{component._id}/requests/' - mock_send_grid.reset_mock() res = app.post_json_api(url, create_payload, auth=noncontrib.auth) assert res.status_code == 201 assert component.parent_admin_contributors.count() == 1 assert component.contributors.count() == 1 assert mock_send_grid.call_count == 1 - def test_request_followed_by_added_as_contrib(elf, app, project, noncontrib, admin, url, create_payload): + def test_request_followed_by_added_as_contrib(self, app, project, noncontrib, admin, url, create_payload): res = app.post_json_api(url, create_payload, auth=noncontrib.auth) assert res.status_code == 201 assert project.requests.filter(creator=noncontrib, machine_state='pending').exists() diff --git a/api_tests/requests/views/test_preprint_request_list.py b/api_tests/requests/views/test_preprint_request_list.py index 72e16862f7a..05c28a834dc 100644 --- a/api_tests/requests/views/test_preprint_request_list.py +++ b/api_tests/requests/views/test_preprint_request_list.py @@ -6,6 +6,7 @@ @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestPreprintRequestListCreate(PreprintRequestTestMixin): def url(self, preprint): return f'/{API_BASE}preprints/{preprint._id}/requests/' @@ -65,7 +66,7 @@ def test_requester_cannot_submit_again(self, app, admin, create_payload, pre_mod assert res.json['errors'][0]['detail'] == 'Users may not have more than one withdrawal request per preprint.' @pytest.mark.skip('TODO: IN-284 -- add emails') - def test_email_sent_to_moderators_on_submit(self, mock_send_grid, app, admin, create_payload, moderator, post_mod_preprint): + def test_email_sent_to_moderators_on_submit(self, mock_notification_send, app, admin, create_payload, moderator, post_mod_preprint): res = app.post_json_api(self.url(post_mod_preprint), create_payload, auth=admin.auth) assert res.status_code == 201 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 diff --git a/api_tests/requests/views/test_request_actions_create.py b/api_tests/requests/views/test_request_actions_create.py index 30e579d3ab3..1e7c94b3c76 100644 --- a/api_tests/requests/views/test_request_actions_create.py +++ b/api_tests/requests/views/test_request_actions_create.py @@ -8,6 +8,7 @@ @pytest.mark.django_db @pytest.mark.enable_enqueue_task @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestCreateNodeRequestAction(NodeRequestTestMixin): @pytest.fixture() def url(self, node_request): @@ -190,8 +191,8 @@ def test_rejects_fail_with_requests_disabled(self, app, admin, url, node_request assert initial_state == node_request.machine_state assert node_request.creator not in node_request.target.contributors - def test_email_sent_on_approve(self, mock_send_grid, app, admin, url, node_request): - mock_send_grid.reset_mock() + def test_email_sent_on_approve(self, mock_notification_send, app, admin, url, node_request): + mock_notification_send.reset_mock() initial_state = node_request.machine_state assert node_request.creator not in node_request.target.contributors payload = self.create_payload(node_request._id, trigger='accept') @@ -200,10 +201,9 @@ def test_email_sent_on_approve(self, mock_send_grid, app, admin, url, node_reque node_request.reload() assert initial_state != node_request.machine_state assert node_request.creator in node_request.target.contributors - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 - def test_email_sent_on_reject(self, mock_send_grid, app, admin, url, node_request): - mock_send_grid.reset_mock() + def test_email_sent_on_reject(self, mock_notification_send, app, admin, url, node_request): initial_state = node_request.machine_state assert node_request.creator not in node_request.target.contributors payload = self.create_payload(node_request._id, trigger='reject') @@ -212,10 +212,10 @@ def test_email_sent_on_reject(self, mock_send_grid, app, admin, url, node_reques node_request.reload() assert initial_state != node_request.machine_state assert node_request.creator not in node_request.target.contributors - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 - def test_email_not_sent_on_reject(self, mock_send_grid, app, requester, url, node_request): - mock_send_grid.reset_mock() + def test_email_not_sent_on_reject(self, mock_notification_send, app, requester, url, node_request): + mock_notification_send.reset_mock() initial_state = node_request.machine_state initial_comment = node_request.comment payload = self.create_payload(node_request._id, trigger='edit_comment', comment='ASDFG') @@ -224,7 +224,7 @@ def test_email_not_sent_on_reject(self, mock_send_grid, app, requester, url, nod node_request.reload() assert initial_state == node_request.machine_state assert initial_comment != node_request.comment - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 def test_set_permissions_on_approve(self, app, admin, url, node_request): assert node_request.creator not in node_request.target.contributors @@ -256,6 +256,7 @@ def test_accept_request_defaults_to_read_and_visible(self, app, admin, url, node @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestCreatePreprintRequestAction(PreprintRequestTestMixin): @pytest.fixture() def url(self, pre_request, post_request, none_request): @@ -385,8 +386,7 @@ def test_write_contrib_and_noncontrib_cannot_edit_comment(self, app, write_contr assert initial_state == request.machine_state assert initial_comment == request.comment - def test_email_sent_on_approve(self, mock_send_grid, app, moderator, url, pre_request, post_request): - mock_send_grid.reset_mock() + def test_email_sent_on_approve(self, mock_notification_send, app, moderator, url, pre_request, post_request): for request in [pre_request, post_request]: initial_state = request.machine_state assert not request.target.is_retracted @@ -397,11 +397,11 @@ def test_email_sent_on_approve(self, mock_send_grid, app, moderator, url, pre_re request.target.reload() assert initial_state != request.machine_state assert request.target.is_retracted - # There are two preprints withdrawn and each preprint have 2 contributors. So 4 emails are sent in total. - assert mock_send_grid.call_count == 4 + # There are two preprints withdrawn and each preprint have 2 contributors. + assert mock_notification_send.call_count == 2 @pytest.mark.skip('TODO: IN-331 -- add emails') - def test_email_sent_on_reject(self, mock_send_grid, app, moderator, url, pre_request, post_request): + def test_email_sent_on_reject(self, mock_notification_send, app, moderator, url, pre_request, post_request): for request in [pre_request, post_request]: initial_state = request.machine_state assert not request.target.is_retracted @@ -411,10 +411,10 @@ def test_email_sent_on_reject(self, mock_send_grid, app, moderator, url, pre_req request.reload() assert initial_state != request.machine_state assert not request.target.is_retracted - assert mock_send_grid.call_count == 2 + assert mock_notification_send.call_count == 2 @pytest.mark.skip('TODO: IN-284/331 -- add emails') - def test_email_not_sent_on_edit_comment(self, mock_send_grid, app, moderator, url, pre_request, post_request): + def test_email_not_sent_on_edit_comment(self, mock_notification_send, app, moderator, url, pre_request, post_request): for request in [pre_request, post_request]: initial_state = request.machine_state assert not request.target.is_retracted @@ -424,7 +424,7 @@ def test_email_not_sent_on_edit_comment(self, mock_send_grid, app, moderator, ur request.reload() assert initial_state != request.machine_state assert not request.target.is_retracted - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 def test_auto_approve(self, app, auto_withdrawable_pre_mod_preprint, auto_approved_pre_request): assert auto_withdrawable_pre_mod_preprint.is_retracted diff --git a/api_tests/subscriptions/views/test_subscriptions_detail.py b/api_tests/subscriptions/views/test_subscriptions_detail.py index 2a8741fc173..0e2fa22b119 100644 --- a/api_tests/subscriptions/views/test_subscriptions_detail.py +++ b/api_tests/subscriptions/views/test_subscriptions_detail.py @@ -1,8 +1,10 @@ import pytest from api.base.settings.defaults import API_BASE -from osf_tests.factories import AuthUserFactory, NotificationSubscriptionFactory - +from osf_tests.factories import ( + AuthUserFactory, + NotificationSubscriptionFactory +) @pytest.mark.django_db class TestSubscriptionDetail: @@ -16,18 +18,16 @@ def user_no_auth(self): return AuthUserFactory() @pytest.fixture() - def global_user_notification(self, user): - notification = NotificationSubscriptionFactory(_id=f'{user._id}_global', user=user, event_name='global') - notification.add_user_to_subscription(user, 'email_transactional') - return notification + def notification(self, user): + return NotificationSubscriptionFactory(user=user) @pytest.fixture() - def url(self, global_user_notification): - return f'/{API_BASE}subscriptions/{global_user_notification._id}/' + def url(self, notification): + return f'/{API_BASE}subscriptions/{notification._id}/' @pytest.fixture() def url_invalid(self): - return '/{}subscriptions/{}/'.format(API_BASE, 'invalid-notification-id') + return f'/{API_BASE}subscriptions/invalid-notification-id/' @pytest.fixture() def payload(self): @@ -51,56 +51,99 @@ def payload_invalid(self): } } - def test_subscription_detail(self, app, user, user_no_auth, global_user_notification, url, url_invalid, payload, payload_invalid): - # GET with valid notification_id - # Invalid user - res = app.get(url, auth=user_no_auth.auth, expect_errors=True) + def test_subscription_detail_invalid_user(self, app, user, user_no_auth, notification, url, payload): + res = app.get( + url, + auth=user_no_auth.auth, + expect_errors=True + ) assert res.status_code == 403 - # No user - res = app.get(url, expect_errors=True) + + def test_subscription_detail_no_user( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): + res = app.get( + url, + expect_errors=True + ) assert res.status_code == 401 - # Valid user + + def test_subscription_detail_valid_user( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): + res = app.get(url, auth=user.auth) notification_id = res.json['data']['id'] assert res.status_code == 200 assert notification_id == f'{user._id}_global' - # GET with invalid notification_id - # No user + def test_subscription_detail_invalid_notification_id_no_user( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): res = app.get(url_invalid, expect_errors=True) assert res.status_code == 404 - # Existing user - res = app.get(url_invalid, auth=user.auth, expect_errors=True) + + def test_subscription_detail_invalid_notification_id_existing_user( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): + res = app.get( + url_invalid, + auth=user.auth, + expect_errors=True + ) assert res.status_code == 404 - # PATCH with valid notification_id and invalid data - # Invalid user + def test_subscription_detail_invalid_payload_403( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): res = app.patch_json_api(url, payload_invalid, auth=user_no_auth.auth, expect_errors=True) assert res.status_code == 403 - # No user + + def test_subscription_detail_invalid_payload_401( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): res = app.patch_json_api(url, payload_invalid, expect_errors=True) assert res.status_code == 401 - # Valid user - res = app.patch_json_api(url, payload_invalid, auth=user.auth, expect_errors=True) + + def test_subscription_detail_invalid_payload_400( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): + res = app.patch_json_api( + url, + payload_invalid, + auth=user.auth, + expect_errors=True + ) assert res.status_code == 400 - assert res.json['errors'][0]['detail'] == 'Invalid frequency "invalid-frequency"' + assert res.json['errors'][0]['detail'] == ('"invalid-frequency" is not a valid choice.') - # PATCH with invalid notification_id - # No user + def test_subscription_detail_patch_invalid_notification_id_no_user( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): res = app.patch_json_api(url_invalid, payload, expect_errors=True) assert res.status_code == 404 - # Existing user + + def test_subscription_detail_patch_invalid_notification_id_existing_user( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): res = app.patch_json_api(url_invalid, payload, auth=user.auth, expect_errors=True) assert res.status_code == 404 - # PATCH with valid notification_id and valid data - # Invalid user + def test_subscription_detail_patch_invalid_user( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): res = app.patch_json_api(url, payload, auth=user_no_auth.auth, expect_errors=True) assert res.status_code == 403 - # No user + + def test_subscription_detail_patch_no_user( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): res = app.patch_json_api(url, payload, expect_errors=True) assert res.status_code == 401 - # Valid user + + def test_subscription_detail_patch( + self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid + ): res = app.patch_json_api(url, payload, auth=user.auth) assert res.status_code == 200 assert res.json['data']['attributes']['frequency'] == 'none' diff --git a/api_tests/subscriptions/views/test_subscriptions_list.py b/api_tests/subscriptions/views/test_subscriptions_list.py index a04f04e3e06..a0a01bf513c 100644 --- a/api_tests/subscriptions/views/test_subscriptions_list.py +++ b/api_tests/subscriptions/views/test_subscriptions_list.py @@ -1,7 +1,13 @@ import pytest from api.base.settings.defaults import API_BASE -from osf_tests.factories import AuthUserFactory, PreprintProviderFactory, ProjectFactory, NotificationSubscriptionFactory +from osf.models import NotificationType +from osf_tests.factories import ( + AuthUserFactory, + PreprintProviderFactory, + ProjectFactory, + NotificationSubscriptionFactory +) @pytest.mark.django_db @@ -23,25 +29,42 @@ def node(self, user): @pytest.fixture() def global_user_notification(self, user): - notification = NotificationSubscriptionFactory(_id=f'{user._id}_global', user=user, event_name='global') - notification.add_user_to_subscription(user, 'email_transactional') - return notification + return NotificationSubscriptionFactory( + notification_type=NotificationType.Type.USER_FILE_UPDATED.instance, + user=user, + ) @pytest.fixture() def file_updated_notification(self, node, user): - notification = NotificationSubscriptionFactory( - _id=node._id + 'file_updated', - owner=node, - event_name='file_updated', + return NotificationSubscriptionFactory( + notification_type=NotificationType.Type.NODE_FILES_UPDATED.instance, + subscribed_object=node, + user=user, + ) + + @pytest.fixture() + def provider_notification(self, provider, user): + return NotificationSubscriptionFactory( + notification_type=NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.instance, + subscribed_object=provider, + user=user, ) - notification.add_user_to_subscription(user, 'email_transactional') - return notification @pytest.fixture() def url(self, user, node): return f'/{API_BASE}subscriptions/' - def test_list_complete(self, app, user, provider, node, global_user_notification, url): + def test_list_complete( + self, + app, + user, + provider, + node, + global_user_notification, + provider_notification, + file_updated_notification, + url + ): res = app.get(url, auth=user.auth) notification_ids = [item['id'] for item in res.json['data']] # There should only be 3 notifications: users' global, node's file updates and provider's preprint added. diff --git a/api_tests/users/views/test_user_claim.py b/api_tests/users/views/test_user_claim.py index 0e265021c5c..243e45ce6ee 100644 --- a/api_tests/users/views/test_user_claim.py +++ b/api_tests/users/views/test_user_claim.py @@ -13,6 +13,7 @@ @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestClaimUser: @pytest.fixture() @@ -116,41 +117,41 @@ def test_claim_unauth_failure(self, app, url, unreg_user, project, wrong_preprin ) assert res.status_code == 401 - def test_claim_unauth_success_with_original_email(self, app, url, project, unreg_user, mock_send_grid): - mock_send_grid.reset_mock() + def test_claim_unauth_success_with_original_email(self, app, url, project, unreg_user, mock_notification_send): + mock_notification_send.reset_mock() res = app.post_json_api( url.format(unreg_user._id), self.payload(email='david@david.son', id=project._id), ) assert res.status_code == 204 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 - def test_claim_unauth_success_with_claimer_email(self, app, url, unreg_user, project, claimer, mock_send_grid): - mock_send_grid.reset_mock() + def test_claim_unauth_success_with_claimer_email(self, app, url, unreg_user, project, claimer, mock_notification_send): + mock_notification_send.reset_mock() res = app.post_json_api( url.format(unreg_user._id), self.payload(email=claimer.username, id=project._id) ) assert res.status_code == 204 - assert mock_send_grid.call_count == 2 + assert mock_notification_send.call_count == 2 - def test_claim_unauth_success_with_unknown_email(self, app, url, project, unreg_user, mock_send_grid): - mock_send_grid.reset_mock() + def test_claim_unauth_success_with_unknown_email(self, app, url, project, unreg_user, mock_notification_send): + mock_notification_send.reset_mock() res = app.post_json_api( url.format(unreg_user._id), self.payload(email='asdf@fdsa.com', id=project._id), ) assert res.status_code == 204 - assert mock_send_grid.call_count == 2 + assert mock_notification_send.call_count == 2 - def test_claim_unauth_success_with_preprint_id(self, app, url, preprint, unreg_user, mock_send_grid): - mock_send_grid.reset_mock() + def test_claim_unauth_success_with_preprint_id(self, app, url, preprint, unreg_user, mock_notification_send): + mock_notification_send.reset_mock() res = app.post_json_api( url.format(unreg_user._id), self.payload(email='david@david.son', id=preprint._id), ) assert res.status_code == 204 - assert mock_send_grid.call_count == 1 + assert mock_notification_send.call_count == 1 def test_claim_auth_failure(self, app, url, claimer, wrong_preprint, project, unreg_user, referrer): _url = url.format(unreg_user._id) @@ -209,10 +210,10 @@ def test_claim_auth_failure(self, app, url, claimer, wrong_preprint, project, un ) assert res.status_code == 403 - def test_claim_auth_throttle_error(self, app, url, claimer, unreg_user, project, mock_send_grid): + def test_claim_auth_throttle_error(self, app, url, claimer, unreg_user, project, mock_notification_send): unreg_user.unclaimed_records[project._id]['last_sent'] = timezone.now() unreg_user.save() - mock_send_grid.reset_mock() + mock_notification_send.reset_mock() res = app.post_json_api( url.format(unreg_user._id), self.payload(id=project._id), @@ -221,14 +222,14 @@ def test_claim_auth_throttle_error(self, app, url, claimer, unreg_user, project, ) assert res.status_code == 400 assert res.json['errors'][0]['detail'] == 'User account can only be claimed with an existing user once every 24 hours' - assert mock_send_grid.call_count == 0 + assert mock_notification_send.call_count == 0 - def test_claim_auth_success(self, app, url, claimer, unreg_user, project, mock_send_grid): - mock_send_grid.reset_mock() + def test_claim_auth_success(self, app, url, claimer, unreg_user, project, mock_notification_send): + mock_notification_send.reset_mock() res = app.post_json_api( url.format(unreg_user._id), self.payload(id=project._id), auth=claimer.auth ) assert res.status_code == 204 - assert mock_send_grid.call_count == 2 + assert mock_notification_send.call_count == 2 diff --git a/api_tests/users/views/test_user_confirm.py b/api_tests/users/views/test_user_confirm.py index 0cb4b7606a2..d304fc456b5 100644 --- a/api_tests/users/views/test_user_confirm.py +++ b/api_tests/users/views/test_user_confirm.py @@ -6,6 +6,7 @@ @pytest.mark.django_db +@pytest.mark.usefixtures('mock_notification_send') class TestConfirmEmail: @pytest.fixture() @@ -147,8 +148,7 @@ def test_post_success_create(self, mock_send_mail, app, confirm_url, user_with_e assert user.external_identity == {'ORCID': {'0002-0001-0001-0001': 'VERIFIED'}} assert user.emails.filter(address=email.lower()).exists() - @mock.patch('website.mails.send_mail') - def test_post_success_link(self, mock_send_mail, app, confirm_url, user_with_email_verification): + def test_post_success_link(self, mock_notification_send, app, confirm_url, user_with_email_verification): user, token, email = user_with_email_verification user.external_identity['ORCID']['0000-0000-0000-0000'] = 'LINK' user.save() @@ -168,7 +168,7 @@ def test_post_success_link(self, mock_send_mail, app, confirm_url, user_with_ema ) assert res.status_code == 201 - assert mock_send_mail.called + assert mock_notification_send.called user.reload() assert user.external_identity['ORCID']['0000-0000-0000-0000'] == 'VERIFIED' diff --git a/api_tests/users/views/test_user_settings.py b/api_tests/users/views/test_user_settings.py index 4854e2528ee..2ff7bdb0a06 100644 --- a/api_tests/users/views/test_user_settings.py +++ b/api_tests/users/views/test_user_settings.py @@ -1,15 +1,12 @@ from unittest import mock import pytest -import urllib from api.base.settings.defaults import API_BASE -from api.base.settings import CSRF_COOKIE_NAME from api.base.utils import hashids from osf_tests.factories import ( AuthUserFactory, UserFactory, ) -from django.middleware import csrf from osf.models import Email, NotableDomain from framework.auth.views import auth_email_logout @@ -44,7 +41,7 @@ def payload(self): } } - def test_get(self, app, user_one, url): + def test_get(self, app, user_one, url, mock_notification_send): res = app.get(url, auth=user_one.auth, expect_errors=True) assert res.status_code == 405 @@ -166,131 +163,6 @@ def test_multiple_errors(self, app, user_one, url, payload): assert res.json['errors'][0]['detail'] == 'Old password is invalid' assert res.json['errors'][1]['detail'] == 'Password should be at least eight characters' - -@pytest.mark.django_db -@pytest.mark.usefixtures('mock_send_grid') -class TestResetPassword: - - @pytest.fixture() - def user_one(self): - user = UserFactory() - user.set_password('password1') - user.auth = (user.username, 'password1') - user.save() - return user - - @pytest.fixture() - def url(self): - return f'/{API_BASE}users/reset_password/' - - @pytest.fixture - def csrf_token(self): - return csrf._mask_cipher_secret(csrf._get_new_csrf_string()) - - def test_get(self, mock_send_grid, app, url, user_one): - encoded_email = urllib.parse.quote(user_one.email) - url = f'{url}?email={encoded_email}' - res = app.get(url) - assert res.status_code == 200 - - user_one.reload() - assert mock_send_grid.call_args[1]['to_addr'] == user_one.username - - def test_get_invalid_email(self, mock_send_grid, app, url): - url = f'{url}?email={'invalid_email'}' - res = app.get(url) - assert res.status_code == 200 - assert not mock_send_grid.called - - def test_post(self, app, url, user_one, csrf_token): - app.set_cookie(CSRF_COOKIE_NAME, csrf_token) - encoded_email = urllib.parse.quote(user_one.email) - url = f'{url}?email={encoded_email}' - res = app.get(url) - user_one.reload() - payload = { - 'data': { - 'attributes': { - 'uid': user_one._id, - 'token': user_one.verification_key_v2['token'], - 'password': 'password2', - } - } - } - - res = app.post_json_api(url, payload, headers={'X-CSRFToken': csrf_token}) - user_one.reload() - assert res.status_code == 200 - assert user_one.check_password('password2') - - def test_post_empty_payload(self, app, url, csrf_token): - app.set_cookie(CSRF_COOKIE_NAME, csrf_token) - payload = { - 'data': { - 'attributes': { - } - } - } - res = app.post_json_api(url, payload, expect_errors=True, headers={'X-CSRFToken': csrf_token}) - assert res.status_code == 400 - - def test_post_invalid_token(self, app, url, user_one, csrf_token): - app.set_cookie(CSRF_COOKIE_NAME, csrf_token) - payload = { - 'data': { - 'attributes': { - 'uid': user_one._id, - 'token': 'invalid_token', - 'password': 'password2', - } - } - } - res = app.post_json_api(url, payload, expect_errors=True, headers={'X-THROTTLE-TOKEN': 'test-token', 'X-CSRFToken': csrf_token}) - assert res.status_code == 400 - - def test_post_invalid_password(self, app, url, user_one, csrf_token): - app.set_cookie(CSRF_COOKIE_NAME, csrf_token) - encoded_email = urllib.parse.quote(user_one.email) - url = f'{url}?email={encoded_email}' - res = app.get(url) - user_one.reload() - payload = { - 'data': { - 'attributes': { - 'uid': user_one._id, - 'token': user_one.verification_key_v2['token'], - 'password': user_one.username, - } - } - } - - res = app.post_json_api(url, payload, expect_errors=True, headers={'X-THROTTLE-TOKEN': 'test-token', 'X-CSRFToken': csrf_token}) - assert res.status_code == 400 - - def test_throttle(self, app, url, user_one): - encoded_email = urllib.parse.quote(user_one.email) - url = f'{url}?email={encoded_email}' - app.get(url) - user_one.reload() - payload = { - 'data': { - 'attributes': { - 'uid': user_one._id, - 'token': user_one.verification_key_v2['token'], - 'password': '12345', - } - } - } - - res = app.post_json_api(url, payload, expect_errors=True) - res = app.post_json_api(url, payload, expect_errors=True) - assert res.status_code == 429 - - res = app.get(url, expect_errors=True) - assert res.json['message'] == 'You have recently requested to change your password. Please wait a few minutes before trying again.' - - -@pytest.mark.django_db class TestUserEmailsList: @pytest.fixture(autouse=True) diff --git a/api_tests/users/views/test_user_settings_reset_password.py b/api_tests/users/views/test_user_settings_reset_password.py new file mode 100644 index 00000000000..2108a8272d7 --- /dev/null +++ b/api_tests/users/views/test_user_settings_reset_password.py @@ -0,0 +1,129 @@ +import pytest +import urllib + +from api.base.settings.defaults import API_BASE +from api.base.settings import CSRF_COOKIE_NAME +from osf_tests.factories import ( + UserFactory, +) +from django.middleware import csrf + +@pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') +class TestResetPassword: + + @pytest.fixture() + def user_one(self): + user = UserFactory() + user.set_password('password1') + user.auth = (user.username, 'password1') + user.save() + return user + + @pytest.fixture() + def url(self): + return f'/{API_BASE}users/reset_password/' + + @pytest.fixture + def csrf_token(self): + return csrf._mask_cipher_secret(csrf._get_new_csrf_string()) + + def test_get(self, mock_notification_send, app, url, user_one): + encoded_email = urllib.parse.quote(user_one.email) + url = f'{url}?email={encoded_email}' + res = app.get(url) + assert res.status_code == 200 + + user_one.reload() + assert mock_notification_send.called + + def test_get_invalid_email(self, mock_send_grid, app, url): + url = f'{url}?email={'invalid_email'}' + res = app.get(url) + assert res.status_code == 200 + assert not mock_send_grid.called + + def test_post(self, app, url, user_one, csrf_token): + app.set_cookie(CSRF_COOKIE_NAME, csrf_token) + encoded_email = urllib.parse.quote(user_one.email) + url = f'{url}?email={encoded_email}' + res = app.get(url) + user_one.reload() + payload = { + 'data': { + 'attributes': { + 'uid': user_one._id, + 'token': user_one.verification_key_v2['token'], + 'password': 'password2', + } + } + } + + res = app.post_json_api(url, payload, headers={'X-CSRFToken': csrf_token}) + user_one.reload() + assert res.status_code == 200 + assert user_one.check_password('password2') + + def test_post_empty_payload(self, app, url, csrf_token): + app.set_cookie(CSRF_COOKIE_NAME, csrf_token) + payload = { + 'data': { + 'attributes': { + } + } + } + res = app.post_json_api(url, payload, expect_errors=True, headers={'X-CSRFToken': csrf_token}) + assert res.status_code == 400 + + def test_post_invalid_token(self, app, url, user_one, csrf_token): + app.set_cookie(CSRF_COOKIE_NAME, csrf_token) + payload = { + 'data': { + 'attributes': { + 'uid': user_one._id, + 'token': 'invalid_token', + 'password': 'password2', + } + } + } + res = app.post_json_api(url, payload, expect_errors=True, headers={'X-THROTTLE-TOKEN': 'test-token', 'X-CSRFToken': csrf_token}) + assert res.status_code == 400 + + def test_post_invalid_password(self, app, url, user_one, csrf_token): + app.set_cookie(CSRF_COOKIE_NAME, csrf_token) + encoded_email = urllib.parse.quote(user_one.email) + url = f'{url}?email={encoded_email}' + res = app.get(url) + user_one.reload() + payload = { + 'data': { + 'attributes': { + 'uid': user_one._id, + 'token': user_one.verification_key_v2['token'], + 'password': user_one.username, + } + } + } + + res = app.post_json_api(url, payload, expect_errors=True, headers={'X-THROTTLE-TOKEN': 'test-token', 'X-CSRFToken': csrf_token}) + assert res.status_code == 400 + + def test_throttle(self, app, url, user_one, csrf_token): + app.set_cookie(CSRF_COOKIE_NAME, csrf_token) + encoded_email = urllib.parse.quote(user_one.email) + url = f'{url}?email={encoded_email}' + app.get(url) + user_one.reload() + payload = { + 'data': { + 'attributes': { + 'uid': user_one._id, + 'token': user_one.verification_key_v2['token'], + 'password': '12345', + } + } + } + res = app.post_json_api(url, payload, expect_errors=True, headers={'X-CSRFToken': csrf_token}) + + res = app.get(url, expect_errors=True) + assert res.json['message'] == 'You have recently requested to change your password. Please wait a few minutes before trying again.' diff --git a/conftest.py b/conftest.py index a65aa7aa50f..f7b7bf72b07 100644 --- a/conftest.py +++ b/conftest.py @@ -18,12 +18,12 @@ from framework.celery_tasks import app as celery_app from osf.external.spam import tasks as spam_tasks from website import settings as website_settings +from osf.management.commands.populate_notification_types import populate_notification_types def pytest_configure(config): if not os.getenv('GITHUB_ACTIONS') == 'true': config.option.allow_hosts += ',mailhog' - logger = logging.getLogger(__name__) # Silence some 3rd-party logging and some "loud" internal loggers @@ -362,6 +362,7 @@ def helpful_thing(self): """ yield from rolledback_transaction('function_transaction') + @pytest.fixture() def mock_send_grid(): with mock.patch.object(website_settings, 'USE_EMAIL', True): @@ -391,3 +392,25 @@ def mock_gravy_valet_get_verified_links(): with mock.patch('osf.external.gravy_valet.translations.get_verified_links') as mock_get_verified_links: mock_get_verified_links.return_value = [] yield mock_get_verified_links + + +@pytest.fixture() +def mock_notification_send(): + with mock.patch.object(website_settings, 'USE_EMAIL', True): + with mock.patch.object(website_settings, 'USE_CELERY', False): + with mock.patch('osf.models.notification.Notification.send') as mock_emit: + mock_emit.return_value = None # Or True, if needed + yield mock_emit + + +def start_mock_notification_send(test_case): + patcher = mock.patch('osf.models.notification.Notification.send') + mocked_emit = patcher.start() + test_case.addCleanup(patcher.stop) + mocked_emit.return_value = None + return mocked_emit + + +@pytest.fixture(autouse=True) +def load_notification_types(db, *args, **kwargs): + populate_notification_types(*args, **kwargs) diff --git a/framework/auth/views.py b/framework/auth/views.py index 26aa494ddd4..a1c42eda1ca 100644 --- a/framework/auth/views.py +++ b/framework/auth/views.py @@ -33,6 +33,7 @@ from osf.exceptions import ValidationValueError, BlockedEmailError from osf.models.provider import PreprintProvider from osf.models.tag import Tag +from osf.models.notification_type import NotificationType from osf.utils.requests import check_select_for_update from website.util.metrics import CampaignClaimedTags, CampaignSourceTags from website.ember_osf_web.decorators import ember_flag_is_active @@ -207,14 +208,14 @@ def redirect_unsupported_institution(auth): def forgot_password_post(): """Dispatches to ``_forgot_password_post`` passing non-institutional user mail template and reset action.""" - return _forgot_password_post(mail_template=mails.FORGOT_PASSWORD, + return _forgot_password_post(mail_template='forgot_password', reset_route='reset_password_get') def forgot_password_institution_post(): """Dispatches to `_forgot_password_post` passing institutional user mail template, reset action, and setting the ``institutional`` flag.""" - return _forgot_password_post(mail_template=mails.FORGOT_PASSWORD_INSTITUTION, + return _forgot_password_post(mail_template='forgot_password_institution', reset_route='reset_password_institution_get', institutional=True) @@ -272,12 +273,13 @@ def _forgot_password_post(mail_template, reset_route, institutional=False): token=user_obj.verification_key_v2['token'] ) ) - mails.send_mail( - to_addr=email, - mail=mail_template, - reset_link=reset_link, - can_change_preferences=False, - ) + notification_type = NotificationType.objects.filter(name=mail_template) + if not notification_type.exists(): + raise NotificationType.DoesNotExist( + f'NotificationType with name {mail_template} does not exist.' + ) + notification_type = notification_type.first() + notification_type.emit(user=user_obj, message_frequency='instantly', event_context={'can_change_preferences': False, 'reset_link': reset_link}) # institutional forgot password page displays the message as main text, not as an alert if institutional: diff --git a/notifications.yaml b/notifications.yaml new file mode 100644 index 00000000000..11ab9db90fd --- /dev/null +++ b/notifications.yaml @@ -0,0 +1,223 @@ +# This file contains the configuration for our notification system using the NotificationType object, this is intended to +# exist as a simple declarative list of NotificationTypes and their attributes. Every notification sent by OSF should be +# represented here for bussiness logic dnd metrics reasons. + +# Workflow: +# 1. Add a new notification template +# 2. Add a entry here with the desired notification types +# 3. Add name tp Enum osf.notification.NotificationType.Type +# 4. Use the emit method to send or subscribe the notification for immediate deliver or periodic digest. +notification_types: + #### GLOBAL (User Notifications) + - name: user_pending_verification_registered + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: password_reset + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/password_reset.html.mako' + notification_freq_default: instantly + - name: forgot_password + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/forgot_password.html.mako' + notification_freq_default: instantly + - name: welcome_osf4i + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/welcome_osf4i.html.mako' + notification_freq_default: instantly + - name: invite_preprints_osf + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/invite_preprints_osf.html.mako' + notification_freq_default: instantly + - name: invite_preprints + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/invite_preprints.html.mako' + notification_freq_default: instantly + - name: invite_draft_registration + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/invite_draft_registration.html.mako' + notification_freq_default: instantly + - name: invite_default + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/invite_default.html.mako' + notification_freq_default: instantly + - name: pending_invite + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/pending_invite.html.mako' + notification_freq_default: instantly + - name: user_forward_invite + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/forward_invite.html.mako' + notification_freq_default: instantly + - name: external_confirm_success + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/external_confirm_success.html.mako' + notification_freq_default: instantly + - name: forgot_password_institution + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/forgot_password_institution.html.mako' + notification_freq_default: instantly + - name: user_contributor_added_default + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/contributor_added_default.html.mako' + notification_freq_default: instantly + - name: user_forward_invite_registered + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/forward_invite_registered.html.mako' + notification_freq_default: instantly + - name: user_invite_draft_registration + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/invite_draft_registration.html.mako' + notification_freq_default: instantly + - name: user_invite_default + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/invite_default.html.mako' + notification_freq_default: instantly + - name: user_invite_osf_preprint + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/contributor_added_preprints_osf.html.mako' + notification_freq_default: instantly + - name: user_contributor_added_osf_preprint + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/contributor_added_preprints_osf.html.mako' + notification_freq_default: instantly + - name: user_contributor_added_draft_registration + __docs__: ... + object_content_type_model_name: osfuser + template: 'website/templates/emails/contributor_added_draft_registration.html.mako' + notification_freq_default: instantly + #### PROVIDER + - name: new_pending_submissions + __docs__: ... + object_content_type_model_name: abstractprovider + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: new_pending_withdraw_requests + __docs__: ... + object_content_type_model_name: abstractprovider + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: provider_user_invite_preprint + __docs__: ... + object_content_type_model_name: abstractprovider + template: 'website/templates/emails/contributor_added_preprints.html.mako' + notification_freq_default: instantly + #### NODE + - name: file_updated + __docs__: ... + object_content_type_model_name: abstractnode + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: wiki_updated + __docs__: ... + object_content_type_model_name: abstractnode + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: node_request_institutional_access_request + __docs__: ... + object_content_type_model_name: abstractnode + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: node_contributor_added_access_request + __docs__: ... + object_content_type_model_name: abstractnode + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + + #### PREPRINT + - name: pending_retraction_admin + __docs__: ... + object_content_type_model_name: preprint + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: preprint_contributor_added_preprint_node_from_osf + __docs__: ... + object_content_type_model_name: preprint + template: 'website/templates/emails/contributor_added_preprint_node_from_osf.html.mako' + notification_freq_default: instantly + #### SUPPORT + - name: crossref_error + __docs__: ... + object_content_type_model_name: abstractnode + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + #### Collection Submissions + - name: collection_submission_removed_moderator + __docs__: ... + object_content_type_model_name: collectionsubmission + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: collection_submission_removed_private + __docs__: ... + object_content_type_model_name: collectionsubmission + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: collection_submission_removed_admin + __docs__: ... + object_content_type_model_name: collectionsubmission + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: collection_submission_submitted + __docs__: ... + object_content_type_model_name: collectionsubmission + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: collection_submission_cancel + __docs__: ... + object_content_type_model_name: collectionsubmission + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: collection_submission_accepted + __docs__: ... + object_content_type_model_name: collectionsubmission + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: collection_submission_rejected + __docs__: ... + object_content_type_model_name: collectionsubmission + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + #### DESK + - name: desk_archive_job_exceeded + __docs__: Archive job failed due to size exceeded. Sent to support desk. + object_content_type_model_name: desk + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: desk_archive_job_copy_error + __docs__: Archive job failed due to copy error. Sent to support desk. + object_content_type_model_name: desk + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: desk_archive_job_file_not_found + __docs__: Archive job failed because files were not found. Sent to support desk. + object_content_type_model_name: desk + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: desk_archive_job_uncaught_error + __docs__: Archive job failed due to an uncaught error. Sent to support desk. + object_content_type_model_name: desk + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly + - name: desk_osf_support_email + __docs__: ... + object_content_type_model_name: desk + template: 'website/templates/emails/new_pending_submissions.html.mako' + notification_freq_default: instantly diff --git a/osf/email/__init__.py b/osf/email/__init__.py new file mode 100644 index 00000000000..d8cc1d6de5a --- /dev/null +++ b/osf/email/__init__.py @@ -0,0 +1,68 @@ +import logging +import smtplib +from email.mime.text import MIMEText +from sendgrid import SendGridAPIClient +from sendgrid.helpers.mail import Mail +from website import settings + +def send_email_over_smtp(to_addr, notification_type, context): + """Send an email notification using SMTP. This is typically not used in productions as other 3rd party mail services + are preferred. This is to be used for tests and on staging environments and special situations. + + Args: + to_addr (str): The recipient's email address. + notification_type (str): The subject of the notification. + context (dict): The email content context. + """ + if not settings.MAIL_SERVER: + raise NotImplementedError('MAIL_SERVER is not set') + if not settings.MAIL_USERNAME and settings.MAIL_PASSWORD: + raise NotImplementedError('MAIL_USERNAME and MAIL_PASSWORD are required for STMP') + + msg = MIMEText( + notification_type.template.format(context), + 'html', + _charset='utf-8' + ) + msg['Subject'] = notification_type.email_subject_line_template.format(context=context) + + with smtplib.SMTP(settings.MAIL_SERVER) as server: + server.ehlo() + server.starttls() + server.ehlo() + server.login(settings.MAIL_USERNAME, settings.MAIL_PASSWORD) + server.sendmail( + settings.FROM_EMAIL, + [to_addr], + msg.as_string() + ) + + +def send_email_with_send_grid(to_addr, notification_type, context): + """Send an email notification using SendGrid. + + Args: + to_addr (str): The recipient's email address. + notification_type (str): The subject of the notification. + context (dict): The email content context. + """ + if not settings.SENDGRID_API_KEY: + raise NotImplementedError('SENDGRID_API_KEY is required for sendgrid notifications.') + + message = Mail( + from_email=settings.FROM_EMAIL, + to_emails=to_addr, + subject=notification_type, + html_content=context.get('message', '') + ) + + try: + sg = SendGridAPIClient(settings.SENDGRID_API_KEY) + response = sg.send(message) + if response.status_code not in (200, 201, 202): + logging.error(f'SendGrid response error: {response.status_code}, body: {response.body}') + response.raise_for_status() + logging.info(f'Notification email sent to {to_addr} for {notification_type}.') + except Exception as exc: + logging.error(f'Failed to send email notification to {to_addr}: {exc}') + raise exc diff --git a/osf/management/commands/add_colon_delim_to_s3_buckets.py b/osf/management/commands/add_colon_delim_to_s3_buckets.py deleted file mode 100644 index 0a283f78f0f..00000000000 --- a/osf/management/commands/add_colon_delim_to_s3_buckets.py +++ /dev/null @@ -1,76 +0,0 @@ -import logging - -from django.core.management.base import BaseCommand -from django.apps import apps -from django.db.models import F, Value -from django.db.models.functions import Concat, Replace - -logger = logging.getLogger(__name__) - - -class Command(BaseCommand): - """ - Adds Colon (':') delineators to s3 buckets to separate them from them from their subfolder, so `` - becomes `:/` , the root path. Folder names will also be updated to maintain consistency. - - """ - - def add_arguments(self, parser): - super().add_arguments(parser) - parser.add_argument( - '--reverse', - action='store_true', - dest='reverse', - help='Unsets date_retraction' - ) - - def handle(self, *args, **options): - reverse = options.get('reverse', False) - if reverse: - reverse_update_folder_names() - else: - update_folder_names() - - -def update_folder_names(): - NodeSettings = apps.get_model('addons_s3', 'NodeSettings') - - # Update folder_id for all records - NodeSettings.objects.exclude( - folder_name__contains=':/' - ).update( - folder_id=Concat(F('folder_id'), Value(':/')) - ) - - # Update folder_name for records containing '(' - NodeSettings.objects.filter( - folder_name__contains=' (' - ).exclude( - folder_name__contains=':/' - ).update( - folder_name=Replace(F('folder_name'), Value(' ('), Value(':/ (')) - ) - NodeSettings.objects.exclude( - folder_name__contains=':/' - ).exclude( - folder_name__contains=' (' - ).update( - folder_name=Concat(F('folder_name'), Value(':/')) - ) - logger.info('Update Folder Names/IDs complete') - - -def reverse_update_folder_names(): - NodeSettings = apps.get_model('addons_s3', 'NodeSettings') - - # Reverse update folder_id for all records - NodeSettings.objects.update(folder_id=Replace(F('folder_id'), Value(':/'), Value(''))) - - # Reverse update folder_name for records containing ':/ (' - NodeSettings.objects.filter(folder_name__contains=':/ (').update( - folder_name=Replace(F('folder_name'), Value(':/ ('), Value(' (')) - ) - NodeSettings.objects.filter(folder_name__contains=':/').update( - folder_name=Replace(F('folder_name'), Value(':/'), Value('')) - ) - logger.info('Reverse Update Folder Names/IDs complete') diff --git a/osf/management/commands/add_egap_registration_schema.py b/osf/management/commands/add_egap_registration_schema.py deleted file mode 100644 index ea5df1e7f4a..00000000000 --- a/osf/management/commands/add_egap_registration_schema.py +++ /dev/null @@ -1,29 +0,0 @@ -import logging - -from django.core.management.base import BaseCommand -from osf.models import RegistrationSchema -from website.project.metadata.schemas import ensure_schema_structure, from_json - -logger = logging.getLogger(__name__) - - -class Command(BaseCommand): - """Add egap-registration schema to the db. - For now, doing this outside of a migration so it can be individually added to - a staging environment for preview. - """ - - def handle(self, *args, **options): - egap_registration_schema = ensure_schema_structure(from_json('egap-registration-3.json')) - schema_obj, created = RegistrationSchema.objects.update_or_create( - name=egap_registration_schema['name'], - schema_version=egap_registration_schema.get('version', 1), - defaults={ - 'schema': egap_registration_schema, - } - ) - - if created: - logger.info('Added schema {} to the database'.format(egap_registration_schema['name'])) - else: - logger.info('updated existing schema {}'.format(egap_registration_schema['name'])) diff --git a/osf/management/commands/add_institution_perm_groups.py b/osf/management/commands/add_institution_perm_groups.py deleted file mode 100644 index d7becaf2d8b..00000000000 --- a/osf/management/commands/add_institution_perm_groups.py +++ /dev/null @@ -1,19 +0,0 @@ -import logging - -from django.core.management.base import BaseCommand -from osf.models import Institution - -logger = logging.getLogger(__name__) - - -class Command(BaseCommand): - """A new permissions group was created for Institutions, which will be created upon each new Institution, - but the old institutions will not have this group. This management command creates those groups for the - existing institutions. - """ - - def handle(self, *args, **options): - institutions = Institution.objects.all() - for institution in institutions: - institution.update_group_permissions() - logger.info(f'Added perms to {institution.name}.') diff --git a/osf/management/commands/add_notification_subscription.py b/osf/management/commands/add_notification_subscription.py deleted file mode 100644 index 7d9a404f37a..00000000000 --- a/osf/management/commands/add_notification_subscription.py +++ /dev/null @@ -1,77 +0,0 @@ -# This is a management command, rather than a migration script, for two primary reasons: -# 1. It makes no changes to database structure (e.g. AlterField), only database content. -# 2. It takes a long time to run and the site doesn't need to be down that long. - -import logging - -import django -django.setup() - -from django.core.management.base import BaseCommand -from django.db import transaction - -from website.notifications.utils import to_subscription_key - -from scripts import utils as script_utils - -logger = logging.getLogger(__name__) - - -def add_reviews_notification_setting(notification_type, state=None): - if state: - OSFUser = state.get_model('osf', 'OSFUser') - NotificationSubscription = state.get_model('osf', 'NotificationSubscription') - else: - from osf.models import OSFUser, NotificationSubscription - - active_users = OSFUser.objects.filter(date_confirmed__isnull=False).exclude(date_disabled__isnull=False).exclude(is_active=False).order_by('id') - total_active_users = active_users.count() - - logger.info(f'About to add a global_reviews setting for {total_active_users} users.') - - total_created = 0 - for user in active_users.iterator(): - user_subscription_id = to_subscription_key(user._id, notification_type) - - subscription = NotificationSubscription.load(user_subscription_id) - if not subscription: - logger.info(f'No {notification_type} subscription found for user {user._id}. Subscribing...') - subscription = NotificationSubscription(_id=user_subscription_id, owner=user, event_name=notification_type) - subscription.save() # Need to save in order to access m2m fields - subscription.add_user_to_subscription(user, 'email_transactional') - else: - logger.info(f'User {user._id} already has a {notification_type} subscription') - total_created += 1 - - logger.info(f'Added subscriptions for {total_created}/{total_active_users} users') - - -class Command(BaseCommand): - """ - Add subscription to all active users for given notification type. - """ - def add_arguments(self, parser): - super().add_arguments(parser) - parser.add_argument( - '--dry', - action='store_true', - dest='dry_run', - help='Run migration and roll back changes to db', - ) - - parser.add_argument( - '--notification', - type=str, - required=True, - help='Notification type to subscribe users to', - ) - - def handle(self, *args, **options): - dry_run = options.get('dry_run', False) - state = options.get('state', None) - if not dry_run: - script_utils.add_file_logger(logger, __file__) - with transaction.atomic(): - add_reviews_notification_setting(notification_type=options['notification'], state=state) - if dry_run: - raise RuntimeError('Dry run, transaction rolled back.') diff --git a/osf/management/commands/addon_deleted_date.py b/osf/management/commands/addon_deleted_date.py deleted file mode 100644 index df2f78b26e0..00000000000 --- a/osf/management/commands/addon_deleted_date.py +++ /dev/null @@ -1,96 +0,0 @@ -import datetime -import logging - -from django.core.management.base import BaseCommand -from django.db import connection, transaction -from framework.celery_tasks import app as celery_app - -logger = logging.getLogger(__name__) - -TABLES_TO_POPULATE_WITH_MODIFIED = [ - 'addons_zotero_usersettings', - 'addons_dropbox_usersettings', - 'addons_dropbox_nodesettings', - 'addons_figshare_nodesettings', - 'addons_figshare_usersettings', - 'addons_forward_nodesettings', - 'addons_github_nodesettings', - 'addons_github_usersettings', - 'addons_gitlab_nodesettings', - 'addons_gitlab_usersettings', - 'addons_googledrive_nodesettings', - 'addons_googledrive_usersettings', - 'addons_mendeley_nodesettings', - 'addons_mendeley_usersettings', - 'addons_onedrive_nodesettings', - 'addons_onedrive_usersettings', - 'addons_osfstorage_nodesettings', - 'addons_osfstorage_usersettings', - 'addons_bitbucket_nodesettings', - 'addons_bitbucket_usersettings', - 'addons_owncloud_nodesettings', - 'addons_box_nodesettings', - 'addons_owncloud_usersettings', - 'addons_box_usersettings', - 'addons_dataverse_nodesettings', - 'addons_dataverse_usersettings', - 'addons_s3_nodesettings', - 'addons_s3_usersettings', - 'addons_twofactor_usersettings', - 'addons_wiki_nodesettings', - 'addons_zotero_nodesettings' -] - -UPDATE_DELETED_WITH_MODIFIED = """UPDATE {} SET deleted=modified - WHERE id IN (SELECT id FROM {} WHERE is_deleted AND deleted IS NULL LIMIT {}) RETURNING id;""" - -@celery_app.task(name='management.commands.addon_deleted_date') -def populate_deleted(dry_run=False, page_size=1000): - with transaction.atomic(): - for table in TABLES_TO_POPULATE_WITH_MODIFIED: - run_statements(UPDATE_DELETED_WITH_MODIFIED, page_size, table) - if dry_run: - raise RuntimeError('Dry Run -- Transaction rolled back') - -def run_statements(statement, page_size, table): - logger.info(f'Populating deleted column in table {table}') - with connection.cursor() as cursor: - cursor.execute(statement.format(table, table, page_size)) - rows = cursor.fetchall() - if rows: - logger.info(f'Table {table} still has rows to populate') - -class Command(BaseCommand): - help = '''Populates new deleted field for various models. Ensure you have run migrations - before running this script.''' - - def add_arguments(self, parser): - parser.add_argument( - '--dry_run', - type=bool, - default=False, - help='Run queries but do not write files', - ) - parser.add_argument( - '--page_size', - type=int, - default=1000, - help='How many rows to process at a time', - ) - - def handle(self, *args, **options): - script_start_time = datetime.datetime.now() - logger.info(f'Script started time: {script_start_time}') - logger.debug(options) - - dry_run = options['dry_run'] - page_size = options['page_size'] - - if dry_run: - logger.info('DRY RUN') - - populate_deleted(dry_run, page_size) - - script_finish_time = datetime.datetime.now() - logger.info(f'Script finished time: {script_finish_time}') - logger.info(f'Run time {script_finish_time - script_start_time}') diff --git a/osf/management/commands/backfill_date_retracted.py b/osf/management/commands/backfill_date_retracted.py deleted file mode 100644 index 698a67c82ae..00000000000 --- a/osf/management/commands/backfill_date_retracted.py +++ /dev/null @@ -1,89 +0,0 @@ -# This is a management command, rather than a migration script, for two primary reasons: -# 1. It makes no changes to database structure (e.g. AlterField), only database content. -# 2. It may need to be ran more than once, as it skips failed registrations. - -from datetime import timedelta -import logging - -import django -django.setup() - -from django.core.management.base import BaseCommand -from django.db import transaction - -from osf.models import Registration, Retraction, Sanction -from scripts import utils as script_utils - -logger = logging.getLogger(__name__) - -def set_date_retracted(*args): - registrations = ( - Registration.objects.filter(retraction__state=Sanction.APPROVED, retraction__date_retracted=None) - .select_related('retraction') - .prefetch_related('registered_from__logs') - .prefetch_related('registered_from__guids') - ) - total = registrations.count() - logger.info(f'Migrating {total} retractions.') - - for registration in registrations: - if not registration.registered_from: - logger.warning(f'Skipping failed registration {registration._id}') - continue - retraction_logs = registration.registered_from.logs.filter(action='retraction_approved', params__retraction_id=registration.retraction._id) - if retraction_logs.count() != 1 and retraction_logs.first().date - retraction_logs.last().date > timedelta(seconds=5): - msg = ( - 'There should be a retraction_approved log for retraction {} on node {}. No retraction_approved log found.' - if retraction_logs.count() == 0 - else 'There should only be one retraction_approved log for retraction {} on node {}. Multiple logs found.' - ) - raise Exception(msg.format(registration.retraction._id, registration.registered_from._id)) - date_retracted = retraction_logs[0].date - logger.info( - 'Setting date_retracted for retraction {} to be {}, from retraction_approved node log {}.'.format( - registration.retraction._id, date_retracted, retraction_logs[0]._id - ) - ) - registration.retraction.date_retracted = date_retracted - registration.retraction.save() - -def unset_date_retracted(*args): - retractions = Retraction.objects.filter(state=Sanction.APPROVED).exclude(date_retracted=None) - logger.info(f'Migrating {retractions.count()} retractions.') - - for retraction in retractions: - retraction.date_retracted = None - retraction.save() - - -class Command(BaseCommand): - """ - Backfill Retraction.date_retracted with `RETRACTION_APPROVED` log date. - """ - def add_arguments(self, parser): - super().add_arguments(parser) - parser.add_argument( - '--dry', - action='store_true', - dest='dry_run', - help='Run migration and roll back changes to db', - ) - parser.add_argument( - '--reverse', - action='store_true', - dest='reverse', - help='Unsets date_retraction' - ) - - def handle(self, *args, **options): - reverse = options.get('reverse', False) - dry_run = options.get('dry_run', False) - if not dry_run: - script_utils.add_file_logger(logger, __file__) - with transaction.atomic(): - if reverse: - unset_date_retracted() - else: - set_date_retracted() - if dry_run: - raise RuntimeError('Dry run, transaction rolled back.') diff --git a/osf/management/commands/create_fake_preprint_actions.py b/osf/management/commands/create_fake_preprint_actions.py deleted file mode 100644 index 85b28ae9f20..00000000000 --- a/osf/management/commands/create_fake_preprint_actions.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python3 - -import random -import logging -from faker import Faker - -from django.core.management.base import BaseCommand - -from osf.models import ReviewAction, Preprint, OSFUser -from osf.utils.workflows import DefaultStates, DefaultTriggers - -logger = logging.getLogger(__name__) - - -class Command(BaseCommand): - """Add fake Actions to every preprint that doesn't already have one""" - - def add_arguments(self, parser): - super().add_arguments(parser) - parser.add_argument( - 'user', - type=str, - nargs='?', - default=None, - help='Guid for user to list as creator for all fake actions (default to arbitrary user)' - ) - parser.add_argument( - '--num-actions', - action='store', - type=int, - default=10, - help='Number of actions to create for each preprint which does not have one' - ) - - def handle(self, *args, **options): - user_guid = options.get('user') - num_actions = options.get('--num-actions') - - if user_guid is None: - user = OSFUser.objects.first() - else: - user = OSFUser.objects.get(guids___id=user_guid) - - fake = Faker() - triggers = [a.value for a in DefaultTriggers] - states = [s.value for s in DefaultStates] - for preprint in Preprint.objects.filter(actions__isnull=True): - for i in range(num_actions): - action = ReviewAction( - target=preprint, - creator=user, - trigger=random.choice(triggers), - from_state=random.choice(states), - to_state=random.choice(states), - comment=fake.text(), - ) - action.save() diff --git a/osf/management/commands/fake_metrics_reports.py b/osf/management/commands/fake_metrics_reports.py deleted file mode 100644 index 765d6e475c1..00000000000 --- a/osf/management/commands/fake_metrics_reports.py +++ /dev/null @@ -1,62 +0,0 @@ -from datetime import date, timedelta -from random import randint - -from django.conf import settings -from django.core.management.base import BaseCommand - -from osf.metrics import ( - UserSummaryReport, - PreprintSummaryReport, -) -from osf.models import PreprintProvider - - -def fake_user_counts(days_back): - yesterday = date.today() - timedelta(days=1) - first_report = UserSummaryReport( - report_date=(yesterday - timedelta(days=days_back)), - active=randint(0, 23), - deactivated=randint(0, 2), - merged=randint(0, 4), - new_users_daily=randint(0, 7), - new_users_with_institution_daily=randint(0, 5), - unconfirmed=randint(0, 3), - ) - first_report.save() - - last_report = first_report - while last_report.report_date < yesterday: - new_user_count = randint(0, 500) - new_report = UserSummaryReport( - report_date=(last_report.report_date + timedelta(days=1)), - active=(last_report.active + randint(0, new_user_count)), - deactivated=(last_report.deactivated + randint(0, new_user_count)), - merged=(last_report.merged + randint(0, new_user_count)), - new_users_daily=new_user_count, - new_users_with_institution_daily=randint(0, new_user_count), - unconfirmed=(last_report.unconfirmed + randint(0, new_user_count)), - ) - new_report.save() - last_report = new_report - - -def fake_preprint_counts(days_back): - yesterday = date.today() - timedelta(days=1) - provider_keys = PreprintProvider.objects.all().values_list('_id', flat=True) - for day_delta in range(days_back): - for provider_key in provider_keys: - preprint_count = randint(100, 5000) * (days_back - day_delta) - PreprintSummaryReport( - report_date=yesterday - timedelta(days=day_delta), - provider_key=provider_key, - preprint_count=preprint_count, - ).save() - - -class Command(BaseCommand): - def handle(self, *args, **kwargs): - if not settings.DEBUG: - raise NotImplementedError('fake_reports requires DEBUG mode') - fake_user_counts(1000) - fake_preprint_counts(1000) - # TODO: more reports diff --git a/osf/management/commands/make_dummy_pageviews_for_metrics.py b/osf/management/commands/make_dummy_pageviews_for_metrics.py deleted file mode 100644 index 09de34bf7a8..00000000000 --- a/osf/management/commands/make_dummy_pageviews_for_metrics.py +++ /dev/null @@ -1,118 +0,0 @@ -"""osf/management/commands/poke_metrics_timespan_queries.py -""" -import logging -import random -import datetime - -from django.core.management.base import BaseCommand -from osf.metrics import CountedAuthUsage - - -logger = logging.getLogger(__name__) - -TIME_FILTERS = ( - {'gte': 'now/d-150d'}, - {'gte': '2021-11-28T23:00:00.000Z', 'lte': '2023-01-16T00:00:00.000Z'}, -) - -PLATFORM_IRI = 'http://localhost:9201/' - -ITEM_GUID = 'foo' - - -class Command(BaseCommand): - - def add_arguments(self, parser): - parser.add_argument( - '--count', - type=int, - default=100, - help='number of fake pageviews to generate', - ) - parser.add_argument( - '--seconds_back', - type=int, - default=60 * 60 * 24 * 14, # up to two weeks back - help='max age in seconds of random event', - ) - - def handle(self, *args, **options): - self._generate_random_countedusage(options.get('count'), options.get('seconds_back')) - - results = [ - self._run_date_query(time_filter) - for time_filter in TIME_FILTERS - ] - - self._print_line( - (str(f) for f in TIME_FILTERS), - label='timefilter:', - ) - - date_keys = { - k - for r in results - for k in r - } - for date_key in sorted(date_keys): - self._print_line( - (r.get(date_key, 0) for r in results), - label=str(date_key), - ) - - def _print_line(self, lineitems, label=''): - print('\t'.join((label, *map(str, lineitems)))) - - def _generate_random_countedusage(self, n, max_age): - now = datetime.datetime.now(tz=datetime.UTC) - for _ in range(n): - seconds_back = random.randint(0, max_age) - timestamp_time = now - datetime.timedelta(seconds=seconds_back) - CountedAuthUsage.record( - platform_iri=PLATFORM_IRI, - timestamp=timestamp_time, - item_guid=ITEM_GUID, - session_id='freshen by key', - user_is_authenticated=bool(random.randint(0, 1)), - item_public=bool(random.randint(0, 1)), - action_labels=[['view', 'download'][random.randint(0, 1)]], - ) - - def _run_date_query(self, time_range_filter): - result = self._run_query({ - 'query': { - 'bool': { - 'filter': { - 'range': { - 'timestamp': time_range_filter, - }, - }, - }, - }, - 'aggs': { - 'by-date': { - 'date_histogram': { - 'field': 'timestamp', - 'interval': 'day', - }, - }, - 'max-timestamp': { - 'max': {'field': 'timestamp'}, - }, - 'min-timestamp': { - 'min': {'field': 'timestamp'}, - }, - }, - }) - return { - 'min': result.aggs['min-timestamp'].value, - 'max': result.aggs['max-timestamp'].value, - **{ - str(bucket.key.date()): bucket.doc_count - for bucket in result.aggs['by-date'] - }, - } - - def _run_query(self, query_dict): - analytics_search = CountedAuthUsage.search().update_from_dict(query_dict) - return analytics_search.execute() diff --git a/osf/management/commands/migrate_notifications.py b/osf/management/commands/migrate_notifications.py new file mode 100644 index 00000000000..f4dfaf3c0c8 --- /dev/null +++ b/osf/management/commands/migrate_notifications.py @@ -0,0 +1,63 @@ +import logging +from django.contrib.contenttypes.models import ContentType +from osf.models import NotificationType, NotificationSubscription +from osf.models.notifications import NotificationSubscriptionLegacy +from django.core.management.base import BaseCommand +from django.db import transaction +from osf.management.commands.populate_notification_types import populate_notification_types + +logger = logging.getLogger(__name__) + +FREQ_MAP = { + 'none': 'none', + 'email_digest': 'weekly', + 'email_transactional': 'instantly', +} + +def migrate_legacy_notification_subscriptions(*args, **kwargs): + """ + Migrate legacy NotificationSubscription data to new notifications app. + """ + logger.info('Beginning legacy notification subscription migration...') + + PROVIDER_BASED_LEGACY_NOTIFICATION_TYPES = [ + 'new_pending_submissions', 'new_pending_withdraw_requests' + ] + + for legacy in NotificationSubscriptionLegacy.objects.all(): + event_name = legacy.event_name + if event_name in PROVIDER_BASED_LEGACY_NOTIFICATION_TYPES: + subscribed_object = legacy.provider + elif subscribed_object := legacy.node: + pass + elif subscribed_object := legacy.user: + pass + else: + raise NotImplementedError(f'Invalid Notification id {event_name}') + content_type = ContentType.objects.get_for_model(subscribed_object.__class__) + subscription, _ = NotificationSubscription.objects.update_or_create( + notification_type=NotificationType.objects.get(name=event_name), + user=legacy.user, + content_type=content_type, + object_id=subscribed_object.id, + defaults={ + 'user': legacy.user, + 'message_frequency': ( + ('weekly' if legacy.email_digest.exists() else 'none'), + 'instantly' if legacy.email_transactional.exists() else 'none' + ), + 'content_type': content_type, + 'object_id': subscribed_object.id, + } + ) + logger.info(f'Created NotificationType "{event_name}" with content_type {content_type}') + +class Command(BaseCommand): + help = 'Migrate legacy NotificationSubscriptionLegacy objects to new Notification app models.' + + def handle(self, *args, **options): + with transaction.atomic(): + populate_notification_types(args, options) + + with transaction.atomic(): + migrate_legacy_notification_subscriptions(args, options) diff --git a/osf/management/commands/migrate_pagecounter_data.py b/osf/management/commands/migrate_pagecounter_data.py deleted file mode 100644 index 050a355123f..00000000000 --- a/osf/management/commands/migrate_pagecounter_data.py +++ /dev/null @@ -1,124 +0,0 @@ -import datetime -import logging - -from django.core.management.base import BaseCommand -from django.db import connection - -from framework import sentry -from framework.celery_tasks import app as celery_app - -logger = logging.getLogger(__name__) - - -LIMIT_CLAUSE = ' LIMIT %s);' -NO_LIMIT_CLAUSE = ');' - -REVERSE_SQL_BASE = ''' -UPDATE osf_pagecounter PC -SET - resource_id = NULL, - file_id = NULL, - version = NULL, - action = NULL -WHERE PC.id IN ( - SELECT PC.id FROM osf_pagecounter PC - INNER JOIN osf_guid Guid on Guid._id = split_part(PC._id, ':', 2) - INNER JOIN osf_basefilenode File on File._id = split_part(PC._id, ':', 3) -''' -REVERSE_SQL = f'{REVERSE_SQL_BASE} {NO_LIMIT_CLAUSE}' -REVERSE_SQL_LIMITED = f'{REVERSE_SQL_BASE} {LIMIT_CLAUSE}' - -FORWARD_SQL_BASE = ''' - UPDATE osf_pagecounter PC - SET - action = split_part(PC._id, ':', 1), - resource_id = Guid.id, - file_id = File.id, - version = NULLIF(split_part(PC._id, ':', 4), '')::int - FROM osf_guid Guid, osf_basefilenode File - WHERE - Guid._id = split_part(PC._id, ':', 2) AND - File._id = split_part(PC._id, ':', 3) AND - PC.id in ( - select PC.id from osf_pagecounter PC - INNER JOIN osf_guid Guid on Guid._id = split_part(PC._id, ':', 2) - INNER JOIN osf_basefilenode File on File._id = split_part(PC._id, ':', 3) - WHERE (PC.resource_id IS NULL OR PC.file_id IS NULL) -''' -FORWARD_SQL = f'{FORWARD_SQL_BASE} {NO_LIMIT_CLAUSE}' -FORWARD_SQL_LIMITED = f'{FORWARD_SQL_BASE} {LIMIT_CLAUSE}' - -COUNT_SQL = ''' -SELECT count(PC.id) - from osf_pagecounter as PC - INNER JOIN osf_guid Guid on Guid._id = split_part(PC._id, ':', 2) - INNER JOIN osf_basefilenode File on File._id = split_part(PC._id, ':', 3) -where (PC.resource_id IS NULL or PC.file_id IS NULL); -''' - -@celery_app.task(name='management.commands.migrate_pagecounter_data') -def migrate_page_counters(dry_run=False, rows=10000, reverse=False): - script_start_time = datetime.datetime.now() - logger.info(f'Script started time: {script_start_time}') - - sql_query = REVERSE_SQL_LIMITED if reverse else FORWARD_SQL_LIMITED - logger.info(f'SQL Query: {sql_query}') - - with connection.cursor() as cursor: - if not dry_run: - cursor.execute(sql_query, [rows]) - if not reverse: - cursor.execute(COUNT_SQL) - number_of_entries_left = cursor.fetchone()[0] - logger.info(f'Entries left: {number_of_entries_left}') - if number_of_entries_left == 0: - sentry.log_message('Migrate pagecounter data complete') - - script_finish_time = datetime.datetime.now() - logger.info(f'Script finished time: {script_finish_time}') - logger.info(f'Run time {script_finish_time - script_start_time}') - - -class Command(BaseCommand): - help = '''Does the work of the pagecounter migration so that it can be done incrementally when convenient. - You will either need to set the page_size large enough to get all of the records, or you will need to run the - script multiple times until it tells you that it is done.''' - - def add_arguments(self, parser): - parser.add_argument( - '--dry_run', - type=bool, - default=False, - help='Run queries but do not write files', - ) - parser.add_argument( - '--rows', - type=int, - default=10000, - help='How many rows to process during this run', - ) - parser.add_argument( - '--reverse', - type=bool, - default=False, - help='Reverse out the migration', - ) - - # Management command handler - def handle(self, *args, **options): - logger.debug(options) - - dry_run = options['dry_run'] - rows = options['rows'] - reverse = options['reverse'] - logger.debug( - 'Dry run: {}, rows: {}, reverse: {}'.format( - dry_run, - rows, - reverse, - ) - ) - if dry_run: - logger.info('DRY RUN') - - migrate_page_counters(dry_run, rows, reverse) diff --git a/osf/management/commands/migrate_preprint_affiliation.py b/osf/management/commands/migrate_preprint_affiliation.py deleted file mode 100644 index e34c6dc6b27..00000000000 --- a/osf/management/commands/migrate_preprint_affiliation.py +++ /dev/null @@ -1,118 +0,0 @@ -import datetime -import logging - -from django.core.management.base import BaseCommand -from django.db import transaction -from django.db.models import F, Exists, OuterRef - -from osf.models import PreprintContributor, InstitutionAffiliation - -logger = logging.getLogger(__name__) - -AFFILIATION_TARGET_DATE = datetime.datetime(2024, 9, 19, 14, 37, 48, tzinfo=datetime.timezone.utc) - - -class Command(BaseCommand): - """Assign affiliations from users to preprints where they have write or admin permissions, with optional exclusion by user GUIDs.""" - - help = 'Assign affiliations from users to preprints where they have write or admin permissions.' - - def add_arguments(self, parser): - parser.add_argument( - '--exclude-guids', - nargs='+', - dest='exclude_guids', - help='List of user GUIDs to exclude from affiliation assignment' - ) - parser.add_argument( - '--dry-run', - action='store_true', - dest='dry_run', - help='If true, performs a dry run without making changes' - ) - parser.add_argument( - '--batch-size', - type=int, - default=1000, - dest='batch_size', - help='Number of contributors to process in each batch' - ) - - def handle(self, *args, **options): - start_time = datetime.datetime.now() - logger.info(f'Script started at: {start_time}') - - exclude_guids = set(options.get('exclude_guids') or []) - dry_run = options.get('dry_run', False) - batch_size = options.get('batch_size', 1000) - - if dry_run: - logger.info('Dry run mode activated.') - - processed_count, updated_count = assign_affiliations_to_preprints( - exclude_guids=exclude_guids, - dry_run=dry_run, - batch_size=batch_size - ) - - finish_time = datetime.datetime.now() - logger.info(f'Script finished at: {finish_time}') - logger.info(f'Total processed: {processed_count}, Updated: {updated_count}') - logger.info(f'Total run time: {finish_time - start_time}') - - -def assign_affiliations_to_preprints(exclude_guids=None, dry_run=True, batch_size=1000): - exclude_guids = exclude_guids or set() - processed_count = updated_count = 0 - - # Subquery to check if the user has any affiliated institutions - user_has_affiliations = Exists( - InstitutionAffiliation.objects.filter( - user=OuterRef('user') - ) - ) - - contributors_qs = PreprintContributor.objects.filter( - preprint__preprintgroupobjectpermission__permission__codename__in=['write_preprint'], - preprint__preprintgroupobjectpermission__group__user=F('user'), - ).filter( - user_has_affiliations - ).select_related( - 'user', - 'preprint' - ).exclude( - user__guids___id__in=exclude_guids - ).order_by('pk') # Ensure consistent ordering for batching - - total_contributors = contributors_qs.count() - logger.info(f'Total contributors to process: {total_contributors}') - - # Process contributors in batches - with transaction.atomic(): - for offset in range(0, total_contributors, batch_size): - # Use select_for_update() to ensure query hits the primary database - batch_contributors = contributors_qs[offset:offset + batch_size].select_for_update() - - logger.info(f'Processing contributors {offset + 1} to {min(offset + batch_size, total_contributors)}') - - for contributor in batch_contributors: - user = contributor.user - preprint = contributor.preprint - - if preprint.created > AFFILIATION_TARGET_DATE: - continue - - user_institutions = user.get_affiliated_institutions() - processed_count += 1 - if not dry_run: - preprint.affiliated_institutions.add(*user_institutions) - updated_count += 1 - logger.info( - f'Assigned {len(user_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.' - ) - else: - logger.info( - f'Dry run: Would assign {len(user_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.' - ) - - return processed_count, updated_count diff --git a/osf/management/commands/migrate_user_institution_affiliation.py b/osf/management/commands/migrate_user_institution_affiliation.py deleted file mode 100644 index 79170c5ece4..00000000000 --- a/osf/management/commands/migrate_user_institution_affiliation.py +++ /dev/null @@ -1,84 +0,0 @@ -import datetime -import logging - -from django.core.management.base import BaseCommand - -from osf.models import Institution, InstitutionAffiliation - -logger = logging.getLogger(__name__) - - -class Command(BaseCommand): - """Update emails of users from a given affiliated institution (when eligible). - """ - - def add_arguments(self, parser): - super().add_arguments(parser) - parser.add_argument( - '--dry', - action='store_true', - dest='dry_run', - help='If true, iterate through eligible users and institutions only' - ) - - def handle(self, *args, **options): - script_start_time = datetime.datetime.now() - logger.info(f'Script started time: {script_start_time}') - - dry_run = options.get('dry_run', False) - if dry_run: - logger.warning('Dry Run: This is a dry-run pass!') - migrate_user_institution_affiliation(dry_run=dry_run) - - script_finish_time = datetime.datetime.now() - logger.info(f'Script finished time: {script_finish_time}') - logger.info(f'Run time {script_finish_time - script_start_time}') - - -def migrate_user_institution_affiliation(dry_run=True): - - institutions = Institution.objects.get_all_institutions() - institution_total = institutions.count() - - institution_count = 0 - user_count = 0 - skipped_user_count = 0 - - for institution in institutions: - institution_count += 1 - user_count_per_institution = 0 - skipped_user_count_per_institution = 0 - users = institution.osfuser_set.all() - user_total_per_institution = users.count() - sso_identity = None - if not institution.delegation_protocol: - sso_identity = InstitutionAffiliation.DEFAULT_VALUE_FOR_SSO_IDENTITY_NOT_AVAILABLE - logger.info(f'Migrating affiliation for <{institution.name}> [{institution_count}/{institution_total}]') - for user in institution.osfuser_set.all(): - user_count_per_institution += 1 - user_count += 1 - logger.info(f'\tMigrating affiliation for <{user._id}::{institution.name}> ' - f'[{user_count_per_institution}/{user_total_per_institution}]') - if not dry_run: - affiliation = user.add_or_update_affiliated_institution( - institution, - sso_identity=sso_identity, - sso_department=user.department - ) - if affiliation: - logger.info(f'\tAffiliation=<{affiliation}> migrated or updated ' - f'for user=<{user._id}> @ institution=<{institution._id}>') - else: - skipped_user_count_per_institution += 1 - skipped_user_count += 1 - logger.info(f'\tSkip migration or update since affiliation exists ' - f'for user=<{user._id}> @ institution=<{institution._id}>') - else: - logger.warning(f'\tDry Run: Affiliation not migrated for {user._id} @ {institution._id}!') - if user_count_per_institution == 0: - logger.warning('No eligible user found') - else: - logger.info(f'Finished migrating affiliation for {user_count_per_institution} users ' - f'@ <{institution.name}>, including {skipped_user_count_per_institution} skipped users') - logger.info(f'Finished migrating affiliation for {user_count} users @ {institution_count} institutions, ' - f'including {skipped_user_count} skipped users') diff --git a/osf/management/commands/move_egap_regs_to_provider.py b/osf/management/commands/move_egap_regs_to_provider.py deleted file mode 100644 index 1dcaa7a6b77..00000000000 --- a/osf/management/commands/move_egap_regs_to_provider.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging - -from django.core.management.base import BaseCommand - -from scripts import utils as script_utils - -logger = logging.getLogger(__name__) - -from osf.models import ( - RegistrationProvider, - RegistrationSchema, - Registration -) - - -def main(dry_run): - egap_provider = RegistrationProvider.objects.get(_id='egap') - egap_schemas = RegistrationSchema.objects.filter(name='EGAP Registration').order_by('-schema_version') - - for egap_schema in egap_schemas: - egap_regs = Registration.objects.filter(registered_schema=egap_schema.id, provider___id='osf') - - if dry_run: - logger.info(f'[DRY RUN] {egap_regs.count()} updated to {egap_provider} with id {egap_provider.id}') - else: - egap_regs.update(provider_id=egap_provider.id) - - -class Command(BaseCommand): - def add_arguments(self, parser): - super().add_arguments(parser) - parser.add_argument( - '--dry', - action='store_true', - dest='dry_run', - help='Dry run', - ) - - def handle(self, *args, **options): - dry_run = options.get('dry_run', False) - if not dry_run: - script_utils.add_file_logger(logger, __file__) - - main(dry_run=dry_run) diff --git a/osf/management/commands/populate_branched_from_node.py b/osf/management/commands/populate_branched_from_node.py deleted file mode 100644 index 086f7e4dbef..00000000000 --- a/osf/management/commands/populate_branched_from_node.py +++ /dev/null @@ -1,67 +0,0 @@ -import logging -import datetime - -from django.core.management.base import BaseCommand -from framework.celery_tasks import app as celery_app -from django.db import connection, transaction - -logger = logging.getLogger(__name__) - -POPULATE_BRANCHED_FROM_NODE = """WITH cte AS ( - SELECT id - FROM osf_abstractnode - WHERE type = 'osf.registration' AND - branched_from_node IS null - LIMIT %s -) -UPDATE osf_abstractnode a - SET branched_from_node = CASE WHEN - EXISTS(SELECT id FROM osf_nodelog WHERE action='project_created_from_draft_reg' AND node_id = a.id) THEN False - ELSE True -END -FROM cte -WHERE cte.id = a.id -""" - -@celery_app.task(name='management.commands.populate_branched_from') -def populate_branched_from(page_size=10000, dry_run=False): - with transaction.atomic(): - with connection.cursor() as cursor: - cursor.execute(POPULATE_BRANCHED_FROM_NODE, [page_size]) - if dry_run: - raise RuntimeError('Dry Run -- Transaction rolled back') - -class Command(BaseCommand): - help = '''Populates new deleted field for various models. Ensure you have run migrations - before running this script.''' - - def add_arguments(self, parser): - parser.add_argument( - '--dry_run', - type=bool, - default=False, - help='Run queries but do not write files', - ) - parser.add_argument( - '--page_size', - type=int, - default=10000, - help='How many rows to process at a time', - ) - - def handle(self, *args, **options): - script_start_time = datetime.datetime.now() - logger.info(f'Script started time: {script_start_time}') - logger.debug(options) - - dry_run = options['dry_run'] - page_size = options['page_size'] - - if dry_run: - logger.info('DRY RUN') - - populate_branched_from(page_size, dry_run) - - script_finish_time = datetime.datetime.now() - logger.info(f'Script finished time: {script_finish_time}') - logger.info(f'Run time {script_finish_time - script_start_time}') diff --git a/osf/management/commands/populate_collection_provider_notification_subscriptions.py b/osf/management/commands/populate_collection_provider_notification_subscriptions.py index 5713b08061b..c3a21eb8d20 100644 --- a/osf/management/commands/populate_collection_provider_notification_subscriptions.py +++ b/osf/management/commands/populate_collection_provider_notification_subscriptions.py @@ -1,7 +1,7 @@ import logging from django.core.management.base import BaseCommand -from osf.models import NotificationSubscription, CollectionProvider +from osf.models import NotificationSubscriptionLegacy, CollectionProvider logger = logging.getLogger(__file__) @@ -12,7 +12,7 @@ def populate_collection_provider_notification_subscriptions(): provider_moderators = provider.get_group('moderator').user_set.all() for subscription in provider.DEFAULT_SUBSCRIPTIONS: - instance, created = NotificationSubscription.objects.get_or_create( + instance, created = NotificationSubscriptionLegacy.objects.get_or_create( _id=f'{provider._id}_{subscription}', event_name=subscription, provider=provider diff --git a/osf/management/commands/populate_initial_schema_responses.py b/osf/management/commands/populate_initial_schema_responses.py deleted file mode 100644 index 26ba3da7710..00000000000 --- a/osf/management/commands/populate_initial_schema_responses.py +++ /dev/null @@ -1,100 +0,0 @@ -import logging - -from django.core.management.base import BaseCommand -from django.db import transaction -from django.db.models import Exists, F, OuterRef -from framework.celery_tasks import app as celery_app - -from osf.exceptions import PreviousSchemaResponseError, SchemaResponseUpdateError -from osf.models import Registration, SchemaResponse -from osf.utils.workflows import ApprovalStates, RegistrationModerationStates as RegStates - -logger = logging.getLogger(__name__) - -# Initial response pending amin approval or rejected while awaiting it -UNAPPROVED_STATES = [RegStates.INITIAL.db_name, RegStates.REVERTED.db_name] -# Initial response pending moderator approval or rejected while awaiting it -PENDING_MODERATION_STATES = [RegStates.PENDING.db_name, RegStates.REJECTED.db_name] - - -def _update_schema_response_state(schema_response): - '''Set the schema_response's state based on the current state of the parent rgistration.''' - moderation_state = schema_response.parent.moderation_state - if moderation_state in UNAPPROVED_STATES: - schema_response.state = ApprovalStates.UNAPPROVED - elif moderation_state in PENDING_MODERATION_STATES: - schema_response.state = ApprovalStates.PENDING_MODERATION - else: # All remainint states imply initial responses were approved by users at some point - schema_response.state = ApprovalStates.APPROVED - schema_response.save() - - -@celery_app.task(name='management.commands.populate_initial_schema_responses') -@transaction.atomic -def populate_initial_schema_responses(dry_run=False, batch_size=None): - '''Migrate registration_responses into a SchemaResponse for historical registrations.''' - # Find all root registrations that do not yet have SchemaResponses - qs = Registration.objects.prefetch_related('root').annotate( - has_schema_response=Exists(SchemaResponse.objects.filter(nodes__id=OuterRef('id'))) - ).filter( - has_schema_response=False, root=F('id') - ) - if batch_size: - qs = qs[:batch_size] - - count = 0 - for registration in qs: - logger.info( - f'{"[DRY RUN] " if dry_run else ""}' - f'Creating initial SchemaResponse for Registration with guid {registration._id}' - ) - try: - registration.copy_registration_responses_into_schema_response() - except SchemaResponseUpdateError as e: - logger.info( - f'Ignoring unsupported values "registration_responses" for registration ' - f'with guid [{registration._id}]: {str(e)}' - ) - except (ValueError, PreviousSchemaResponseError): - logger.exception( - f'{"[DRY RUN] " if dry_run else ""}' - f'Failure creating SchemaResponse for Registration with guid {registration._id}' - ) - # These errors should have prevented SchemaResponse creation, but better safe than sorry - registration.schema_responses.all().delete() - continue - - _update_schema_response_state(registration.schema_responses.last()) - count += 1 - - logger.info( - f'{"[DRY RUN] " if dry_run else ""}' - f'Created initial SchemaResponses for {count} registrations' - ) - - if dry_run: - raise RuntimeError('Dry run, transaction rolled back') - - return count - - -class Command(BaseCommand): - def add_arguments(self, parser): - super().add_arguments(parser) - parser.add_argument( - '--dry', - action='store_true', - dest='dry_run', - help='Dry run', - ) - - parser.add_argument( - '--batch_size', - type=int, - default=0 - ) - - def handle(self, *args, **options): - dry_run = options.get('dry_run') - batch_size = options.get('batch_size') - populate_initial_schema_responses(dry_run=dry_run, batch_size=batch_size) diff --git a/osf/management/commands/populate_notification_types.py b/osf/management/commands/populate_notification_types.py new file mode 100644 index 00000000000..8f20531f06a --- /dev/null +++ b/osf/management/commands/populate_notification_types.py @@ -0,0 +1,72 @@ +import yaml +from django.apps import apps +from website import settings + +import logging +from django.contrib.contenttypes.models import ContentType +from osf.models import NotificationType +from django.core.management.base import BaseCommand +from django.db import transaction + +logger = logging.getLogger(__name__) + +FREQ_MAP = { + 'none': 'none', + 'email_digest': 'weekly', + 'email_transactional': 'instantly', +} + +def populate_notification_types(*args, **kwargs): + + with open(settings.NOTIFICATION_TYPES_YAML) as stream: + notification_types = yaml.safe_load(stream) + for notification_type in notification_types['notification_types']: + notification_type.pop('__docs__') + object_content_type_model_name = notification_type.pop('object_content_type_model_name') + notification_freq = notification_type.pop('notification_freq_default') + + if object_content_type_model_name == 'desk': + content_type = None + elif object_content_type_model_name == 'osfuser': + OSFUser = apps.get_model('osf', 'OSFUser') + content_type = ContentType.objects.get_for_model(OSFUser) + elif object_content_type_model_name == 'preprint': + Preprint = apps.get_model('osf', 'Preprint') + content_type = ContentType.objects.get_for_model(Preprint) + elif object_content_type_model_name == 'collectionsubmission': + CollectionSubmission = apps.get_model('osf', 'CollectionSubmission') + content_type = ContentType.objects.get_for_model(CollectionSubmission) + elif object_content_type_model_name == 'abstractprovider': + AbstractProvider = apps.get_model('osf', 'abstractprovider') + content_type = ContentType.objects.get_for_model(AbstractProvider) + elif object_content_type_model_name == 'osfuser': + OSFUser = apps.get_model('osf', 'OSFUser') + content_type = ContentType.objects.get_for_model(OSFUser) + else: + try: + content_type = ContentType.objects.get( + app_label='osf', + model=object_content_type_model_name + ) + except ContentType.DoesNotExist: + raise ValueError(f'No content type for osf.{object_content_type_model_name}') + + with open(notification_type['template']) as stream: + template = stream.read() + + notification_types['template'] = template + notification_types['notification_freq'] = notification_freq + nt, _ = NotificationType.objects.update_or_create( + name=notification_type['name'], + defaults=notification_type, + ) + nt.object_content_type = content_type + nt.save() + + +class Command(BaseCommand): + help = 'Population notification types.' + + def handle(self, *args, **options): + with transaction.atomic(): + populate_notification_types(args, options) diff --git a/osf/management/commands/populate_registration_provider_notification_subscriptions.py b/osf/management/commands/populate_registration_provider_notification_subscriptions.py index fe372fcbb80..db4b44acba5 100644 --- a/osf/management/commands/populate_registration_provider_notification_subscriptions.py +++ b/osf/management/commands/populate_registration_provider_notification_subscriptions.py @@ -2,7 +2,7 @@ from django.contrib.auth.models import Group from django.core.management.base import BaseCommand -from osf.models import NotificationSubscription, RegistrationProvider +from osf.models import RegistrationProvider, NotificationSubscriptionLegacy logger = logging.getLogger(__file__) @@ -17,7 +17,7 @@ def populate_registration_provider_notification_subscriptions(): continue for subscription in provider.DEFAULT_SUBSCRIPTIONS: - instance, created = NotificationSubscription.objects.get_or_create( + instance, created = NotificationSubscriptionLegacy.objects.get_or_create( _id=f'{provider._id}_{subscription}', event_name=subscription, provider=provider diff --git a/osf/migrations/0032_alter_notificationsubscription_options_and_more.py b/osf/migrations/0032_alter_notificationsubscription_options_and_more.py new file mode 100644 index 00000000000..b4f273108d5 --- /dev/null +++ b/osf/migrations/0032_alter_notificationsubscription_options_and_more.py @@ -0,0 +1,132 @@ +# Generated by Django 4.2.13 on 2025-07-08 17:07 + +from django.conf import settings +import django.contrib.postgres.fields +from django.db import migrations, models +import django.db.models.deletion +import django_extensions.db.fields +import osf.models.base +import osf.models.notification_type + + +class Migration(migrations.Migration): + + dependencies = [ + ('contenttypes', '0002_remove_content_type_name'), + ('osf', '0031_alter_osfgroupgroupobjectpermission_unique_together_and_more'), + ] + + operations = [ + migrations.AlterModelOptions( + name='notificationsubscription', + options={'verbose_name': 'Notification Subscription', 'verbose_name_plural': 'Notification Subscriptions'}, + ), + migrations.AlterUniqueTogether( + name='notificationsubscription', + unique_together=set(), + ), + migrations.AddField( + model_name='notificationsubscription', + name='content_type', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='contenttypes.contenttype'), + ), + migrations.AddField( + model_name='notificationsubscription', + name='message_frequency', + field=models.CharField(max_length=500, null=True), + ), + migrations.AddField( + model_name='notificationsubscription', + name='object_id', + field=models.CharField(blank=True, max_length=255, null=True), + ), + migrations.AlterField( + model_name='notificationsubscription', + name='user', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='subscriptions', to=settings.AUTH_USER_MODEL), + ), + migrations.CreateModel( + name='NotificationType', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('notification_interval_choices', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=32), blank=True, default=osf.models.notification_type.get_default_frequency_choices, size=None)), + ('name', models.CharField(max_length=255, unique=True)), + ('template', models.TextField(help_text='Template used to render the event_info. Supports Django template syntax.')), + ('subject', models.TextField(blank=True, help_text='Template used to render the subject line of email. Supports Django template syntax.', null=True)), + ('object_content_type', models.ForeignKey(blank=True, help_text='Content type for subscribed objects. Null means global event.', null=True, on_delete=django.db.models.deletion.SET_NULL, to='contenttypes.contenttype')), + ], + options={ + 'verbose_name': 'Notification Type', + 'verbose_name_plural': 'Notification Types', + }, + ), + migrations.CreateModel( + name='Notification', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('event_context', models.JSONField()), + ('sent', models.DateTimeField(blank=True, null=True)), + ('seen', models.DateTimeField(blank=True, null=True)), + ('created', models.DateTimeField(auto_now_add=True)), + ('subscription', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='notifications', to='osf.notificationsubscription')), + ], + options={ + 'verbose_name': 'Notification', + 'verbose_name_plural': 'Notifications', + }, + ), + migrations.RemoveField( + model_name='notificationsubscription', + name='_id', + ), + migrations.RemoveField( + model_name='notificationsubscription', + name='email_digest', + ), + migrations.RemoveField( + model_name='notificationsubscription', + name='email_transactional', + ), + migrations.RemoveField( + model_name='notificationsubscription', + name='event_name', + ), + migrations.RemoveField( + model_name='notificationsubscription', + name='node', + ), + migrations.RemoveField( + model_name='notificationsubscription', + name='none', + ), + migrations.RemoveField( + model_name='notificationsubscription', + name='provider', + ), + migrations.AddField( + model_name='notificationsubscription', + name='notification_type', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, to='osf.notificationtype'), + ), + migrations.CreateModel( + name='NotificationSubscriptionLegacy', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created', django_extensions.db.fields.CreationDateTimeField(auto_now_add=True, verbose_name='created')), + ('modified', django_extensions.db.fields.ModificationDateTimeField(auto_now=True, verbose_name='modified')), + ('_id', models.CharField(db_index=True, max_length=100)), + ('event_name', models.CharField(max_length=100)), + ('email_digest', models.ManyToManyField(related_name='+', to=settings.AUTH_USER_MODEL)), + ('email_transactional', models.ManyToManyField(related_name='+', to=settings.AUTH_USER_MODEL)), + ('node', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='notification_subscriptions', to='osf.node')), + ('none', models.ManyToManyField(related_name='+', to=settings.AUTH_USER_MODEL)), + ('provider', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='notification_subscriptions', to='osf.abstractprovider')), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='notification_subscriptions', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'db_table': 'osf_notificationsubscription_legacy', + 'unique_together': {('_id', 'provider')}, + }, + bases=(models.Model, osf.models.base.QuerySetExplainMixin), + ), + ] diff --git a/osf/models/__init__.py b/osf/models/__init__.py index 909183adab6..d09e350adfe 100644 --- a/osf/models/__init__.py +++ b/osf/models/__init__.py @@ -62,7 +62,11 @@ from .node_relation import NodeRelation from .nodelog import NodeLog from .notable_domain import NotableDomain, DomainReference -from .notifications import NotificationDigest, NotificationSubscription +from .notifications import NotificationDigest, NotificationSubscriptionLegacy +from .notification_subscription import NotificationSubscription +from .notification_type import NotificationType +from .notification import Notification + from .oauth import ( ApiOAuth2Application, ApiOAuth2PersonalToken, diff --git a/osf/models/collection_submission.py b/osf/models/collection_submission.py index 893533d85d1..56c5a64f659 100644 --- a/osf/models/collection_submission.py +++ b/osf/models/collection_submission.py @@ -132,10 +132,10 @@ def _notify_moderators_pending(self, event_data): 'allow_submissions': True, } - from .notifications import NotificationSubscription + from .notifications import NotificationSubscriptionLegacy from website.notifications.emails import store_emails - provider_subscription, created = NotificationSubscription.objects.get_or_create( + provider_subscription, created = NotificationSubscriptionLegacy.objects.get_or_create( _id=f'{self.collection.provider._id}_new_pending_submissions', provider=self.collection.provider ) diff --git a/osf/models/notification.py b/osf/models/notification.py new file mode 100644 index 00000000000..557712b81a5 --- /dev/null +++ b/osf/models/notification.py @@ -0,0 +1,67 @@ +import logging + +from django.db import models +from website import settings +from api.base import settings as api_settings +from osf import email + +class Notification(models.Model): + subscription = models.ForeignKey( + 'NotificationSubscription', + on_delete=models.CASCADE, + related_name='notifications' + ) + event_context: dict = models.JSONField() + sent = models.DateTimeField(null=True, blank=True) + seen = models.DateTimeField(null=True, blank=True) + created = models.DateTimeField(auto_now_add=True) + + def send(self, protocol_type='email', recipient=None): + if not settings.USE_EMAIL: + return + if not protocol_type == 'email': + raise NotImplementedError(f'Protocol type {protocol_type}. Email notifications are only implemented.') + + recipient_address = getattr(recipient, 'username', None) or self.subscription.user + + if protocol_type == 'email' and settings.DEV_MODE and settings.ENABLE_TEST_EMAIL: + email.send_email_over_smtp( + recipient_address, + self.subscription.notification_type, + self.event_context + ) + elif protocol_type == 'email' and settings.DEV_MODE: + if not api_settings.CI_ENV: + logging.info( + f"Attempting to send email in DEV_MODE with ENABLE_TEST_EMAIL false just logs:" + f"\nto={recipient_address}" + f"\ntype={self.subscription.notification_type.name}" + f"\ncontext={self.event_context}" + ) + elif protocol_type == 'email': + email.send_email_with_send_grid( + getattr(recipient, 'username', None) or self.subscription.user, + self.subscription.notification_type, + self.event_context + ) + else: + raise NotImplementedError(f'protocol `{protocol_type}` is not supported.') + + self.mark_sent() + + def mark_sent(self) -> None: + raise NotImplementedError('mark_sent must be implemented by subclasses.') + # self.sent = timezone.now() + # self.save(update_fields=['sent']) + + def mark_seen(self) -> None: + raise NotImplementedError('mark_seen must be implemented by subclasses.') + # self.seen = timezone.now() + # self.save(update_fields=['seen']) + + def __str__(self) -> str: + return f'Notification for {self.subscription.user} [{self.subscription.notification_type.name}]' + + class Meta: + verbose_name = 'Notification' + verbose_name_plural = 'Notifications' diff --git a/osf/models/notification_subscription.py b/osf/models/notification_subscription.py new file mode 100644 index 00000000000..a1c9467b50e --- /dev/null +++ b/osf/models/notification_subscription.py @@ -0,0 +1,101 @@ +from django.db import models +from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType +from django.core.exceptions import ValidationError +from osf.models.notification_type import get_default_frequency_choices +from osf.models.notification import Notification + +from .base import BaseModel + + +class NotificationSubscription(BaseModel): + notification_type = models.ForeignKey( + 'NotificationType', + on_delete=models.CASCADE, + null=True + ) + user = models.ForeignKey( + 'osf.OSFUser', + null=True, + on_delete=models.CASCADE, + related_name='subscriptions' + ) + message_frequency: str = models.CharField( + max_length=500, + null=True + ) + + content_type = models.ForeignKey(ContentType, null=True, blank=True, on_delete=models.CASCADE) + object_id = models.CharField(max_length=255, null=True, blank=True) + subscribed_object = GenericForeignKey('content_type', 'object_id') + + def clean(self): + ct = self.notification_type.object_content_type + + if ct: + if self.content_type != ct: + raise ValidationError('Subscribed object must match type\'s content_type.') + if not self.object_id: + raise ValidationError('Subscribed object ID is required.') + else: + if self.content_type or self.object_id: + raise ValidationError('Global subscriptions must not have an object.') + + allowed_freqs = self.notification_type.notification_interval_choices or get_default_frequency_choices() + if self.message_frequency not in allowed_freqs: + raise ValidationError(f'{self.message_frequency!r} is not allowed for {self.notification_type.name!r}.') + + def __str__(self) -> str: + return f'<{self.user} via {self.subscribed_object} subscribes to {self.notification_type.name} ({self.message_frequency})>' + + class Meta: + verbose_name = 'Notification Subscription' + verbose_name_plural = 'Notification Subscriptions' + + def emit(self, user, subscribed_object=None, event_context=None): + """Emit a notification to a user by creating Notification and NotificationSubscription objects. + + Args: + user (OSFUser): The recipient of the notification. + subscribed_object (optional): The object the subscription is related to. + event_context (dict, optional): Context for rendering the notification template. + """ + if self.message_frequency == 'instantly': + Notification.objects.create( + subscription=self, + event_context=event_context + ).send() + else: + Notification.objects.create( + subscription=self, + event_context=event_context + ) + + @property + def absolute_api_v2_url(self): + from api.base.utils import absolute_reverse + return absolute_reverse('institutions:institution-detail', kwargs={'institution_id': self._id, 'version': 'v2'}) + + @property + def _id(self): + """ + Legacy subscription id for API compatibility. + Provider: _ + User/global: _global_ + Node/etc: _ + """ + # Safety checks + event = self.notification_type.name + ct = self.notification_type.object_content_type + match getattr(ct, 'model', None): + case 'preprintprovider' | 'collectionprovider' | 'registrationprovider': + # Providers: use subscribed_object._id (which is the provider short name, e.g. 'mindrxiv') + return f'{self.subscribed_object._id}_new_pending_submissions' + case 'node' | 'collection' | 'preprint': + # Node-like objects: use object_id (guid) + return f'{self.subscribed_object._id}_{event}' + case 'osfuser' | 'user' | None: + # Global: _global + return f'{self.user._id}_global' + case _: + raise NotImplementedError() diff --git a/osf/models/notification_type.py b/osf/models/notification_type.py new file mode 100644 index 00000000000..9b36d20e93a --- /dev/null +++ b/osf/models/notification_type.py @@ -0,0 +1,250 @@ +from django.db import models +from django.contrib.postgres.fields import ArrayField +from django.contrib.contenttypes.models import ContentType + +from osf.models.notification import Notification +from enum import Enum + + +class FrequencyChoices(Enum): + NONE = 'none' + INSTANTLY = 'instantly' + DAILY = 'daily' + WEEKLY = 'weekly' + MONTHLY = 'monthly' + + @classmethod + def choices(cls): + return [(key.value, key.name.capitalize()) for key in cls] + +def get_default_frequency_choices(): + DEFAULT_FREQUENCY_CHOICES = ['none', 'instantly', 'daily', 'weekly', 'monthly'] + return DEFAULT_FREQUENCY_CHOICES.copy() + + +class NotificationType(models.Model): + + class Type(str, Enum): + # Desk notifications + DESK_REQUEST_EXPORT = 'desk_request_export' + DESK_REQUEST_DEACTIVATION = 'desk_request_deactivation' + DESK_OSF_SUPPORT_EMAIL = 'desk_osf_support_email' + DESK_REGISTRATION_BULK_UPLOAD_PRODUCT_OWNER = 'desk_registration_bulk_upload_product_owner' + DESK_USER_REGISTRATION_BULK_UPLOAD_UNEXPECTED_FAILURE = 'desk_user_registration_bulk_upload_unexpected_failure' + DESK_ARCHIVE_JOB_EXCEEDED = 'desk_archive_job_exceeded' + DESK_ARCHIVE_JOB_COPY_ERROR = 'desk_archive_job_copy_error' + DESK_ARCHIVE_JOB_FILE_NOT_FOUND = 'desk_archive_job_file_not_found' + DESK_ARCHIVE_JOB_UNCAUGHT_ERROR = 'desk_archive_job_uncaught_error' + + # User notifications + USER_PENDING_VERIFICATION = 'user_pending_verification' + USER_PENDING_VERIFICATION_REGISTERED = 'user_pending_verification_registered' + USER_STORAGE_CAP_EXCEEDED_ANNOUNCEMENT = 'user_storage_cap_exceeded_announcement' + USER_SPAM_BANNED = 'user_spam_banned' + USER_REQUEST_DEACTIVATION_COMPLETE = 'user_request_deactivation_complete' + USER_PRIMARY_EMAIL_CHANGED = 'user_primary_email_changed' + USER_INSTITUTION_DEACTIVATION = 'user_institution_deactivation' + USER_FORGOT_PASSWORD = 'user_forgot_password' + USER_FORGOT_PASSWORD_INSTITUTION = 'user_forgot_password_institution' + USER_REQUEST_EXPORT = 'user_request_export' + USER_CONTRIBUTOR_ADDED_OSF_PREPRINT = 'user_contributor_added_osf_preprint' + USER_CONTRIBUTOR_ADDED_DEFAULT = 'user_contributor_added_default' + USER_DUPLICATE_ACCOUNTS_OSF4I = 'user_duplicate_accounts_osf4i' + USER_EXTERNAL_LOGIN_LINK_SUCCESS = 'user_external_login_link_success' + USER_REGISTRATION_BULK_UPLOAD_FAILURE_ALL = 'user_registration_bulk_upload_failure_all' + USER_REGISTRATION_BULK_UPLOAD_SUCCESS_PARTIAL = 'user_registration_bulk_upload_success_partial' + USER_REGISTRATION_BULK_UPLOAD_SUCCESS_ALL = 'user_registration_bulk_upload_success_all' + USER_ADD_SSO_EMAIL_OSF4I = 'user_add_sso_email_osf4i' + USER_WELCOME_OSF4I = 'user_welcome_osf4i' + USER_ARCHIVE_JOB_EXCEEDED = 'user_archive_job_exceeded' + USER_ARCHIVE_JOB_COPY_ERROR = 'user_archive_job_copy_error' + USER_ARCHIVE_JOB_FILE_NOT_FOUND = 'user_archive_job_file_not_found' + USER_ARCHIVE_JOB_UNCAUGHT_ERROR = 'user_archive_job_uncaught_error' + USER_COMMENT_REPLIES = 'user_comment_replies' + USER_COMMENTS = 'user_comments' + USER_FILE_UPDATED = 'user_file_updated' + USER_COMMENT_MENTIONS = 'user_mentions' + USER_REVIEWS = 'user_reviews' + USER_PASSWORD_RESET = 'user_password_reset' + USER_CONTRIBUTOR_ADDED_DRAFT_REGISTRATION = 'user_contributor_added_draft_registration' + USER_EXTERNAL_LOGIN_CONFIRM_EMAIL_CREATE = 'user_external_login_confirm_email_create' + USER_EXTERNAL_LOGIN_CONFIRM_EMAIL_LINK = 'user_external_login_confirm_email_link' + USER_CONFIRM_MERGE = 'user_confirm_merge' + USER_CONFIRM_EMAIL = 'user_confirm_email' + USER_INITIAL_CONFIRM_EMAIL = 'user_initial_confirm_email' + USER_INVITE_DEFAULT = 'user_invite_default' + USER_PENDING_INVITE = 'user_pending_invite' + USER_FORWARD_INVITE = 'user_forward_invite' + USER_FORWARD_INVITE_REGISTERED = 'user_forward_invite_registered' + USER_INVITE_DRAFT_REGISTRATION = 'user_invite_draft_registration' + USER_INVITE_OSF_PREPRINT = 'user_invite_osf_preprint' + + # Node notifications + NODE_COMMENT = 'node_comments' + NODE_FILES_UPDATED = 'node_files_updated' + NODE_AFFILIATION_CHANGED = 'node_affiliation_changed' + NODE_REQUEST_ACCESS_SUBMITTED = 'node_access_request_submitted' + NODE_REQUEST_ACCESS_DENIED = 'node_request_access_denied' + NODE_FORK_COMPLETED = 'node_fork_completed' + NODE_FORK_FAILED = 'node_fork_failed' + NODE_REQUEST_INSTITUTIONAL_ACCESS_REQUEST = 'node_request_institutional_access_request' + NODE_CONTRIBUTOR_ADDED_ACCESS_REQUEST = 'node_contributor_added_access_request' + NODE_PENDING_EMBARGO_ADMIN = 'node_pending_embargo_admin' + NODE_PENDING_EMBARGO_NON_ADMIN = 'node_pending_embargo_non_admin' + NODE_PENDING_RETRACTION_NON_ADMIN = 'node_pending_retraction_non_admin' + NODE_PENDING_RETRACTION_ADMIN = 'node_pending_retraction_admin' + NODE_PENDING_REGISTRATION_NON_ADMIN = 'node_pending_registration_non_admin' + NODE_PENDING_REGISTRATION_ADMIN = 'node_pending_registration_admin' + NODE_PENDING_EMBARGO_TERMINATION_NON_ADMIN = 'node_pending_embargo_termination_non_admin' + NODE_PENDING_EMBARGO_TERMINATION_ADMIN = 'node_pending_embargo_termination_admin' + + # Provider notifications + PROVIDER_NEW_PENDING_SUBMISSIONS = 'provider_new_pending_submissions' + PROVIDER_REVIEWS_SUBMISSION_CONFIRMATION = 'provider_reviews_submission_confirmation' + PROVIDER_REVIEWS_MODERATOR_SUBMISSION_CONFIRMATION = 'provider_reviews_moderator_submission_confirmation' + PROVIDER_REVIEWS_WITHDRAWAL_REQUESTED = 'preprint_request_withdrawal_requested' + PROVIDER_REVIEWS_REJECT_CONFIRMATION = 'provider_reviews_reject_confirmation' + PROVIDER_REVIEWS_ACCEPT_CONFIRMATION = 'provider_reviews_accept_confirmation' + PROVIDER_REVIEWS_RESUBMISSION_CONFIRMATION = 'provider_reviews_resubmission_confirmation' + PROVIDER_REVIEWS_COMMENT_EDITED = 'provider_reviews_comment_edited' + PROVIDER_CONTRIBUTOR_ADDED_PREPRINT = 'provider_contributor_added_preprint' + PROVIDER_CONFIRM_EMAIL_MODERATION = 'provider_confirm_email_moderation' + PROVIDER_MODERATOR_ADDED = 'provider_moderator_added' + PROVIDER_CONFIRM_EMAIL_PREPRINTS = 'provider_confirm_email_preprints' + PROVIDER_USER_INVITE_PREPRINT = 'provider_user_invite_preprint' + + # Preprint notifications + PREPRINT_REQUEST_WITHDRAWAL_APPROVED = 'preprint_request_withdrawal_approved' + PREPRINT_REQUEST_WITHDRAWAL_DECLINED = 'preprint_request_withdrawal_declined' + PREPRINT_CONTRIBUTOR_ADDED_PREPRINT_NODE_FROM_OSF = 'preprint_contributor_added_preprint_node_from_osf' + + # Collections Submission notifications + COLLECTION_SUBMISSION_REMOVED_ADMIN = 'collection_submission_removed_admin' + COLLECTION_SUBMISSION_REMOVED_MODERATOR = 'collection_submission_removed_moderator' + COLLECTION_SUBMISSION_REMOVED_PRIVATE = 'collection_submission_removed_private' + COLLECTION_SUBMISSION_SUBMITTED = 'collection_submission_submitted' + COLLECTION_SUBMISSION_ACCEPTED = 'collection_submission_accepted' + COLLECTION_SUBMISSION_REJECTED = 'collection_submission_rejected' + COLLECTION_SUBMISSION_CANCEL = 'collection_submission_cancel' + + # Schema Response notifications + SCHEMA_RESPONSE_REJECTED = 'schema_response_rejected' + SCHEMA_RESPONSE_APPROVED = 'schema_response_approved' + SCHEMA_RESPONSE_SUBMITTED = 'schema_response_submitted' + SCHEMA_RESPONSE_INITIATED = 'schema_response_initiated' + + REGISTRATION_BULK_UPLOAD_FAILURE_DUPLICATES = 'registration_bulk_upload_failure_duplicates' + + @property + def instance(self): + obj, created = NotificationType.objects.get_or_create(name=self.value) + return obj + + @classmethod + def user_types(cls): + return [member for member in cls if member.name.startswith('USER_')] + + @classmethod + def node_types(cls): + return [member for member in cls if member.name.startswith('NODE_')] + + @classmethod + def preprint_types(cls): + return [member for member in cls if member.name.startswith('PREPRINT_')] + + @classmethod + def provider_types(cls): + return [member for member in cls if member.name.startswith('PROVIDER_')] + + @classmethod + def schema_response_types(cls): + return [member for member in cls if member.name.startswith('SCHEMA_RESPONSE_')] + + @classmethod + def desk_types(cls): + return [member for member in cls if member.name.startswith('DESK_')] + + notification_interval_choices = ArrayField( + base_field=models.CharField(max_length=32), + default=get_default_frequency_choices, + blank=True + ) + + name: str = models.CharField(max_length=255, unique=True, null=False, blank=False) + + object_content_type = models.ForeignKey( + ContentType, + on_delete=models.SET_NULL, + null=True, + blank=True, + help_text='Content type for subscribed objects. Null means global event.' + ) + + template: str = models.TextField( + help_text='Template used to render the event_info. Supports Django template syntax.' + ) + subject: str = models.TextField( + blank=True, + null=True, + help_text='Template used to render the subject line of email. Supports Django template syntax.' + ) + + def emit(self, user, subscribed_object=None, message_frequency=None, event_context=None): + """Emit a notification to a user by creating Notification and NotificationSubscription objects. + + Args: + user (OSFUser): The recipient of the notification. + subscribed_object (optional): The object the subscription is related to. + event_context (dict, optional): Context for rendering the notification template. + """ + from osf.models.notification_subscription import NotificationSubscription + subscription, created = NotificationSubscription.objects.get_or_create( + notification_type=self, + user=user, + content_type=ContentType.objects.get_for_model(subscribed_object) if subscribed_object else None, + object_id=subscribed_object.pk if subscribed_object else None, + defaults={'message_frequency': message_frequency}, + ) + if subscription.message_frequency == 'instantly': + Notification.objects.create( + subscription=subscription, + event_context=event_context + ).send() + + def add_user_to_subscription(self, user, *args, **kwargs): + """ + """ + from osf.models.notification_subscription import NotificationSubscription + + provider = kwargs.pop('provider', None) + node = kwargs.pop('node', None) + data = {} + if subscribed_object := provider or node: + data = { + 'object_id': subscribed_object.id, + 'content_type_id': ContentType.objects.get_for_model(subscribed_object).id, + } + + notification, created = NotificationSubscription.objects.get_or_create( + user=user, + notification_type=self, + **data, + ) + return notification + + def remove_user_from_subscription(self, user): + """ + """ + from osf.models.notification_subscription import NotificationSubscription + notification, _ = NotificationSubscription.objects.update_or_create( + user=user, + notification_type=self, + defaults={'message_frequency': FrequencyChoices.NONE.value} + ) + + def __str__(self) -> str: + return self.name + + class Meta: + verbose_name = 'Notification Type' + verbose_name_plural = 'Notification Types' diff --git a/osf/models/notifications.py b/osf/models/notifications.py index 86be3424832..41ec120b4ee 100644 --- a/osf/models/notifications.py +++ b/osf/models/notifications.py @@ -1,15 +1,16 @@ from django.contrib.postgres.fields import ArrayField from django.db import models + +from website.notifications.constants import NOTIFICATION_TYPES from .node import Node from .user import OSFUser from .base import BaseModel, ObjectIDMixin from .validators import validate_subscription_type from osf.utils.fields import NonNaiveDateTimeField -from website.notifications.constants import NOTIFICATION_TYPES from website.util import api_v2_url -class NotificationSubscription(BaseModel): +class NotificationSubscriptionLegacy(BaseModel): primary_identifier_name = '_id' _id = models.CharField(max_length=100, db_index=True, unique=False) # pxyz_wiki_updated, uabc_comment_replies @@ -29,6 +30,7 @@ class NotificationSubscription(BaseModel): class Meta: # Both PreprintProvider and RegistrationProvider default instances use "osf" as their `_id` unique_together = ('_id', 'provider') + db_table = 'osf_notificationsubscription_legacy' @classmethod def load(cls, q): @@ -95,7 +97,6 @@ def remove_user_from_subscription(self, user, save=True): if save: self.save() - class NotificationDigest(ObjectIDMixin, BaseModel): user = models.ForeignKey('OSFUser', null=True, blank=True, on_delete=models.CASCADE) provider = models.ForeignKey('AbstractProvider', null=True, blank=True, on_delete=models.CASCADE) diff --git a/osf/models/provider.py b/osf/models/provider.py index 2ee920a77e5..b8dacc174bf 100644 --- a/osf/models/provider.py +++ b/osf/models/provider.py @@ -19,7 +19,7 @@ from .brand import Brand from .citation import CitationStyle from .licenses import NodeLicense -from .notifications import NotificationSubscription +from .notifications import NotificationSubscriptionLegacy from .storage import ProviderAssetFile from .subject import Subject from osf.utils.datetime_aware_jsonfield import DateTimeAwareJSONField @@ -464,7 +464,7 @@ def create_provider_auth_groups(sender, instance, created, **kwargs): def create_provider_notification_subscriptions(sender, instance, created, **kwargs): if created: for subscription in instance.DEFAULT_SUBSCRIPTIONS: - NotificationSubscription.objects.get_or_create( + NotificationSubscriptionLegacy.objects.get_or_create( _id=f'{instance._id}_{subscription}', event_name=subscription, provider=instance diff --git a/osf/models/user.py b/osf/models/user.py index ede9c96d5e5..420171dc61f 100644 --- a/osf/models/user.py +++ b/osf/models/user.py @@ -57,11 +57,12 @@ from osf.utils.requests import check_select_for_update from osf.utils.permissions import API_CONTRIBUTOR_PERMISSIONS, MANAGER, MEMBER, ADMIN from website import settings as website_settings -from website import filters, mails +from website import filters from website.project import new_bookmark_collection from website.util.metrics import OsfSourceTags, unregistered_created_source_tag from importlib import import_module from osf.utils.requests import get_headers_from_request +from osf.models.notification_type import NotificationType SessionStore = import_module(settings.SESSION_ENGINE).SessionStore @@ -1071,13 +1072,13 @@ def set_password(self, raw_password, notify=True): raise ChangePasswordError(['Password cannot be the same as your email address']) super().set_password(raw_password) if had_existing_password and notify: - mails.send_mail( - to_addr=self.username, - mail=mails.PASSWORD_RESET, - user=self, - can_change_preferences=False, - osf_contact_email=website_settings.OSF_CONTACT_EMAIL - ) + notification_type = NotificationType.objects.filter(name='password_reset') + if not notification_type.exists(): + raise NotificationType.DoesNotExist( + 'NotificationType with name password_reset does not exist.', + ) + notification_type = notification_type.first() + notification_type.emit(user=self, message_frequency='instantly', event_context={'can_change_preferences': False, 'osf_contact_email': website_settings.OSF_CONTACT_EMAIL}) remove_sessions_for_user(self) @classmethod diff --git a/osf_tests/factories.py b/osf_tests/factories.py index 1310c9aed63..d1c7e640250 100644 --- a/osf_tests/factories.py +++ b/osf_tests/factories.py @@ -1040,9 +1040,20 @@ def handle_callback(self, response): } +class NotificationSubscriptionLegacyFactory(DjangoModelFactory): + class Meta: + model = models.NotificationSubscriptionLegacy + + class NotificationSubscriptionFactory(DjangoModelFactory): class Meta: model = models.NotificationSubscription + notification_type = factory.LazyAttribute(lambda o: NotificationTypeFactory()) + + +class NotificationTypeFactory(DjangoModelFactory): + class Meta: + model = models.NotificationType def make_node_lineage(): diff --git a/osf_tests/management_commands/test_migrate_notifications.py b/osf_tests/management_commands/test_migrate_notifications.py new file mode 100644 index 00000000000..35837f7cc7c --- /dev/null +++ b/osf_tests/management_commands/test_migrate_notifications.py @@ -0,0 +1,132 @@ +import pytest +from django.contrib.contenttypes.models import ContentType + +from osf.models import Node, RegistrationProvider +from osf_tests.factories import ( + AuthUserFactory, + PreprintProviderFactory, + ProjectFactory, +) +from osf.models import ( + NotificationType, + NotificationSubscription, + NotificationSubscriptionLegacy +) +from osf.management.commands.migrate_notifications import ( + migrate_legacy_notification_subscriptions, + populate_notification_types +) + +@pytest.mark.django_db +class TestNotificationSubscriptionMigration: + + @pytest.fixture(autouse=True) + def notification_types(self): + return populate_notification_types() + + @pytest.fixture() + def user(self): + return AuthUserFactory() + + @pytest.fixture() + def users(self): + return { + 'none': AuthUserFactory(), + 'digest': AuthUserFactory(), + 'transactional': AuthUserFactory(), + } + + @pytest.fixture() + def provider(self): + return PreprintProviderFactory() + + @pytest.fixture() + def provider2(self): + return PreprintProviderFactory() + + @pytest.fixture() + def node(self): + return ProjectFactory() + + def create_legacy_sub(self, event_name, users, user=None, provider=None, node=None): + legacy = NotificationSubscriptionLegacy.objects.create( + _id=f'{(provider or node)._id}_{event_name}', + user=user, + event_name=event_name, + provider=provider, + node=node + ) + legacy.none.add(users['none']) + legacy.email_digest.add(users['digest']) + legacy.email_transactional.add(users['transactional']) + return legacy + + def test_migrate_provider_subscription(self, user, provider, provider2): + NotificationSubscriptionLegacy.objects.get( + event_name='new_pending_submissions', + provider=provider + ) + NotificationSubscriptionLegacy.objects.get( + event_name='new_pending_submissions', + provider=provider2 + ) + NotificationSubscriptionLegacy.objects.get( + event_name='new_pending_submissions', + provider=RegistrationProvider.get_default() + ) + migrate_legacy_notification_subscriptions() + + subs = NotificationSubscription.objects.filter(notification_type__name='new_pending_submissions') + assert subs.count() == 3 + assert subs.get( + notification_type__name='new_pending_submissions', + object_id=provider.id, + content_type=ContentType.objects.get_for_model(provider.__class__) + ) + assert subs.get( + notification_type__name='new_pending_submissions', + object_id=provider2.id, + content_type=ContentType.objects.get_for_model(provider2.__class__) + ) + + def test_migrate_node_subscription(self, users, user, node): + self.create_legacy_sub('wiki_updated', users, user=user, node=node) + + migrate_legacy_notification_subscriptions() + + nt = NotificationType.objects.get(name='wiki_updated') + assert nt.object_content_type == ContentType.objects.get_for_model(Node) + + subs = NotificationSubscription.objects.filter(notification_type=nt) + assert subs.count() == 1 + + for sub in subs: + assert sub.subscribed_object == node + + def test_multiple_subscriptions_different_types(self, users, user, provider, node): + assert not NotificationSubscription.objects.filter(user=user) + self.create_legacy_sub('wiki_updated', users, user=user, node=node) + migrate_legacy_notification_subscriptions() + assert NotificationSubscription.objects.get(user=user).notification_type.name == 'wiki_updated' + assert NotificationSubscription.objects.get(notification_type__name='wiki_updated', user=user) + + def test_idempotent_migration(self, users, user, node, provider): + self.create_legacy_sub('file_updated', users, user=user, node=node) + migrate_legacy_notification_subscriptions() + migrate_legacy_notification_subscriptions() + assert NotificationSubscription.objects.get( + user=user, + object_id=node.id, + content_type=ContentType.objects.get_for_model(node.__class__), + notification_type__name='file_updated' + ) + + def test_errors_invalid_subscription(self, users): + legacy = NotificationSubscriptionLegacy.objects.create( + _id='broken', + event_name='invalid_event' + ) + legacy.none.add(users['none']) + + with pytest.raises(NotImplementedError): + migrate_legacy_notification_subscriptions() diff --git a/osf_tests/management_commands/test_migrate_preprint_affiliations.py b/osf_tests/management_commands/test_migrate_preprint_affiliations.py deleted file mode 100644 index 8c80737b3dd..00000000000 --- a/osf_tests/management_commands/test_migrate_preprint_affiliations.py +++ /dev/null @@ -1,151 +0,0 @@ -import pytest -from datetime import timedelta -from osf.management.commands.migrate_preprint_affiliation import AFFILIATION_TARGET_DATE, assign_affiliations_to_preprints -from osf_tests.factories import ( - PreprintFactory, - InstitutionFactory, - AuthUserFactory, -) - - -@pytest.mark.django_db -class TestAssignAffiliationsToPreprints: - - @pytest.fixture() - def institution(self): - return InstitutionFactory() - - @pytest.fixture() - def user_with_affiliation(self, institution): - user = AuthUserFactory() - user.add_or_update_affiliated_institution(institution) - user.save() - return user - - @pytest.fixture() - def user_without_affiliation(self): - return AuthUserFactory() - - @pytest.fixture() - def preprint_with_affiliated_contributor(self, user_with_affiliation): - preprint = PreprintFactory() - preprint.add_contributor( - user_with_affiliation, - permissions='admin', - visible=True - ) - preprint.created = AFFILIATION_TARGET_DATE - timedelta(days=1) - preprint.save() - return preprint - - @pytest.fixture() - def preprint_with_non_affiliated_contributor(self, user_without_affiliation): - preprint = PreprintFactory() - preprint.add_contributor( - user_without_affiliation, - permissions='admin', - visible=True - ) - preprint.created = AFFILIATION_TARGET_DATE - timedelta(days=1) - preprint.save() - return preprint - - @pytest.fixture() - def preprint_past_target_date_with_affiliated_contributor(self, user_with_affiliation): - preprint = PreprintFactory() - preprint.add_contributor( - user_with_affiliation, - permissions='admin', - visible=True - ) - preprint.created = AFFILIATION_TARGET_DATE + timedelta(days=1) - preprint.save() - return preprint - - @pytest.mark.parametrize('dry_run', [True, False]) - def test_assign_affiliations_with_affiliated_contributor(self, preprint_with_affiliated_contributor, institution, dry_run): - preprint = preprint_with_affiliated_contributor - preprint.affiliated_institutions.clear() - preprint.save() - - assign_affiliations_to_preprints(dry_run=dry_run) - - if dry_run: - assert not preprint.affiliated_institutions.exists() - else: - assert institution in preprint.affiliated_institutions.all() - - @pytest.mark.parametrize('dry_run', [True, False]) - def test_no_affiliations_for_non_affiliated_contributor(self, preprint_with_non_affiliated_contributor, dry_run): - preprint = preprint_with_non_affiliated_contributor - preprint.affiliated_institutions.clear() - preprint.save() - - assign_affiliations_to_preprints(dry_run=dry_run) - - assert not preprint.affiliated_institutions.exists() - - @pytest.mark.parametrize('dry_run', [True, False]) - def test_exclude_contributor_by_guid(self, preprint_with_affiliated_contributor, user_with_affiliation, institution, dry_run): - preprint = preprint_with_affiliated_contributor - preprint.affiliated_institutions.clear() - preprint.save() - - assert user_with_affiliation.get_affiliated_institutions() - assert user_with_affiliation in preprint.contributors.all() - exclude_guids = {user._id for user in preprint.contributors.all()} - - assign_affiliations_to_preprints(exclude_guids=exclude_guids, dry_run=dry_run) - - assert not preprint.affiliated_institutions.exists() - - @pytest.mark.parametrize('dry_run', [True, False]) - def test_affiliations_from_multiple_contributors(self, institution, dry_run): - institution_not_include = InstitutionFactory() - read_contrib = AuthUserFactory() - read_contrib.add_or_update_affiliated_institution(institution_not_include) - read_contrib.save() - - write_contrib = AuthUserFactory() - write_contrib.add_or_update_affiliated_institution(institution) - write_contrib.save() - - admin_contrib = AuthUserFactory() - institution2 = InstitutionFactory() - admin_contrib.add_or_update_affiliated_institution(institution2) - admin_contrib.save() - - preprint = PreprintFactory() - preprint.affiliated_institutions.clear() - preprint.created = AFFILIATION_TARGET_DATE - timedelta(days=1) - preprint.add_contributor(read_contrib, permissions='read', visible=True) - preprint.add_contributor(write_contrib, permissions='write', visible=True) - preprint.add_contributor(admin_contrib, permissions='admin', visible=True) - preprint.save() - - assign_affiliations_to_preprints(dry_run=dry_run) - - if dry_run: - assert not preprint.affiliated_institutions.exists() - else: - affiliations = set(preprint.affiliated_institutions.all()) - assert affiliations == {institution, institution2} - assert institution_not_include not in affiliations - - @pytest.mark.parametrize('dry_run', [True, False]) - def test_exclude_recent_preprints(self, preprint_past_target_date_with_affiliated_contributor, preprint_with_affiliated_contributor, institution, dry_run): - new_preprint = preprint_past_target_date_with_affiliated_contributor - new_preprint.affiliated_institutions.clear() - new_preprint.save() - - old_preprint = preprint_with_affiliated_contributor - old_preprint.affiliated_institutions.clear() - old_preprint.save() - - assign_affiliations_to_preprints(dry_run=dry_run) - - assert not new_preprint.affiliated_institutions.exists() - if dry_run: - assert not old_preprint.affiliated_institutions.exists() - else: - assert institution in old_preprint.affiliated_institutions.all() diff --git a/osf_tests/management_commands/test_move_egap_regs_to_provider.py b/osf_tests/management_commands/test_move_egap_regs_to_provider.py deleted file mode 100644 index 4e1ac7291aa..00000000000 --- a/osf_tests/management_commands/test_move_egap_regs_to_provider.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest - -from osf_tests.factories import ( - RegistrationFactory, - RegistrationProviderFactory -) - -from osf.models import ( - RegistrationSchema, - RegistrationProvider -) - -from osf.management.commands.move_egap_regs_to_provider import ( - main as move_egap_regs -) - - -@pytest.mark.django_db -class TestEGAPMoveToProvider: - - @pytest.fixture() - def egap_provider(self): - return RegistrationProviderFactory(_id='egap') - - @pytest.fixture() - def non_egap_provider(self): - return RegistrationProvider.get_default() - - @pytest.fixture() - def egap_reg(self): - egap_schema = RegistrationSchema.objects.filter( - name='EGAP Registration' - ).order_by( - '-schema_version' - )[0] - cos = RegistrationProvider.get_default() - return RegistrationFactory(schema=egap_schema, provider=cos) - - @pytest.fixture() - def egap_non_reg(self, non_egap_provider): - return RegistrationFactory(provider=non_egap_provider) - - def test_move_to_provider(self, egap_provider, egap_reg, non_egap_provider, egap_non_reg): - assert egap_reg.provider != egap_provider - assert egap_non_reg.provider != egap_provider - - move_egap_regs(dry_run=False) - - egap_reg.refresh_from_db() - assert egap_reg.provider == egap_provider - assert egap_non_reg.provider != egap_provider diff --git a/osf_tests/management_commands/test_populate_initial_schema_responses.py b/osf_tests/management_commands/test_populate_initial_schema_responses.py deleted file mode 100644 index 18949c09b33..00000000000 --- a/osf_tests/management_commands/test_populate_initial_schema_responses.py +++ /dev/null @@ -1,130 +0,0 @@ -import pytest - -from osf.management.commands.populate_initial_schema_responses import populate_initial_schema_responses -from osf.models import SchemaResponse, SchemaResponseBlock -from osf.utils.workflows import ApprovalStates, RegistrationModerationStates as RegStates -from osf_tests.factories import ProjectFactory, RegistrationFactory -from osf_tests.utils import get_default_test_schema - -DEFAULT_RESPONSES = { - 'q1': 'An answer', 'q2': 'Another answer', 'q3': 'A', 'q4': ['E'], 'q5': '', 'q6': [], -} - -@pytest.fixture -def control_registration(): - return RegistrationFactory() - - -@pytest.fixture -def test_registration(): - registration = RegistrationFactory(schema=get_default_test_schema()) - registration.schema_responses.clear() - registration.registration_responses = dict(DEFAULT_RESPONSES) - registration.save() - return registration - - -@pytest.fixture -def nested_registration(test_registration): - registration = RegistrationFactory( - project=ProjectFactory(parent=test_registration.registered_from), - parent=test_registration - ) - registration.schema_responses.clear() - return registration - - -@pytest.mark.django_db -class TestPopulateInitialSchemaResponses: - - def test_schema_response_created(self, test_registration): - assert not test_registration.schema_responses.exists() - - count = populate_initial_schema_responses() - assert count == 1 - - assert test_registration.schema_responses.count() == 1 - - schema_response = test_registration.schema_responses.get() - assert schema_response.schema == test_registration.registration_schema - assert schema_response.all_responses == test_registration.registration_responses - - @pytest.mark.parametrize( - 'registration_state, schema_response_state', - [ - (RegStates.INITIAL, ApprovalStates.UNAPPROVED), - (RegStates.PENDING, ApprovalStates.PENDING_MODERATION), - (RegStates.ACCEPTED, ApprovalStates.APPROVED), - (RegStates.EMBARGO, ApprovalStates.APPROVED), - (RegStates.PENDING_EMBARGO_TERMINATION, ApprovalStates.APPROVED), - (RegStates.PENDING_WITHDRAW_REQUEST, ApprovalStates.APPROVED), - (RegStates.PENDING_WITHDRAW, ApprovalStates.APPROVED), - (RegStates.WITHDRAWN, ApprovalStates.APPROVED), - (RegStates.REVERTED, ApprovalStates.UNAPPROVED), - (RegStates.REJECTED, ApprovalStates.PENDING_MODERATION), - ] - ) - def test_schema_response_state( - self, test_registration, registration_state, schema_response_state): - test_registration.moderation_state = registration_state.db_name - test_registration.save() - - populate_initial_schema_responses() - - schema_response = test_registration.schema_responses.get() - assert schema_response.state == schema_response_state - - def test_errors_from_invalid_keys_are_ignored(self, test_registration): - test_registration.registration_responses.update({'invalid_key': 'lolol'}) - test_registration.save() - - populate_initial_schema_responses() - - schema_response = test_registration.schema_responses.get() - assert schema_response.all_responses == DEFAULT_RESPONSES - - def test_populate_responses_is_atomic_per_registration(self, test_registration): - invalid_registration = RegistrationFactory() - invalid_registration.schema_responses.clear() - invalid_registration.registered_schema.clear() - - count = populate_initial_schema_responses() - assert count == 1 - - assert test_registration.schema_responses.exists() - assert not invalid_registration.schema_responses.exists() - - def test_dry_run(self, test_registration): - # donfirm that the delete works even if the schema_response isn't IN_PROGRESS - test_registration.moderation_state = RegStates.ACCEPTED.db_name - test_registration.save() - with pytest.raises(RuntimeError): - populate_initial_schema_responses(dry_run=True) - - assert not test_registration.schema_responses.exists() - assert not SchemaResponse.objects.exists() - assert not SchemaResponseBlock.objects.exists() - - def test_batch_size(self): - for _ in range(5): - r = RegistrationFactory() - r.schema_responses.clear() - assert not SchemaResponse.objects.exists() - - count = populate_initial_schema_responses(batch_size=3) - assert count == 3 - - assert SchemaResponse.objects.count() == 3 - - def test_schema_response_not_created_for_registration_with_response(self, control_registration): - control_registration_response = control_registration.schema_responses.get() - - count = populate_initial_schema_responses() - assert count == 0 - - assert control_registration.schema_responses.get() == control_registration_response - - def test_schema_response_not_created_for_nested_registration(self, nested_registration): - count = populate_initial_schema_responses() - assert count == 1 # parent registration - assert not nested_registration.schema_responses.exists() diff --git a/osf_tests/metadata/test_osf_gathering.py b/osf_tests/metadata/test_osf_gathering.py index 33be346e2df..bdac112be6e 100644 --- a/osf_tests/metadata/test_osf_gathering.py +++ b/osf_tests/metadata/test_osf_gathering.py @@ -31,11 +31,13 @@ from website import settings as website_settings from website.project import new_bookmark_collection from osf_tests.metadata._utils import assert_triples +from osf.management.commands.populate_notification_types import populate_notification_types class TestOsfGathering(TestCase): @classmethod def setUpTestData(cls): + populate_notification_types() # users: cls.user__admin = factories.UserFactory() cls.user__readwrite = factories.UserFactory( diff --git a/osf_tests/test_collection_submission.py b/osf_tests/test_collection_submission.py index 2ff2b279a6b..80500eaf979 100644 --- a/osf_tests/test_collection_submission.py +++ b/osf_tests/test_collection_submission.py @@ -336,7 +336,7 @@ def test_remove_success(self, user_role, node, unmoderated_collection_submission unmoderated_collection_submission.remove(user=user, comment='Test Comment') assert unmoderated_collection_submission.state == CollectionSubmissionStates.REMOVED - def test_notify_moderated_removed_admin(self, node, unmoderated_collection_submission, mock_send_grid): + def test_notify_moderated_removed_admin(self, mock_send_grid, node, unmoderated_collection_submission): unmoderated_collection_submission.state_machine.set_state(CollectionSubmissionStates.ACCEPTED) moderator = configure_test_auth(node, UserRoles.ADMIN_USER) diff --git a/osf_tests/test_merging_users.py b/osf_tests/test_merging_users.py index ee13c7bc107..d0b1978f508 100644 --- a/osf_tests/test_merging_users.py +++ b/osf_tests/test_merging_users.py @@ -24,7 +24,7 @@ from tests.utils import run_celery_tasks from waffle.testutils import override_flag from osf.features import ENABLE_GV -from conftest import start_mock_send_grid +from conftest import start_mock_send_grid, start_mock_notification_send SessionStore = import_module(django_conf_settings.SESSION_ENGINE).SessionStore @@ -40,6 +40,7 @@ def setUp(self): with self.context: handlers.celery_before_request() self.mock_send_grid = start_mock_send_grid(self) + self.mock_notification_send = start_mock_notification_send(self) def _add_unconfirmed_user(self): self.unconfirmed = UnconfirmedUserFactory() @@ -297,4 +298,4 @@ def test_merge_doesnt_send_signal(self): with override_flag(ENABLE_GV, active=True): self.user.merge_user(other_user) assert other_user.merged_by._id == self.user._id - assert self.mock_send_grid.called is False + assert self.mock_notification_send.called is False diff --git a/osf_tests/test_registration_moderation_notifications.py b/osf_tests/test_registration_moderation_notifications.py deleted file mode 100644 index 100c15e64e1..00000000000 --- a/osf_tests/test_registration_moderation_notifications.py +++ /dev/null @@ -1,457 +0,0 @@ -import pytest -from unittest import mock -from unittest.mock import call - -from django.utils import timezone -from osf.management.commands.add_notification_subscription import add_reviews_notification_setting -from osf.management.commands.populate_registration_provider_notification_subscriptions import populate_registration_provider_notification_subscriptions - -from osf.migrations import update_provider_auth_groups -from osf.models import Brand, NotificationDigest -from osf.models.action import RegistrationAction -from osf.utils.notifications import ( - notify_submit, - notify_accept_reject, - notify_moderator_registration_requests_withdrawal, - notify_reject_withdraw_request, - notify_withdraw_registration -) -from osf.utils.workflows import RegistrationModerationTriggers, RegistrationModerationStates - -from osf_tests.factories import ( - RegistrationFactory, - AuthUserFactory, - RetractionFactory -) - -from website import settings -from website.notifications import emails, tasks - - -def get_moderator(provider): - user = AuthUserFactory() - provider.add_to_group(user, 'moderator') - return user - - -def get_daily_moderator(provider): - user = AuthUserFactory() - provider.add_to_group(user, 'moderator') - for subscription_type in provider.DEFAULT_SUBSCRIPTIONS: - subscription = provider.notification_subscriptions.get(event_name=subscription_type) - subscription.add_user_to_subscription(user, 'email_digest') - return user - - -# Set USE_EMAIL to true and mock out the default mailer for consistency with other mocked settings -@pytest.mark.django_db -@pytest.mark.usefixtures('mock_send_grid') -class TestRegistrationMachineNotification: - - MOCK_NOW = timezone.now() - - @pytest.fixture(autouse=True) - def setup(self): - populate_registration_provider_notification_subscriptions() - with mock.patch('osf.utils.machines.timezone.now', return_value=self.MOCK_NOW): - yield - - @pytest.fixture() - def contrib(self): - return AuthUserFactory() - - @pytest.fixture() - def admin(self): - return AuthUserFactory() - - @pytest.fixture() - def registration(self, admin, contrib): - registration = RegistrationFactory(creator=admin) - registration.add_contributor(admin, 'admin') - registration.add_contributor(contrib, 'write') - update_provider_auth_groups() - return registration - - @pytest.fixture() - def registration_with_retraction(self, admin, contrib): - sanction = RetractionFactory(user=admin) - registration = sanction.target_registration - registration.update_moderation_state() - registration.add_contributor(admin, 'admin') - registration.add_contributor(contrib, 'write') - registration.save() - return registration - - @pytest.fixture() - def provider(self, registration): - return registration.provider - - @pytest.fixture() - def moderator(self, provider): - user = AuthUserFactory() - provider.add_to_group(user, 'moderator') - return user - - @pytest.fixture() - def daily_moderator(self, provider): - user = AuthUserFactory() - provider.add_to_group(user, 'moderator') - for subscription_type in provider.DEFAULT_SUBSCRIPTIONS: - subscription = provider.notification_subscriptions.get(event_name=subscription_type) - subscription.add_user_to_subscription(user, 'email_digest') - return user - - @pytest.fixture() - def accept_action(self, registration, admin): - registration_action = RegistrationAction.objects.create( - creator=admin, - target=registration, - trigger=RegistrationModerationTriggers.ACCEPT_SUBMISSION.db_name, - from_state=RegistrationModerationStates.INITIAL.db_name, - to_state=RegistrationModerationStates.ACCEPTED.db_name, - comment='yo' - ) - return registration_action - - @pytest.fixture() - def withdraw_request_action(self, registration, admin): - registration_action = RegistrationAction.objects.create( - creator=admin, - target=registration, - trigger=RegistrationModerationTriggers.REQUEST_WITHDRAWAL.db_name, - from_state=RegistrationModerationStates.ACCEPTED.db_name, - to_state=RegistrationModerationStates.PENDING_WITHDRAW.db_name, - comment='yo' - ) - return registration_action - - @pytest.fixture() - def withdraw_action(self, registration, admin): - registration_action = RegistrationAction.objects.create( - creator=admin, - target=registration, - trigger=RegistrationModerationTriggers.ACCEPT_WITHDRAWAL.db_name, - from_state=RegistrationModerationStates.PENDING_WITHDRAW.db_name, - to_state=RegistrationModerationStates.WITHDRAWN.db_name, - comment='yo' - ) - return registration_action - - def test_submit_notifications(self, registration, moderator, admin, contrib, provider, mock_send_grid): - """ - [REQS-96] "As moderator of branded registry, I receive email notification upon admin author(s) submission approval" - :param mock_email: - :param draft_registration: - :return: - """ - # Set up mock_send_mail as a pass-through to the original function. - # This lets us assert on the call/args and also implicitly ensures - # that the email acutally renders as normal in send_mail. - notify_submit(registration, admin) - - assert len(mock_send_grid.call_args_list) == 2 - admin_message, contrib_message = mock_send_grid.call_args_list - - assert admin_message[1]['to_addr'] == admin.email - assert contrib_message[1]['to_addr'] == contrib.email - assert admin_message[1]['subject'] == 'Confirmation of your submission to OSF Registries' - assert contrib_message[1]['subject'] == 'Confirmation of your submission to OSF Registries' - - assert NotificationDigest.objects.count() == 1 - digest = NotificationDigest.objects.last() - - assert digest.user == moderator - assert digest.send_type == 'email_transactional' - assert digest.event == 'new_pending_submissions' - - def test_accept_notifications(self, registration, moderator, admin, contrib, accept_action): - """ - [REQS-98] "As registration authors, we receive email notification upon moderator acceptance" - :param draft_registration: - :return: - """ - add_reviews_notification_setting('global_reviews') - - # Set up mock_email as a pass-through to the original function. - # This lets us assert on the call count/args and also implicitly - # ensures that the email acutally renders correctly. - store_emails = emails.store_emails - with mock.patch.object(emails, 'store_emails', side_effect=store_emails) as mock_email: - notify_accept_reject(registration, registration.creator, accept_action, RegistrationModerationStates) - - assert len(mock_email.call_args_list) == 2 - - admin_message, contrib_message = mock_email.call_args_list - - assert admin_message == call( - [admin._id], - 'email_transactional', - 'global_reviews', - admin, - registration, - self.MOCK_NOW, - comment='yo', - document_type='registration', - domain='http://localhost:5000/', - draft_registration=registration.draft_registration.get(), - has_psyarxiv_chronos_text=False, - is_creator=True, - is_rejected=False, - notify_comment='yo', - provider_contact_email=settings.OSF_CONTACT_EMAIL, - provider_support_email=settings.OSF_SUPPORT_EMAIL, - provider_url='http://localhost:5000/', - requester=admin, - reviewable=registration, - template='reviews_submission_status', - was_pending=False, - workflow=None - ) - - assert contrib_message == call( - [contrib._id], - 'email_transactional', - 'global_reviews', - admin, - registration, - self.MOCK_NOW, - comment='yo', - document_type='registration', - domain='http://localhost:5000/', - draft_registration=registration.draft_registration.get(), - has_psyarxiv_chronos_text=False, - is_creator=False, - is_rejected=False, - notify_comment='yo', - provider_contact_email=settings.OSF_CONTACT_EMAIL, - provider_support_email=settings.OSF_SUPPORT_EMAIL, - provider_url='http://localhost:5000/', - reviewable=registration, - requester=admin, - template='reviews_submission_status', - was_pending=False, - workflow=None - ) - - def test_reject_notifications(self, registration, moderator, admin, contrib, accept_action): - """ - [REQS-100] "As authors of rejected by moderator registration, we receive email notification of registration returned - to draft state" - :param draft_registration: - :return: - """ - add_reviews_notification_setting('global_reviews') - - # Set up mock_email as a pass-through to the original function. - # This lets us assert on the call count/args and also implicitly - # ensures that the email acutally renders correctly - store_emails = emails.store_emails - with mock.patch.object(emails, 'store_emails', side_effect=store_emails) as mock_email: - notify_accept_reject(registration, registration.creator, accept_action, RegistrationModerationStates) - - assert len(mock_email.call_args_list) == 2 - - admin_message, contrib_message = mock_email.call_args_list - - assert admin_message == call( - [admin._id], - 'email_transactional', - 'global_reviews', - admin, - registration, - self.MOCK_NOW, - comment='yo', - document_type='registration', - domain='http://localhost:5000/', - draft_registration=registration.draft_registration.get(), - has_psyarxiv_chronos_text=False, - is_creator=True, - is_rejected=False, - notify_comment='yo', - provider_contact_email=settings.OSF_CONTACT_EMAIL, - provider_support_email=settings.OSF_SUPPORT_EMAIL, - provider_url='http://localhost:5000/', - reviewable=registration, - requester=admin, - template='reviews_submission_status', - was_pending=False, - workflow=None - ) - - assert contrib_message == call( - [contrib._id], - 'email_transactional', - 'global_reviews', - admin, - registration, - self.MOCK_NOW, - comment='yo', - document_type='registration', - domain='http://localhost:5000/', - draft_registration=registration.draft_registration.get(), - has_psyarxiv_chronos_text=False, - is_creator=False, - is_rejected=False, - notify_comment='yo', - provider_contact_email=settings.OSF_CONTACT_EMAIL, - provider_support_email=settings.OSF_SUPPORT_EMAIL, - provider_url='http://localhost:5000/', - reviewable=registration, - requester=admin, - template='reviews_submission_status', - was_pending=False, - workflow=None - ) - - def test_notify_moderator_registration_requests_withdrawal_notifications(self, moderator, daily_moderator, registration, admin, provider): - """ - [REQS-106] "As moderator, I receive registration withdrawal request notification email" - - :param mock_email: - :param draft_registration: - :param contrib: - :return: - """ - assert NotificationDigest.objects.count() == 0 - notify_moderator_registration_requests_withdrawal(registration, admin) - - assert NotificationDigest.objects.count() == 2 - - daily_digest = NotificationDigest.objects.get(send_type='email_digest') - transactional_digest = NotificationDigest.objects.get(send_type='email_transactional') - assert daily_digest.user == daily_moderator - assert transactional_digest.user == moderator - - for digest in (daily_digest, transactional_digest): - assert 'requested withdrawal' in digest.message - assert digest.event == 'new_pending_withdraw_requests' - assert digest.provider == provider - - def test_withdrawal_registration_accepted_notifications(self, registration_with_retraction, contrib, admin, withdraw_action, mock_send_grid): - """ - [REQS-109] "As registration author(s) requesting registration withdrawal, we receive notification email of moderator - decision" - - :param mock_email: - :param draft_registration: - :param contrib: - :return: - """ - # Set up mock_send_mail as a pass-through to the original function. - # This lets us assert on the call count/args and also implicitly - # ensures that the email acutally renders as normal in send_mail. - notify_withdraw_registration(registration_with_retraction, withdraw_action) - - assert len(mock_send_grid.call_args_list) == 2 - admin_message, contrib_message = mock_send_grid.call_args_list - - assert admin_message[1]['to_addr'] == admin.email - assert contrib_message[1]['to_addr'] == contrib.email - assert admin_message[1]['subject'] == 'Your registration has been withdrawn' - assert contrib_message[1]['subject'] == 'Your registration has been withdrawn' - - def test_withdrawal_registration_rejected_notifications(self, registration, contrib, admin, withdraw_request_action, mock_send_grid): - """ - [REQS-109] "As registration author(s) requesting registration withdrawal, we receive notification email of moderator - decision" - - :param mock_email: - :param draft_registration: - :param contrib: - :return: - """ - # Set up mock_send_mail as a pass-through to the original function. - # This lets us assert on the call count/args and also implicitly - # ensures that the email acutally renders as normal in send_mail. - notify_reject_withdraw_request(registration, withdraw_request_action) - - assert len(mock_send_grid.call_args_list) == 2 - admin_message, contrib_message = mock_send_grid.call_args_list - - assert admin_message[1]['to_addr'] == admin.email - assert contrib_message[1]['to_addr'] == contrib.email - assert admin_message[1]['subject'] == 'Your withdrawal request has been declined' - assert contrib_message[1]['subject'] == 'Your withdrawal request has been declined' - - def test_withdrawal_registration_force_notifications(self, registration_with_retraction, contrib, admin, withdraw_action, mock_send_grid): - """ - [REQS-109] "As registration author(s) requesting registration withdrawal, we receive notification email of moderator - decision" - - :param mock_email: - :param draft_registration: - :param contrib: - :return: - """ - # Set up mock_send_mail as a pass-through to the original function. - # This lets us assert on the call count/args and also implicitly - # ensures that the email acutally renders as normal in send_mail. - notify_withdraw_registration(registration_with_retraction, withdraw_action) - - assert len(mock_send_grid.call_args_list) == 2 - admin_message, contrib_message = mock_send_grid.call_args_list - - assert admin_message[1]['to_addr'] == admin.email - assert contrib_message[1]['to_addr'] == contrib.email - assert admin_message[1]['subject'] == 'Your registration has been withdrawn' - assert contrib_message[1]['subject'] == 'Your registration has been withdrawn' - - @pytest.mark.parametrize( - 'digest_type, expected_recipient', - [('email_transactional', get_moderator), ('email_digest', get_daily_moderator)] - ) - def test_submissions_and_withdrawals_both_appear_in_moderator_digest(self, digest_type, expected_recipient, registration, admin, provider, mock_send_grid): - # Invoke the fixture function to get the recipient because parametrize - expected_recipient = expected_recipient(provider) - - notify_submit(registration, admin) - notify_moderator_registration_requests_withdrawal(registration, admin) - - # One user, one provider => one email - grouped_notifications = list(tasks.get_moderators_emails(digest_type)) - assert len(grouped_notifications) == 1 - - moderator_message = grouped_notifications[0] - assert moderator_message['user_id'] == expected_recipient._id - assert moderator_message['provider_id'] == provider.id - - # No fixed ordering of the entires, so just make sure that - # keywords for each action type are in some message - updates = moderator_message['info'] - assert len(updates) == 2 - assert any('submitted' in entry['message'] for entry in updates) - assert any('requested withdrawal' in entry['message'] for entry in updates) - - @pytest.mark.parametrize('digest_type', ['email_transactional', 'email_digest']) - def test_submsissions_and_withdrawals_do_not_appear_in_node_digest(self, digest_type, registration, admin, moderator, daily_moderator): - notify_submit(registration, admin) - notify_moderator_registration_requests_withdrawal(registration, admin) - - assert not list(tasks.get_users_emails(digest_type)) - - def test_moderator_digest_emails_render(self, registration, admin, moderator, mock_send_grid): - notify_moderator_registration_requests_withdrawal(registration, admin) - # Set up mock_send_mail as a pass-through to the original function. - # This lets us assert on the call count/args and also implicitly - # ensures that the email acutally renders as normal in send_mail. - tasks._send_reviews_moderator_emails('email_transactional') - - mock_send_grid.assert_called() - - def test_branded_provider_notification_renders(self, registration, admin, moderator): - # Set brand details to be checked in notify_base.mako - provider = registration.provider - provider.brand = Brand.objects.create(hero_logo_image='not-a-url', primary_color='#FFA500') - provider.name = 'Test Provider' - provider.save() - - # Implicitly check that all of our uses of notify_base.mako render with branded details: - # - # notify_submit renders reviews_submission_confirmation using context from - # osf.utils.notifications and stores emails to be picked up in the moderator digest - # - # _send_Reviews_moderator_emails renders digest_reviews_moderators using context from - # website.notifications.tasks - notify_submit(registration, admin) - tasks._send_reviews_moderator_emails('email_transactional') - assert True # everything rendered! diff --git a/osf_tests/test_s3_folder_migration.py b/osf_tests/test_s3_folder_migration.py deleted file mode 100644 index 067e63c34a3..00000000000 --- a/osf_tests/test_s3_folder_migration.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -from osf.management.commands.add_colon_delim_to_s3_buckets import update_folder_names, reverse_update_folder_names - -@pytest.mark.django_db -class TestUpdateFolderNamesMigration: - - def test_update_folder_names_migration(self): - from addons.s3.models import NodeSettings - from addons.s3.tests.factories import S3NodeSettingsFactory - # Create sample folder names and IDs - S3NodeSettingsFactory(folder_name='Folder 1 (Location 1)', folder_id='folder1') - S3NodeSettingsFactory(folder_name='Folder 2', folder_id='folder2') - S3NodeSettingsFactory(folder_name='Folder 3 (Location 3)', folder_id='folder3') - S3NodeSettingsFactory(folder_name='Folder 4:/ (Location 4)', folder_id='folder4:/') - - update_folder_names() - - # Verify updated folder names and IDs - updated_folder_names_ids = NodeSettings.objects.values_list('folder_name', 'folder_id') - expected_updated_folder_names_ids = { - ('Folder 1:/ (Location 1)', 'folder1:/'), - ('Folder 2:/', 'folder2:/'), - ('Folder 3:/ (Location 3)', 'folder3:/'), - ('Folder 3:/ (Location 3)', 'folder3:/'), - ('Folder 4:/ (Location 4)', 'folder4:/'), - - } - assert set(updated_folder_names_ids) == expected_updated_folder_names_ids - - # Reverse the migration - reverse_update_folder_names() - - # Verify the folder names and IDs after the reverse migration - reverted_folder_names_ids = NodeSettings.objects.values_list('folder_name', 'folder_id') - expected_reverted_folder_names_ids = { - ('Folder 1 (Location 1)', 'folder1'), - ('Folder 2', 'folder2'), - ('Folder 3 (Location 3)', 'folder3'), - ('Folder 4 (Location 4)', 'folder4'), - } - assert set(reverted_folder_names_ids) == expected_reverted_folder_names_ids diff --git a/osf_tests/test_schema_responses.py b/osf_tests/test_schema_responses.py index 40965c7cf31..cc50d57a2e3 100644 --- a/osf_tests/test_schema_responses.py +++ b/osf_tests/test_schema_responses.py @@ -96,6 +96,7 @@ def revised_response(initial_response): @pytest.mark.enable_bookmark_creation @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestCreateSchemaResponse(): def test_create_initial_response_sets_attributes(self, registration, schema): @@ -142,11 +143,11 @@ def test_create_initial_response_assigns_default_values(self, registration): for block in response.response_blocks.all(): assert block.response == DEFAULT_SCHEMA_RESPONSE_VALUES[block.schema_key] - def test_create_initial_response_does_not_notify(self, registration, admin_user, mock_send_grid): + def test_create_initial_response_does_not_notify(self, registration, admin_user, mock_notification_send): schema_response.SchemaResponse.create_initial_response( parent=registration, initiator=admin_user ) - assert not mock_send_grid.called + assert not mock_notification_send.called def test_create_initial_response_fails_if_no_schema_and_no_parent_schema(self, registration): registration.registered_schema.clear() @@ -543,6 +544,7 @@ def test_delete_fails_if_state_is_invalid(self, invalid_response_state, initial_ @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestUnmoderatedSchemaResponseApprovalFlows(): def test_submit_response_adds_pending_approvers( @@ -584,13 +586,13 @@ def test_submit_response_notification( assert mock_send_grid.called - def test_no_submit_notification_on_initial_response(self, initial_response, admin_user, mock_send_grid): + def test_no_submit_notification_on_initial_response(self, initial_response, admin_user, mock_notification_send): initial_response.approvals_state_machine.set_state(ApprovalStates.IN_PROGRESS) initial_response.update_responses({'q1': 'must change one response or can\'t submit'}) initial_response.revision_justification = 'has for valid revision_justification for submission' initial_response.save() initial_response.submit(user=admin_user, required_approvers=[admin_user]) - assert not mock_send_grid.called + assert not mock_notification_send.called def test_submit_response_requires_user(self, initial_response, admin_user): initial_response.approvals_state_machine.set_state(ApprovalStates.IN_PROGRESS) @@ -682,13 +684,13 @@ def test_approve_response_notification( revised_response.approve(user=alternate_user) assert mock_send_grid.called - def test_no_approve_notification_on_initial_response(self, initial_response, admin_user, mock_send_grid): + def test_no_approve_notification_on_initial_response(self, initial_response, admin_user, mock_notification_send): initial_response.approvals_state_machine.set_state(ApprovalStates.UNAPPROVED) initial_response.save() initial_response.pending_approvers.add(admin_user) initial_response.approve(user=admin_user) - assert not mock_send_grid.called + assert not mock_notification_send.called def test_approve_response_requires_user(self, initial_response, admin_user): initial_response.approvals_state_machine.set_state(ApprovalStates.UNAPPROVED) @@ -748,13 +750,13 @@ def test_reject_response_notification( assert mock_send_grid.called - def test_no_reject_notification_on_initial_response(self, initial_response, admin_user, mock_send_grid): + def test_no_reject_notification_on_initial_response(self, initial_response, admin_user, mock_notification_send): initial_response.approvals_state_machine.set_state(ApprovalStates.UNAPPROVED) initial_response.save() initial_response.pending_approvers.add(admin_user) initial_response.reject(user=admin_user) - assert not mock_send_grid.called + assert not mock_notification_send.called def test_reject_response_requires_user(self, initial_response, admin_user): initial_response.approvals_state_machine.set_state(ApprovalStates.UNAPPROVED) @@ -802,6 +804,7 @@ def test_internal_accept_clears_pending_approvers(self, initial_response, admin_ @pytest.mark.django_db @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestModeratedSchemaResponseApprovalFlows(): @pytest.fixture @@ -909,12 +912,12 @@ def test_moderator_accept_notification( assert mock_send_grid.called def test_no_moderator_accept_notification_on_initial_response( - self, initial_response, moderator, mock_send_grid): + self, initial_response, moderator, mock_notification_send): initial_response.approvals_state_machine.set_state(ApprovalStates.PENDING_MODERATION) initial_response.save() initial_response.accept(user=moderator) - assert not mock_send_grid.called + assert not mock_notification_send.called def test_moderator_reject(self, initial_response, admin_user, moderator): initial_response.approvals_state_machine.set_state(ApprovalStates.PENDING_MODERATION) @@ -947,12 +950,12 @@ def test_moderator_reject_notification( assert mock_send_grid.called def test_no_moderator_reject_notification_on_initial_response( - self, initial_response, moderator, mock_send_grid): + self, initial_response, moderator, mock_notification_send): initial_response.approvals_state_machine.set_state(ApprovalStates.PENDING_MODERATION) initial_response.save() initial_response.reject(user=moderator) - assert not mock_send_grid.called + assert not mock_notification_send.called def test_moderator_cannot_submit(self, initial_response, moderator): initial_response.approvals_state_machine.set_state(ApprovalStates.IN_PROGRESS) diff --git a/osf_tests/test_user.py b/osf_tests/test_user.py index 3a2e508dd2d..70d3a7ceb17 100644 --- a/osf_tests/test_user.py +++ b/osf_tests/test_user.py @@ -886,6 +886,7 @@ def test_get_user_by_cookie_no_session(self): @pytest.mark.usefixtures('mock_send_grid') +@pytest.mark.usefixtures('mock_notification_send') class TestChangePassword: def test_change_password(self, user): @@ -897,19 +898,19 @@ def test_change_password(self, user): user.change_password(old_password, new_password, confirm_password) assert bool(user.check_password(new_password)) is True - def test_set_password_notify_default(self, mock_send_grid, user): + def test_set_password_notify_default(self, mock_notification_send, user): old_password = 'password' user.set_password(old_password) user.save() - assert mock_send_grid.called is True + assert mock_notification_send.called is True - def test_set_password_no_notify(self, mock_send_grid, user): + def test_set_password_no_notify(self, mock_notification_send, user): old_password = 'password' user.set_password(old_password, notify=False) user.save() - assert mock_send_grid.called is False + assert mock_notification_send.called is False - def test_check_password_upgrade_hasher_no_notify(self, mock_send_grid, user, settings): + def test_check_password_upgrade_hasher_no_notify(self, mock_notification_send, user, settings): # NOTE: settings fixture comes from pytest-django. # changes get reverted after tests run settings.PASSWORD_HASHERS = ( @@ -920,7 +921,7 @@ def test_check_password_upgrade_hasher_no_notify(self, mock_send_grid, user, set user.password = 'sha1$lNb72DKWDv6P$e6ae16dada9303ae0084e14fc96659da4332bb05' user.check_password(raw_password) assert user.password.startswith('md5$') - assert mock_send_grid.called is False + assert mock_notification_send.called is False def test_change_password_invalid(self, old_password=None, new_password=None, confirm_password=None, error_message='Old password is invalid'): diff --git a/osf_tests/utils.py b/osf_tests/utils.py index a8364a15478..b3f3c92bc88 100644 --- a/osf_tests/utils.py +++ b/osf_tests/utils.py @@ -16,7 +16,7 @@ Sanction, RegistrationProvider, RegistrationSchema, - NotificationSubscription + NotificationSubscriptionLegacy ) from osf.utils.migrations import create_schema_blocks_for_atomic_schema @@ -229,7 +229,7 @@ def _ensure_subscriptions(provider): Avoid that. ''' for subscription in provider.DEFAULT_SUBSCRIPTIONS: - NotificationSubscription.objects.get_or_create( + NotificationSubscriptionLegacy.objects.get_or_create( _id=f'{provider._id}_{subscription}', event_name=subscription, provider=provider diff --git a/scripts/add_global_subscriptions.py b/scripts/add_global_subscriptions.py index b326c6f9f67..52746875d79 100644 --- a/scripts/add_global_subscriptions.py +++ b/scripts/add_global_subscriptions.py @@ -6,13 +6,13 @@ import logging import sys +from osf.models.notifications import NotificationSubscriptionLegacy from website.app import setup_django setup_django() from django.apps import apps from django.db import transaction from website.app import init_app -from osf.models import NotificationSubscription from website.notifications import constants from website.notifications.utils import to_subscription_key @@ -35,10 +35,10 @@ def add_global_subscriptions(dry=True): for user_event in user_events: user_event_id = to_subscription_key(user._id, user_event) - subscription = NotificationSubscription.load(user_event_id) + subscription = NotificationSubscriptionLegacy.load(user_event_id) if not subscription: logger.info(f'No {user_event} subscription found for user {user._id}. Subscribing...') - subscription = NotificationSubscription(_id=user_event_id, owner=user, event_name=user_event) + subscription = NotificationSubscriptionLegacy(_id=user_event_id, owner=user, event_name=user_event) subscription.save() # Need to save in order to access m2m fields subscription.add_user_to_subscription(user, notification_type) subscription.save() diff --git a/scripts/remove_notification_subscriptions_from_registrations.py b/scripts/remove_notification_subscriptions_from_registrations.py index 8984cb25b50..94b20a19a93 100644 --- a/scripts/remove_notification_subscriptions_from_registrations.py +++ b/scripts/remove_notification_subscriptions_from_registrations.py @@ -17,7 +17,7 @@ def remove_notification_subscriptions_from_registrations(dry_run=True): Registration = apps.get_model('osf.Registration') NotificationSubscription = apps.get_model('osf.NotificationSubscription') - notifications_to_delete = NotificationSubscription.objects.filter(node__type='osf.registration') + notifications_to_delete = NotificationSubscriptionLegacy.objects.filter(node__type='osf.registration') registrations_affected = Registration.objects.filter( id__in=notifications_to_delete.values_list( 'node_id', flat=True diff --git a/tests/test_adding_contributor_views.py b/tests/test_adding_contributor_views.py index 17c2da39bc3..64d04357ddc 100644 --- a/tests/test_adding_contributor_views.py +++ b/tests/test_adding_contributor_views.py @@ -49,7 +49,8 @@ send_claim_registered_email, ) from website.util.metrics import OsfSourceTags, OsfClaimedTags, provider_source_tag, provider_claimed_tag -from conftest import start_mock_send_grid +from conftest import start_mock_send_grid, start_mock_notification_send + @pytest.mark.enable_implicit_clean @mock.patch('website.mails.settings.USE_EMAIL', True) @@ -65,6 +66,7 @@ def setUp(self): contributor_added.connect(notify_added_contributor) self.mock_send_grid = start_mock_send_grid(self) + self.mock_notification_send = start_mock_notification_send(self) def test_serialize_unregistered_without_record(self): name, email = fake.name(), fake_email() @@ -241,7 +243,7 @@ def test_add_contributors_post_only_sends_one_email_to_registered_user(self): self.app.post(url, json=payload, auth=self.creator.auth) # send_mail should only have been called once - assert self.mock_send_grid.call_count == 1 + assert self.mock_notification_send.call_count == 1 def test_add_contributors_post_sends_email_if_user_not_contributor_on_parent_node(self): # Project has a component with a sub-component @@ -268,7 +270,7 @@ def test_add_contributors_post_sends_email_if_user_not_contributor_on_parent_nod self.app.post(url, json=payload, auth=self.creator.auth) # send_mail is called for both the project and the sub-component - assert self.mock_send_grid.call_count == 2 + assert self.mock_notification_send.call_count == 2 @mock.patch('website.project.views.contributor.send_claim_email') def test_email_sent_when_unreg_user_is_added(self, send_mail): @@ -299,7 +301,7 @@ def test_email_sent_when_reg_user_is_added(self): project = ProjectFactory(creator=self.auth.user) project.add_contributors(contributors, auth=self.auth) project.save() - assert self.mock_send_grid.called + assert self.mock_notification_send.called assert contributor.contributor_added_email_records[project._id]['last_sent'] == approx(int(time.time()), rel=1) @@ -308,17 +310,17 @@ def test_contributor_added_email_sent_to_unreg_user(self): project = ProjectFactory() project.add_unregistered_contributor(fullname=unreg_user.fullname, email=unreg_user.email, auth=Auth(project.creator)) project.save() - assert self.mock_send_grid.called + assert self.mock_notification_send.called def test_forking_project_does_not_send_contributor_added_email(self): project = ProjectFactory() project.fork_node(auth=Auth(project.creator)) - assert not self.mock_send_grid.called + assert not self.mock_notification_send.called def test_templating_project_does_not_send_contributor_added_email(self): project = ProjectFactory() project.use_as_template(auth=Auth(project.creator)) - assert not self.mock_send_grid.called + assert not self.mock_notification_send.called @mock.patch('website.archiver.tasks.archive') def test_registering_project_does_not_send_contributor_added_email(self, mock_archive): @@ -331,18 +333,18 @@ def test_registering_project_does_not_send_contributor_added_email(self, mock_ar None, provider=provider ) - assert not self.mock_send_grid.called + assert not self.mock_notification_send.called def test_notify_contributor_email_does_not_send_before_throttle_expires(self): contributor = UserFactory() project = ProjectFactory() auth = Auth(project.creator) notify_added_contributor(project, contributor, auth) - assert self.mock_send_grid.called + assert self.mock_notification_send.called # 2nd call does not send email because throttle period has not expired notify_added_contributor(project, contributor, auth) - assert self.mock_send_grid.call_count == 1 + assert self.mock_notification_send.call_count == 1 def test_notify_contributor_email_sends_after_throttle_expires(self): throttle = 0.5 @@ -351,37 +353,37 @@ def test_notify_contributor_email_sends_after_throttle_expires(self): project = ProjectFactory() auth = Auth(project.creator) notify_added_contributor(project, contributor, auth, throttle=throttle) - assert self.mock_send_grid.called + assert self.mock_notification_send.called time.sleep(1) # throttle period expires notify_added_contributor(project, contributor, auth, throttle=throttle) - assert self.mock_send_grid.call_count == 2 + assert self.mock_notification_send.call_count == 2 def test_add_contributor_to_fork_sends_email(self): contributor = UserFactory() fork = self.project.fork_node(auth=Auth(self.creator)) fork.add_contributor(contributor, auth=Auth(self.creator)) fork.save() - assert self.mock_send_grid.called - assert self.mock_send_grid.call_count == 1 + assert self.mock_notification_send.called + assert self.mock_notification_send.call_count == 1 def test_add_contributor_to_template_sends_email(self): contributor = UserFactory() template = self.project.use_as_template(auth=Auth(self.creator)) template.add_contributor(contributor, auth=Auth(self.creator)) template.save() - assert self.mock_send_grid.called - assert self.mock_send_grid.call_count == 1 + assert self.mock_notification_send.called + assert self.mock_notification_send.call_count == 1 def test_creating_fork_does_not_email_creator(self): contributor = UserFactory() fork = self.project.fork_node(auth=Auth(self.creator)) - assert not self.mock_send_grid.called + assert not self.mock_notification_send.called def test_creating_template_does_not_email_creator(self): contributor = UserFactory() template = self.project.use_as_template(auth=Auth(self.creator)) - assert not self.mock_send_grid.called + assert not self.mock_notification_send.called def test_add_multiple_contributors_only_adds_one_log(self): n_logs_pre = self.project.logs.count() @@ -448,6 +450,7 @@ def setUp(self): self.invite_url = f'/api/v1/project/{self.project._primary_key}/invite_contributor/' self.mock_send_grid = start_mock_send_grid(self) + self.mock_notification_send = start_mock_notification_send(self) def test_invite_contributor_post_if_not_in_db(self): name, email = fake.name(), fake_email() @@ -527,7 +530,7 @@ def test_send_claim_email_to_given_email(self): project.save() send_claim_email(email=given_email, unclaimed_user=unreg_user, node=project) - self.mock_send_grid.assert_called() + self.mock_notification_send.assert_called() def test_send_claim_email_to_referrer(self): project = ProjectFactory() @@ -540,7 +543,7 @@ def test_send_claim_email_to_referrer(self): project.save() send_claim_email(email=real_email, unclaimed_user=unreg_user, node=project) - assert self.mock_send_grid.called + assert self.mock_notification_send.called def test_send_claim_email_before_throttle_expires(self): project = ProjectFactory() @@ -552,11 +555,11 @@ def test_send_claim_email_before_throttle_expires(self): ) project.save() send_claim_email(email=fake_email(), unclaimed_user=unreg_user, node=project) - self.mock_send_grid.reset_mock() + self.mock_notification_send.reset_mock() # 2nd call raises error because throttle hasn't expired with pytest.raises(HTTPError): send_claim_email(email=fake_email(), unclaimed_user=unreg_user, node=project) - assert not self.mock_send_grid.called + assert not self.mock_notification_send.called @pytest.mark.enable_implicit_clean @@ -594,6 +597,7 @@ def setUp(self): self.project.save() self.mock_send_grid = start_mock_send_grid(self) + self.mock_notification_send = start_mock_notification_send(self) @mock.patch('website.project.views.contributor.send_claim_email') def test_claim_user_already_registered_redirects_to_claim_user_registered(self, claim_email): @@ -704,14 +708,8 @@ def test_claim_user_post_with_registered_user_id(self): res = self.app.post(url, json=payload) # mail was sent - assert self.mock_send_grid.call_count == 2 + assert self.mock_notification_send.call_count == 2 # ... to the correct address - referrer_call = self.mock_send_grid.call_args_list[0] - claimer_call = self.mock_send_grid.call_args_list[1] - - assert referrer_call[1]['to_addr'] == self.referrer.email - assert claimer_call[1]['to_addr'] == reg_user.email - # view returns the correct JSON assert res.json == { 'status': 'success', @@ -726,11 +724,7 @@ def test_send_claim_registered_email(self): unclaimed_user=self.user, node=self.project ) - assert self.mock_send_grid.call_count == 2 - first_call_args = self.mock_send_grid.call_args_list[0][1] - assert first_call_args['to_addr'] == self.referrer.email - second_call_args = self.mock_send_grid.call_args_list[1][1] - assert second_call_args['to_addr'] == reg_user.email + assert self.mock_notification_send.call_count == 2 def test_send_claim_registered_email_before_throttle_expires(self): reg_user = UserFactory() @@ -739,7 +733,7 @@ def test_send_claim_registered_email_before_throttle_expires(self): unclaimed_user=self.user, node=self.project, ) - self.mock_send_grid.reset_mock() + self.mock_notification_send.reset_mock() # second call raises error because it was called before throttle period with pytest.raises(HTTPError): send_claim_registered_email( @@ -747,7 +741,7 @@ def test_send_claim_registered_email_before_throttle_expires(self): unclaimed_user=self.user, node=self.project, ) - assert not self.mock_send_grid.called + assert not self.mock_notification_send.called @mock.patch('website.project.views.contributor.send_claim_registered_email') def test_claim_user_post_with_email_already_registered_sends_correct_email( @@ -935,17 +929,17 @@ def test_claim_user_post_returns_fullname(self): }, ) assert res.json['fullname'] == self.given_name - assert self.mock_send_grid.called + assert self.mock_notification_send.called def test_claim_user_post_if_email_is_different_from_given_email(self): email = fake_email() # email that is different from the one the referrer gave url = f'/api/v1/user/{self.user._primary_key}/{self.project._primary_key}/claim/email/' self.app.post(url, json={'value': email, 'pk': self.user._primary_key} ) - assert self.mock_send_grid.called - assert self.mock_send_grid.call_count == 2 - call_to_invited = self.mock_send_grid.mock_calls[0] + assert self.mock_notification_send.called + assert self.mock_notification_send.call_count == 2 + call_to_invited = self.mock_notification_send.mock_calls[0] call_to_invited.assert_called_with(to_addr=email) - call_to_referrer = self.mock_send_grid.mock_calls[1] + call_to_referrer = self.mock_notification_send.mock_calls[1] call_to_referrer.assert_called_with(to_addr=self.given_email) def test_claim_url_with_bad_token_returns_400(self): diff --git a/tests/test_auth.py b/tests/test_auth.py index 6088c608e67..52156529d92 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -36,7 +36,7 @@ must_have_addon, must_be_addon_authorizer, ) from website.util import api_url_for -from conftest import start_mock_send_grid +from conftest import start_mock_send_grid, start_mock_notification_send from tests.test_cas_authentication import generate_external_user_with_resp @@ -50,6 +50,7 @@ class TestAuthUtils(OsfTestCase): def setUp(self): super().setUp() self.mock_send_grid = start_mock_send_grid(self) + self.start_mock_notification_send = start_mock_notification_send(self) def test_citation_with_only_fullname(self): user = UserFactory() @@ -173,11 +174,7 @@ def test_password_change_sends_email(self): user = UserFactory() user.set_password('killerqueen') user.save() - assert len(self.mock_send_grid.call_args_list) == 1 - empty, kwargs = self.mock_send_grid.call_args - - assert empty == () - assert kwargs['to_addr'] == user.username + assert len(self.start_mock_notification_send.call_args_list) == 1 @mock.patch('framework.auth.utils.requests.post') def test_validate_recaptcha_success(self, req_post): diff --git a/tests/test_events.py b/tests/test_events.py index 866bf6ec337..c9e30273b49 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -131,7 +131,7 @@ def setUp(self): self.user_2 = factories.AuthUserFactory() self.project = factories.ProjectFactory(creator=self.user_1) # subscription - self.sub = factories.NotificationSubscriptionFactory( + self.sub = factories.NotificationSubscriptionLegacyFactory( _id=self.project._id + 'file_updated', owner=self.project, event_name='file_updated', @@ -157,7 +157,7 @@ def setUp(self): self.user = factories.UserFactory() self.consolidate_auth = Auth(user=self.user) self.project = factories.ProjectFactory() - self.project_subscription = factories.NotificationSubscriptionFactory( + self.project_subscription = factories.NotificationSubscriptionLegacyFactory( _id=self.project._id + '_file_updated', owner=self.project, event_name='file_updated' @@ -184,7 +184,7 @@ def setUp(self): self.user = factories.UserFactory() self.consolidate_auth = Auth(user=self.user) self.project = factories.ProjectFactory() - self.project_subscription = factories.NotificationSubscriptionFactory( + self.project_subscription = factories.NotificationSubscriptionLegacyFactory( _id=self.project._id + '_file_updated', owner=self.project, event_name='file_updated' @@ -219,7 +219,7 @@ def setUp(self): self.user = factories.UserFactory() self.consolidate_auth = Auth(user=self.user) self.project = factories.ProjectFactory() - self.project_subscription = factories.NotificationSubscriptionFactory( + self.project_subscription = factories.NotificationSubscriptionLegacyFactory( _id=self.project._id + '_file_updated', owner=self.project, event_name='file_updated' @@ -249,7 +249,7 @@ def setUp(self): self.user_2 = factories.AuthUserFactory() self.project = factories.ProjectFactory(creator=self.user_1) # subscription - self.sub = factories.NotificationSubscriptionFactory( + self.sub = factories.NotificationSubscriptionLegacyFactory( _id=self.project._id + 'file_updated', owner=self.project, event_name='file_updated', @@ -303,21 +303,21 @@ def setUp(self): ) # Subscriptions # for parent node - self.sub = factories.NotificationSubscriptionFactory( + self.sub = factories.NotificationSubscriptionLegacyFactory( _id=self.project._id + '_file_updated', owner=self.project, event_name='file_updated' ) self.sub.save() # for private node - self.private_sub = factories.NotificationSubscriptionFactory( + self.private_sub = factories.NotificationSubscriptionLegacyFactory( _id=self.private_node._id + '_file_updated', owner=self.private_node, event_name='file_updated' ) self.private_sub.save() # for file subscription - self.file_sub = factories.NotificationSubscriptionFactory( + self.file_sub = factories.NotificationSubscriptionLegacyFactory( _id='{pid}_{wbid}_file_updated'.format( pid=self.project._id, wbid=self.event.waterbutler_id @@ -398,21 +398,21 @@ def setUp(self): ) # Subscriptions # for parent node - self.sub = factories.NotificationSubscriptionFactory( + self.sub = factories.NotificationSubscriptionLegacyFactory( _id=self.project._id + '_file_updated', owner=self.project, event_name='file_updated' ) self.sub.save() # for private node - self.private_sub = factories.NotificationSubscriptionFactory( + self.private_sub = factories.NotificationSubscriptionLegacyFactory( _id=self.private_node._id + '_file_updated', owner=self.private_node, event_name='file_updated' ) self.private_sub.save() # for file subscription - self.file_sub = factories.NotificationSubscriptionFactory( + self.file_sub = factories.NotificationSubscriptionLegacyFactory( _id='{pid}_{wbid}_file_updated'.format( pid=self.project._id, wbid=self.event.waterbutler_id @@ -480,21 +480,21 @@ def setUp(self): ) # Subscriptions # for parent node - self.sub = factories.NotificationSubscriptionFactory( + self.sub = factories.NotificationSubscriptionLegacyFactory( _id=self.project._id + '_file_updated', owner=self.project, event_name='file_updated' ) self.sub.save() # for private node - self.private_sub = factories.NotificationSubscriptionFactory( + self.private_sub = factories.NotificationSubscriptionLegacyFactory( _id=self.private_node._id + '_file_updated', owner=self.private_node, event_name='file_updated' ) self.private_sub.save() # for file subscription - self.file_sub = factories.NotificationSubscriptionFactory( + self.file_sub = factories.NotificationSubscriptionLegacyFactory( _id='{pid}_{wbid}_file_updated'.format( pid=self.project._id, wbid=self.event.waterbutler_id diff --git a/tests/test_notifications.py b/tests/test_notifications.py deleted file mode 100644 index 49c6f1083d2..00000000000 --- a/tests/test_notifications.py +++ /dev/null @@ -1,1174 +0,0 @@ -import collections -from unittest import mock - -import pytest -from babel import dates, Locale -from schema import Schema, And, Use, Or -from django.utils import timezone - -from framework.auth import Auth -from osf.models import Comment, NotificationDigest, NotificationSubscription, Guid, OSFUser - -from website.notifications.tasks import get_users_emails, send_users_email, group_by_node, remove_notifications -from website.notifications.exceptions import InvalidSubscriptionError -from website.notifications import constants -from website.notifications import emails -from website.notifications import utils -from website import mails -from website.profile.utils import get_profile_image_url -from website.project.signals import contributor_removed, node_deleted -from website.reviews import listeners -from website.util import api_url_for -from website.util import web_url_for -from website import settings - -from osf_tests import factories -from osf.utils import permissions -from tests.base import capture_signals -from tests.base import OsfTestCase, NotificationTestCase - - - -class TestNotificationsModels(OsfTestCase): - - def setUp(self): - super().setUp() - # Create project with component - self.user = factories.UserFactory() - self.consolidate_auth = Auth(user=self.user) - self.parent = factories.ProjectFactory(creator=self.user) - self.node = factories.NodeFactory(creator=self.user, parent=self.parent) - - def test_has_permission_on_children(self): - non_admin_user = factories.UserFactory() - parent = factories.ProjectFactory() - parent.add_contributor(contributor=non_admin_user, permissions=permissions.READ) - parent.save() - - node = factories.NodeFactory(parent=parent, category='project') - sub_component = factories.NodeFactory(parent=node) - sub_component.add_contributor(contributor=non_admin_user) - sub_component.save() - sub_component2 = factories.NodeFactory(parent=node) - - assert node.has_permission_on_children(non_admin_user, permissions.READ) - - def test_check_user_has_permission_excludes_deleted_components(self): - non_admin_user = factories.UserFactory() - parent = factories.ProjectFactory() - parent.add_contributor(contributor=non_admin_user, permissions=permissions.READ) - parent.save() - - node = factories.NodeFactory(parent=parent, category='project') - sub_component = factories.NodeFactory(parent=node) - sub_component.add_contributor(contributor=non_admin_user) - sub_component.is_deleted = True - sub_component.save() - sub_component2 = factories.NodeFactory(parent=node) - - assert not node.has_permission_on_children(non_admin_user, permissions.READ) - - def test_check_user_does_not_have_permission_on_private_node_child(self): - non_admin_user = factories.UserFactory() - parent = factories.ProjectFactory() - parent.add_contributor(contributor=non_admin_user, permissions=permissions.READ) - parent.save() - node = factories.NodeFactory(parent=parent, category='project') - sub_component = factories.NodeFactory(parent=node) - - assert not node.has_permission_on_children(non_admin_user,permissions.READ) - - def test_check_user_child_node_permissions_false_if_no_children(self): - non_admin_user = factories.UserFactory() - parent = factories.ProjectFactory() - parent.add_contributor(contributor=non_admin_user, permissions=permissions.READ) - parent.save() - node = factories.NodeFactory(parent=parent, category='project') - - assert not node.has_permission_on_children(non_admin_user,permissions.READ) - - def test_check_admin_has_permissions_on_private_component(self): - parent = factories.ProjectFactory() - node = factories.NodeFactory(parent=parent, category='project') - sub_component = factories.NodeFactory(parent=node) - - assert node.has_permission_on_children(parent.creator,permissions.READ) - - def test_check_user_private_node_child_permissions_excludes_pointers(self): - user = factories.UserFactory() - parent = factories.ProjectFactory() - pointed = factories.ProjectFactory(creator=user) - parent.add_pointer(pointed, Auth(parent.creator)) - parent.save() - - assert not parent.has_permission_on_children(user,permissions.READ) - - def test_new_project_creator_is_subscribed(self): - user = factories.UserFactory() - factories.ProjectFactory(creator=user) - user_subscriptions = list(utils.get_all_user_subscriptions(user)) - event_types = [sub.event_name for sub in user_subscriptions] - - assert len(user_subscriptions) == 1 # subscribed to file_updated - assert 'file_updated' in event_types - - def test_new_node_creator_is_not_subscribed(self): - user = factories.UserFactory() - factories.NodeFactory(creator=user) - user_subscriptions = list(utils.get_all_user_subscriptions(user)) - - assert len(user_subscriptions) == 0 - - def test_new_project_creator_is_subscribed_with_global_settings(self): - user = factories.UserFactory() - - factories.NotificationSubscriptionFactory( - _id=user._id + '_' + 'global_file_updated', - user=user, - event_name='global_file_updated' - ).add_user_to_subscription(user, 'none') - - node = factories.ProjectFactory(creator=user) - - user_subscriptions = list(utils.get_all_user_subscriptions(user)) - event_types = [sub.event_name for sub in user_subscriptions] - - file_updated_subscription = NotificationSubscription.objects.get(_id=node._id + '_file_updated') - - assert len(user_subscriptions) == 2 # subscribed to both node and user settings - assert 'file_updated' in event_types - assert 'global_file_updated' in event_types - assert file_updated_subscription.none.count() == 1 - assert file_updated_subscription.email_transactional.count() == 0 - - def test_new_node_creator_is_not_subscribed_with_global_settings(self): - user = factories.UserFactory() - - factories.NotificationSubscriptionFactory( - _id=user._id + '_' + 'global_file_updated', - user=user, - event_name='global_file_updated' - ).add_user_to_subscription(user, 'none') - - node = factories.NodeFactory(creator=user) - - user_subscriptions = list(utils.get_all_user_subscriptions(user)) - event_types = [sub.event_name for sub in user_subscriptions] - - assert len(user_subscriptions) == 1 # subscribed to only user settings - assert 'global_file_updated' in event_types - - def test_subscribe_user_to_global_notfiications(self): - user = factories.UserFactory() - utils.subscribe_user_to_global_notifications(user) - subscription_event_names = list(user.notification_subscriptions.values_list('event_name', flat=True)) - for event_name in constants.USER_SUBSCRIPTIONS_AVAILABLE: - assert event_name in subscription_event_names - - def test_subscribe_user_to_registration_notifications(self): - registration = factories.RegistrationFactory() - with pytest.raises(InvalidSubscriptionError): - utils.subscribe_user_to_notifications(registration, self.user) - - def test_new_project_creator_is_subscribed_with_default_global_settings(self): - user = factories.UserFactory() - - factories.NotificationSubscriptionFactory( - _id=user._id + '_' + 'global_file_updated', - user=user, - event_name='global_file_updated' - ).add_user_to_subscription(user, 'email_transactional') - - node = factories.ProjectFactory(creator=user) - - user_subscriptions = list(utils.get_all_user_subscriptions(user)) - event_types = [sub.event_name for sub in user_subscriptions] - - file_updated_subscription = NotificationSubscription.objects.get(_id=node._id + '_file_updated') - - assert len(user_subscriptions) == 2 # subscribed to both node and user settings - assert 'file_updated' in event_types - assert 'global_file_updated' in event_types - assert file_updated_subscription.email_transactional.count() == 1 - - def test_new_fork_creator_is_subscribed_with_default_global_settings(self): - user = factories.UserFactory() - project = factories.ProjectFactory(creator=user) - - factories.NotificationSubscriptionFactory( - _id=user._id + '_' + 'global_file_updated', - user=user, - event_name='global_file_updated' - ).add_user_to_subscription(user, 'email_transactional') - - node = factories.ForkFactory(project=project) - - user_subscriptions = list(utils.get_all_user_subscriptions(user)) - event_types = [sub.event_name for sub in user_subscriptions] - - node_file_updated_subscription = NotificationSubscription.objects.get(_id=node._id + '_file_updated') - project_file_updated_subscription = NotificationSubscription.objects.get(_id=project._id + '_file_updated') - - assert len(user_subscriptions) == 3 # subscribed to project, fork, and user settings - assert 'file_updated' in event_types - assert 'global_file_updated' in event_types - assert node_file_updated_subscription.email_transactional.count() == 1 - assert project_file_updated_subscription.email_transactional.count() == 1 - - def test_new_node_creator_is_not_subscribed_with_default_global_settings(self): - user = factories.UserFactory() - - factories.NotificationSubscriptionFactory( - _id=user._id + '_' + 'global_file_updated', - user=user, - event_name='global_file_updated' - ).add_user_to_subscription(user, 'email_transactional') - - node = factories.NodeFactory(creator=user) - - user_subscriptions = list(utils.get_all_user_subscriptions(user)) - event_types = [sub.event_name for sub in user_subscriptions] - - assert len(user_subscriptions) == 1 # subscribed to only user settings - assert 'global_file_updated' in event_types - - - def test_contributor_subscribed_when_added_to_project(self): - user = factories.UserFactory() - contributor = factories.UserFactory() - project = factories.ProjectFactory(creator=user) - project.add_contributor(contributor=contributor) - contributor_subscriptions = list(utils.get_all_user_subscriptions(contributor)) - event_types = [sub.event_name for sub in contributor_subscriptions] - - assert len(contributor_subscriptions) == 1 - assert 'file_updated' in event_types - - def test_contributor_subscribed_when_added_to_component(self): - user = factories.UserFactory() - contributor = factories.UserFactory() - - factories.NotificationSubscriptionFactory( - _id=contributor._id + '_' + 'global_file_updated', - user=contributor, - event_name='global_file_updated' - ).add_user_to_subscription(contributor, 'email_transactional') - - node = factories.NodeFactory(creator=user) - node.add_contributor(contributor=contributor) - - contributor_subscriptions = list(utils.get_all_user_subscriptions(contributor)) - event_types = [sub.event_name for sub in contributor_subscriptions] - - file_updated_subscription = NotificationSubscription.objects.get(_id=node._id + '_file_updated') - - assert len(contributor_subscriptions) == 2 # subscribed to both node and user settings - assert 'file_updated' in event_types - assert 'global_file_updated' in event_types - assert file_updated_subscription.email_transactional.count() == 1 - - def test_unregistered_contributor_not_subscribed_when_added_to_project(self): - user = factories.AuthUserFactory() - unregistered_contributor = factories.UnregUserFactory() - project = factories.ProjectFactory(creator=user) - project.add_unregistered_contributor( - unregistered_contributor.fullname, - unregistered_contributor.email, - Auth(user), - existing_user=unregistered_contributor - ) - - contributor_subscriptions = list(utils.get_all_user_subscriptions(unregistered_contributor)) - assert len(contributor_subscriptions) == 0 - - -class TestRemoveNodeSignal(OsfTestCase): - - def test_node_subscriptions_and_backrefs_removed_when_node_is_deleted(self): - project = factories.ProjectFactory() - component = factories.NodeFactory(parent=project, creator=project.creator) - - s = NotificationSubscription.objects.filter(email_transactional=project.creator) - assert s.count() == 1 - - s = NotificationSubscription.objects.filter(email_transactional=component.creator) - assert s.count() == 1 - - with capture_signals() as mock_signals: - project.remove_node(auth=Auth(project.creator)) - project.reload() - component.reload() - - assert project.is_deleted - assert component.is_deleted - assert mock_signals.signals_sent() == {node_deleted} - - s = NotificationSubscription.objects.filter(email_transactional=project.creator) - assert s.count() == 0 - - s = NotificationSubscription.objects.filter(email_transactional=component.creator) - assert s.count() == 0 - - with pytest.raises(NotificationSubscription.DoesNotExist): - NotificationSubscription.objects.get(node=project) - - with pytest.raises(NotificationSubscription.DoesNotExist): - NotificationSubscription.objects.get(node=component) - - -def list_or_dict(data): - # Generator only returns lists or dicts from list or dict - if isinstance(data, dict): - for key in data: - if isinstance(data[key], dict) or isinstance(data[key], list): - yield data[key] - elif isinstance(data, list): - for item in data: - if isinstance(item, dict) or isinstance(item, list): - yield item - - -def has(data, sub_data): - # Recursive approach to look for a subset of data in data. - # WARNING: Don't use on huge structures - # :param data: Data structure - # :param sub_data: subset being checked for - # :return: True or False - try: - next(item for item in data if item == sub_data) - return True - except StopIteration: - lists_and_dicts = list_or_dict(data) - for item in lists_and_dicts: - if has(item, sub_data): - return True - return False - - -def subscription_schema(project, structure, level=0): - # builds a schema from a list of nodes and events - # :param project: validation type - # :param structure: list of nodes (another list) and events - # :return: schema - sub_list = [] - for item in list_or_dict(structure): - sub_list.append(subscription_schema(project, item, level=level+1)) - sub_list.append(event_schema(level)) - - node_schema = { - 'node': { - 'id': Use(type(project._id), error=f'node_id{level}'), - 'title': Use(type(project.title), error=f'node_title{level}'), - 'url': Use(type(project.url), error=f'node_{level}') - }, - 'kind': And(str, Use(lambda s: s in ('node', 'folder'), - error=f"kind didn't match node or folder {level}")), - 'nodeType': Use(lambda s: s in ('project', 'component'), error='nodeType not project or component'), - 'category': Use(lambda s: s in settings.NODE_CATEGORY_MAP, error='category not in settings.NODE_CATEGORY_MAP'), - 'permissions': { - 'view': Use(lambda s: s in (True, False), error='view permissions is not True/False') - }, - 'children': sub_list - } - if level == 0: - return Schema([node_schema]) - return node_schema - - -def event_schema(level=None): - return { - 'event': { - 'title': And(Use(str, error=f'event_title{level} not a string'), - Use(lambda s: s in constants.NOTIFICATION_TYPES, - error=f'event_title{level} not in list')), - 'description': And(Use(str, error=f'event_desc{level} not a string'), - Use(lambda s: s in constants.NODE_SUBSCRIPTIONS_AVAILABLE, - error=f'event_desc{level} not in list')), - 'notificationType': And(str, Or('adopt_parent', lambda s: s in constants.NOTIFICATION_TYPES)), - 'parent_notification_type': Or(None, 'adopt_parent', lambda s: s in constants.NOTIFICATION_TYPES) - }, - 'kind': 'event', - 'children': And(list, lambda l: len(l) == 0) - } - - -class TestNotificationUtils(OsfTestCase): - - def setUp(self): - super().setUp() - self.user = factories.UserFactory() - self.project = factories.ProjectFactory(creator=self.user) - - self.user.notifications_configured[self.project._id] = True - self.user.save() - - self.node = factories.NodeFactory(parent=self.project, creator=self.user) - - self.user_subscription = [ - factories.NotificationSubscriptionFactory( - _id=self.user._id + '_' + 'global_file_updated', - user=self.user, - event_name='global_file_updated' - )] - - for x in self.user_subscription: - x.save() - for x in self.user_subscription: - x.email_transactional.add(self.user) - for x in self.user_subscription: - x.save() - - def test_to_subscription_key(self): - key = utils.to_subscription_key('xyz', 'comments') - assert key == 'xyz_comments' - - def test_from_subscription_key(self): - parsed_key = utils.from_subscription_key('xyz_comment_replies') - assert parsed_key == { - 'uid': 'xyz', - 'event': 'comment_replies' - } - - def test_get_configured_project_ids_does_not_return_user_or_node_ids(self): - configured_nodes = utils.get_configured_projects(self.user) - configured_ids = [n._id for n in configured_nodes] - # No duplicates! - assert len(configured_nodes) == 1 - - assert self.project._id in configured_ids - assert self.node._id not in configured_ids - assert self.user._id not in configured_ids - - def test_get_configured_project_ids_excludes_deleted_projects(self): - project = factories.ProjectFactory() - project.is_deleted = True - project.save() - assert project not in utils.get_configured_projects(self.user) - - def test_get_configured_project_ids_excludes_node_with_project_category(self): - node = factories.NodeFactory(parent=self.project, category='project') - assert node not in utils.get_configured_projects(self.user) - - def test_get_configured_project_ids_includes_top_level_private_projects_if_subscriptions_on_node(self): - private_project = factories.ProjectFactory() - node = factories.NodeFactory(parent=private_project) - node_comments_subscription = factories.NotificationSubscriptionFactory( - _id=node._id + '_' + 'comments', - node=node, - event_name='comments' - ) - node_comments_subscription.save() - node_comments_subscription.email_transactional.add(node.creator) - node_comments_subscription.save() - - node.creator.notifications_configured[node._id] = True - node.creator.save() - configured_project_nodes = utils.get_configured_projects(node.creator) - assert private_project in configured_project_nodes - - def test_get_configured_project_ids_excludes_private_projects_if_no_subscriptions_on_node(self): - user = factories.UserFactory() - - private_project = factories.ProjectFactory() - node = factories.NodeFactory(parent=private_project) - node.add_contributor(user) - - utils.remove_contributor_from_subscriptions(node, user) - - configured_project_nodes = utils.get_configured_projects(user) - assert private_project not in configured_project_nodes - - def test_format_user_subscriptions(self): - data = utils.format_user_subscriptions(self.user) - expected = [ - { - 'event': { - 'title': 'global_file_updated', - 'description': constants.USER_SUBSCRIPTIONS_AVAILABLE['global_file_updated'], - 'notificationType': 'email_transactional', - 'parent_notification_type': None, - }, - 'kind': 'event', - 'children': [] - }, { - 'event': { - 'title': 'global_reviews', - 'description': constants.USER_SUBSCRIPTIONS_AVAILABLE['global_reviews'], - 'notificationType': 'email_transactional', - 'parent_notification_type': None - }, - 'kind': 'event', - 'children': [] - } - ] - - assert data == expected - - def test_format_data_user_settings(self): - data = utils.format_user_and_project_subscriptions(self.user) - expected = [ - { - 'node': { - 'id': self.user._id, - 'title': 'Default Notification Settings', - 'help': 'These are default settings for new projects you create or are added to. Modifying these settings will not modify settings on existing projects.' - }, - 'kind': 'heading', - 'children': utils.format_user_subscriptions(self.user) - }, - { - 'node': { - 'help': 'These are settings for each of your projects. Modifying these settings will only modify the settings for the selected project.', - 'id': '', - 'title': 'Project Notifications' - }, - 'kind': 'heading', - 'children': utils.format_data(self.user, utils.get_configured_projects(self.user)) - }] - assert data == expected - - -class TestCompileSubscriptions(NotificationTestCase): - def setUp(self): - super().setUp() - self.user_1 = factories.UserFactory() - self.user_2 = factories.UserFactory() - self.user_3 = factories.UserFactory() - self.user_4 = factories.UserFactory() - # Base project + 1 project shared with 3 + 1 project shared with 2 - self.base_project = factories.ProjectFactory(is_public=False, creator=self.user_1) - self.shared_node = factories.NodeFactory(parent=self.base_project, is_public=False, creator=self.user_1) - self.private_node = factories.NodeFactory(parent=self.base_project, is_public=False, creator=self.user_1) - # Adding contributors - for node in [self.base_project, self.shared_node, self.private_node]: - node.add_contributor(self.user_2, permissions=permissions.ADMIN) - self.base_project.add_contributor(self.user_3, permissions=permissions.WRITE) - self.shared_node.add_contributor(self.user_3, permissions=permissions.WRITE) - # Setting basic subscriptions - self.base_sub = factories.NotificationSubscriptionFactory( - _id=self.base_project._id + '_file_updated', - node=self.base_project, - event_name='file_updated' - ) - self.base_sub.save() - self.shared_sub = factories.NotificationSubscriptionFactory( - _id=self.shared_node._id + '_file_updated', - node=self.shared_node, - event_name='file_updated' - ) - self.shared_sub.save() - self.private_sub = factories.NotificationSubscriptionFactory( - _id=self.private_node._id + '_file_updated', - node=self.private_node, - event_name='file_updated' - ) - self.private_sub.save() - - def test_no_subscription(self): - node = factories.NodeFactory() - result = emails.compile_subscriptions(node, 'file_updated') - assert {'email_transactional': [], 'none': [], 'email_digest': []} == result - - def test_no_subscribers(self): - node = factories.NodeFactory() - node_sub = factories.NotificationSubscriptionFactory( - _id=node._id + '_file_updated', - node=node, - event_name='file_updated' - ) - node_sub.save() - result = emails.compile_subscriptions(node, 'file_updated') - assert {'email_transactional': [], 'none': [], 'email_digest': []} == result - - def test_creator_subbed_parent(self): - # Basic sub check - self.base_sub.email_transactional.add(self.user_1) - self.base_sub.save() - result = emails.compile_subscriptions(self.base_project, 'file_updated') - assert {'email_transactional': [self.user_1._id], 'none': [], 'email_digest': []} == result - - def test_creator_subbed_to_parent_from_child(self): - # checks the parent sub is the one to appear without a child sub - self.base_sub.email_transactional.add(self.user_1) - self.base_sub.save() - result = emails.compile_subscriptions(self.shared_node, 'file_updated') - assert {'email_transactional': [self.user_1._id], 'none': [], 'email_digest': []} == result - - def test_creator_subbed_to_both_from_child(self): - # checks that only one sub is in the list. - self.base_sub.email_transactional.add(self.user_1) - self.base_sub.save() - self.shared_sub.email_transactional.add(self.user_1) - self.shared_sub.save() - result = emails.compile_subscriptions(self.shared_node, 'file_updated') - assert {'email_transactional': [self.user_1._id], 'none': [], 'email_digest': []} == result - - def test_creator_diff_subs_to_both_from_child(self): - # Check that the child node sub overrides the parent node sub - self.base_sub.email_transactional.add(self.user_1) - self.base_sub.save() - self.shared_sub.none.add(self.user_1) - self.shared_sub.save() - result = emails.compile_subscriptions(self.shared_node, 'file_updated') - assert {'email_transactional': [], 'none': [self.user_1._id], 'email_digest': []} == result - - def test_user_wo_permission_on_child_node_not_listed(self): - # Tests to see if a user without permission gets an Email about a node they cannot see. - self.base_sub.email_transactional.add(self.user_3) - self.base_sub.save() - result = emails.compile_subscriptions(self.private_node, 'file_updated') - assert {'email_transactional': [], 'none': [], 'email_digest': []} == result - - def test_several_nodes_deep(self): - self.base_sub.email_transactional.add(self.user_1) - self.base_sub.save() - node2 = factories.NodeFactory(parent=self.shared_node) - node3 = factories.NodeFactory(parent=node2) - node4 = factories.NodeFactory(parent=node3) - node5 = factories.NodeFactory(parent=node4) - subs = emails.compile_subscriptions(node5, 'file_updated') - assert subs == {'email_transactional': [self.user_1._id], 'email_digest': [], 'none': []} - - def test_several_nodes_deep_precedence(self): - self.base_sub.email_transactional.add(self.user_1) - self.base_sub.save() - node2 = factories.NodeFactory(parent=self.shared_node) - node3 = factories.NodeFactory(parent=node2) - node4 = factories.NodeFactory(parent=node3) - node4_subscription = factories.NotificationSubscriptionFactory( - _id=node4._id + '_file_updated', - node=node4, - event_name='file_updated' - ) - node4_subscription.save() - node4_subscription.email_digest.add(self.user_1) - node4_subscription.save() - node5 = factories.NodeFactory(parent=node4) - subs = emails.compile_subscriptions(node5, 'file_updated') - assert subs == {'email_transactional': [], 'email_digest': [self.user_1._id], 'none': []} - - -class TestMoveSubscription(NotificationTestCase): - def setUp(self): - super().setUp() - self.blank = {key: [] for key in constants.NOTIFICATION_TYPES} # For use where it is blank. - self.user_1 = factories.AuthUserFactory() - self.auth = Auth(user=self.user_1) - self.user_2 = factories.AuthUserFactory() - self.user_3 = factories.AuthUserFactory() - self.user_4 = factories.AuthUserFactory() - self.project = factories.ProjectFactory(creator=self.user_1) - self.private_node = factories.NodeFactory(parent=self.project, is_public=False, creator=self.user_1) - self.sub = factories.NotificationSubscriptionFactory( - _id=self.project._id + '_file_updated', - node=self.project, - event_name='file_updated' - ) - self.sub.email_transactional.add(self.user_1) - self.sub.save() - self.file_sub = factories.NotificationSubscriptionFactory( - _id=self.project._id + '_xyz42_file_updated', - node=self.project, - event_name='xyz42_file_updated' - ) - self.file_sub.save() - - def test_separate_users(self): - self.private_node.add_contributor(self.user_2, permissions=permissions.ADMIN, auth=self.auth) - self.private_node.add_contributor(self.user_3, permissions=permissions.WRITE, auth=self.auth) - self.private_node.save() - subbed, removed = utils.separate_users( - self.private_node, [self.user_2._id, self.user_3._id, self.user_4._id] - ) - assert [self.user_2._id, self.user_3._id] == subbed - assert [self.user_4._id] == removed - - def test_event_subs_same(self): - self.file_sub.email_transactional.add(self.user_2, self.user_3, self.user_4) - self.file_sub.save() - self.private_node.add_contributor(self.user_2, permissions=permissions.ADMIN, auth=self.auth) - self.private_node.add_contributor(self.user_3, permissions=permissions.WRITE, auth=self.auth) - self.private_node.save() - results = utils.users_to_remove('xyz42_file_updated', self.project, self.private_node) - assert {'email_transactional': [self.user_4._id], 'email_digest': [], 'none': []} == results - - def test_event_nodes_same(self): - self.file_sub.email_transactional.add(self.user_2, self.user_3, self.user_4) - self.file_sub.save() - self.private_node.add_contributor(self.user_2, permissions=permissions.ADMIN, auth=self.auth) - self.private_node.add_contributor(self.user_3, permissions=permissions.WRITE, auth=self.auth) - self.private_node.save() - results = utils.users_to_remove('xyz42_file_updated', self.project, self.project) - assert {'email_transactional': [], 'email_digest': [], 'none': []} == results - - def test_move_sub(self): - # Tests old sub is replaced with new sub. - utils.move_subscription(self.blank, 'xyz42_file_updated', self.project, 'abc42_file_updated', self.private_node) - self.file_sub.reload() - assert 'abc42_file_updated' == self.file_sub.event_name - assert self.private_node == self.file_sub.owner - assert self.private_node._id + '_abc42_file_updated' == self.file_sub._id - - def test_move_sub_with_none(self): - # Attempt to reproduce an error that is seen when moving files - self.project.add_contributor(self.user_2, permissions=permissions.WRITE, auth=self.auth) - self.project.save() - self.file_sub.none.add(self.user_2) - self.file_sub.save() - results = utils.users_to_remove('xyz42_file_updated', self.project, self.private_node) - assert {'email_transactional': [], 'email_digest': [], 'none': [self.user_2._id]} == results - - def test_remove_one_user(self): - # One user doesn't have permissions on the node the sub is moved to. Should be listed. - self.file_sub.email_transactional.add(self.user_2, self.user_3, self.user_4) - self.file_sub.save() - self.private_node.add_contributor(self.user_2, permissions=permissions.ADMIN, auth=self.auth) - self.private_node.add_contributor(self.user_3, permissions=permissions.WRITE, auth=self.auth) - self.private_node.save() - results = utils.users_to_remove('xyz42_file_updated', self.project, self.private_node) - assert {'email_transactional': [self.user_4._id], 'email_digest': [], 'none': []} == results - - def test_remove_one_user_warn_another(self): - # Two users do not have permissions on new node, but one has a project sub. Both should be listed. - self.private_node.add_contributor(self.user_2, permissions=permissions.ADMIN, auth=self.auth) - self.private_node.save() - self.project.add_contributor(self.user_3, permissions=permissions.WRITE, auth=self.auth) - self.project.save() - self.sub.email_digest.add(self.user_3) - self.sub.save() - self.file_sub.email_transactional.add(self.user_2, self.user_4) - - results = utils.users_to_remove('xyz42_file_updated', self.project, self.private_node) - utils.move_subscription(results, 'xyz42_file_updated', self.project, 'abc42_file_updated', self.private_node) - assert {'email_transactional': [self.user_4._id], 'email_digest': [self.user_3._id], 'none': []} == results - assert self.sub.email_digest.filter(id=self.user_3.id).exists() # Is not removed from the project subscription. - - def test_warn_user(self): - # One user with a project sub does not have permission on new node. User should be listed. - self.private_node.add_contributor(self.user_2, permissions=permissions.ADMIN, auth=self.auth) - self.private_node.save() - self.project.add_contributor(self.user_3, permissions=permissions.WRITE, auth=self.auth) - self.project.save() - self.sub.email_digest.add(self.user_3) - self.sub.save() - self.file_sub.email_transactional.add(self.user_2) - results = utils.users_to_remove('xyz42_file_updated', self.project, self.private_node) - utils.move_subscription(results, 'xyz42_file_updated', self.project, 'abc42_file_updated', self.private_node) - assert {'email_transactional': [], 'email_digest': [self.user_3._id], 'none': []} == results - assert self.user_3 in self.sub.email_digest.all() # Is not removed from the project subscription. - - def test_user_node_subbed_and_not_removed(self): - self.project.add_contributor(self.user_3, permissions=permissions.WRITE, auth=self.auth) - self.project.save() - self.private_node.add_contributor(self.user_3, permissions=permissions.WRITE, auth=self.auth) - self.private_node.save() - self.sub.email_digest.add(self.user_3) - self.sub.save() - utils.move_subscription(self.blank, 'xyz42_file_updated', self.project, 'abc42_file_updated', self.private_node) - assert not self.file_sub.email_digest.filter().exists() - - # Regression test for commit ea15186 - def test_garrulous_event_name(self): - self.file_sub.email_transactional.add(self.user_2, self.user_3, self.user_4) - self.file_sub.save() - self.private_node.add_contributor(self.user_2, permissions=permissions.ADMIN, auth=self.auth) - self.private_node.add_contributor(self.user_3, permissions=permissions.WRITE, auth=self.auth) - self.private_node.save() - results = utils.users_to_remove('complicated/path_to/some/file/ASDFASDF.txt_file_updated', self.project, self.private_node) - assert {'email_transactional': [], 'email_digest': [], 'none': []} == results - -class TestSendEmails(NotificationTestCase): - def setUp(self): - super().setUp() - self.user = factories.AuthUserFactory() - self.project = factories.ProjectFactory() - self.node = factories.NodeFactory(parent=self.project) - - - def test_get_settings_url_for_node(self): - url = emails.get_settings_url(self.project._id, self.user) - assert url == self.project.absolute_url + 'settings/' - - def test_get_settings_url_for_user(self): - url = emails.get_settings_url(self.user._id, self.user) - assert url == web_url_for('user_notifications', _absolute=True) - - def test_get_node_lineage(self): - node_lineage = emails.get_node_lineage(self.node) - assert node_lineage == [self.project._id, self.node._id] - - def test_fix_locale(self): - assert emails.fix_locale('en') == 'en' - assert emails.fix_locale('de_DE') == 'de_DE' - assert emails.fix_locale('de_de') == 'de_DE' - - def test_localize_timestamp(self): - timestamp = timezone.now() - self.user.timezone = 'America/New_York' - self.user.locale = 'en_US' - self.user.save() - tz = dates.get_timezone(self.user.timezone) - locale = Locale(self.user.locale) - formatted_date = dates.format_date(timestamp, format='full', locale=locale) - formatted_time = dates.format_time(timestamp, format='short', tzinfo=tz, locale=locale) - formatted_datetime = f'{formatted_time} on {formatted_date}' - assert emails.localize_timestamp(timestamp, self.user) == formatted_datetime - - def test_localize_timestamp_empty_timezone(self): - timestamp = timezone.now() - self.user.timezone = '' - self.user.locale = 'en_US' - self.user.save() - tz = dates.get_timezone('Etc/UTC') - locale = Locale(self.user.locale) - formatted_date = dates.format_date(timestamp, format='full', locale=locale) - formatted_time = dates.format_time(timestamp, format='short', tzinfo=tz, locale=locale) - formatted_datetime = f'{formatted_time} on {formatted_date}' - assert emails.localize_timestamp(timestamp, self.user) == formatted_datetime - - def test_localize_timestamp_empty_locale(self): - timestamp = timezone.now() - self.user.timezone = 'America/New_York' - self.user.locale = '' - self.user.save() - tz = dates.get_timezone(self.user.timezone) - locale = Locale('en') - formatted_date = dates.format_date(timestamp, format='full', locale=locale) - formatted_time = dates.format_time(timestamp, format='short', tzinfo=tz, locale=locale) - formatted_datetime = f'{formatted_time} on {formatted_date}' - assert emails.localize_timestamp(timestamp, self.user) == formatted_datetime - - def test_localize_timestamp_handles_unicode(self): - timestamp = timezone.now() - self.user.timezone = 'Europe/Moscow' - self.user.locale = 'ru_RU' - self.user.save() - tz = dates.get_timezone(self.user.timezone) - locale = Locale(self.user.locale) - formatted_date = dates.format_date(timestamp, format='full', locale=locale) - formatted_time = dates.format_time(timestamp, format='short', tzinfo=tz, locale=locale) - formatted_datetime = f'{formatted_time} on {formatted_date}' - assert emails.localize_timestamp(timestamp, self.user) == formatted_datetime - - -@mock.patch('website.mails.settings.USE_EMAIL', True) -@mock.patch('website.mails.settings.USE_CELERY', False) -class TestSendDigest(OsfTestCase): - def setUp(self): - super().setUp() - self.user_1 = factories.UserFactory() - self.user_2 = factories.UserFactory() - self.project = factories.ProjectFactory() - self.timestamp = timezone.now() - - from conftest import start_mock_send_grid - self.mock_send_grid = start_mock_send_grid(self) - - def test_group_notifications_by_user_transactional(self): - send_type = 'email_transactional' - d = factories.NotificationDigestFactory( - user=self.user_1, - send_type=send_type, - timestamp=self.timestamp, - message='Hello', - node_lineage=[self.project._id] - ) - d.save() - d2 = factories.NotificationDigestFactory( - user=self.user_2, - send_type=send_type, - timestamp=self.timestamp, - message='Hello', - node_lineage=[self.project._id] - ) - d2.save() - d3 = factories.NotificationDigestFactory( - user=self.user_2, - send_type='email_digest', - timestamp=self.timestamp, - message='Hello, but this should not appear (this is a digest)', - node_lineage=[self.project._id] - ) - d3.save() - user_groups = list(get_users_emails(send_type)) - expected = [ - { - 'user_id': self.user_1._id, - 'info': [{ - 'message': 'Hello', - 'node_lineage': [str(self.project._id)], - '_id': d._id - }] - }, - { - 'user_id': self.user_2._id, - 'info': [{ - 'message': 'Hello', - 'node_lineage': [str(self.project._id)], - '_id': d2._id - }] - } - ] - - assert len(user_groups) == 2 - assert user_groups == expected - digest_ids = [d._id, d2._id, d3._id] - remove_notifications(email_notification_ids=digest_ids) - - def test_group_notifications_by_user_digest(self): - send_type = 'email_digest' - d2 = factories.NotificationDigestFactory( - user=self.user_2, - send_type=send_type, - timestamp=self.timestamp, - message='Hello', - node_lineage=[self.project._id] - ) - d2.save() - d3 = factories.NotificationDigestFactory( - user=self.user_2, - send_type='email_transactional', - timestamp=self.timestamp, - message='Hello, but this should not appear (this is transactional)', - node_lineage=[self.project._id] - ) - d3.save() - user_groups = list(get_users_emails(send_type)) - expected = [ - { - 'user_id': str(self.user_2._id), - 'info': [{ - 'message': 'Hello', - 'node_lineage': [str(self.project._id)], - '_id': str(d2._id) - }] - } - ] - - assert len(user_groups) == 1 - assert user_groups == expected - digest_ids = [d2._id, d3._id] - remove_notifications(email_notification_ids=digest_ids) - - def test_send_users_email_called_with_correct_args(self): - send_type = 'email_transactional' - d = factories.NotificationDigestFactory( - send_type=send_type, - event='comment_replies', - timestamp=timezone.now(), - message='Hello', - node_lineage=[factories.ProjectFactory()._id] - ) - d.save() - user_groups = list(get_users_emails(send_type)) - send_users_email(send_type) - mock_send_grid = self.mock_send_grid - assert mock_send_grid.called - assert mock_send_grid.call_count == len(user_groups) - - last_user_index = len(user_groups) - 1 - user = OSFUser.load(user_groups[last_user_index]['user_id']) - args, kwargs = mock_send_grid.call_args - - assert kwargs['to_addr'] == user.username - - def test_send_users_email_ignores_disabled_users(self): - send_type = 'email_transactional' - d = factories.NotificationDigestFactory( - send_type=send_type, - event='comment_replies', - timestamp=timezone.now(), - message='Hello', - node_lineage=[factories.ProjectFactory()._id] - ) - d.save() - - user_groups = list(get_users_emails(send_type)) - last_user_index = len(user_groups) - 1 - - user = OSFUser.load(user_groups[last_user_index]['user_id']) - user.is_disabled = True - user.save() - - send_users_email(send_type) - assert not self.mock_send_grid.called - - def test_remove_sent_digest_notifications(self): - d = factories.NotificationDigestFactory( - event='comment_replies', - timestamp=timezone.now(), - message='Hello', - node_lineage=[factories.ProjectFactory()._id] - ) - digest_id = d._id - remove_notifications(email_notification_ids=[digest_id]) - with pytest.raises(NotificationDigest.DoesNotExist): - NotificationDigest.objects.get(_id=digest_id) - - -@mock.patch('website.mails.settings.USE_EMAIL', True) -@mock.patch('website.mails.settings.USE_CELERY', False) -class TestNotificationsReviews(OsfTestCase): - def setUp(self): - super().setUp() - self.provider = factories.PreprintProviderFactory(_id='engrxiv') - self.preprint = factories.PreprintFactory(provider=self.provider) - self.user = factories.UserFactory() - self.sender = factories.UserFactory() - self.context_info = { - 'domain': 'osf.io', - 'reviewable': self.preprint, - 'workflow': 'pre-moderation', - 'provider_contact_email': settings.OSF_CONTACT_EMAIL, - 'provider_support_email': settings.OSF_SUPPORT_EMAIL, - 'document_type': 'preprint', - 'referrer': self.sender, - 'provider_url': self.provider.landing_url, - } - self.action = factories.ReviewActionFactory() - factories.NotificationSubscriptionFactory( - _id=self.user._id + '_' + 'global_comments', - user=self.user, - event_name='global_comments' - ).add_user_to_subscription(self.user, 'email_transactional') - - factories.NotificationSubscriptionFactory( - _id=self.user._id + '_' + 'global_file_updated', - user=self.user, - event_name='global_file_updated' - ).add_user_to_subscription(self.user, 'email_transactional') - - factories.NotificationSubscriptionFactory( - _id=self.user._id + '_' + 'global_reviews', - user=self.user, - event_name='global_reviews' - ).add_user_to_subscription(self.user, 'email_transactional') - - from conftest import start_mock_send_grid - self.mock_send_grid = start_mock_send_grid(self) - - def test_reviews_base_notification(self): - contributor_subscriptions = list(utils.get_all_user_subscriptions(self.user)) - event_types = [sub.event_name for sub in contributor_subscriptions] - assert 'global_reviews' in event_types - - def test_reviews_submit_notification(self): - listeners.reviews_submit_notification(self, context=self.context_info, recipients=[self.sender, self.user]) - assert self.mock_send_grid.called - - @mock.patch('website.notifications.emails.notify_global_event') - def test_reviews_notification(self, mock_notify): - listeners.reviews_notification(self, creator=self.sender, context=self.context_info, action=self.action, template='test.html.mako') - assert mock_notify.called - - -class QuerySetMatcher: - def __init__(self, some_obj): - self.some_obj = some_obj - - def __eq__(self, other): - return list(self.some_obj) == list(other) - - -class TestNotificationsReviewsModerator(OsfTestCase): - - def setUp(self): - super().setUp() - self.provider = factories.PreprintProviderFactory(_id='engrxiv') - self.preprint = factories.PreprintFactory(provider=self.provider) - self.submitter = factories.UserFactory() - self.moderator_transacitonal = factories.UserFactory() - self.moderator_digest= factories.UserFactory() - - self.context_info_submission = { - 'referrer': self.submitter, - 'domain': 'osf.io', - 'reviewable': self.preprint, - 'workflow': 'pre-moderation', - 'provider_contact_email': settings.OSF_CONTACT_EMAIL, - 'provider_support_email': settings.OSF_SUPPORT_EMAIL, - } - - self.context_info_request = { - 'requester': self.submitter, - 'domain': 'osf.io', - 'reviewable': self.preprint, - 'workflow': 'pre-moderation', - 'provider_contact_email': settings.OSF_CONTACT_EMAIL, - 'provider_support_email': settings.OSF_SUPPORT_EMAIL, - } - - self.action = factories.ReviewActionFactory() - self.subscription = NotificationSubscription.load(self.provider._id+'_new_pending_submissions') - self.subscription.add_user_to_subscription(self.moderator_transacitonal, 'email_transactional') - self.subscription.add_user_to_subscription(self.moderator_digest, 'email_digest') - - @mock.patch('website.notifications.emails.store_emails') - def test_reviews_submit_notification(self, mock_store): - time_now = timezone.now() - - preprint = self.context_info_submission['reviewable'] - provider = preprint.provider - - self.context_info_submission['message'] = f'submitted {preprint.title}.' - self.context_info_submission['profile_image_url'] = get_profile_image_url(self.context_info_submission['referrer']) - self.context_info_submission['reviews_submission_url'] = f'{settings.DOMAIN}reviews/preprints/{provider._id}/{preprint._id}' - listeners.reviews_submit_notification_moderators(self, time_now, self.context_info_submission) - subscription = NotificationSubscription.load(self.provider._id + '_new_pending_submissions') - digest_subscriber_ids = list(subscription.email_digest.all().values_list('guids___id', flat=True)) - instant_subscriber_ids = list(subscription.email_transactional.all().values_list('guids___id', flat=True)) - - mock_store.assert_any_call( - digest_subscriber_ids, - 'email_digest', - 'new_pending_submissions', - self.context_info_submission['referrer'], - self.context_info_submission['reviewable'], - time_now, - abstract_provider=self.context_info_submission['reviewable'].provider, - **self.context_info_submission - ) - - mock_store.assert_any_call( - instant_subscriber_ids, - 'email_transactional', - 'new_pending_submissions', - self.context_info_submission['referrer'], - self.context_info_submission['reviewable'], - time_now, - abstract_provider=self.context_info_request['reviewable'].provider, - **self.context_info_submission - ) - - @mock.patch('website.notifications.emails.store_emails') - def test_reviews_request_notification(self, mock_store): - time_now = timezone.now() - self.context_info_request['message'] = 'has requested withdrawal of {} "{}".'.format(self.context_info_request['reviewable'].provider.preprint_word, - self.context_info_request['reviewable'].title) - self.context_info_request['profile_image_url'] = get_profile_image_url(self.context_info_request['requester']) - self.context_info_request['reviews_submission_url'] = '{}reviews/preprints/{}/{}'.format(settings.DOMAIN, - self.context_info_request[ - 'reviewable'].provider._id, - self.context_info_request[ - 'reviewable']._id) - listeners.reviews_withdrawal_requests_notification(self, time_now, self.context_info_request) - subscription = NotificationSubscription.load(self.provider._id + '_new_pending_submissions') - digest_subscriber_ids = subscription.email_digest.all().values_list('guids___id', flat=True) - instant_subscriber_ids = subscription.email_transactional.all().values_list('guids___id', flat=True) - mock_store.assert_any_call(QuerySetMatcher(digest_subscriber_ids), - 'email_digest', - 'new_pending_submissions', - self.context_info_request['requester'], - self.context_info_request['reviewable'], - time_now, - abstract_provider=self.context_info_request['reviewable'].provider, - **self.context_info_request) - - mock_store.assert_any_call(QuerySetMatcher(instant_subscriber_ids), - 'email_transactional', - 'new_pending_submissions', - self.context_info_request['requester'], - self.context_info_request['reviewable'], - time_now, - abstract_provider=self.context_info_request['reviewable'].provider, - **self.context_info_request) diff --git a/tests/test_registrations/test_embargoes.py b/tests/test_registrations/test_embargoes.py index 4c310eecd79..6ae8f8e953a 100644 --- a/tests/test_registrations/test_embargoes.py +++ b/tests/test_registrations/test_embargoes.py @@ -29,7 +29,7 @@ from osf.models.sanctions import SanctionCallbackMixin, Embargo from osf.utils import permissions from osf.models import Registration, Contributor, OSFUser, SpamStatus -from conftest import start_mock_send_grid +from conftest import start_mock_send_grid, start_mock_notification_send DUMMY_TOKEN = tokens.encode({ 'dummy': 'token' @@ -1102,6 +1102,7 @@ def setUp(self): }) self.mock_send_grid = start_mock_send_grid(self) + self.mock_notification_send = start_mock_notification_send(self) @mock.patch('osf.models.sanctions.EmailApprovableSanction.ask') @@ -1159,8 +1160,6 @@ def test_embargoed_registration_set_privacy_sends_mail(self): for contributor in self.registration.contributors: if Contributor.objects.get(user_id=contributor.id, node_id=self.registration.id).permission == permissions.ADMIN: admin_contributors.append(contributor) - for admin in admin_contributors: - assert any([each[1]['to_addr'] == admin.username for each in self.mock_send_grid.call_args_list]) @mock.patch('osf.models.sanctions.EmailApprovableSanction.ask') def test_make_child_embargoed_registration_public_asks_all_admins_in_tree(self, mock_ask): diff --git a/tests/test_registrations/test_retractions.py b/tests/test_registrations/test_retractions.py index 22ee51827dd..dcc62d40b8b 100644 --- a/tests/test_registrations/test_retractions.py +++ b/tests/test_registrations/test_retractions.py @@ -807,7 +807,6 @@ def test_POST_retraction_does_not_send_email_to_unregistered_admins(self): json={'justification': ''}, auth=self.user.auth, ) - # Only the creator gets an email; the unreg user does not get emailed assert self.mock_send_grid.call_count == 1 def test_POST_pending_embargo_returns_HTTPError_HTTPOK(self): diff --git a/tests/test_user_profile_view.py b/tests/test_user_profile_view.py index 8403a9d63c9..bb801340423 100644 --- a/tests/test_user_profile_view.py +++ b/tests/test_user_profile_view.py @@ -1,102 +1,31 @@ #!/usr/bin/env python3 """Views tests for the OSF.""" -from unittest.mock import MagicMock, ANY -from urllib import parse - -import datetime as dt -import time -import unittest from hashlib import md5 -from http.cookies import SimpleCookie from unittest import mock -from urllib.parse import quote_plus import pytest -from django.core.exceptions import ValidationError -from django.utils import timezone -from flask import request, g -from lxml import html -from pytest import approx from rest_framework import status as http_status from addons.github.tests.factories import GitHubAccountFactory -from addons.osfstorage import settings as osfstorage_settings -from addons.wiki.models import WikiPage -from framework import auth -from framework.auth import Auth, authenticate, cas, core -from framework.auth.campaigns import ( - get_campaigns, - is_institution_login, - is_native_login, - is_proxy_login, - campaign_url_for -) -from framework.auth.exceptions import InvalidTokenError -from framework.auth.utils import impute_names_model, ensure_external_identity_uniqueness -from framework.auth.views import login_and_register_handler from framework.celery_tasks import handlers -from framework.exceptions import HTTPError, TemplateHTTPError -from framework.flask import redirect -from framework.transactions.handlers import no_auto_transaction from osf.external.spam import tasks as spam_tasks from osf.models import ( - Comment, - AbstractNode, - OSFUser, - Tag, - SpamStatus, - NodeRelation, NotableDomain ) -from osf.utils import permissions from osf_tests.factories import ( fake_email, ApiOAuth2ApplicationFactory, ApiOAuth2PersonalTokenFactory, AuthUserFactory, - CollectionFactory, - CommentFactory, - NodeFactory, - PreprintFactory, - PreprintProviderFactory, - PrivateLinkFactory, - ProjectFactory, - ProjectWithAddonFactory, - RegistrationProviderFactory, - UserFactory, - UnconfirmedUserFactory, - UnregUserFactory, RegionFactory, - DraftRegistrationFactory, ) from tests.base import ( - assert_is_redirect, - capture_signals, fake, - get_default_metaschema, OsfTestCase, - assert_datetime_equal, - test_app -) -from tests.test_cas_authentication import generate_external_user_with_resp -from tests.utils import run_celery_tasks -from website import mailchimp_utils, mails, settings, language -from website.profile.utils import add_contributor_json, serialize_unregistered -from website.profile.views import update_osf_help_mails_subscription -from website.project.decorators import check_can_access -from website.project.model import has_anonymous_link -from website.project.signals import contributor_added -from website.project.views.contributor import ( - deserialize_contributors, - notify_added_contributor, - send_claim_email, - send_claim_registered_email, ) -from website.project.views.node import _should_show_wiki_widget, abbrev_authors +from website import mailchimp_utils from website.settings import MAILCHIMP_GENERAL_LIST from website.util import api_url_for, web_url_for -from website.util import rubeus -from website.util.metrics import OsfSourceTags, OsfClaimedTags, provider_source_tag, provider_claimed_tag from conftest import start_mock_send_grid diff --git a/tests/test_webtests.py b/tests/test_webtests.py index ae1a30e7618..c55e6b523f4 100644 --- a/tests/test_webtests.py +++ b/tests/test_webtests.py @@ -36,7 +36,7 @@ from addons.wiki.tests.factories import WikiFactory, WikiVersionFactory from website import language from website.util import web_url_for, api_url_for -from conftest import start_mock_send_grid +from conftest import start_mock_send_grid, start_mock_notification_send logging.getLogger('website.project.model').setLevel(logging.ERROR) @@ -805,6 +805,7 @@ def setUp(self): self.user.save() self.mock_send_grid = start_mock_send_grid(self) + self.start_mock_notification_send = start_mock_notification_send(self) # log users out before they land on forgot password page def test_forgot_password_logs_out_user(self): @@ -833,7 +834,7 @@ def test_can_receive_reset_password_email(self): res = form.submit(self.app) # check mail was sent - assert self.mock_send_grid.called + assert self.start_mock_notification_send.called # check http 200 response assert res.status_code == 200 # check request URL is /forgotpassword @@ -923,6 +924,7 @@ def setUp(self): self.user.save() self.mock_send_grid = start_mock_send_grid(self) + self.start_mock_notification_send = start_mock_notification_send(self) # log users out before they land on institutional forgot password page def test_forgot_password_logs_out_user(self): @@ -949,7 +951,7 @@ def test_can_receive_reset_password_email(self): res = self.app.post(self.post_url, data={'forgot_password-email': self.user.username}) # check mail was sent - assert self.mock_send_grid.called + assert self.start_mock_notification_send.called # check http 200 response assert res.status_code == 200 # check request URL is /forgotpassword diff --git a/website/notifications/emails.py b/website/notifications/emails.py index d26d43351d5..56f513920af 100644 --- a/website/notifications/emails.py +++ b/website/notifications/emails.py @@ -2,7 +2,8 @@ from babel import dates, core, Locale -from osf.models import AbstractNode, NotificationDigest, NotificationSubscription +from osf.models import AbstractNode, NotificationSubscriptionLegacy +from osf.models.notifications import NotificationDigest from osf.utils.permissions import ADMIN, READ from website import mails from website.notifications import constants @@ -159,7 +160,7 @@ def check_node(node, event): """Return subscription for a particular node and event.""" node_subscriptions = {key: [] for key in constants.NOTIFICATION_TYPES} if node: - subscription = NotificationSubscription.load(utils.to_subscription_key(node._id, event)) + subscription = NotificationSubscriptionLegacy.load(utils.to_subscription_key(node._id, event)) for notification_type in node_subscriptions: users = getattr(subscription, notification_type, []) if users: @@ -172,7 +173,7 @@ def check_node(node, event): def get_user_subscriptions(user, event): if user.is_disabled: return {} - user_subscription = NotificationSubscription.load(utils.to_subscription_key(user._id, event)) + user_subscription = NotificationSubscriptionLegacy.load(utils.to_subscription_key(user._id, event)) if user_subscription: return {key: list(getattr(user_subscription, key).all().values_list('guids___id', flat=True)) for key in constants.NOTIFICATION_TYPES} else: diff --git a/website/notifications/utils.py b/website/notifications/utils.py index bc79781abc4..51d487ff67a 100644 --- a/website/notifications/utils.py +++ b/website/notifications/utils.py @@ -91,10 +91,10 @@ def remove_supplemental_node(node): @app.task(max_retries=5, default_retry_delay=60) def remove_subscription_task(node_id): AbstractNode = apps.get_model('osf.AbstractNode') - NotificationSubscription = apps.get_model('osf.NotificationSubscription') + NotificationSubscriptionLegacy = apps.get_model('osf.NotificationSubscriptionLegacy') node = AbstractNode.load(node_id) - NotificationSubscription.objects.filter(node=node).delete() + NotificationSubscriptionLegacy.objects.filter(node=node).delete() parent = node.parent_node if parent and parent.child_node_subscriptions: @@ -144,12 +144,12 @@ def users_to_remove(source_event, source_node, new_node): :param new_node: Node instance where a sub or new sub will be. :return: Dict of notification type lists with user_ids """ - NotificationSubscription = apps.get_model('osf.NotificationSubscription') + NotificationSubscriptionLegacy = apps.get_model('osf.NotificationSubscriptionLegacy') removed_users = {key: [] for key in constants.NOTIFICATION_TYPES} if source_node == new_node: return removed_users - old_sub = NotificationSubscription.load(to_subscription_key(source_node._id, source_event)) - old_node_sub = NotificationSubscription.load(to_subscription_key(source_node._id, + old_sub = NotificationSubscriptionLegacy.load(to_subscription_key(source_node._id, source_event)) + old_node_sub = NotificationSubscriptionLegacy.load(to_subscription_key(source_node._id, '_'.join(source_event.split('_')[-2:]))) if not old_sub and not old_node_sub: return removed_users @@ -172,11 +172,11 @@ def move_subscription(remove_users, source_event, source_node, new_event, new_no :param new_node: Instance of Node :return: Returns a NOTIFICATION_TYPES list of removed users without permissions """ - NotificationSubscription = apps.get_model('osf.NotificationSubscription') + NotificationSubscriptionLegacy = apps.get_model('osf.NotificationSubscriptionLegacy') OSFUser = apps.get_model('osf.OSFUser') if source_node == new_node: return - old_sub = NotificationSubscription.load(to_subscription_key(source_node._id, source_event)) + old_sub = NotificationSubscriptionLegacy.load(to_subscription_key(source_node._id, source_event)) if not old_sub: return elif old_sub: @@ -236,8 +236,8 @@ def check_project_subscriptions_are_all_none(user, node): def get_all_user_subscriptions(user, extra=None): """ Get all Subscription objects that the user is subscribed to""" - NotificationSubscription = apps.get_model('osf.NotificationSubscription') - queryset = NotificationSubscription.objects.filter( + NotificationSubscriptionLegacy = apps.get_model('osf.NotificationSubscriptionLegacy') + queryset = NotificationSubscriptionLegacy.objects.filter( Q(none=user.pk) | Q(email_digest=user.pk) | Q(email_transactional=user.pk) @@ -391,14 +391,14 @@ def get_parent_notification_type(node, event, user): :return: str notification type (e.g. 'email_transactional') """ AbstractNode = apps.get_model('osf.AbstractNode') - NotificationSubscription = apps.get_model('osf.NotificationSubscription') + NotificationSubscriptionLegacy = apps.get_model('osf.NotificationSubscriptionLegacy') if node and isinstance(node, AbstractNode) and node.parent_node and node.parent_node.has_permission(user, READ): parent = node.parent_node key = to_subscription_key(parent._id, event) try: - subscription = NotificationSubscription.objects.get(_id=key) - except NotificationSubscription.DoesNotExist: + subscription = NotificationSubscriptionLegacy.objects.get(_id=key) + except NotificationSubscriptionLegacy.DoesNotExist: return get_parent_notification_type(parent, event, user) for notification_type in constants.NOTIFICATION_TYPES: @@ -428,19 +428,19 @@ def check_if_all_global_subscriptions_are_none(user): # This function predates comment mentions, which is a global_ notification that cannot be disabled # Therefore, an actual check would never return True. # If this changes, an optimized query would look something like: - # not NotificationSubscription.objects.filter(Q(event_name__startswith='global_') & (Q(email_digest=user.pk)|Q(email_transactional=user.pk))).exists() + # not NotificationSubscriptionLegacy.objects.filter(Q(event_name__startswith='global_') & (Q(email_digest=user.pk)|Q(email_transactional=user.pk))).exists() return False def subscribe_user_to_global_notifications(user): - NotificationSubscription = apps.get_model('osf.NotificationSubscription') + NotificationSubscriptionLegacy = apps.get_model('osf.NotificationSubscriptionLegacy') notification_type = 'email_transactional' user_events = constants.USER_SUBSCRIPTIONS_AVAILABLE for user_event in user_events: user_event_id = to_subscription_key(user._id, user_event) # get_or_create saves on creation - subscription, created = NotificationSubscription.objects.get_or_create(_id=user_event_id, user=user, event_name=user_event) + subscription, created = NotificationSubscriptionLegacy.objects.get_or_create(_id=user_event_id, user=user, event_name=user_event) subscription.add_user_to_subscription(user, notification_type) subscription.save() @@ -449,7 +449,7 @@ def subscribe_user_to_notifications(node, user): """ Update the notification settings for the creator or contributors :param user: User to subscribe to notifications """ - NotificationSubscription = apps.get_model('osf.NotificationSubscription') + NotificationSubscriptionLegacy = apps.get_model('osf.NotificationSubscriptionLegacy') Preprint = apps.get_model('osf.Preprint') DraftRegistration = apps.get_model('osf.DraftRegistration') if isinstance(node, Preprint): @@ -475,16 +475,16 @@ def subscribe_user_to_notifications(node, user): for event in events: event_id = to_subscription_key(target_id, event) global_event_id = to_subscription_key(user._id, 'global_' + event) - global_subscription = NotificationSubscription.load(global_event_id) + global_subscription = NotificationSubscriptionLegacy.load(global_event_id) - subscription = NotificationSubscription.load(event_id) + subscription = NotificationSubscriptionLegacy.load(event_id) # If no subscription for component and creator is the user, do not create subscription # If no subscription exists for the component, this means that it should adopt its # parent's settings if not (node and node.parent_node and not subscription and node.creator == user): if not subscription: - subscription = NotificationSubscription(_id=event_id, owner=node, event_name=event) + subscription = NotificationSubscriptionLegacy(_id=event_id, owner=node, event_name=event) # Need to save here in order to access m2m fields subscription.save() if global_subscription: diff --git a/website/notifications/views.py b/website/notifications/views.py index 8ca4775367d..1cbb62ee08d 100644 --- a/website/notifications/views.py +++ b/website/notifications/views.py @@ -6,7 +6,8 @@ from framework.auth.decorators import must_be_logged_in from framework.exceptions import HTTPError -from osf.models import AbstractNode, NotificationSubscription, Registration +from osf.models import AbstractNode, Registration +from osf.models.notifications import NotificationSubscriptionLegacy from osf.utils.permissions import READ from website.notifications import utils from website.notifications.constants import NOTIFICATION_TYPES @@ -95,17 +96,17 @@ def configure_subscription(auth): raise HTTPError(http_status.HTTP_400_BAD_REQUEST) # If adopt_parent make sure that this subscription is None for the current User - subscription = NotificationSubscription.load(event_id) + subscription = NotificationSubscriptionLegacy.load(event_id) if not subscription: return {} # We're done here subscription.remove_user_from_subscription(user) return {} - subscription = NotificationSubscription.load(event_id) + subscription = NotificationSubscriptionLegacy.load(event_id) if not subscription: - subscription = NotificationSubscription(_id=event_id, owner=owner, event_name=event) + subscription = NotificationSubscriptionLegacy(_id=event_id, owner=owner, event_name=event) subscription.save() if node and node._id not in user.notifications_configured: diff --git a/website/project/views/contributor.py b/website/project/views/contributor.py index f3e06aff3fc..d144642634d 100644 --- a/website/project/views/contributor.py +++ b/website/project/views/contributor.py @@ -20,14 +20,14 @@ from osf.models import Tag from osf.exceptions import NodeStateError from osf.models import AbstractNode, DraftRegistration, OSFUser, Preprint, PreprintProvider, RecentlyAddedContributor +from osf.models.notification_type import FrequencyChoices, NotificationType from osf.utils import sanitize from osf.utils.permissions import ADMIN -from website import mails, language, settings +from website import language, settings from website.notifications.utils import check_if_all_global_subscriptions_are_none from website.profile import utils as profile_utils from website.project.decorators import (must_have_permission, must_be_valid_project, must_not_be_registration, must_be_contributor_or_public, must_be_contributor) -from website.project.views.node import serialize_preprints from website.project.model import has_anonymous_link from website.project.signals import unreg_contributor_added, contributor_added from website.util import web_url_for, is_json_request @@ -421,29 +421,41 @@ def send_claim_registered_email(claimer, unclaimed_user, node, throttle=24 * 360 ) # Send mail to referrer, telling them to forward verification link to claimer - mails.send_mail( - referrer.username, - mails.FORWARD_INVITE_REGISTERED, - user=unclaimed_user, - referrer=referrer, - node=node, - claim_url=claim_url, - fullname=unclaimed_record['name'], - can_change_preferences=False, - osf_contact_email=settings.OSF_CONTACT_EMAIL, + notification_type_name = NotificationType.Type.USER_FORWARD_INVITE_REGISTERED.value + notification_type = NotificationType.objects.get(name=notification_type_name) + event_context = { + 'referrer': { + 'fullname': referrer.fullname, + }, + 'node': { + 'title': node.title, + }, + 'fullname': unclaimed_record['name'], + 'claim_url': claim_url, + 'osf_contact_email': settings.OSF_CONTACT_EMAIL, + } + notification_type.emit( + user=referrer, + subscribed_object=node, + message_frequency=FrequencyChoices.INSTANTLY.value, + event_context=event_context, ) unclaimed_record['last_sent'] = get_timestamp() unclaimed_user.save() # Send mail to claimer, telling them to wait for referrer - mails.send_mail( - claimer.username, - mails.PENDING_VERIFICATION_REGISTERED, - fullname=claimer.fullname, - referrer=referrer, - node=node, - can_change_preferences=False, - osf_contact_email=settings.OSF_CONTACT_EMAIL, + notification_type_name = NotificationType.Type.USER_PENDING_VERIFICATION_REGISTERED.value + notification_type = NotificationType.objects.get(name=notification_type_name) + event_context = { + 'fullname': claimer.fullname, + 'can_change_preferences': False, + 'osf_contact_email': settings.OSF_CONTACT_EMAIL, + } + notification_type.emit( + user=claimer, + subscribed_object=node, + message_frequency=FrequencyChoices.INSTANTLY.value, + event_context=event_context, ) @@ -474,20 +486,17 @@ def send_claim_email(email, unclaimed_user, node, notify=True, throttle=24 * 360 # Option 1: # When adding the contributor, the referrer provides both name and email. # The given email is the same provided by user, just send to that email. - logo = None if unclaimed_record.get('email') == claimer_email: # check email template for branded preprints if email_template == 'preprint': if node.provider.is_default: - mail_tpl = mails.INVITE_OSF_PREPRINT - logo = settings.OSF_PREPRINTS_LOGO + notification_type_name = NotificationType.Type.USER_INVITE_OSF_PREPRINT.value else: - mail_tpl = mails.INVITE_PREPRINT(node.provider) - logo = node.provider._id + notification_type_name = NotificationType.Type.PROVIDER_USER_INVITE_PREPRINT.value elif email_template == 'draft_registration': - mail_tpl = mails.INVITE_DRAFT_REGISTRATION + notification_type_name = NotificationType.Type.USER_INVITE_DRAFT_REGISTRATION.value else: - mail_tpl = mails.INVITE_DEFAULT + notification_type_name = NotificationType.Type.USER_INVITE_DEFAULT.value to_addr = claimer_email unclaimed_record['claimer_email'] = claimer_email @@ -515,34 +524,40 @@ def send_claim_email(email, unclaimed_user, node, notify=True, throttle=24 * 360 claim_url = unclaimed_user.get_claim_url(node._primary_key, external=True) # send an email to the invited user without `claim_url` if notify: - pending_mail = mails.PENDING_VERIFICATION - mails.send_mail( - claimer_email, - pending_mail, + notification_type_name = NotificationType.Type.USER_PENDING_VERIFICATION_REGISTERED.value + notification_type = NotificationType.objects.get(name=notification_type_name) + event_context = { + 'fullname': unclaimed_record['name'], + 'can_change_preferences': False, + 'osf_contact_email': settings.OSF_CONTACT_EMAIL, + } + notification_type.emit( user=unclaimed_user, - referrer=referrer, - fullname=unclaimed_record['name'], - node=node, - can_change_preferences=False, - osf_contact_email=settings.OSF_CONTACT_EMAIL, + subscribed_object=node, + message_frequency=FrequencyChoices.INSTANTLY.value, + event_context=event_context, ) - mail_tpl = mails.FORWARD_INVITE + notification_type_name = NotificationType.Type.USER_FORWARD_INVITE.value to_addr = referrer.username # Send an email to the claimer (Option 1) or to the referrer (Option 2) with `claim_url` - mails.send_mail( - to_addr, - mail_tpl, + notification_type = NotificationType.objects.get(name=notification_type_name) + event_context = { + 'fullname': unclaimed_record['name'], + 'referrer': { + 'fullname': referrer.fullname, + }, + 'node': { + 'title': node.title, + }, + 'claim_url': claim_url, + 'osf_contact_email': settings.OSF_CONTACT_EMAIL, + } + notification_type.emit( user=unclaimed_user, - referrer=referrer, - node=node, - claim_url=claim_url, - email=claimer_email, - fullname=unclaimed_record['name'], - branded_service=node.provider, - can_change_preferences=False, - logo=logo if logo else settings.OSF_LOGO, - osf_contact_email=settings.OSF_CONTACT_EMAIL, + subscribed_object=node, + message_frequency=FrequencyChoices.INSTANTLY.value, + event_context=event_context, ) return to_addr @@ -562,7 +577,6 @@ def check_email_throttle(node, contributor, throttle=None): @contributor_added.connect def notify_added_contributor(node, contributor, auth=None, email_template='default', throttle=None, *args, **kwargs): - logo = settings.OSF_LOGO if check_email_throttle(node, contributor, throttle=throttle): return if email_template == 'false': @@ -584,35 +598,32 @@ def notify_added_contributor(node, contributor, auth=None, email_template='defau if contrib_on_parent_node: if email_template == 'preprint': if node.provider.is_default: - email_template = mails.CONTRIBUTOR_ADDED_OSF_PREPRINT - logo = settings.OSF_PREPRINTS_LOGO + notification_type_name = NotificationType.Type.USER_CONTRIBUTOR_ADDED_OSF_PREPRINT.value else: - email_template = mails.CONTRIBUTOR_ADDED_PREPRINT(node.provider) - logo = node.provider._id + notification_type_name = NotificationType.Type.USER_CONTRIBUTOR_ADDED_OSF_PREPRINT.value elif email_template == 'draft_registration': - email_template = mails.CONTRIBUTOR_ADDED_DRAFT_REGISTRATION + notification_type_name = NotificationType.Type.USER_CONTRIBUTOR_ADDED_DRAFT_REGISTRATION.value elif email_template == 'access_request': - email_template = mails.CONTRIBUTOR_ADDED_ACCESS_REQUEST + notification_type_name = NotificationType.Type.NODE_CONTRIBUTOR_ADDED_ACCESS_REQUEST.value elif node.has_linked_published_preprints: - # Project holds supplemental materials for a published preprint - email_template = mails.CONTRIBUTOR_ADDED_PREPRINT_NODE_FROM_OSF - logo = settings.OSF_PREPRINTS_LOGO + notification_type_name = NotificationType.Type.PREPRINT_CONTRIBUTOR_ADDED_PREPRINT_NODE_FROM_OSF.value else: - email_template = mails.CONTRIBUTOR_ADDED_DEFAULT - - mails.send_mail( - to_addr=contributor.username, - mail=email_template, + notification_type_name = NotificationType.Type.USER_CONTRIBUTOR_ADDED_DEFAULT.value + + notification_type = NotificationType.objects.get(name=notification_type_name) + event_context = { + 'referrer_name': auth.user.fullname if auth else '', + 'is_initiator': getattr(auth, 'user', False) == contributor, + 'all_global_subscriptions_none': check_if_all_global_subscriptions_are_none(contributor), + 'branded_service': node.id, + 'can_change_preferences': False, + 'osf_contact_email': settings.OSF_CONTACT_EMAIL, + } + notification_type.emit( user=contributor, - node=node, - referrer_name=auth.user.fullname if auth else '', - is_initiator=getattr(auth, 'user', False) == contributor, - all_global_subscriptions_none=check_if_all_global_subscriptions_are_none(contributor), - branded_service=node.provider, - can_change_preferences=False, - logo=logo, - osf_contact_email=settings.OSF_CONTACT_EMAIL, - published_preprints=[] if isinstance(node, (Preprint, DraftRegistration)) else serialize_preprints(node, user=None) + message_frequency=FrequencyChoices.INSTANTLY.value, + subscribed_object=node, + event_context=event_context, ) contributor.contributor_added_email_records[node._id]['last_sent'] = get_timestamp() diff --git a/website/reviews/listeners.py b/website/reviews/listeners.py index 27a15c2c337..d6f3471dac7 100644 --- a/website/reviews/listeners.py +++ b/website/reviews/listeners.py @@ -71,7 +71,7 @@ def reviews_submit_notification_moderators(self, timestamp, context): Handle email notifications to notify moderators of new submissions or resubmission. """ # imports moved here to avoid AppRegistryNotReady error - from osf.models import NotificationSubscription + from osf.models import NotificationSubscriptionLegacy from website.profile.utils import get_profile_image_url from website.notifications.emails import store_emails @@ -103,7 +103,7 @@ def reviews_submit_notification_moderators(self, timestamp, context): context['message'] = f'submitted "{resource.title}".' # Get NotificationSubscription instance, which contains reference to all subscribers - provider_subscription, created = NotificationSubscription.objects.get_or_create( + provider_subscription, created = NotificationSubscriptionLegacy.objects.get_or_create( _id=f'{provider._id}_new_pending_submissions', provider=provider ) @@ -138,7 +138,7 @@ def reviews_submit_notification_moderators(self, timestamp, context): @reviews_signals.reviews_withdraw_requests_notification_moderators.connect def reviews_withdraw_requests_notification_moderators(self, timestamp, context): # imports moved here to avoid AppRegistryNotReady error - from osf.models import NotificationSubscription + from osf.models import NotificationSubscriptionLegacy from website.profile.utils import get_profile_image_url from website.notifications.emails import store_emails @@ -146,7 +146,7 @@ def reviews_withdraw_requests_notification_moderators(self, timestamp, context): provider = resource.provider # Get NotificationSubscription instance, which contains reference to all subscribers - provider_subscription, created = NotificationSubscription.objects.get_or_create( + provider_subscription, created = NotificationSubscriptionLegacy.objects.get_or_create( _id=f'{provider._id}_new_pending_withdraw_requests', provider=provider ) @@ -191,13 +191,13 @@ def reviews_withdraw_requests_notification_moderators(self, timestamp, context): @reviews_signals.reviews_email_withdrawal_requests.connect def reviews_withdrawal_requests_notification(self, timestamp, context): # imports moved here to avoid AppRegistryNotReady error - from osf.models import NotificationSubscription + from osf.models import NotificationSubscriptionLegacy from website.notifications.emails import store_emails from website.profile.utils import get_profile_image_url from website import settings # Get NotificationSubscription instance, which contains reference to all subscribers - provider_subscription = NotificationSubscription.load( + provider_subscription = NotificationSubscriptionLegacy.load( '{}_new_pending_submissions'.format(context['reviewable'].provider._id)) preprint = context['reviewable'] preprint_word = preprint.provider.preprint_word diff --git a/website/settings/defaults.py b/website/settings/defaults.py index 80cc6b18ed1..a20a50c3e52 100644 --- a/website/settings/defaults.py +++ b/website/settings/defaults.py @@ -179,6 +179,7 @@ def parent_dir(path): MAILCHIMP_LIST_MAP = { MAILCHIMP_GENERAL_LIST: '123', } +NOTIFICATION_TYPES_YAML = 'notifications.yaml' #Triggered emails OSF_HELP_LIST = 'Open Science Framework Help' @@ -440,7 +441,6 @@ class CeleryConfig: 'osf.management.commands.migrate_pagecounter_data', 'osf.management.commands.migrate_deleted_date', 'osf.management.commands.addon_deleted_date', - 'osf.management.commands.migrate_registration_responses', 'osf.management.commands.archive_registrations_on_IA' 'osf.management.commands.sync_doi_metadata', 'osf.management.commands.sync_collection_provider_indices', @@ -692,9 +692,6 @@ class CeleryConfig: # 'task': 'management.commands.migrate_pagecounter_data', # 'schedule': crontab(minute=0, hour=7), # Daily 2:00 a.m. # }, - # 'migrate_registration_responses': { - # 'task': 'management.commands.migrate_registration_responses', - # 'schedule': crontab(minute=32, hour=7), # Daily 2:32 a.m. # 'migrate_deleted_date': { # 'task': 'management.commands.migrate_deleted_date', # 'schedule': crontab(minute=0, hour=3), @@ -702,10 +699,6 @@ class CeleryConfig: # 'task': 'management.commands.addon_deleted_date', # 'schedule': crontab(minute=0, hour=3), # Daily 11:00 p.m. # }, - # 'populate_branched_from': { - # 'task': 'management.commands.populate_branched_from', - # 'schedule': crontab(minute=0, hour=3), - # }, 'generate_sitemap': { 'task': 'scripts.generate_sitemap', 'schedule': crontab(minute=0, hour=5), # Daily 12:00 a.m.