Skip to content

Commit

Permalink
Change path for loading
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenlugosch committed Apr 5, 2019
1 parent a4a1f45 commit e618d8a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 17 deletions.
19 changes: 3 additions & 16 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def get_SLU_datasets(config):
base_path = config.slu_path

# Split
train_df = pd.read_csv(os.path.join(base_path, "train.csv"))
valid_df = pd.read_csv(os.path.join(base_path, "valid.csv"))
test_df = pd.read_csv(os.path.join(base_path, "test.csv"))
train_df = pd.read_csv(os.path.join(base_path, "data", "train.csv"))
valid_df = pd.read_csv(os.path.join(base_path, "data", "valid.csv"))
test_df = pd.read_csv(os.path.join(base_path, "data", "test.csv"))

# Get list of slots
Sy_intent = {"action": {}, "object": {}, "location": {}}
Expand Down Expand Up @@ -307,15 +307,6 @@ def __len__(self):
def __getitem__(self, idx):
x, fs = sf.read(self.wav_paths[idx])

# https://github.com/jameslyons/python_speech_features/blob/master/python_speech_features/base.py
# if config.use_fbank:
# eps = 1e-8
# fbank = python_speech_features.fbank(x, nfilt=40, winfunc=np.hamming)
# fbank = np.concatenate([fbank[1].reshape(-1,1), fbank[0]], axis=1) + eps
# fbank = np.log(fbank)
# fbank = (fbank - fbank.mean(0))
# fbank = fbank/(np.sqrt(fbank.var(0)))

tg = textgrid.TextGrid()
tg.read(self.textgrid_paths[idx])

Expand Down Expand Up @@ -348,10 +339,6 @@ def __getitem__(self, idx):
return (x, y_phoneme, y_word)

class CollateWavsASR:
# def __init__(self, Sy_phoneme, Sy_word):
# self.Sy_phoneme = Sy_phoneme
# self.Sy_word = Sy_word

def __call__(self, batch):
"""
batch: list of tuples (input wav, phoneme labels, word labels)
Expand Down
4 changes: 3 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,6 @@

trainer.save_checkpoint()

test_intent_acc, test_intent_loss = trainer.test(test_dataset)
test_intent_acc, test_intent_loss = trainer.test(test_dataset)
print("========= Test results =========")
print("*intents*| test accuracy: %.2f| test loss: %.2f| valid accuracy: %.2f| valid loss: %.2f\n" % (train_intent_acc, train_intent_loss, valid_intent_acc, valid_intent_loss) )

0 comments on commit e618d8a

Please sign in to comment.