Skip to content

Commit

Permalink
Allow extending any model's interfaces via graphql_interface
Browse files Browse the repository at this point in the history
  • Loading branch information
zerolab committed Sep 11, 2023
1 parent 8325bf6 commit 75cbefd
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 15 deletions.
32 changes: 18 additions & 14 deletions grapple/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,21 +189,21 @@ def get_nested_field(cls, extract_key):
return getattr(cls, extract_key[0])

# Get data from nested field
field = getattr(cls, extract_key[0])
if field is None:
nested_field = getattr(cls, extract_key[0])
if nested_field is None:
return None
if issubclass(type(field), models.Manager):
field = field.all()
if issubclass(type(nested_field), models.Manager):
nested_field = nested_field.all()

# If field data is a list then iterate over it
if isinstance(field, Iterable):
if isinstance(nested_field, Iterable):
return [
get_nested_field(nested_cls, extract_key[1:])
for nested_cls in field
for nested_cls in nested_field
]

# If single value then return it.
return get_nested_field(field, extract_key[1:])
return get_nested_field(nested_field, extract_key[1:])

if field.extract_key:
return [
Expand Down Expand Up @@ -266,18 +266,22 @@ class UnmanagedMeta:
class StubMeta:
model = stub_model

# Gather any interfaces, and discard None values
interfaces = {interface, *getattr(cls, "graphql_interfaces", ())}
interfaces.discard(None)

type_meta = {
"Meta": StubMeta,
"type": lambda: {
"cls": cls,
"lazy": True,
"name": type_name,
"base_type": base_type,
"interface": interface,
"interfaces": tuple(interfaces),
},
}

return type("Stub" + type_name, (DjangoObjectType,), type_meta)
return type(f"Stub{type_name}", (base_type,), type_meta)


def load_type_fields():
Expand All @@ -290,13 +294,13 @@ def load_type_fields():
# Get the original django model data
cls = type_definition.get("cls")
base_type = type_definition.get("base_type")
interface = type_definition.get("interface")
type_name = type_definition.get("name")
_interfaces = type_definition.get("interfaces")

# Recreate the graphene type with the fields set
class Meta:
model = cls
interfaces = (interface,) if interface is not None else ()
interfaces = _interfaces if _interfaces is not None else ()

type_meta = {"Meta": Meta, "id": graphene.ID(), "name": type_name}

Expand All @@ -310,8 +314,8 @@ class Meta:
# Either remove the custom field or remove the field from the "exclude" list.' warning
if (
field == "id"
or hasattr(interface, field)
or hasattr(base_type_for_exclusion_checks, field)
or any(hasattr(interface, field) for interface in _interfaces)
):
continue

Expand Down Expand Up @@ -457,7 +461,7 @@ def build_streamfield_type(
"""

# Alias the argument name so we can use it in the class block
interfaces_ = interfaces
_interfaces = interfaces

# Create a new blank node type
class Meta:
Expand All @@ -466,7 +470,7 @@ class Meta:
registry.streamfield_blocks.get(block) for block in cls.graphql_types
]
else:
interfaces = interfaces_ if interfaces_ is not None else ()
interfaces = _interfaces if _interfaces is not None else ()
# Add description to type if the Meta class declares it
description = getattr(cls._meta_class, "graphql_description", None)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,27 @@ def test_schema_for_streamfield_block_with_custom_interface(self):
[{"name": "CustomInterface"}, {"name": "StreamFieldInterface"}],
)

def test_schema_for_page_with_graphql_interface(self):
results = self.introspect_schema_by_type("AuthorPage")
self.assertListEqual(
sorted(results["data"]["__type"]["interfaces"], key=lambda x: x["name"]),
[{"name": "CustomInterface"}, {"name": "PageInterface"}],
)

def test_schem_for_snippet_with_graphql_interface(self):
results = self.introspect_schema_by_type("Advert")
self.assertListEqual(
sorted(results["data"]["__type"]["interfaces"], key=lambda x: x["name"]),
[{"name": "CustomInterface"}],
)

def test_schema_for_django_model_with_graphql_interfaces(self):
results = self.introspect_schema_by_type("SimpleModel")
self.assertListEqual(
sorted(results["data"]["__type"]["interfaces"], key=lambda x: x["name"]),
[{"name": "CustomInterface"}],
)


@tag("needs-custom-settings")
@skipUnless(
Expand Down
5 changes: 4 additions & 1 deletion tests/testapp/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from grapple.utils import resolve_paginated_queryset
from testapp.blocks import StreamFieldBlock
from testapp.interfaces import CustomInterface


document_model_string = getattr(
Expand All @@ -49,7 +50,7 @@

@register_singular_query_field("simpleModel")
class SimpleModel(models.Model):
pass
graphql_interfaces = (CustomInterface,)


def custom_middleware_one(next, root, info, **args):
Expand Down Expand Up @@ -82,6 +83,7 @@ class AuthorPage(Page):
content_panels = Page.content_panels + [FieldPanel("name")]

graphql_fields = [GraphQLString("name")]
graphql_interfaces = (CustomInterface,)


class BlogPageTag(TaggedItemBase):
Expand Down Expand Up @@ -262,6 +264,7 @@ class Advert(models.Model):
GraphQLString("string_rich_text", source="rich_text"),
GraphQLString("extra_rich_text", deprecation_reason="Use rich_text instead"),
]
graphql_interfaces = (CustomInterface,)

def __str__(self):
return self.text
Expand Down

0 comments on commit 75cbefd

Please sign in to comment.