Skip to content

Commit

Permalink
Merge pull request #332 from XpressAI/paul/copy-able-connections
Browse files Browse the repository at this point in the history
Introduce Parallel Execution Component and Make Args & Components Copyable
  • Loading branch information
MFA-X-AI authored Jun 17, 2024
2 parents e316b65 + a1f2744 commit c34d411
Show file tree
Hide file tree
Showing 4 changed files with 1,271 additions and 6 deletions.
80 changes: 76 additions & 4 deletions xai_components/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from argparse import Namespace
from typing import TypeVar, Generic, Tuple, NamedTuple, Callable, List
from copy import deepcopy

T = TypeVar('T')

class InArg(Generic[T]):
class OutArg(Generic[T]):
def __init__(self, value: T = None, getter: Callable[[T], any] = lambda x: x) -> None:
self._value = value
self._getter = getter
Expand All @@ -16,9 +17,28 @@ def value(self):
def value(self, value: T):
self._value = value

class OutArg(Generic[T]):
def connect(self, ref: 'OutArg[T]'):
self._value = ref
self._getter = lambda x: x.value

def __copy__(self):
return type(self)(self._value, self._getter)

def __deepcopy__(self, memo):
id_self = id(self)
_copy = memo.get(id_self)
if _copy is None:
_copy = type(self)(
deepcopy(self._value, memo),
deepcopy(self._getter, memo)
)
memo[id_self] = _copy
return _copy


class InArg(Generic[T]):
def __init__(self, value: T = None, getter: Callable[[T], any] = lambda x: x) -> None:
self.value = value
self._value = value
self._getter = getter

@property
Expand All @@ -29,9 +49,27 @@ def value(self):
def value(self, value: T):
self._value = value

def connect(self, ref: OutArg[T]):
self._value = ref
self._getter = lambda x: x.value

def __copy__(self):
return type(self)(self._value, self._getter)

def __deepcopy__(self, memo):
id_self = id(self)
_copy = memo.get(id_self)
if _copy is None:
_copy = type(self)(
deepcopy(self._value, memo),
deepcopy(self._getter, memo)
)
memo[id_self] = _copy
return _copy

class InCompArg(Generic[T]):
def __init__(self, value: T = None, getter: Callable[[T], any] = lambda x: x) -> None:
self.value = value
self._value = value
self._getter = getter

@property
Expand All @@ -42,6 +80,24 @@ def value(self):
def value(self, value: T):
self._value = value

def connect(self, ref: OutArg[T]):
self._value = ref
self._getter = lambda x: x.value

def __copy__(self):
return type(self)(self._value, self._getter)

def __deepcopy__(self, memo):
id_self = id(self)
_copy = memo.get(id_self)
if _copy is None:
_copy = type(self)(
deepcopy(self._value, memo),
deepcopy(self._getter, memo)
)
memo[id_self] = _copy
return _copy

def xai_component(*args, **kwargs):
# Passthrough element without any changes.
# This is used for parser metadata only.
Expand Down Expand Up @@ -93,6 +149,22 @@ def execute(self, ctx) -> None:
def do(self, ctx) -> 'BaseComponent':
pass

def __copy__(self):
_copy = type(self)()
for key, type_arg in self.__dict__.items():
setattr(_copy, key, getattr(self, key))
return _copy

def __deepcopy__(self, memo):
id_self = id(self)
_copy = memo.get(id_self)
if _copy is None:
_copy = type(self)()
memo[id_self] = _copy
for key, type_arg in self.__dict__.items():
setattr(_copy, key, deepcopy(getattr(self, key), memo))
return _copy

class Component(BaseComponent):
next: BaseComponent

Expand Down
Loading

0 comments on commit c34d411

Please sign in to comment.