Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Xircuits Context #109

Merged
merged 6 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 0 additions & 47 deletions Untitled.ipynb

This file was deleted.

10 changes: 6 additions & 4 deletions src/components/xircuitBodyWidget.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,9 @@ export const BodyWidget: FC<BodyWidgetProps> = ({
pythonCode += "outarg_output_data = []\n";
pythonCode += "is_done_list = []\n";

pythonCode += "\ndef main(args):\n";
pythonCode += "\ndef main(args):\n\n";
pythonCode += ' ' + 'ctx = {}\n';
pythonCode += ' ' + "ctx['args'] = args\n\n";

for (let i = 0; i < allNodes.length; i++) {

Expand Down Expand Up @@ -589,17 +591,17 @@ export const BodyWidget: FC<BodyWidgetProps> = ({
pythonCode += '\n';

pythonCode += ' ' + 'if len(input_data) > 0 and input_data[-1] == \'run\':\n';
pythonCode += ' ' + 'is_done, next_component = next_component.do()\n';
pythonCode += ' ' + 'is_done, next_component = next_component.do(ctx)\n';
pythonCode += ' ' + 'input_data.clear()\n';
pythonCode += ' ' + 'is_done_list.append(is_done)\n';
pythonCode += '\n';

pythonCode += ' ' + 'if len(input_data) > 0 and input_data[-1] == \'skip\':\n';
pythonCode += ' ' + 'is_done, next_component = next_component.do()\n';
pythonCode += ' ' + 'is_done, next_component = next_component.do(ctx)\n';
pythonCode += '\n';

pythonCode += ' ' + 'else:\n';
pythonCode += ' ' + 'is_done, next_component = next_component.do()\n';
pythonCode += ' ' + 'is_done, next_component = next_component.do(ctx)\n';
pythonCode += '\n';

pythonCode += '@app.route(\'/terminate\')\n';
Expand Down
Empty file removed untitled.txt
Empty file.
22 changes: 11 additions & 11 deletions xai_components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ class Component(BaseComponent):
next: BaseComponent
done: False

def do(self) -> BaseComponent:
def do(self, ctx) -> BaseComponent:
print(f"\nExecuting: {self.__class__.__name__}")
self.execute()
self.execute(ctx)

return self.done, self.next

Expand All @@ -87,7 +87,7 @@ class BranchComponent(BaseComponent):

condition: InArg[bool]

def do(self) -> BaseComponent:
def do(self, ctx) -> BaseComponent:
if self.condition.value:
return self.when_true
else:
Expand All @@ -100,28 +100,28 @@ class LoopComponent(Component):

condition: InArg[bool]

def do(self) -> BaseComponent:
def do(self, ctx) -> BaseComponent:
while self.condition.value:
next_body = self.body.do()
next_body = self.body.do(ctx)
while next_body:
next_body = next_body.do()
next_body = next_body.do(ctx)
return self
return self.next


def execute_graph(args: Namespace, start: BaseComponent) -> None:
def execute_graph(args: Namespace, start: BaseComponent, ctx) -> None:
BaseComponent.set_execution_context(ExecutionContext(args))

if 'debug' in args and args['debug']:
import pdb
pdb.set_trace()

current_component = start
next_component = current_component.do()
next_component = current_component.do(ctx)
while next_component:
current_component = next_component
next_component = current_component.do()
next_component = current_component.do(ctx)
else:
next_component = start.do()
next_component = start.do(ctx)
while next_component:
next_component = next_component.do()
next_component = next_component.do(ctx)
4 changes: 2 additions & 2 deletions xai_components/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ def __init__(self):
self.data_name = InArg.default()
self.data_set = OutArg.default()

def execute(self):
def execute(self, ctx):
# logic here
pass

class RotateCounterClockWiseComponent(Component):
data_set: InArg[Dataset]
out: OutArg[Dataset]

def execute(self):
def execute(self, ctx):
pass

4 changes: 2 additions & 2 deletions xai_components/xai_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):
self.dataset_name = InArg.empty()
self.dataset = OutArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:
if self.dataset_name.value == 'mnist':
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(self):
self.dataset = InArg.empty()
self.resized_dataset = OutArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:
x = self.dataset.value[0]
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2])

Expand Down
18 changes: 9 additions & 9 deletions xai_components/xai_learning/tf_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self):
self.model = OutArg(None)


def execute(self) -> None:
def execute(self, ctx) -> None:
model = None
model_name = (self.model_name.value).lower()

Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self):
self.class_list = InArg(None)
self.target_shape = InArg(None)

def execute(self) -> None:
def execute(self, ctx) -> None:
model = self.model.value
img_path = self.img_string.value
class_list = self.class_list.value if self.class_list.value else []
Expand Down Expand Up @@ -232,7 +232,7 @@ def __init__(self):
self.model = OutArg(None)


def execute(self) -> None:
def execute(self, ctx) -> None:
model_config = resnet_model_config()

#dynamically sync model config with node inputs
Expand Down Expand Up @@ -274,7 +274,7 @@ def __init__(self):
self.model = OutArg(None)


def execute(self) -> None:
def execute(self, ctx) -> None:
model_config = resnet_model_config()

#dynamically sync model config with node inputs
Expand Down Expand Up @@ -316,7 +316,7 @@ def __init__(self):
self.model = OutArg(None)


def execute(self) -> None:
def execute(self, ctx) -> None:
model_config = resnet_model_config()
#dynamically sync model config with node inputs
for port in self.__dict__.keys():
Expand Down Expand Up @@ -375,7 +375,7 @@ def __init__(self):
self.model = OutArg(None)


def execute(self) -> None:
def execute(self, ctx) -> None:

model_config = vgg_model_config()

Expand Down Expand Up @@ -418,7 +418,7 @@ def __init__(self):
self.model = OutArg(None)


def execute(self) -> None:
def execute(self, ctx) -> None:

model_config = vgg_model_config()

Expand Down Expand Up @@ -461,7 +461,7 @@ def __init__(self):
self.model = OutArg(None)


def execute(self) -> None:
def execute(self, ctx) -> None:

model_config = vgg_model_config()

Expand Down Expand Up @@ -535,7 +535,7 @@ def __init__(self):
self.model = OutArg(None)


def execute(self) -> None:
def execute(self, ctx) -> None:

model_config = vgg_model_config()

Expand Down
22 changes: 11 additions & 11 deletions xai_components/xai_learning/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self):
self.class_dict = OutArg.empty()


def execute(self) -> None:
def execute(self, ctx) -> None:

if self.dataset_name.value == 'mnist':
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
Expand Down Expand Up @@ -137,7 +137,7 @@ def __init__(self):
self.mask_dataset_name = InArg.empty()
self.dataset = OutArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:

if self.dataset_name.value and self.mask_dataset_name.value:

Expand All @@ -162,7 +162,7 @@ def __init__(self):
self.dataset = InArg.empty()
self.resized_dataset = OutArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:

x = self.dataset.value[0]
x = x.reshape(x.shape[0], -1)
Expand Down Expand Up @@ -191,7 +191,7 @@ def __init__(self):
self.train = OutArg.empty()
self.test = OutArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:

train_split = self.train_split.value if self.train_split.value else 0.75
shuffle = self.shuffle.value if self.shuffle.value else True
Expand Down Expand Up @@ -222,7 +222,7 @@ def __init__(self):
self.training_data = InArg.empty()
self.model = OutArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:
x_shape = self.training_data.value[0].shape
y_shape = self.training_data.value[1].shape

Expand Down Expand Up @@ -257,7 +257,7 @@ def __init__(self):
self.model_config = OutArg.empty()


def execute(self) -> None:
def execute(self, ctx) -> None:

x_shape = self.training_data.value[0].shape[1:]
y_shape = self.training_data.value[1].shape[1]
Expand Down Expand Up @@ -313,7 +313,7 @@ def __init__(self):
self.trained_model = OutArg.empty()
self.training_metrics = OutArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:

model = self.model.value

Expand Down Expand Up @@ -348,7 +348,7 @@ def __init__(self):
self.eval_dataset = InArg.empty()
self.metrics = OutArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:
(loss, acc) = self.model.value.evaluate(self.eval_dataset.value[0], self.eval_dataset.value[1], verbose=0)
metrics = {
'loss': str(loss),
Expand Down Expand Up @@ -377,7 +377,7 @@ def __init__(self):
self.should_retrain = OutArg(True)
self.retries = 0

def execute(self) -> None:
def execute(self, ctx) -> None:
self.retries += 1

if self.retries < self.max_retries.value:
Expand Down Expand Up @@ -409,7 +409,7 @@ def __init__(self):

self.model_h5_path = OutArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:
model = self.model.value
model_name = self.model_name.value if self.model_name.value else os.path.splitext(sys.argv[0])[0] + ".h5"
model.save(model_name)
Expand All @@ -431,7 +431,7 @@ def __init__(self):
self.experiment_name = InArg.empty()
self.metrics = InArg.empty()

def execute(self) -> None:
def execute(self, ctx) -> None:
config = self.execution_context.args

if not os.path.exists(os.path.join('..', 'experiments')):
Expand Down
Loading