From 9f8448b4eacdee0e9eba9cd9129e2ad1eef8ffd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Raimbault?= Date: Tue, 6 Sep 2022 00:12:37 +0200 Subject: [PATCH 1/4] Reduce the number of SQL queries - only one query instead of one for each group when no MIRROR_GROUPS - the use of iterator() inside a list() creates a useless SQL cursor - avoid to create Django instances of groups by using values_list() - replace a frozenset by a tuple --- django_auth_adfs/backend.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index f556d85..847d17d 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -4,7 +4,8 @@ from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group -from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied +from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, + PermissionDenied) from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -322,27 +323,24 @@ def update_user_groups(self, user, claim_groups): """ if settings.GROUPS_CLAIM is not None: # Update the user's group memberships - django_groups = [group.name for group in user.groups.all()] + user_group_names = user.groups.all().values_list("name", flat=True) - if sorted(claim_groups) != sorted(django_groups): - existing_groups = list(Group.objects.filter(name__in=claim_groups).iterator()) - existing_group_names = frozenset(group.name for group in existing_groups) - new_groups = [] + if sorted(claim_groups) != sorted(user_group_names): + # Get the list of already existing groups in one query + existing_claimed_groups = Group.objects.filter(name__in=claim_groups) + existing_claimed_group_names = ( + group.name for group in existing_claimed_groups + ) + + new_claimed_group_names = (name for name in claim_groups if name not in existing_claimed_group_names) if settings.MIRROR_GROUPS: - new_groups = [ + new_claimed_groups = [ Group.objects.get_or_create(name=name)[0] - for name in claim_groups - if name not in existing_group_names + for name in new_claimed_group_names ] else: - for name in claim_groups: - if name not in existing_group_names: - try: - group = Group.objects.get(name=name) - new_groups.append(group) - except ObjectDoesNotExist: - pass - user.groups.set(existing_groups + new_groups) + new_claimed_groups = Group.objects.filter(name__in=new_claimed_group_names) + user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups)) def update_user_flags(self, user, claims, claim_groups): """ From 40f2de6ed14072f64bfa2d9e08c453cb0616e59f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Raimbault?= Date: Tue, 6 Sep 2022 00:49:32 +0200 Subject: [PATCH 2/4] Add coverage for MIRROR_GROUPS --- tests/test_authentication.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 9f38cbe..d7bc2cf 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -3,21 +3,22 @@ from django_auth_adfs.exceptions import MFARequired try: - from urllib.parse import urlparse, parse_qs + from urllib.parse import parse_qs, urlparse except ImportError: # Python 2.7 from urlparse import urlparse, parse_qs from copy import deepcopy -from django.contrib.auth.models import User, Group +from django.contrib.auth.models import Group, User from django.core.exceptions import ObjectDoesNotExist, PermissionDenied from django.db.models.signals import post_save -from django.test import TestCase, RequestFactory +from django.test import RequestFactory, TestCase from mock import Mock, patch from django_auth_adfs import signals from django_auth_adfs.backend import AdfsAuthCodeBackend from django_auth_adfs.config import ProviderConfig, Settings + from .models import Profile from .utils import mock_adfs @@ -175,6 +176,22 @@ def test_no_group_claim(self): self.assertEqual(user.email, "john.doe@example.com") self.assertEqual(len(user.groups.all()), 0) + @mock_adfs("2016") + def test_group_claim_with_mirror_groups(self): + # Remove one group + Group.objects.filter(name="group1").delete() + + backend = AdfsAuthCodeBackend() + with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", "group"): + with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True): + user = backend.authenticate(self.request, authorization_code="dummycode") + self.assertIsInstance(user, User) + self.assertEqual(user.first_name, "John") + self.assertEqual(user.last_name, "Doe") + self.assertEqual(user.email, "john.doe@example.com") + # Group restored + self.assertEqual(len(user.groups.all()), 2) + @mock_adfs("2016", empty_keys=True) def test_empty_keys(self): backend = AdfsAuthCodeBackend() From c956402098e7f736bc43cc207e59ed950cbcc488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Raimbault?= Date: Sat, 24 Sep 2022 16:18:47 +0200 Subject: [PATCH 3/4] Don't filter on new claimed groups when no MIRROR_GROUPS --- django_auth_adfs/backend.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 847d17d..4574cb4 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -326,21 +326,24 @@ def update_user_groups(self, user, claim_groups): user_group_names = user.groups.all().values_list("name", flat=True) if sorted(claim_groups) != sorted(user_group_names): - # Get the list of already existing groups in one query + # Get the list of already existing groups in one SQL query existing_claimed_groups = Group.objects.filter(name__in=claim_groups) - existing_claimed_group_names = ( - group.name for group in existing_claimed_groups - ) - new_claimed_group_names = (name for name in claim_groups if name not in existing_claimed_group_names) if settings.MIRROR_GROUPS: + existing_claimed_group_names = ( + group.name for group in existing_claimed_groups + ) + # One SQL query by created group. + # bulk_create could have been used here but we want to send signals. new_claimed_groups = [ Group.objects.get_or_create(name=name)[0] - for name in new_claimed_group_names + for name in claim_groups if name not in existing_claimed_group_names ] + # Associate the users to all claimed groups + user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups)) else: - new_claimed_groups = Group.objects.filter(name__in=new_claimed_group_names) - user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups)) + # Associate the user to only existing claimed groups + user.groups.set(existing_claimed_groups) def update_user_flags(self, user, claims, claim_groups): """ From e440441d591369e92e640478026aee1414c6a60e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Raimbault?= Date: Sat, 24 Sep 2022 16:37:18 +0200 Subject: [PATCH 4/4] Test without MIRROR_GROUPS --- tests/test_authentication.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index d7bc2cf..c16691f 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -182,15 +182,31 @@ def test_group_claim_with_mirror_groups(self): Group.objects.filter(name="group1").delete() backend = AdfsAuthCodeBackend() - with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", "group"): - with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True): - user = backend.authenticate(self.request, authorization_code="dummycode") - self.assertIsInstance(user, User) - self.assertEqual(user.first_name, "John") - self.assertEqual(user.last_name, "Doe") - self.assertEqual(user.email, "john.doe@example.com") - # Group restored - self.assertEqual(len(user.groups.all()), 2) + with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True): + user = backend.authenticate(self.request, authorization_code="dummycode") + self.assertIsInstance(user, User) + self.assertEqual(user.first_name, "John") + self.assertEqual(user.last_name, "Doe") + self.assertEqual(user.email, "john.doe@example.com") + # group1 is restored + group_names = user.groups.order_by("name").values_list("name", flat=True) + self.assertSequenceEqual(group_names, ['group1', 'group2']) + + @mock_adfs("2016") + def test_group_claim_without_mirror_groups(self): + # Remove one group + Group.objects.filter(name="group1").delete() + + backend = AdfsAuthCodeBackend() + with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", False): + user = backend.authenticate(self.request, authorization_code="dummycode") + self.assertIsInstance(user, User) + self.assertEqual(user.first_name, "John") + self.assertEqual(user.last_name, "Doe") + self.assertEqual(user.email, "john.doe@example.com") + # User is not added to group1 because the group doesn't exist + group_names = user.groups.values_list("name", flat=True) + self.assertSequenceEqual(group_names, ['group2']) @mock_adfs("2016", empty_keys=True) def test_empty_keys(self):