From 1b67a4f61b52a8cd2857818d9a8a97e91f28573a Mon Sep 17 00:00:00 2001 From: James Riehl <33920192+jrriehl@users.noreply.github.com> Date: Fri, 8 Nov 2024 12:56:35 +0000 Subject: [PATCH] feat(core): add batch almanac api and contract registrations for Bureau (#551) --- python/src/uagents/agent.py | 209 ++++++++++++++++++++---- python/src/uagents/crypto/__init__.py | 4 +- python/src/uagents/network.py | 160 +++++++++++++++---- python/src/uagents/registration.py | 220 ++++++++++++++++++++++++-- python/src/uagents/types.py | 1 + python/tests/test_bureau.py | 110 +++++++++++++ 6 files changed, 626 insertions(+), 78 deletions(-) create mode 100644 python/tests/test_bureau.py diff --git a/python/src/uagents/agent.py b/python/src/uagents/agent.py index c8231472..9c4656dc 100644 --- a/python/src/uagents/agent.py +++ b/python/src/uagents/agent.py @@ -50,7 +50,11 @@ from uagents.registration import ( AgentRegistrationPolicy, AgentStatusUpdate, + BatchLedgerRegistrationPolicy, + BatchRegistrationPolicy, + DefaultBatchRegistrationPolicy, DefaultRegistrationPolicy, + LedgerBasedRegistrationPolicy, update_agent_status, ) from uagents.resolver import GlobalResolver, Resolver @@ -372,15 +376,18 @@ def __init__( self._on_shutdown = [] self._test = test self._version = version or "0.1.0" - self._registration_policy = registration_policy or DefaultRegistrationPolicy( - self._identity, - self._ledger, - self._wallet, - self._almanac_contract, - self._test, - logger=self._logger, - almanac_api=self._almanac_api_url, - ) + self._registration_policy = registration_policy or None + + if self._registration_policy is None: + self._registration_policy = DefaultRegistrationPolicy( + self._identity, + self._ledger, + self._wallet, + self._almanac_contract, + self._test, + logger=self._logger, + almanac_api=self._almanac_api_url, + ) self._metadata = self._initialize_metadata(metadata) self.initialize_wallet_messaging(enable_wallet_messaging) @@ -642,6 +649,21 @@ def balance(self) -> int: return self.ledger.query_bank_balance(Address(self.wallet.address())) + @property + def info(self) -> AgentInfo: + """ + Get basic information about the agent. + + Returns: + AgentInfo: The agent's address, endpoints, protocols, and metadata. + """ + return AgentInfo( + agent_address=self.address, + endpoints=self._endpoints, + protocols=list(self.protocols.keys()), + metadata=self.metadata, + ) + @property def metadata(self) -> Dict[str, Any]: """ @@ -699,19 +721,30 @@ def sign_digest(self, digest: bytes) -> str: """ return self._identity.sign_digest(digest) - def sign_registration(self, current_time: int) -> str: + def sign_registration( + self, timestamp: int, sender_wallet_address: Optional[str] = None + ) -> str: """ Sign the registration data for Almanac contract. + + Args: + timestamp (int): The timestamp for the registration. + sender_wallet_address (Optional[str]): The wallet address of the transaction sender. + Returns: str: The signature of the registration data. + Raises: - AssertionError: If the Almanac contract address is None. + AssertionError: If the Almanac contract is None. """ - assert self._almanac_contract.address is not None + sender_address = sender_wallet_address or str(self.wallet.address()) + + assert self._almanac_contract is not None + return self._identity.sign_registration( str(self._almanac_contract.address), - current_time, - str(self.wallet.address()), + timestamp, + sender_address, ) def update_endpoints(self, endpoints: List[AgentEndpoint]): @@ -722,7 +755,6 @@ def update_endpoints(self, endpoints: List[AgentEndpoint]): endpoints (List[AgentEndpoint]): List of endpoint dictionaries. """ - self._endpoints = endpoints def update_loop(self, loop): @@ -766,11 +798,13 @@ async def register(self): if necessary. """ + assert self._registration_policy is not None, "Agent has no registration policy" + await self._registration_policy.register( self.address, list(self.protocols.keys()), self._endpoints, self._metadata ) - async def _registration_loop(self): + async def _schedule_registration(self): """ Execute the registration loop. @@ -784,12 +818,12 @@ async def _registration_loop(self): except InsufficientFundsError: time_until_next_registration = 2 * AVERAGE_BLOCK_INTERVAL except Exception as ex: - self._logger.exception(f"Failed to register on almanac contract: {ex}") + self._logger.exception(f"Failed to register: {ex}") time_until_next_registration = REGISTRATION_RETRY_INTERVAL_SECONDS # schedule the next registration update self._loop.create_task( - _delay(self._registration_loop(), time_until_next_registration) + _delay(self._schedule_registration(), time_until_next_registration) ) def on_interval( @@ -1064,13 +1098,13 @@ async def _startup(self): Perform startup actions. """ - if self._endpoints: - await self._registration_loop() - - else: - self._logger.warning( - "No endpoints provided. Skipping registration: Agent won't be reachable." - ) + if self._registration_policy: + if self._endpoints: + await self._schedule_registration() + else: + self._logger.warning( + "No endpoints provided. Skipping registration: Agent won't be reachable." + ) for handler in self._on_startup: try: ctx = self._build_context() @@ -1339,6 +1373,12 @@ def __init__( agents: Optional[List[Agent]] = None, port: Optional[int] = None, endpoint: Optional[Union[str, List[str], Dict[str, dict]]] = None, + agentverse: Optional[Union[str, Dict[str, str]]] = None, + registration_policy: Optional[BatchRegistrationPolicy] = None, + ledger: Optional[LedgerClient] = None, + wallet: Optional[LocalWallet] = None, + seed: Optional[str] = None, + test: bool = True, loop: Optional[asyncio.AbstractEventLoop] = None, log_level: Union[int, str] = logging.INFO, ): @@ -1346,9 +1386,16 @@ def __init__( Initialize a Bureau instance. Args: - port (Optional[int]): The port on which the bureau's server will run. - endpoint (Optional[Union[str, List[str], Dict[str, dict]]]): The endpoint configuration - for the bureau. + agents (Optional[List[Agent]]): The list of agents to be managed by the bureau. + port (Optional[int]): The port number for the server. + endpoint (Optional[Union[str, List[str], Dict[str, dict]]]): The endpoint configuration. + agentverse (Optional[Union[str, Dict[str, str]]]): The agentverse configuration. + registration_policy (Optional[BatchRegistrationPolicy]): The registration policy. + wallet (Optional[LocalWallet]): The wallet for the bureau (overrides 'seed'). + seed (Optional[str]): The seed phrase for the wallet (overridden by 'wallet'). + test (Optional[bool]): True if the bureau will register and transact on the testnet. + loop (Optional[asyncio.AbstractEventLoop]): The event loop. + log_level (Union[int, str]): The logging level for the bureau. """ self._loop = loop or asyncio.get_event_loop_policy().get_event_loop() self._agents: List[Agent] = [] @@ -1362,31 +1409,124 @@ def __init__( queries=self._queries, logger=self._logger, ) - self._use_mailbox = False + self._agentverse = parse_agentverse_config(agentverse) + self._use_mailbox = self._agentverse["use_mailbox"] + almanac_api_url = f"{self._agentverse['http_prefix']}://{self._agentverse['base_url']}/v1/almanac" + almanac_contract = get_almanac_contract(test) + + if wallet and seed: + self._logger.warning( + "Ignoring 'seed' argument because 'wallet' is provided." + ) + elif seed: + wallet = LocalWallet( + PrivateKey(derive_key_from_seed(seed, LEDGER_PREFIX, 0)), + prefix=LEDGER_PREFIX, + ) + + if registration_policy is not None: + if ( + isinstance(registration_policy, BatchLedgerRegistrationPolicy) + and wallet is None + ): + raise ValueError( + "Argument 'wallet' must be provided when using " + "the batch ledger registration policy." + ) + self._registration_policy = registration_policy + else: + self._registration_policy = DefaultBatchRegistrationPolicy( + ledger or get_ledger(test), + wallet, + almanac_contract, + test, + logger=self._logger, + almanac_api=almanac_api_url, + ) if agents is not None: for agent in agents: self.add(agent) - def add(self, agent: Agent): + def _update_agent(self, agent: Agent): """ - Add an agent to the bureau. + Update the agent to be taken over by the Bureau. Args: - agent (Agent): The agent to be added. + agent (Agent): The agent to be updated. """ - if agent in self._agents: - return agent.update_loop(self._loop) agent.update_queries(self._queries) if agent.agentverse["use_mailbox"]: self._use_mailbox = True else: + if agent._endpoints: + self._logger.warning( + f"Overwriting the agent's endpoints {agent._endpoints} " + f"with the Bureau's endpoints {self._endpoints}." + ) agent.update_endpoints(self._endpoints) self._server._rest_handler_map.update(agent._server._rest_handler_map) + + # Run the batch Almanac API registration by default and only run the agent's + # ledger registration if the Bureau is not using a batch ledger registration + # policy because it has no wallet address. + agent._registration_policy = None + if ( + isinstance(self._registration_policy, DefaultBatchRegistrationPolicy) + and self._registration_policy._ledger_policy is None + and agent._almanac_contract is not None + ): + agent._registration_policy = LedgerBasedRegistrationPolicy( + agent._identity, + agent._ledger, + agent._wallet, + agent._almanac_contract, + agent._test, + logger=agent._logger, + ) + + agent._agentverse = self._agentverse + agent._logger.setLevel(self._logger.level) + + def add(self, agent: Agent): + """ + Add an agent to the bureau. + + Args: + agent (Agent): The agent to be added. + + """ + if agent in self._agents: + return + self._registration_policy.add_agent(agent.info, agent._identity) + self._update_agent(agent) self._agents.append(agent) + async def _schedule_registration(self): + """ + Start the batch registration loop. + + """ + + if not any(agent._endpoints for agent in self._agents): + return + + time_to_next_registration = REGISTRATION_UPDATE_INTERVAL_SECONDS + try: + await self._registration_policy.register() + except InsufficientFundsError: + time_to_next_registration = 2 * AVERAGE_BLOCK_INTERVAL + except Exception as ex: + self._logger.exception(f"Failed to register: {ex}") + time_to_next_registration = REGISTRATION_RETRY_INTERVAL_SECONDS + + # schedule the next registration update + self._loop.create_task( + _delay(self._schedule_registration(), time_to_next_registration) + ) + async def run_async(self): """ Run the agents managed by the bureau. @@ -1397,6 +1537,7 @@ async def run_async(self): await agent.setup() if agent.agentverse["use_mailbox"] and agent.mailbox_client is not None: tasks.append(agent.mailbox_client.run()) + tasks.append(self._schedule_registration()) try: await asyncio.gather(*tasks) diff --git a/python/src/uagents/crypto/__init__.py b/python/src/uagents/crypto/__init__.py index d19e7219..d18e2cd6 100644 --- a/python/src/uagents/crypto/__init__.py +++ b/python/src/uagents/crypto/__init__.py @@ -144,14 +144,14 @@ def sign_digest(self, digest: bytes) -> str: def sign_registration( self, contract_address: str, - sequence: int, + timestamp: int, wallet_address: str, ) -> str: """Sign the registration data for the Almanac contract.""" hasher = hashlib.sha256() hasher.update(encode_length_prefixed(contract_address)) hasher.update(encode_length_prefixed(self.address)) - hasher.update(encode_length_prefixed(sequence)) + hasher.update(encode_length_prefixed(timestamp)) hasher.update(encode_length_prefixed(wallet_address)) return self.sign_digest(hasher.digest()) diff --git a/python/src/uagents/network.py b/python/src/uagents/network.py index c591d5ed..51dd25a6 100644 --- a/python/src/uagents/network.py +++ b/python/src/uagents/network.py @@ -1,8 +1,9 @@ """Network and Contracts.""" import asyncio +import time from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from cosmpy.aerial.client import ( DEFAULT_QUERY_INTERVAL_SECS, @@ -22,6 +23,7 @@ from uagents.config import ( ALMANAC_CONTRACT_VERSION, + ALMANAC_REGISTRATION_WAIT, AVERAGE_BLOCK_INTERVAL, MAINNET_CONTRACT_ALMANAC, MAINNET_CONTRACT_NAME_SERVICE, @@ -30,7 +32,8 @@ TESTNET_CONTRACT_ALMANAC, TESTNET_CONTRACT_NAME_SERVICE, ) -from uagents.types import AgentEndpoint +from uagents.crypto import Identity +from uagents.types import AgentEndpoint, AgentInfo from uagents.utils import get_logger logger = get_logger("network") @@ -45,6 +48,19 @@ class InsufficientFundsError(Exception): """Raised when an agent has insufficient funds for a transaction.""" +class AlmanacContractRecord(AgentInfo): + contract_address: str + sender_address: str + timestamp: Optional[int] = None + signature: Optional[str] = None + + def sign(self, identity: Identity): + self.timestamp = int(time.time()) - ALMANAC_REGISTRATION_WAIT + self.signature = identity.sign_registration( + self.contract_address, self.timestamp, self.sender_address + ) + + def get_ledger(test: bool = True) -> LedgerClient: """ Get the Ledger client. @@ -224,52 +240,98 @@ def is_registered(self, address: str) -> bool: return bool(response.get("record")) - def get_expiry(self, address: str) -> int: + def registration_needs_update( + self, + address: str, + endpoints: List[AgentEndpoint], + protocols: List[str], + min_seconds_left: int, + ) -> bool: + """ + Check if an agent's registration needs to be updated. + + Args: + address (str): The agent's address. + endpoints (List[AgentEndpoint]): The agent's endpoints. + protocols (List[str]): The agent's protocols. + min_time_left (int): The minimum time left before the agent's registration expires + + Returns: + bool: True if the agent's registration needs to be updated or will expire sooner + than the specified minimum time, False otherwise. """ - Get the expiry height of an agent's registration. + seconds_to_expiry, registered_endpoints, registered_protocols = ( + self.query_agent_record(address) + ) + return ( + not self.is_registered(address) + or seconds_to_expiry < min_seconds_left + or endpoints != registered_endpoints + or protocols != registered_protocols + ) + + def query_agent_record( + self, address: str + ) -> Tuple[int, List[AgentEndpoint], List[str]]: + """ + Get the records associated with an agent's registration. Args: address (str): The agent's address. Returns: - int: The expiry height of the agent's registration. + Tuple[int, List[AgentEndpoint], List[str]]: The expiry height of the agent's + registration, the agent's endpoints, and the agent's protocols. """ query_msg = {"query_records": {"agent_address": address}} response = self.query_contract(query_msg) + if not response.get("record"): + return [] + if not response.get("record"): contract_state = self.query_contract({"query_contract_state": {}}) expiry = contract_state.get("state", {}).get("expiry_height", 0) return expiry * AVERAGE_BLOCK_INTERVAL - expiry = response["record"][0].get("expiry", 0) - height = response.get("height", 0) + expiry_block = response["record"][0].get("expiry", 0) + current_block = response.get("height", 0) - return (expiry - height) * AVERAGE_BLOCK_INTERVAL + seconds_to_expiry = (expiry_block - current_block) * AVERAGE_BLOCK_INTERVAL - def get_endpoints(self, address: str) -> List[AgentEndpoint]: + endpoints = [] + for endpoint in response["record"][0]["record"]["service"]["endpoints"]: + endpoints.append(AgentEndpoint.model_validate(endpoint)) + + protocols = response["record"][0]["record"]["service"]["protocols"] + + return seconds_to_expiry, endpoints, protocols + + def get_expiry(self, address: str) -> int: """ - Get the endpoints associated with an agent's registration. + Get the approximate seconds to expiry of an agent's registration. Args: address (str): The agent's address. Returns: - List[AgentEndpoint]: The endpoints associated with the agent's registration. + int: The approximate seconds to expiry of the agent's registration. """ - query_msg = {"query_records": {"agent_address": address}} - response = self.query_contract(query_msg) + return self.query_agent_record(address)[0] - if not response.get("record"): - return [] + def get_endpoints(self, address: str) -> List[AgentEndpoint]: + """ + Get the endpoints associated with an agent's registration. - endpoints = [] - for endpoint in response["record"][0]["record"]["service"]["endpoints"]: - endpoints.append(AgentEndpoint.model_validate(endpoint)) + Args: + address (str): The agent's address. - return endpoints + Returns: + List[AgentEndpoint]: The agent's registered endpoints. + """ + return self.query_agent_record(address)[1] - def get_protocols(self, address: str): + def get_protocols(self, address: str) -> List[str]: """ Get the protocols associated with an agent's registration. @@ -277,15 +339,9 @@ def get_protocols(self, address: str): address (str): The agent's address. Returns: - Any: The protocols associated with the agent's registration. + List[str]: The agent's registered protocols. """ - query_msg = {"query_records": {"agent_address": address}} - response = self.query_contract(query_msg) - - if not response.get("record"): - return None - - return response["record"][0]["record"]["service"]["protocols"] + return self.query_agent_record(address)[2] def get_registration_msg( self, @@ -357,6 +413,54 @@ async def register( ) await wait_for_tx_to_complete(transaction.tx_hash, ledger) + async def register_batch( + self, + ledger: LedgerClient, + wallet: LocalWallet, + agent_records: List[AlmanacContractRecord], + ): + """ + Register multiple agents with the Almanac contract. + + Args: + ledger (LedgerClient): The Ledger client. + wallet (LocalWallet): The wallet of the registration sender. + agents (List[ALmanacContractRecord]): The list of signed agent records to register. + """ + if not self.address: + raise ValueError("Contract address not set") + + transaction = Transaction() + + for record in agent_records: + if record.timestamp is None: + raise ValueError("Agent record is missing timestamp") + + if record.signature is None: + raise ValueError("Agent record is not signed") + + almanac_msg = self.get_registration_msg( + protocols=record.protocols, + endpoints=record.endpoints, + signature=record.signature, + sequence=record.timestamp, + address=record.agent_address, + ) + + transaction.add_message( + create_cosmwasm_execute_msg( + wallet.address(), + self.address, + almanac_msg, + funds=f"{REGISTRATION_FEE}{REGISTRATION_DENOM}", + ) + ) + + transaction = prepare_and_broadcast_basic_transaction( + ledger, transaction, wallet + ) + await wait_for_tx_to_complete(transaction.tx_hash, ledger) + def get_sequence(self, address: str) -> int: """ Get the agent's sequence number for Almanac registration. diff --git a/python/src/uagents/registration.py b/python/src/uagents/registration.py index bbefa0dd..62826517 100644 --- a/python/src/uagents/registration.py +++ b/python/src/uagents/registration.py @@ -16,13 +16,19 @@ ALMANAC_API_MAX_RETRIES, ALMANAC_API_TIMEOUT_SECONDS, ALMANAC_API_URL, + ALMANAC_CONTRACT_VERSION, ALMANAC_REGISTRATION_WAIT, REGISTRATION_FEE, REGISTRATION_UPDATE_INTERVAL_SECONDS, ) from uagents.crypto import Identity -from uagents.network import AlmanacContract, InsufficientFundsError, add_testnet_funds -from uagents.types import AgentEndpoint +from uagents.network import ( + AlmanacContract, + AlmanacContractRecord, + InsufficientFundsError, + add_testnet_funds, +) +from uagents.types import AgentEndpoint, AgentInfo class VerifiableModel(BaseModel): @@ -58,6 +64,10 @@ class AgentRegistrationAttestation(VerifiableModel): metadata: Optional[Dict[str, Union[str, Dict[str, str]]]] = None +class AgentRegistrationAttestationBatch(BaseModel): + attestations: List[AgentRegistrationAttestation] + + class AgentStatusUpdate(VerifiableModel): is_active: bool @@ -88,6 +98,17 @@ def coerce_metadata_to_str( return out +def extract_geo_metadata( + metadata: Optional[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: + """ + Extract geo-location metadata from the metadata dictionary. + """ + if metadata is None: + return None + return {k: v for k, v in metadata.items() if k == "geolocation"} + + async def almanac_api_post( url: str, data: BaseModel, raise_from: bool = True, retries: int = 3 ) -> bool: @@ -124,6 +145,17 @@ async def register( pass +class BatchRegistrationPolicy(ABC): + @abstractmethod + # pylint: disable=unnecessary-pass + async def register(self): + pass + + @abstractmethod + def add_agent(self, agent_info: AgentInfo, identity: Identity): + pass + + class AlmanacApiRegistrationPolicy(AgentRegistrationPolicy): def __init__( self, @@ -145,18 +177,12 @@ async def register( endpoints: List[AgentEndpoint], metadata: Optional[Dict[str, Any]] = None, ): - clean_metadata = ( - {k: v for k, v in metadata.items() if k == "geolocation"} - if metadata - else None - ) # only keep geolocation metadata for registration - # create the attestation attestation = AgentRegistrationAttestation( agent_address=agent_address, protocols=protocols, endpoints=endpoints, - metadata=coerce_metadata_to_str(clean_metadata), + metadata=coerce_metadata_to_str(extract_geo_metadata(metadata)), ) # sign the attestation @@ -167,6 +193,41 @@ async def register( ) if success: self._logger.info("Registration on Almanac API successful") + else: + self._logger.warning("Registration on Almanac API failed") + + +class BatchAlmanacApiRegistrationPolicy(AgentRegistrationPolicy): + def __init__( + self, almanac_api: Optional[str] = None, logger: Optional[logging.Logger] = None + ): + self._almanac_api = almanac_api or ALMANAC_API_URL + self._attestations: List[AgentRegistrationAttestation] = [] + self._logger = logger or logging.getLogger(__name__) + + def add_agent(self, agent_info: AgentInfo, identity: Identity): + attestation = AgentRegistrationAttestation( + agent_address=agent_info.agent_address, + protocols=list(agent_info.protocols), + endpoints=agent_info.endpoints, + metadata=coerce_metadata_to_str(extract_geo_metadata(agent_info.metadata)), + ) + attestation.sign(identity) + self._attestations.append(attestation) + + async def register(self): + if not self._attestations: + return + attestations = AgentRegistrationAttestationBatch( + attestations=self._attestations + ) + success = await almanac_api_post( + f"{self._almanac_api}/agents/batch", attestations + ) + if success: + self._logger.info("Batch registration on Almanac API successful") + else: + self._logger.warning("Batch registration on Almanac API failed") class LedgerBasedRegistrationPolicy(AgentRegistrationPolicy): @@ -187,6 +248,20 @@ def __init__( self._almanac_contract = almanac_contract self._logger = logger or logging.getLogger(__name__) + def check_contract_version(self): + """ + Check the version of the deployed Almanac contract and log a warning + if it is different from the supported version. + """ + deployed_version = self._almanac_contract.get_contract_version() + if deployed_version != ALMANAC_CONTRACT_VERSION: + self._logger.warning( + "Mismatch in almanac contract versions: supported (%s), deployed (%s). " + "Update uAgents to the latest version for compatibility.", + ALMANAC_CONTRACT_VERSION, + deployed_version, + ) + async def register( self, agent_address: str, @@ -194,8 +269,12 @@ async def register( endpoints: List[AgentEndpoint], metadata: Optional[Dict[str, Any]] = None, ): - # register if not yet registered or registration is about to expire - # or anything has changed from the last registration + """ + Register the agent on the Almanac contract if registration is about to expire or + the registration data has changed. + """ + self.check_contract_version() + if ( not self._almanac_contract.is_registered(agent_address) or self._almanac_contract.get_expiry(agent_address) @@ -239,10 +318,13 @@ async def register( def _get_balance(self) -> int: return self._ledger.query_bank_balance(Address(self._wallet.address())) - def _sign_registration(self, current_time: int) -> str: + def _sign_registration(self, timestamp: int) -> str: """ Sign the registration data for Almanac contract. + Args: + timestamp (int): The timestamp for the registration. + Returns: str: The signature of the registration data. @@ -253,18 +335,75 @@ def _sign_registration(self, current_time: int) -> str: assert self._almanac_contract.address is not None return self._identity.sign_registration( str(self._almanac_contract.address), - current_time, + timestamp, str(self._wallet.address()), ) +class BatchLedgerRegistrationPolicy(BatchRegistrationPolicy): + def __init__( + self, + ledger: LedgerClient, + wallet: LocalWallet, + almanac_contract: AlmanacContract, + testnet: bool, + logger: Optional[logging.Logger] = None, + ): + self._ledger = ledger + self._wallet = wallet + self._almanac_contract = almanac_contract + self._testnet = testnet + self._logger = logger or logging.getLogger(__name__) + self._records: List[AlmanacContractRecord] = [] + self._identities: Dict[str, Identity] = {} + + def add_agent(self, agent_info: AgentInfo, identity: Identity): + agent_record = AlmanacContractRecord( + agent_address=agent_info.agent_address, + protocols=agent_info.protocols, + endpoints=agent_info.endpoints, + contract_address=str(self._almanac_contract.address), + sender_address=str(self._wallet.address()), + ) + self._records.append(agent_record) + self._identities[agent_info.agent_address] = identity + + def _get_balance(self) -> int: + return self._ledger.query_bank_balance(Address(self._wallet.address())) + + async def register(self): + self._logger.info("Registering agents on Almanac contract...") + for record in self._records: + record.sign(self._identities[record.agent_address]) + + if self._get_balance() < REGISTRATION_FEE * len(self._records): + self._logger.warning( + f"I do not have enough funds to register {len(self._records)} " + "agents on Almanac contract" + ) + if self._testnet: + add_testnet_funds(str(self._wallet.address())) + self._logger.info(f"Adding testnet funds to {self._wallet.address()}") + else: + self._logger.info( + f"Send funds to wallet address: {self._wallet.address()}" + ) + raise InsufficientFundsError() + + await self._almanac_contract.register_batch( + self._ledger, self._wallet, self._records + ) + + self._logger.info("Registering agents on Almanac contract...complete") + + class DefaultRegistrationPolicy(AgentRegistrationPolicy): def __init__( self, identity: Identity, ledger: LedgerClient, wallet: LocalWallet, - almanac_contract: AlmanacContract, + almanac_contract: Optional[AlmanacContract], testnet: bool, *, logger: Optional[logging.Logger] = None, @@ -322,3 +461,56 @@ async def update_agent_status(status: AgentStatusUpdate, almanac_api: str): status, raise_from=False, ) + + +class DefaultBatchRegistrationPolicy(BatchRegistrationPolicy): + def __init__( + self, + ledger: LedgerClient, + wallet: Optional[LocalWallet] = None, + almanac_contract: Optional[AlmanacContract] = None, + testnet: bool = True, + *, + logger: Optional[logging.Logger] = None, + almanac_api: Optional[str] = None, + ): + self._logger = logger or logging.getLogger(__name__) + self._api_policy = BatchAlmanacApiRegistrationPolicy( + almanac_api=almanac_api, logger=logger + ) + + if almanac_contract is None or wallet is None: + self._ledger_policy = None + else: + self._ledger_policy = BatchLedgerRegistrationPolicy( + ledger, wallet, almanac_contract, testnet, logger=logger + ) + + def add_agent(self, agent_info: AgentInfo, identity: Identity): + self._api_policy.add_agent(agent_info, identity) + if self._ledger_policy is not None: + self._ledger_policy.add_agent(agent_info, identity) + + async def register(self): + # prefer the API registration policy as it is faster + try: + await self._api_policy.register() + except Exception as e: + self._logger.warning( + f"Failed to batch register on Almanac API: {e.__class__.__name__}" + ) + + if self._ledger_policy is None: + return + + # schedule the ledger registration + try: + await self._ledger_policy.register() + except InsufficientFundsError: + self._logger.warning( + "Failed to batch register on Almanac contract due to insufficient funds" + ) + raise + except Exception as e: + self._logger.error(f"Failed to batch register on Almanac contract: {e}") + raise diff --git a/python/src/uagents/types.py b/python/src/uagents/types.py index 0ef45123..730b38ba 100644 --- a/python/src/uagents/types.py +++ b/python/src/uagents/types.py @@ -47,6 +47,7 @@ class AgentInfo(BaseModel): agent_address: str endpoints: List[AgentEndpoint] protocols: List[str] + metadata: Optional[Dict[str, Any]] = None class RestHandlerDetails(BaseModel): diff --git a/python/tests/test_bureau.py b/python/tests/test_bureau.py new file mode 100644 index 00000000..d32fba68 --- /dev/null +++ b/python/tests/test_bureau.py @@ -0,0 +1,110 @@ +import asyncio +import unittest + +from cosmpy.aerial.wallet import LocalWallet + +from uagents import Agent, Bureau +from uagents.registration import ( + AgentEndpoint, + BatchLedgerRegistrationPolicy, + DefaultBatchRegistrationPolicy, + DefaultRegistrationPolicy, + LedgerBasedRegistrationPolicy, +) + +ALICE_ENDPOINT = AgentEndpoint(url="http://alice:8000/submit", weight=1) +BOB_ENDPOINT = AgentEndpoint(url="http://bob:8000/submit", weight=1) +BUREAU_ENDPOINT = AgentEndpoint(url="http://bureau:8000/submit", weight=1) + + +bureau_wallet = LocalWallet.generate() + + +class TestBureau(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.loop = asyncio.get_event_loop() + return super().setUp() + + def test_bureau_updates_agents_no_ledger_batch(self): + alice = Agent(name="alice", endpoint=ALICE_ENDPOINT.url, loop=self.loop) + bob = Agent(name="bob", endpoint=BOB_ENDPOINT.url) + + assert alice._endpoints == [ALICE_ENDPOINT] + assert bob._endpoints == [BOB_ENDPOINT] + + assert isinstance(alice._registration_policy, DefaultRegistrationPolicy) + assert isinstance(bob._registration_policy, DefaultRegistrationPolicy) + + bureau = Bureau(agents=[alice, bob], endpoint=BUREAU_ENDPOINT.url) + + assert alice._endpoints == [BUREAU_ENDPOINT] + assert bob._endpoints == [BUREAU_ENDPOINT] + + assert isinstance(bureau._registration_policy, DefaultBatchRegistrationPolicy) + assert bureau._registration_policy._ledger_policy is None + assert isinstance(alice._registration_policy, LedgerBasedRegistrationPolicy) + assert isinstance(bob._registration_policy, LedgerBasedRegistrationPolicy) + + def test_bureau_updates_agents_with_wallet(self): + alice = Agent(name="alice", endpoint=ALICE_ENDPOINT.url) + bob = Agent(name="bob", endpoint=BOB_ENDPOINT.url) + + assert isinstance(alice._registration_policy, DefaultRegistrationPolicy) + assert isinstance(bob._registration_policy, DefaultRegistrationPolicy) + + bureau = Bureau(agents=[alice, bob], wallet=bureau_wallet) + + assert alice._endpoints == [] + assert bob._endpoints == [] + + assert isinstance(bureau._registration_policy, DefaultBatchRegistrationPolicy) + assert isinstance( + bureau._registration_policy._ledger_policy, BatchLedgerRegistrationPolicy + ) + assert alice._registration_policy is None + assert bob._registration_policy is None + + def test_bureau_updates_agents_with_seed(self): + alice = Agent(name="alice", endpoint=ALICE_ENDPOINT.url) + bob = Agent(name="bob", endpoint=BOB_ENDPOINT.url) + + assert isinstance(alice._registration_policy, DefaultRegistrationPolicy) + assert isinstance(bob._registration_policy, DefaultRegistrationPolicy) + + bureau = Bureau( + agents=[alice, bob], + endpoint=BUREAU_ENDPOINT.url, + seed="bureau test seed phrase", + ) + + assert isinstance(bureau._registration_policy, DefaultBatchRegistrationPolicy) + assert isinstance( + bureau._registration_policy._ledger_policy, BatchLedgerRegistrationPolicy + ) + assert alice._registration_policy is None + assert bob._registration_policy is None + + def test_bureau_updates_agents_wallet_overrides_seed(self): + alice = Agent(name="alice", endpoint=ALICE_ENDPOINT.url) + bob = Agent(name="bob", endpoint=BOB_ENDPOINT.url) + + assert isinstance(alice._registration_policy, DefaultRegistrationPolicy) + assert isinstance(bob._registration_policy, DefaultRegistrationPolicy) + + bureau = Bureau( + agents=[alice, bob], + endpoint=BUREAU_ENDPOINT.url, + wallet=bureau_wallet, + seed="bureau test seed phrase", + ) + + assert isinstance(bureau._registration_policy, DefaultBatchRegistrationPolicy) + assert isinstance( + bureau._registration_policy._ledger_policy, BatchLedgerRegistrationPolicy + ) + assert ( + bureau._registration_policy._ledger_policy._wallet.address() + == bureau_wallet.address() + ) + assert alice._registration_policy is None + assert bob._registration_policy is None