Skip to content

Commit

Permalink
fix(experimental): quota protocol fix (#572)
Browse files Browse the repository at this point in the history
  • Loading branch information
Archento authored Oct 29, 2024
1 parent 64d38c8 commit 9c97a1e
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions python/src/uagents/experimental/quota/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def handle(ctx: Context, sender: str, msg: ExampleMessage3):
runtime. This can be useful for dynamic access control rules based on the state of the
agent or the network.
```python
acl = AccessControlList(default=True, allowed={""}, blocked={""})
acl = AccessControlList(default=True)
@proto.on_message(model=Message, access_control_list=acl)
async def message_handler(ctx: Context, sender: str, msg: Message):
Expand Down Expand Up @@ -97,8 +97,9 @@ class RateLimit(BaseModel):

class AccessControlList(BaseModel):
default: bool
allowed: set[str]
blocked: set[str]
allowed: set[str] = set()
blocked: set[str] = set()
bypass_rate_limit: set[str] = set()


class QuotaProtocol(Protocol):
Expand All @@ -108,6 +109,7 @@ def __init__(
name: Optional[str] = None,
version: Optional[str] = None,
default_rate_limit: Optional[RateLimit] = None,
default_acl: Optional[AccessControlList] = None,
):
"""
Initialize a QuotaProtocol instance.
Expand All @@ -116,11 +118,15 @@ def __init__(
storage_reference (StorageAPI): The storage reference to use for rate limiting.
name (Optional[str], optional): The name of the protocol. Defaults to None.
version (Optional[str], optional): The version of the protocol. Defaults to None.
acl (Optional[AccessControlList], optional): The access control list. Defaults to None.
default_rate_limit (Optional[RateLimit], optional): The default rate limit.
Defaults to None.
default_acl (Optional[AccessControlList], optional): The access control list.
Defaults to None.
"""
super().__init__(name=name, version=version)
self.storage_ref = storage_reference
self.default_rate_limit = default_rate_limit
self.default_acl = default_acl

def on_message(
self,
Expand Down Expand Up @@ -172,8 +178,9 @@ def wrap(
Returns:
Callable: The decorated
"""
acl = acl or self.default_acl
if acl is None:
acl = AccessControlList(default=True, allowed=set(), blocked=set())
acl = AccessControlList(default=True)

rate_limit = rate_limit or self.default_rate_limit

Expand All @@ -184,28 +191,26 @@ async def decorator(ctx: Context, sender: str, msg: Type[Model]):
):
return await ctx.send(
sender,
ErrorMessage(
error=("You are not allowed to access this endpoint.")
),
ErrorMessage(error=("You are not allowed to access this handler.")),
)
if (
sender in acl.bypass_rate_limit
or not rate_limit
or self.add_request(
sender,
func.__name__,
rate_limit.window_size_minutes,
rate_limit.max_requests,
)
if not rate_limit or self.add_request(
sender,
func.__name__,
rate_limit.window_size_minutes,
rate_limit.max_requests,
):
result = await func(ctx, sender, msg)
else:
result = await ctx.send(
sender,
ErrorMessage(
error=(
f"Rate limit exceeded for {msg.schema()["title"]}. "
f"This endpoint allows for {rate_limit.max_requests} calls per "
f"{rate_limit.window_size_minutes} minutes. Try again later."
)
),
err = (
f"Rate limit exceeded for {msg.schema()['title']}. "
f"This handler allows for {rate_limit.max_requests} calls per "
f"{rate_limit.window_size_minutes} minutes. Try again later."
)
result = await ctx.send(sender, ErrorMessage(error=err))
return result

return decorator # type: ignore
Expand Down

0 comments on commit 9c97a1e

Please sign in to comment.