Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix jsonb 'contained by' query #3643

Merged
merged 12 commits into from
Oct 14, 2024
6 changes: 3 additions & 3 deletions mula/scheduler/storage/filters/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def cast_expression(expression: BinaryExpression, filter_: Filter) -> BinaryExpr
# if the value can be decoded.
try:
decoded_value = json.loads(filter_.value)
if isinstance(decoded_value, dict):
# If it's a JSON object, return the expression as is. We don't
# need to cast it.
# If the string is a JSON object, return the expression as is.
# We don't need to cast it.
if isinstance(decoded_value, dict | list):
underdarknl marked this conversation as resolved.
Show resolved Hide resolved
return expression
expression = expression.astext
except json.JSONDecodeError:
Expand Down
28 changes: 28 additions & 0 deletions mula/tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,34 @@ def test_pop_queue_filters_nested(self):
self.assertEqual(second_item_id, response.json().get("id"))
self.assertEqual(0, self.scheduler.queue.qsize())

def test_pop_queue_filters_nested_contained_by(self):
# Add one task to the queue
first_item = create_task_in(1, data=functions.TestModel(id="123", name="test", categories=["foo", "bar"]))
response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=first_item)
self.assertEqual(response.status_code, 201)
self.assertEqual(1, self.scheduler.queue.qsize())

# Add second item to the queue
second_item = create_task_in(2, data=functions.TestModel(id="456", name="test", categories=["baz", "bat"]))
response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=second_item)
second_item_id = response.json().get("id")
self.assertEqual(response.status_code, 201)
self.assertEqual(2, self.scheduler.queue.qsize())

# Test contained by
response = self.client.post(
f"/queues/{self.scheduler.scheduler_id}/pop",
json={
"filters": [
{"column": "data", "operator": "<@", "field": "categories", "value": json.dumps(["baz", "bat"])}
]
},
)

self.assertEqual(200, response.status_code)
self.assertEqual(second_item_id, response.json().get("id"))
self.assertEqual(1, self.scheduler.queue.qsize())

def test_pop_empty(self):
"""When queue is empty it should return an empty response"""
response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/pop")
Expand Down
38 changes: 32 additions & 6 deletions mula/tests/unit/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,31 @@ def setUp(self):
age=25,
height=1.8,
is_active=True,
data={"foo": "bar", "score": 15, "nested": {"bar": "baz"}, "list": ["foo", "bar"]},
data={"foo": "bar", "score": 15, "nested": {"bar": "baz"}, "list": ["ipv4", "network/local"]},
),
TestModel(
name="Bob",
age=30,
height=1.7,
is_active=False,
data={"foo": "baz", "score": 25, "nested": {"bar": "baz"}, "list": ["bar", "baz"]},
data={
"foo": "baz",
"score": 25,
"nested": {"bar": "baz"},
"list": ["ipv4", "ipv6", "network/local"],
},
),
TestModel(
name="Charlie",
age=28,
height=1.6,
is_active=True,
data={"foo": "bar", "score": 35, "nested": {"bar": "baz"}, "list": ["foo", "bar"]},
data={
"foo": "bar",
"score": 35,
"nested": {"bar": "baz"},
"list": ["ipv4", "ipv6", "network/internet"],
},
),
]
)
Expand Down Expand Up @@ -705,9 +715,25 @@ def test_apply_filter_jsonb_contains(self):
self.assertEqual(results[0].name, "Alice")
self.assertEqual(results[1].name, "Charlie")

def test_apply_filter_jsonb_in_list(self):
def test_apply_filter_jsonb_contains_list(self):
filter_request = FilterRequest(
filters=[Filter(column="data", operator="@>", value=json.dumps({"list": ["foo"]}))]
filters=[Filter(column="data", field="list", operator="@>", value=json.dumps(["ipv4"]))]
)

query = session.query(TestModel)
filtered_query = apply_filter(TestModel, query, filter_request)

results = filtered_query.order_by(TestModel.name).all()
self.assertEqual(len(results), 3)
self.assertEqual(results[0].name, "Alice")
self.assertEqual(results[1].name, "Bob")
self.assertEqual(results[2].name, "Charlie")

def test_apply_filter_jsonb_contained_by_list(self):
filter_request = FilterRequest(
filters=[
Filter(column="data", field="list", operator="<@", value=json.dumps(["ipv4", "ipv6", "network/local"]))
]
)

query = session.query(TestModel)
Expand All @@ -716,4 +742,4 @@ def test_apply_filter_jsonb_in_list(self):
results = filtered_query.order_by(TestModel.name).all()
self.assertEqual(len(results), 2)
self.assertEqual(results[0].name, "Alice")
self.assertEqual(results[1].name, "Charlie")
self.assertEqual(results[1].name, "Bob")