Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(experimental): quota protocol fix #572

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions python/src/uagents/experimental/quota/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def handle(ctx: Context, sender: str, msg: ExampleMessage3):
runtime. This can be useful for dynamic access control rules based on the state of the
agent or the network.
```python
acl = AccessControlList(default=True, allowed={""}, blocked={""})
acl = AccessControlList(default=True)
@proto.on_message(model=Message, access_control_list=acl)
async def message_handler(ctx: Context, sender: str, msg: Message):
Expand Down Expand Up @@ -97,8 +97,9 @@ class RateLimit(BaseModel):

class AccessControlList(BaseModel):
default: bool
allowed: set[str]
blocked: set[str]
allowed: set[str] = set()
blocked: set[str] = set()
bypass_rate_limit: set[str] = set()


class QuotaProtocol(Protocol):
Expand All @@ -108,6 +109,7 @@ def __init__(
name: Optional[str] = None,
version: Optional[str] = None,
default_rate_limit: Optional[RateLimit] = None,
default_acl: Optional[AccessControlList] = None,
):
"""
Initialize a QuotaProtocol instance.
Expand All @@ -116,11 +118,15 @@ def __init__(
storage_reference (StorageAPI): The storage reference to use for rate limiting.
name (Optional[str], optional): The name of the protocol. Defaults to None.
version (Optional[str], optional): The version of the protocol. Defaults to None.
acl (Optional[AccessControlList], optional): The access control list. Defaults to None.
default_rate_limit (Optional[RateLimit], optional): The default rate limit.
Defaults to None.
default_acl (Optional[AccessControlList], optional): The access control list.
Defaults to None.
"""
super().__init__(name=name, version=version)
self.storage_ref = storage_reference
self.default_rate_limit = default_rate_limit
self.default_acl = default_acl

def on_message(
self,
Expand Down Expand Up @@ -172,8 +178,9 @@ def wrap(
Returns:
Callable: The decorated
"""
acl = acl or self.default_acl
if acl is None:
acl = AccessControlList(default=True, allowed=set(), blocked=set())
acl = AccessControlList(default=True)

rate_limit = rate_limit or self.default_rate_limit

Expand All @@ -184,28 +191,26 @@ async def decorator(ctx: Context, sender: str, msg: Type[Model]):
):
return await ctx.send(
sender,
ErrorMessage(
error=("You are not allowed to access this endpoint.")
),
ErrorMessage(error=("You are not allowed to access this handler.")),
)
if (
sender in acl.bypass_rate_limit
or not rate_limit
or self.add_request(
sender,
func.__name__,
rate_limit.window_size_minutes,
rate_limit.max_requests,
)
if not rate_limit or self.add_request(
sender,
func.__name__,
rate_limit.window_size_minutes,
rate_limit.max_requests,
):
result = await func(ctx, sender, msg)
else:
result = await ctx.send(
sender,
ErrorMessage(
error=(
f"Rate limit exceeded for {msg.schema()["title"]}. "
f"This endpoint allows for {rate_limit.max_requests} calls per "
f"{rate_limit.window_size_minutes} minutes. Try again later."
)
),
err = (
f"Rate limit exceeded for {msg.schema()['title']}. "
f"This handler allows for {rate_limit.max_requests} calls per "
f"{rate_limit.window_size_minutes} minutes. Try again later."
)
result = await ctx.send(sender, ErrorMessage(error=err))
return result

return decorator # type: ignore
Expand Down
Loading