Skip to content

Commit

Permalink
Fixed relative path problems mentioned in Issue #1
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhifeng JIANG committed Mar 6, 2023
1 parent f8b21b9 commit 834c834
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 13 deletions.
4 changes: 0 additions & 4 deletions plato/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,6 @@ def __new__(cls):
# A run ID is unique to each client in an experiment
Config.params['run_id'] = os.getpid()

# Pretrained models
Config.params['model_dir'] = "./models/pretrained/"
Config.params['pretrained_model_dir'] = "./models/pretrained/"

return cls._instance

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions plato/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def load_model(self, filename=None):
@staticmethod
def save_accuracy(accuracy, filename=None):
"""Saving the test accuracy to a file."""
model_dir = Config().params['model_dir']
model_dir = Config().result_dir
model_name = Config().trainer.model_name

if not os.path.exists(model_dir):
Expand All @@ -76,7 +76,7 @@ def save_accuracy(accuracy, filename=None):
@staticmethod
def load_accuracy(filename=None):
"""Loading the test accuracy from a file."""
model_dir = Config().params['model_dir']
model_dir = Config().result_dir
model_name = Config().trainer.model_name

if filename is not None:
Expand Down Expand Up @@ -110,7 +110,7 @@ def pause_training(self):
(self.client_id, ))

model_name = Config().trainer.model_name
model_dir = Config().params['model_dir']
model_dir = Config().result_dir
model_file = f"{model_dir}{model_name}_{self.client_id}_{Config().params['run_id']}.pth"
accuracy_file = f"{model_dir}{model_name}_{self.client_id}_{Config().params['run_id']}.acc"

Expand Down
4 changes: 2 additions & 2 deletions plato/trainers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def zeros(self, shape):
def save_model(self, filename=None):
"""Saving the model to a file."""
model_name = Config().trainer.model_name
model_dir = Config().params['model_dir']
model_dir = Config().result_dir

if not os.path.exists(model_dir):
os.makedirs(model_dir)
Expand All @@ -96,7 +96,7 @@ def save_model(self, filename=None):

def load_model(self, filename=None):
"""Loading pre-trained model weights from a file."""
model_dir = Config().params['pretrained_model_dir']
model_dir = Config().result_dir
model_name = Config().trainer.model_name

if filename is not None:
Expand Down
4 changes: 2 additions & 2 deletions plato/trainers/mindspore/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def zeros(self, shape):
def save_model(self, filename=None):
"""Saving the model to a file."""
model_name = Config().trainer.model_name
model_dir = Config().params['model_dir']
model_dir = Config().result_dir

if not os.path.exists(model_dir):
os.makedirs(model_dir)
Expand All @@ -88,7 +88,7 @@ def save_model(self, filename=None):
def load_model(self, filename=None):
"""Loading pre-trained model weights from a file."""
model_name = Config().trainer.model_name
model_dir = Config().params['pretrained_model_dir']
model_dir = Config().result_dir

if filename is not None:
model_path = f'{model_dir}{filename}'
Expand Down
4 changes: 2 additions & 2 deletions plato/trainers/tensorflow/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def zeros(self, shape):
def save_model(self, filename=None):
"""Saving the model to a file."""
model_name = Config().trainer.model_name
model_dir = Config().params['model_dir']
model_dir = Config().result_dir

if not os.path.exists(model_dir):
os.makedirs(model_dir)
Expand All @@ -62,7 +62,7 @@ def save_model(self, filename=None):
def load_model(self, filename=None):
"""Loading pre-trained model weights from a file."""
model_name = Config().trainer.model_name
model_dir = Config().params['pretrained_model_dir']
model_dir = Config().result_dir

if filename is not None:
model_path = f'{model_dir}{filename}'
Expand Down

0 comments on commit 834c834

Please sign in to comment.