Skip to content

Commit

Permalink
sldkfj
Browse files · Browse the repository at this point in the history
  • Loading branch information
lorenlugosch committed Mar 31, 2019
1 parent 52dd00e commit aba567c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
7 changes: 6 additions & 1 deletion data.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -65,10 +65,11 @@ def read_config(config_file):

#[pretraining]
config.asr_path=parser.get("pretraining", "asr_path")
config.pretraining_type=int(parser.get("pretraining", "pretraining_type")) # 0 - no pre-training, 1 - phoneme pre-training, 2 - word pre-training, 3 - embedding-constrained word pre-training
config.pretraining_type=int(parser.get("pretraining", "pretraining_type")) # 0 - no pre-training, 1 - phoneme pre-training, 2 - phoneme + word pre-training, 3 - word pre-training
if config.pretraining_type == 0: config.starting_unfreezing_index = 1 + len(config.word_rnn_num_hidden) + len(config.phone_rnn_num_hidden) + len(config.cnn_N_filt)
if config.pretraining_type == 1: config.starting_unfreezing_index = 1 + len(config.word_rnn_num_hidden)
if config.pretraining_type == 2: config.starting_unfreezing_index = 1
if config.pretraining_type == 3: config.starting_unfreezing_index = 1
config.pretraining_lr=float(parser.get("pretraining", "pretraining_lr"))
config.pretraining_batch_size=int(parser.get("pretraining", "pretraining_batch_size"))
config.pretraining_num_epochs=int(parser.get("pretraining", "pretraining_num_epochs"))
Expand All @@ -82,6 +83,10 @@ def read_config(config_file):
config.training_batch_size=int(parser.get("training", "training_batch_size"))
config.training_num_epochs=int(parser.get("training", "training_num_epochs"))
config.dataset_subset_percentage=float(parser.get("training", "dataset_subset_percentage"))
config.train_wording_path=parser.get("training", "train_wording_path")
if config.train_wording_path=="None": config.train_wording_path = None
config.test_wording_path=parser.get("training", "test_wording_path")
if config.test_wording_path=="None": config.test_wording_path = None

# compute downsample factor (divide T by this number)
config.phone_downsample_factor = 1
Expand Down
6 changes: 0 additions & 6 deletions main.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -11,20 +11,14 @@
parser.add_argument('--train', action='store_true', help='run SLU training')
parser.add_argument('--restart', action='store_true', help='load checkpoint from a previous run')
parser.add_argument('--config_path', type=str, help='path to config file with hyperparameters, etc.')
parser.add_argument('--train_wording_path', type=str, default=None, help='path to .txt file containing wordings that will be allowed during training')
parser.add_argument('--test_wording_path', type=str, default=None, help='path to .txt file containing wordings that will be allowed during validation/testing')
args = parser.parse_args()
pretrain = args.pretrain
train = args.train
restart = args.restart
config_path = args.config_path
train_wording_path = args.train_wording_path
test_wording_path = args.test_wording_path

# Read config file
config = read_config(config_path)
config.train_wording_path = train_wording_path
config.test_wording_path = test_wording_path
torch.manual_seed(config.seed); np.random.seed(config.seed)

if pretrain:
Expand Down
3 changes: 2 additions & 1 deletion training.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -58,7 +58,8 @@ def train(self, dataset, print_interval=100):
num_examples += batch_size
phoneme_loss, word_loss, phoneme_acc, word_acc = self.model(x,y_phoneme,y_word)
if self.config.pretraining_type == 1: loss = phoneme_loss
else: loss = phoneme_loss + word_loss
if self.config.pretraining_type == 2: loss = phoneme_loss + word_loss
if self.config.pretraining_type == 3: loss = word_loss
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
Expand Down

0 comments on commit aba567c

Please sign in to comment.