From c5258fd2424e85eb91ab77108a519d5adc251673 Mon Sep 17 00:00:00 2001 From: Florian Date: Mon, 7 Mar 2016 12:45:42 +0100 Subject: [PATCH 1/3] added new reset() function to sequential model --- keras/models.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/keras/models.py b/keras/models.py index d65cb4ff9a1..81c5825594b 100644 --- a/keras/models.py +++ b/keras/models.py @@ -166,7 +166,7 @@ def model_from_json(json_string, custom_objects={}): return model_from_config(config, custom_objects=custom_objects) -def model_from_config(config, custom_objects={}): +def model_from_config(config, custom_objects={}, reset=False): ''' ''' model_name = config.get('name') @@ -181,6 +181,10 @@ def model_from_config(config, custom_objects={}): elif model_name == 'Sequential': model.__class__ = Sequential model.name = model_name + model.name = model_name + if reset: + for layer in model.layers: + layer.build() if 'optimizer' in config: # if it has an optimizer, the model is assumed to be compiled @@ -463,6 +467,14 @@ def summary(self): ''' model_summary(self) + def reset(self): + ''' Reset all weights and biases to random values + + Returns: Model + ''' + + return model_from_config(self.get_config(), reset=True) + class Sequential(Model, containers.Sequential): '''Linear stack of layers. From 342696eff584bcbeca32e5ed1341244990a74614 Mon Sep 17 00:00:00 2001 From: Florian Date: Mon, 7 Mar 2016 12:48:56 +0100 Subject: [PATCH 2/3] added tiny part of function doc --- keras/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/models.py b/keras/models.py index 81c5825594b..34ae3d0fecd 100644 --- a/keras/models.py +++ b/keras/models.py @@ -468,7 +468,7 @@ def summary(self): model_summary(self) def reset(self): - ''' Reset all weights and biases to random values + ''' Reset all weights and biases to random values, then recompiles Returns: Model ''' From 290d762baccd2e69ed3089c64fab6c7672f41928 Mon Sep 17 00:00:00 2001 From: Florian Golemo Date: Mon, 7 Mar 2016 12:52:52 +0100 Subject: [PATCH 3/3] accidentally duplicated a line --- keras/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/models.py b/keras/models.py index 34ae3d0fecd..087ce307491 100644 --- a/keras/models.py +++ b/keras/models.py @@ -181,7 +181,6 @@ def model_from_config(config, custom_objects={}, reset=False): elif model_name == 'Sequential': model.__class__ = Sequential model.name = model_name - model.name = model_name if reset: for layer in model.layers: layer.build()