Skip to content

Commit

Permalink
feat(sqlalchemy): allow repository functions to be filter by expressi…
Browse files Browse the repository at this point in the history
…ons (litestar-org#2265)

* feat: allow queries to be filter by expressions as well.

* feat: allow queries to be filter by expressions as well.

* fix(sqlalchemy-repositories): Fixed type hints and extended filter support for all relevant repository methods.

* test(sqlalchemy-repository-filter): Added tests for relevant repository methods that supports filter.

* fix: updated lock file

* fix: test case update

* feat: allow queries to be filter by expressions as well.

* fix(sqlalchemy-repositories): Fixed type hints and extended filter support for all relevant repository methods.

* test(sqlalchemy-repository-filter): Added tests for relevant repository methods that supports filter.

* fix: updated lock file

* fix: test case update

* chore: remove FIXME as the issue has been fixed in previous commit

* fix: filtering updates

---------

Co-authored-by: Alc-Alc <alc@localhost>
Co-authored-by: Na'aman Hirschfeld <[email protected]>
  • Loading branch information
3 people authored Sep 7, 2023
1 parent 6fe22e4 commit 7badaf5
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 30 deletions.
47 changes: 32 additions & 15 deletions litestar/contrib/sqlalchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from sqlalchemy import func as sql_func
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql import ColumnElement, ColumnExpressionArgument

from litestar.repository import AbstractAsyncRepository, RepositoryError
from litestar.repository.filters import (
Expand All @@ -42,6 +43,8 @@

DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950

WhereClauseT = ColumnExpressionArgument[bool]


class SQLAlchemyAsyncRepository(AbstractAsyncRepository[ModelT], Generic[ModelT]):
"""SQLAlchemy based implementation of the repository interface."""
Expand Down Expand Up @@ -247,7 +250,7 @@ async def delete_many(
def _get_insertmanyvalues_max_parameters(self, chunk_size: int | None = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS

async def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool:
async def exists(self, *filters: FilterTypes | ColumnElement[bool], **kwargs: Any) -> bool:
"""Return true if the object specified by ``kwargs`` exists.
Args:
Expand Down Expand Up @@ -438,7 +441,7 @@ async def get_or_create(

async def count(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
) -> int:
Expand All @@ -455,9 +458,10 @@ async def count(
"""
statement = self._get_base_stmt(statement)
fragment = self.get_id_attribute_value(self.model_type)
statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True)
statement += lambda s: s.order_by(None)
statement = self._filter_select_by_kwargs(statement, kwargs)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
results = await self._execute(statement)
return results.scalar_one() # type: ignore

Expand Down Expand Up @@ -567,7 +571,7 @@ def _get_update_many_statement(model_type: type[ModelT], supports_returning: boo

async def list_and_count(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
auto_commit: bool | None = None,
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
Expand Down Expand Up @@ -627,7 +631,7 @@ async def _refresh(

async def _list_and_count_window(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -663,7 +667,7 @@ async def _list_and_count_window(

async def _list_and_count_basic(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -795,7 +799,7 @@ async def upsert_many(

async def list(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -890,7 +894,10 @@ def _apply_limit_offset_pagination(
return statement

def _apply_filters(
self, *filters: FilterTypes, apply_pagination: bool = True, statement: StatementLambdaElement
self,
*filters: FilterTypes | ColumnElement[bool],
apply_pagination: bool = True,
statement: StatementLambdaElement,
) -> StatementLambdaElement:
"""Apply filters to a select statement.
Expand All @@ -906,7 +913,9 @@ def _apply_filters(
The select with filters applied.
"""
for filter_ in filters:
if isinstance(filter_, LimitOffset):
if isinstance(filter_, ColumnElement):
statement = self._filter_by_expression(expression=filter_, statement=statement)
elif isinstance(filter_, LimitOffset):
if apply_pagination:
statement = self._apply_limit_offset_pagination(filter_.limit, filter_.offset, statement=statement)
elif isinstance(filter_, BeforeAfter):
Expand Down Expand Up @@ -987,10 +996,18 @@ def _filter_select_by_kwargs(
statement = self._filter_by_where(statement, key, val) # pyright: ignore[reportGeneralTypeIssues]
return statement

def _filter_by_where(self, statement: StatementLambdaElement, key: str, val: Any) -> StatementLambdaElement:
def _filter_by_expression(
self, statement: StatementLambdaElement, expression: ColumnElement[bool]
) -> StatementLambdaElement:
statement += lambda s: s.filter(expression)
return statement

def _filter_by_where(
self, statement: StatementLambdaElement, field_name: str | InstrumentedAttribute, value: Any
) -> StatementLambdaElement:
model_type = self.model_type
field = get_instrumented_attr(model_type, key)
statement += lambda s: s.where(field == val)
field = get_instrumented_attr(model_type, field_name)
statement += lambda s: s.where(field == value)
return statement

def _filter_by_like(
Expand All @@ -1005,9 +1022,9 @@ def _filter_by_like(
return statement

def _filter_by_not_like(
self, statement: StatementLambdaElement, field_name: str, value: str, ignore_case: bool
self, statement: StatementLambdaElement, field_name: str | InstrumentedAttribute, value: str, ignore_case: bool
) -> StatementLambdaElement:
field = getattr(self.model_type, field_name)
field = get_instrumented_attr(self.model_type, field_name)
search_text = f"%{value}%"
if ignore_case:
statement += lambda s: s.where(field.not_ilike(search_text))
Expand Down
47 changes: 32 additions & 15 deletions litestar/contrib/sqlalchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from sqlalchemy import func as sql_func
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.sql import ColumnElement, ColumnExpressionArgument

from litestar.repository import AbstractSyncRepository, RepositoryError
from litestar.repository.filters import (
Expand All @@ -43,6 +44,8 @@

DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950

WhereClauseT = ColumnExpressionArgument[bool]


class SQLAlchemySyncRepository(AbstractSyncRepository[ModelT], Generic[ModelT]):
"""SQLAlchemy based implementation of the repository interface."""
Expand Down Expand Up @@ -248,7 +251,7 @@ def delete_many(
def _get_insertmanyvalues_max_parameters(self, chunk_size: int | None = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS

def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool:
def exists(self, *filters: FilterTypes | ColumnElement[bool], **kwargs: Any) -> bool:
"""Return true if the object specified by ``kwargs`` exists.
Args:
Expand Down Expand Up @@ -439,7 +442,7 @@ def get_or_create(

def count(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
) -> int:
Expand All @@ -456,9 +459,10 @@ def count(
"""
statement = self._get_base_stmt(statement)
fragment = self.get_id_attribute_value(self.model_type)
statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True)
statement += lambda s: s.order_by(None)
statement = self._filter_select_by_kwargs(statement, kwargs)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
results = self._execute(statement)
return results.scalar_one() # type: ignore

Expand Down Expand Up @@ -568,7 +572,7 @@ def _get_update_many_statement(model_type: type[ModelT], supports_returning: boo

def list_and_count(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
auto_commit: bool | None = None,
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
Expand Down Expand Up @@ -628,7 +632,7 @@ def _refresh(

def _list_and_count_window(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -664,7 +668,7 @@ def _list_and_count_window(

def _list_and_count_basic(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -796,7 +800,7 @@ def upsert_many(

def list(
self,
*filters: FilterTypes,
*filters: FilterTypes | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -891,7 +895,10 @@ def _apply_limit_offset_pagination(
return statement

def _apply_filters(
self, *filters: FilterTypes, apply_pagination: bool = True, statement: StatementLambdaElement
self,
*filters: FilterTypes | ColumnElement[bool],
apply_pagination: bool = True,
statement: StatementLambdaElement,
) -> StatementLambdaElement:
"""Apply filters to a select statement.
Expand All @@ -907,7 +914,9 @@ def _apply_filters(
The select with filters applied.
"""
for filter_ in filters:
if isinstance(filter_, LimitOffset):
if isinstance(filter_, ColumnElement):
statement = self._filter_by_expression(expression=filter_, statement=statement)
elif isinstance(filter_, LimitOffset):
if apply_pagination:
statement = self._apply_limit_offset_pagination(filter_.limit, filter_.offset, statement=statement)
elif isinstance(filter_, BeforeAfter):
Expand Down Expand Up @@ -988,10 +997,18 @@ def _filter_select_by_kwargs(
statement = self._filter_by_where(statement, key, val) # pyright: ignore[reportGeneralTypeIssues]
return statement

def _filter_by_where(self, statement: StatementLambdaElement, key: str, val: Any) -> StatementLambdaElement:
def _filter_by_expression(
self, statement: StatementLambdaElement, expression: ColumnElement[bool]
) -> StatementLambdaElement:
statement += lambda s: s.filter(expression)
return statement

def _filter_by_where(
self, statement: StatementLambdaElement, field_name: str | InstrumentedAttribute, value: Any
) -> StatementLambdaElement:
model_type = self.model_type
field = get_instrumented_attr(model_type, key)
statement += lambda s: s.where(field == val)
field = get_instrumented_attr(model_type, field_name)
statement += lambda s: s.where(field == value)
return statement

def _filter_by_like(
Expand All @@ -1006,9 +1023,9 @@ def _filter_by_like(
return statement

def _filter_by_not_like(
self, statement: StatementLambdaElement, field_name: str, value: str, ignore_case: bool
self, statement: StatementLambdaElement, field_name: str | InstrumentedAttribute, value: str, ignore_case: bool
) -> StatementLambdaElement:
field = getattr(self.model_type, field_name)
field = get_instrumented_attr(self.model_type, field_name)
search_text = f"%{value}%"
if ignore_case:
statement += lambda s: s.where(field.not_ilike(search_text))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,23 @@ async def test_repo_count_method(author_repo: AuthorRepository) -> None:
assert await maybe_async(author_repo.count()) == 2


async def test_repo_count_method_with_filters(raw_authors: RawRecordData, author_repo: AuthorRepository) -> None:
"""Test SQLAlchemy count with filters.
Args:
author_repo: The author mock repository
"""
assert (
await maybe_async(
author_repo.count(
author_repo.model_type.id == raw_authors[0]["id"],
author_repo.model_type.name == raw_authors[0]["name"],
)
)
== 1
)


async def test_repo_list_and_count_method(raw_authors: RawRecordData, author_repo: AuthorRepository) -> None:
"""Test SQLAlchemy list with count in asyncpg.
Expand All @@ -573,6 +590,27 @@ async def test_repo_list_and_count_method(raw_authors: RawRecordData, author_rep
assert len(collection) == exp_count


async def test_repo_list_and_count_method_with_filters(
raw_authors: RawRecordData, author_repo: AuthorRepository
) -> None:
"""Test SQLAlchemy list with count and filters in asyncpg.
Args:
raw_authors: list of authors pre-seeded into the mock repository
author_repo: The author mock repository
"""
exp_name = raw_authors[0]["name"]
exp_id = raw_authors[0]["id"]
collection, count = await maybe_async(
author_repo.list_and_count(author_repo.model_type.id == exp_id, author_repo.model_type.name == exp_name)
)
assert count == 1
assert isinstance(collection, list)
assert len(collection) == 1
assert collection[0].id == exp_id
assert collection[0].name == exp_name


async def test_repo_list_and_count_basic_method(raw_authors: RawRecordData, author_repo: AuthorRepository) -> None:
"""Test SQLAlchemy basic list with count in asyncpg.
Expand Down Expand Up @@ -624,6 +662,21 @@ async def test_repo_list_method(
assert len(collection) == exp_count


async def test_repo_list_method_with_filters(
raw_authors: RawRecordData,
author_repo: AuthorRepository,
) -> None:
exp_name = raw_authors[0]["name"]
exp_id = raw_authors[0]["id"]
collection = await maybe_async(
author_repo.list(author_repo.model_type.id == exp_id, author_repo.model_type.name == exp_name)
)
assert isinstance(collection, list)
assert len(collection) == 1
assert collection[0].id == exp_id
assert collection[0].name == exp_name


async def test_repo_add_method(
raw_authors: RawRecordData, author_repo: AuthorRepository, author_model: AuthorModel
) -> None:
Expand Down Expand Up @@ -675,6 +728,18 @@ async def test_repo_exists_method(author_repo: AuthorRepository, first_author_id
assert exists


async def test_repo_exists_method_with_filters(
raw_authors: RawRecordData, author_repo: AuthorRepository, first_author_id: Any
) -> None:
exists = await maybe_async(
author_repo.exists(
author_repo.model_type.name == raw_authors[0]["name"],
id=first_author_id,
)
)
assert exists


async def test_repo_update_method(author_repo: AuthorRepository, first_author_id: Any) -> None:
obj = await maybe_async(author_repo.get(first_author_id))
obj.name = "Updated Name"
Expand Down

0 comments on commit 7badaf5

Please sign in to comment.