Skip to content

Commit

Permalink
Merge pull request #255 from stephane/speedup-group-creation
Browse files Browse the repository at this point in the history
Reduce the number of SQL queries in updates of groups
  • Loading branch information
JonasKs authored Oct 4, 2022
2 parents de92fa1 + e440441 commit e51135c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
35 changes: 18 additions & 17 deletions django_auth_adfs/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from django.contrib.auth.models import Group
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied
from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist,
PermissionDenied)

from django_auth_adfs import signals
from django_auth_adfs.config import provider_config, settings
Expand Down Expand Up @@ -322,27 +323,27 @@ def update_user_groups(self, user, claim_groups):
"""
if settings.GROUPS_CLAIM is not None:
# Update the user's group memberships
django_groups = [group.name for group in user.groups.all()]
user_group_names = user.groups.all().values_list("name", flat=True)

if sorted(claim_groups) != sorted(user_group_names):
# Get the list of already existing groups in one SQL query
existing_claimed_groups = Group.objects.filter(name__in=claim_groups)

if sorted(claim_groups) != sorted(django_groups):
existing_groups = list(Group.objects.filter(name__in=claim_groups).iterator())
existing_group_names = frozenset(group.name for group in existing_groups)
new_groups = []
if settings.MIRROR_GROUPS:
new_groups = [
existing_claimed_group_names = (
group.name for group in existing_claimed_groups
)
# One SQL query by created group.
# bulk_create could have been used here but we want to send signals.
new_claimed_groups = [
Group.objects.get_or_create(name=name)[0]
for name in claim_groups
if name not in existing_group_names
for name in claim_groups if name not in existing_claimed_group_names
]
# Associate the users to all claimed groups
user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups))
else:
for name in claim_groups:
if name not in existing_group_names:
try:
group = Group.objects.get(name=name)
new_groups.append(group)
except ObjectDoesNotExist:
pass
user.groups.set(existing_groups + new_groups)
# Associate the user to only existing claimed groups
user.groups.set(existing_claimed_groups)

def update_user_flags(self, user, claims, claim_groups):
"""
Expand Down
39 changes: 36 additions & 3 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
from django_auth_adfs.exceptions import MFARequired

try:
from urllib.parse import urlparse, parse_qs
from urllib.parse import parse_qs, urlparse
except ImportError: # Python 2.7
from urlparse import urlparse, parse_qs

from copy import deepcopy

from django.contrib.auth.models import User, Group
from django.contrib.auth.models import Group, User
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
from django.db.models.signals import post_save
from django.test import TestCase, RequestFactory
from django.test import RequestFactory, TestCase
from mock import Mock, patch

from django_auth_adfs import signals
from django_auth_adfs.backend import AdfsAuthCodeBackend
from django_auth_adfs.config import ProviderConfig, Settings

from .models import Profile
from .utils import mock_adfs

Expand Down Expand Up @@ -175,6 +176,38 @@ def test_no_group_claim(self):
self.assertEqual(user.email, "[email protected]")
self.assertEqual(len(user.groups.all()), 0)

@mock_adfs("2016")
def test_group_claim_with_mirror_groups(self):
# Remove one group
Group.objects.filter(name="group1").delete()

backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True):
user = backend.authenticate(self.request, authorization_code="dummycode")
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
self.assertEqual(user.email, "[email protected]")
# group1 is restored
group_names = user.groups.order_by("name").values_list("name", flat=True)
self.assertSequenceEqual(group_names, ['group1', 'group2'])

@mock_adfs("2016")
def test_group_claim_without_mirror_groups(self):
# Remove one group
Group.objects.filter(name="group1").delete()

backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", False):
user = backend.authenticate(self.request, authorization_code="dummycode")
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
self.assertEqual(user.email, "[email protected]")
# User is not added to group1 because the group doesn't exist
group_names = user.groups.values_list("name", flat=True)
self.assertSequenceEqual(group_names, ['group2'])

@mock_adfs("2016", empty_keys=True)
def test_empty_keys(self):
backend = AdfsAuthCodeBackend()
Expand Down

0 comments on commit e51135c

Please sign in to comment.