diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index f556d85..4574cb4 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -4,7 +4,8 @@ from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group -from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied +from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, + PermissionDenied) from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -322,27 +323,27 @@ def update_user_groups(self, user, claim_groups): """ if settings.GROUPS_CLAIM is not None: # Update the user's group memberships - django_groups = [group.name for group in user.groups.all()] + user_group_names = user.groups.all().values_list("name", flat=True) + + if sorted(claim_groups) != sorted(user_group_names): + # Get the list of already existing groups in one SQL query + existing_claimed_groups = Group.objects.filter(name__in=claim_groups) - if sorted(claim_groups) != sorted(django_groups): - existing_groups = list(Group.objects.filter(name__in=claim_groups).iterator()) - existing_group_names = frozenset(group.name for group in existing_groups) - new_groups = [] if settings.MIRROR_GROUPS: - new_groups = [ + existing_claimed_group_names = ( + group.name for group in existing_claimed_groups + ) + # One SQL query by created group. + # bulk_create could have been used here but we want to send signals. + new_claimed_groups = [ Group.objects.get_or_create(name=name)[0] - for name in claim_groups - if name not in existing_group_names + for name in claim_groups if name not in existing_claimed_group_names ] + # Associate the users to all claimed groups + user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups)) else: - for name in claim_groups: - if name not in existing_group_names: - try: - group = Group.objects.get(name=name) - new_groups.append(group) - except ObjectDoesNotExist: - pass - user.groups.set(existing_groups + new_groups) + # Associate the user to only existing claimed groups + user.groups.set(existing_claimed_groups) def update_user_flags(self, user, claims, claim_groups): """ diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 9f38cbe..c16691f 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -3,21 +3,22 @@ from django_auth_adfs.exceptions import MFARequired try: - from urllib.parse import urlparse, parse_qs + from urllib.parse import parse_qs, urlparse except ImportError: # Python 2.7 from urlparse import urlparse, parse_qs from copy import deepcopy -from django.contrib.auth.models import User, Group +from django.contrib.auth.models import Group, User from django.core.exceptions import ObjectDoesNotExist, PermissionDenied from django.db.models.signals import post_save -from django.test import TestCase, RequestFactory +from django.test import RequestFactory, TestCase from mock import Mock, patch from django_auth_adfs import signals from django_auth_adfs.backend import AdfsAuthCodeBackend from django_auth_adfs.config import ProviderConfig, Settings + from .models import Profile from .utils import mock_adfs @@ -175,6 +176,38 @@ def test_no_group_claim(self): self.assertEqual(user.email, "john.doe@example.com") self.assertEqual(len(user.groups.all()), 0) + @mock_adfs("2016") + def test_group_claim_with_mirror_groups(self): + # Remove one group + Group.objects.filter(name="group1").delete() + + backend = AdfsAuthCodeBackend() + with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True): + user = backend.authenticate(self.request, authorization_code="dummycode") + self.assertIsInstance(user, User) + self.assertEqual(user.first_name, "John") + self.assertEqual(user.last_name, "Doe") + self.assertEqual(user.email, "john.doe@example.com") + # group1 is restored + group_names = user.groups.order_by("name").values_list("name", flat=True) + self.assertSequenceEqual(group_names, ['group1', 'group2']) + + @mock_adfs("2016") + def test_group_claim_without_mirror_groups(self): + # Remove one group + Group.objects.filter(name="group1").delete() + + backend = AdfsAuthCodeBackend() + with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", False): + user = backend.authenticate(self.request, authorization_code="dummycode") + self.assertIsInstance(user, User) + self.assertEqual(user.first_name, "John") + self.assertEqual(user.last_name, "Doe") + self.assertEqual(user.email, "john.doe@example.com") + # User is not added to group1 because the group doesn't exist + group_names = user.groups.values_list("name", flat=True) + self.assertSequenceEqual(group_names, ['group2']) + @mock_adfs("2016", empty_keys=True) def test_empty_keys(self): backend = AdfsAuthCodeBackend()