Skip to content

Commit

Permalink
Merge pull request #281 from seb-b/fixes/227
Browse files Browse the repository at this point in the history
Use all middleware defined in decorator
  • Loading branch information
dopry authored May 11, 2023
2 parents 985a3e3 + 43f90b6 commit f0b73e5
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 3 deletions.
6 changes: 6 additions & 0 deletions example/home/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AuthorPage,
BlogPage,
BlogPageRelatedLink,
MiddlewareModel,
Person,
SimpleModel,
)
Expand Down Expand Up @@ -163,3 +164,8 @@ class Meta:
class SimpleModelFactory(factory.django.DjangoModelFactory):
class Meta:
model = SimpleModel


class MiddlewareModelFactory(factory.django.DjangoModelFactory):
class Meta:
model = MiddlewareModel
26 changes: 26 additions & 0 deletions example/home/migrations/0006_middleware_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Generated by Django 3.2.16 on 2022-12-02 05:10

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("home", "0005_auto_20220909_0959"),
]

operations = [
migrations.CreateModel(
name="MiddlewareModel",
fields=[
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
],
),
]
20 changes: 20 additions & 0 deletions example/home/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,26 @@ class SimpleModel(models.Model):
pass


def custom_middleware_one(next, root, info, **args):
info.context.custom_middleware_one = True
return next(root, info, **args)


def custom_middleware_two(next, root, info, **args):
if not info.context.custom_middleware_one:
raise Exception("Middleware one should have been called")
if args["id"] == 2:
return None
return next(root, info, **args)


@register_query_field(
"middlewareModel", middleware=[custom_middleware_one, custom_middleware_two]
)
class MiddlewareModel(models.Model):
pass


class HomePage(Page):
pass

Expand Down
27 changes: 26 additions & 1 deletion example/home/test/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from django.contrib.auth.models import AnonymousUser
from django.test import RequestFactory, override_settings
from home.factories import AdvertFactory, BlogPageFactory, SimpleModelFactory
from home.factories import (
AdvertFactory,
BlogPageFactory,
MiddlewareModelFactory,
SimpleModelFactory,
)

from example.tests.test_grapple import BaseGrappleTest

Expand Down Expand Up @@ -80,6 +85,7 @@ def setUp(self):
self.blog_post = BlogPageFactory(parent=self.home, slug="post-one")
self.another_post = BlogPageFactory(parent=self.home, slug="post-two")
self.child_post = BlogPageFactory(parent=self.another_post, slug="post-one")
self.middleware_instance = MiddlewareModelFactory()

def test_query_field_plural(self):
query = """
Expand Down Expand Up @@ -141,6 +147,25 @@ def test_query_field(self):
data = results["data"]["post"]
self.assertEqual(int(data["id"]), self.another_post.id)

def test_multiple_middleware(self):
query = """
query ($id: Int) {
middlewareModel(id: $id) {
id
}
}
"""
results = self.client.execute(
query, variables={"id": 1}, context_value=self.request
)
# Check that both middleware ran ok, value returned means the check for middleware_1 passed in middleware_2
self.assertEqual(int(results["data"]["middlewareModel"]["id"]), 1)
results = self.client.execute(
query, variables={"id": 2}, context_value=self.request
)
# Check that the second middleware failed when id = 2
self.assertEqual(results["data"]["middlewareModel"], None)


class TestRegisterPaginatedQueryField(BaseGrappleTest):
def setUp(self):
Expand Down
7 changes: 5 additions & 2 deletions grapple/middleware.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

from graphene import ResolveInfo
from graphql.execution.middleware import get_middleware_resolvers

Expand Down Expand Up @@ -37,7 +39,8 @@ def resolve(self, next, root, info: ResolveInfo, **kwargs):
field_name = info.field_name
parent_name = info.parent_type.name
if field_name in self.field_middlewares and parent_name in ROOT_TYPES:
for middleware in self.field_middlewares[field_name]:
return middleware(next, root, info, **kwargs)
middlewares = self.field_middlewares[field_name].copy()
while middlewares:
next = partial(middlewares.pop(), next)

return next(root, info, **kwargs)

0 comments on commit f0b73e5

Please sign in to comment.