diff --git a/deployment/registry.py b/deployment/registry.py index a75521aa..d756b4bc 100644 --- a/deployment/registry.py +++ b/deployment/registry.py @@ -1,5 +1,5 @@ import json -from collections import defaultdict +from collections import OrderedDict, defaultdict from enum import Enum from pathlib import Path from typing import Dict, List, NamedTuple, Optional @@ -148,7 +148,10 @@ class ConflictResolution(Enum): def _select_conflict_resolution( registry_1_entry, registry_1_filepath, registry_2_entry, registry_2_filepath ) -> ConflictResolution: - print(f"\n! Conflict detected for {registry_1_entry.name}:") + print( + f"\n! Conflict detected for {registry_1_entry.name} " + f"on chain id {registry_1_entry.chain_id}:" + ) print( f"[1]: {registry_1_entry.name} at {registry_1_entry.address} " f"for {registry_1_filepath}" ) @@ -196,33 +199,49 @@ def merge_registries( deprecated_contracts = deprecated_contracts or [] # Read the registries, excluding deprecated contracts - reg1 = { - e.name: e for e in read_registry(registry_1_filepath) if e.name not in deprecated_contracts - } - reg2 = { - e.name: e for e in read_registry(registry_2_filepath) if e.name not in deprecated_contracts - } + reg1 = defaultdict(OrderedDict) + reg2 = defaultdict(OrderedDict) + + for e in read_registry(registry_1_filepath): + if e.name in deprecated_contracts: + continue + reg1[str(e.chain_id)][e.name] = e + + for e in read_registry(registry_2_filepath): + if e.name in deprecated_contracts: + continue + reg2[str(e.chain_id)][e.name] = e merged: List[RegistryEntry] = list() - # Iterate over all unique contract names across both registries - conflicts, contracts = set(reg1) & set(reg2), set(reg1) | set(reg2) - for name in contracts: - entry_1, entry_2 = reg1.get(name), reg2.get(name) - conflict = name in conflicts and entry_1.chain_id == entry_2.chain_id - if conflict: - resolution = _select_conflict_resolution( - registry_1_entry=entry_1, - registry_2_entry=entry_2, - registry_1_filepath=registry_1_filepath, - registry_2_filepath=registry_2_filepath, - ) - selected_entry = entry_1 if resolution == ConflictResolution.USE_1 else entry_2 + # Iterate over all chains and unique contract names across both registries + all_chains = set(reg1) | set(reg2) + common_chains = set(reg1) & set(reg2) + for chain in all_chains: + reg1_chain_entries, reg2_chain_entries = reg1.get(str(chain), {}), reg2.get(str(chain), {}) + if chain in common_chains: + # check for conflicting contracts + all_contracts = set(reg1_chain_entries) | set(reg2_chain_entries) + for name in all_contracts: + entry_1, entry_2 = reg1_chain_entries.get(name), reg2_chain_entries.get(name) + if entry_1 and entry_2: + # entries for the same name (same chain) + resolution = _select_conflict_resolution( + registry_1_entry=entry_1, + registry_2_entry=entry_2, + registry_1_filepath=registry_1_filepath, + registry_2_filepath=registry_2_filepath, + ) + selected_entry = entry_1 if resolution == ConflictResolution.USE_1 else entry_2 + else: + selected_entry = entry_1 or entry_2 + + # commit the selected entry + merged.append(selected_entry) else: - selected_entry = entry_1 or entry_2 - - # commit the selected entry - merged.append(selected_entry) + # not a common chain so just move on right along + selected_entries = reg1_chain_entries or reg2_chain_entries + merged.extend(list(selected_entries.values())) # Write the merged registry to the specified output file path write_registry(entries=merged, filepath=output_filepath)