
Commit

Merge pull request #522 from GoogleCloudPlatform/update_encoder_decoder
Update encoder-decoder notebooks
takumiohym authored Sep 27, 2024
2 parents a619354 + df241a4 commit 4b9b789
Showing 2 changed files with 118 additions and 146 deletions.
127 changes: 55 additions & 72 deletions notebooks/text_models/labs/rnn_encoder_decoder.ipynb
@@ -20,15 +20,6 @@
 "At last, we'll benchmark our results using the industry standard BLEU score."
 ]
 },
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"pip install nltk"
-]
-},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -51,14 +42,15 @@
 "import pickle\n",
 "import sys\n",
 "\n",
-"import nltk\n",
+"import evaluate\n",
 "import numpy as np\n",
 "import pandas as pd\n",
 "import tensorflow as tf\n",
 "import utils_preproc\n",
 "from sklearn.model_selection import train_test_split\n",
 "from tensorflow.keras.layers import GRU, Dense, Embedding, Input\n",
 "from tensorflow.keras.models import Model, load_model\n",
+"from tqdm import tqdm\n",
 "\n",
 "print(tf.__version__)"
 ]
@@ -724,7 +716,7 @@
 "source": [
 "## Implementing the translation (or decoding) function\n",
 "\n",
-"We can't just use model.predict(), because we don't know all the inputs we used during training. We only know the encoder_input (source language) but not the decoder_input (target language), which is what we want to predict (i.e., the translation of the source language)!\n",
+"We can't just use model(), because we don't know all the inputs we used during training. We only know the encoder_input (source language) but not the decoder_input (target language), which is what we want to predict (i.e., the translation of the source language)!\n",
 "\n",
 "We do however know the first token of the decoder input, which is the `<start>` token. So using this plus the state of the encoder RNN, we can predict the next token. We will then use that token to be the second token of decoder input, and continue like this until we predict the `<end>` token, or we reach some defined max length.\n",
 "\n",
@@ -764,8 +756,8 @@
 "outputs": [],
 "source": [
 "if LOAD_CHECKPOINT:\n",
-" encoder_model = load_model(os.path.join(MODEL_PATH, 'encoder_model.h5'))\n",
-" decoder_model = load_model(os.path.join(MODEL_PATH, 'decoder_model.h5'))\n",
+" encoder_model = load_model(os.path.join(MODEL_PATH, 'encoder_model'))\n",
+" decoder_model = load_model(os.path.join(MODEL_PATH, 'decoder_model'))\n",
 "\n",
 "else:\n",
 " encoder_model = # TODO\n",
@@ -814,7 +806,7 @@
 " Returns translated sentences\n",
 " \"\"\"\n",
 " # Encode the input as state vectors.\n",
-" states_value = encoder_model.predict(input_seqs)\n",
+" states_value = encoder_model(input_seqs)\n",
 "\n",
 " # Populate the first character of target sequence with the start character.\n",
 " batch_size = input_seqs.shape[0]\n",
@@ -824,7 +816,7 @@
 "\n",
 " for i in range(max_decode_length):\n",
 "\n",
-" output_tokens, decoder_state = decoder_model.predict(\n",
+" output_tokens, decoder_state = decoder_model(\n",
 " [target_seq, states_value])\n",
 "\n",
 " # Sample a token\n",
@@ -905,9 +897,9 @@
 "### Exercise 9\n",
 "\n",
 "Save\n",
-"* `model` to disk as the file `model.h5`\n",
-"* `encoder_model` to disk as the file `encoder_model.h5`\n",
-"* `decoder_model` to disk as the file `decoder_model.h5`\n"
+"* `model` to disk as the file `model`\n",
+"* `encoder_model` to disk as the file `encoder_model`\n",
+"* `decoder_model` to disk as the file `decoder_model`\n"
 ]
 },
 {
@@ -951,7 +943,7 @@
 "\n",
 "It still is imperfect, since it gives no credit to synonyms and so human evaluation is still best when feasible. However BLEU is commonly considered the best among bad options for an automated metric.\n",
 "\n",
-"The NLTK framework has an implementation that we will use.\n",
+"The Hugging Face evaluate framework has an implementation that we will use.\n",
 "\n",
 "We can't run calculate BLEU during training, because at that time the correct decoder input is used. Instead we'll calculate it now.\n",
 "\n",
@@ -964,13 +956,16 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def bleu_1(reference, candidate):\n",
-" reference = list(filter(lambda x: x != \"\", reference)) # remove padding\n",
-" candidate = list(filter(lambda x: x != \"\", candidate)) # remove padding\n",
-" smoothing_function = nltk.translate.bleu_score.SmoothingFunction().method1\n",
-" return nltk.translate.bleu_score.sentence_bleu(\n",
-" reference, candidate, (1,), smoothing_function\n",
-" )"
+"def postprocess(sentence):\n",
+" filtered = list(filter(lambda x: x != \"\" and x != \"<end>\", sentence))\n",
+" return \" \".join(filtered)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Let's now average the `bleu_1` and `bleu_4` scores for all the sentence pairs in the eval set. The next cell takes around 1 minute (8 minutes for full dataset eval) to run, the bulk of which is decoding the sentences in the validation set. Please wait until completes."
 ]
 },
 {
@@ -979,22 +974,30 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def bleu_4(reference, candidate):\n",
-" reference = list(filter(lambda x: x != \"\", reference)) # remove padding\n",
-" candidate = list(filter(lambda x: x != \"\", candidate)) # remove padding\n",
-" smoothing_function = nltk.translate.bleu_score.SmoothingFunction().method1\n",
-" return nltk.translate.bleu_score.sentence_bleu(\n",
-" reference, candidate, (0.25, 0.25, 0.25, 0.25), smoothing_function\n",
-" )"
+"NUM_EVALUATE = 1000 # `len(input_tensor_val)` for full eval.\n",
+"\n",
+"reference = []\n",
+"candidate = []\n",
+"\n",
+"\n",
+"for idx in tqdm(range(NUM_EVALUATE)):\n",
+" reference_sentence = utils_preproc.int2word(\n",
+" targ_lang, target_tensor_val[idx][1:]\n",
+" )\n",
+"\n",
+" decoded_sentence = decode_sequences(\n",
+" input_tensor_val[idx : idx + 1], targ_lang, max_length_targ\n",
+" )[0]\n",
+"\n",
+" candidate.append(postprocess(decoded_sentence))\n",
+" reference.append([postprocess(reference_sentence)])"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Exercise 10\n",
-"\n",
-"Let's now average the `bleu_1` and `bleu_4` scores for all the sentence pairs in the eval set. The next cell takes some time to run, the bulk of which is decoding the 6000 sentences in the validation set. Please wait unitl completes."
+"### Check the score"
 ]
 },
 {
@@ -1003,47 +1006,27 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%%time\n",
-"num_examples = len(input_tensor_val)\n",
-"bleu_1_total = 0\n",
-"bleu_4_total = 0\n",
-"\n",
-"\n",
-"for idx in range(num_examples):\n",
-" reference_sentence = utils_preproc.int2word(\n",
-" targ_lang, target_tensor_val[idx][1:])\n",
-"\n",
-" decoded_sentence = decode_sequences(\n",
-" input_tensor_val[idx:idx+1], targ_lang, max_length_targ)[0]\n",
-"\n",
-" bleu_1_total += # TODO\n",
-" bleu_4_total += # TODO\n",
-"\n",
-"print('BLEU 1: {}'.format(bleu_1_total/num_examples))\n",
-"print('BLEU 4: {}'.format(bleu_4_total/num_examples))"
+"bleu = evaluate.load(\"bleu\")\n",
+"bleu_1 = bleu.compute(predictions=candidate, references=reference, max_order=1)\n",
+"bleu_4 = bleu.compute(predictions=candidate, references=reference, max_order=4)"
 ]
 },
 {
-"cell_type": "markdown",
+"cell_type": "code",
+"execution_count": null,
 "metadata": {},
+"outputs": [],
 "source": [
+"bleu_1[\"bleu\"]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
-"## Results\n",
-"\n",
-"**Hyperparameters**\n",
-"\n",
-"- Batch_Size: 64\n",
-"- Optimizer: adam\n",
-"- Embed_dim: 256\n",
-"- GRU Units: 1024\n",
-"- Train Examples: 24,000\n",
-"- Epochs: 10\n",
-"- Hardware: P100 GPU\n",
-"\n",
-"**Performance**\n",
-"- Training Time: 5min \n",
-"- Cross-entropy loss: train: 0.0722 - val: 0.9062\n",
-"- BLEU 1: 0.2519574312515255\n",
-"- BLEU 4: 0.04589972764144636"
+"bleu_4[\"bleu\"]"
 ]
 },
 {
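
Note on the decoding loop touched by this commit: the notebook text in the diff describes starting from the `<start>` token and the encoder state, predicting one token, feeding it back in, and stopping at `<end>` or a maximum length. The sketch below shows that control flow for a single sentence in plain Python. It is an illustration only: `toy_encoder`, `toy_decoder`, the token ids, and `VOCAB_SIZE` are hypothetical stand-ins, not the notebook's Keras `encoder_model` and `decoder_model`.

```python
START_ID, END_ID = 1, 2  # hypothetical ids for the <start> and <end> tokens
VOCAB_SIZE = 6           # hypothetical vocabulary size


def toy_encoder(input_ids):
    """Stand-in for the encoder: the 'state' is simply the first input id."""
    return input_ids[0]


def toy_decoder(prev_token, state):
    """Stand-in for one decoder step: returns (scores over vocab, new state).

    This toy ignores prev_token and counts upward from the state, emitting
    END_ID once the state leaves the vocabulary. A real decoder would
    condition on both prev_token and state.
    """
    next_id = state if state < VOCAB_SIZE else END_ID
    scores = [0.0] * VOCAB_SIZE
    scores[next_id] = 1.0
    return scores, state + 1


def greedy_decode(input_ids, max_decode_length):
    """Greedy decoding: feed each predicted token back in as the next input."""
    state = toy_encoder(input_ids)
    prev_token = START_ID
    decoded = []
    for _ in range(max_decode_length):
        scores, state = toy_decoder(prev_token, state)
        # Greedy choice: take the highest-scoring token at this step.
        next_id = max(range(VOCAB_SIZE), key=scores.__getitem__)
        if next_id == END_ID:
            break  # stop once <end> is predicted
        decoded.append(next_id)
        prev_token = next_id
    return decoded
```

With these toy stand-ins, `greedy_decode([3, 0], 10)` yields `[3, 4, 5]`, and capping the length at 2 truncates the output to `[3, 4]`. The notebook's `decode_sequences` follows the same idea but works on batches of sequences.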

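
Note on the BLEU change: the commit replaces per-sentence NLTK scores (with smoothing, averaged over examples) by a single `bleu.compute(...)` call on lists of joined strings. To my understanding, the `evaluate` BLEU metric is computed at corpus level, so its numbers are not directly comparable to the old averages. The sketch below is a plain-Python illustration of unsmoothed corpus-level BLEU with a `max_order` parameter, assuming whitespace tokenization. It is not the `evaluate` implementation. `postprocess` is copied from the diff, while `ngram_counts` and `corpus_bleu` are hypothetical helper names.

```python
import math
from collections import Counter


def postprocess(sentence):
    """Drop padding ("") and "<end>" tokens, then join (as in the diff)."""
    filtered = list(filter(lambda x: x != "" and x != "<end>", sentence))
    return " ".join(filtered)


def ngram_counts(tokens, n):
    """Count all n-grams of order n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(predictions, references, max_order=4):
    """Unsmoothed corpus-level BLEU over whitespace-tokenized strings.

    predictions: list of strings; references: list of lists of strings.
    """
    matches = [0] * max_order
    possible = [0] * max_order
    pred_len = ref_len = 0
    for pred, refs in zip(predictions, references):
        pred_tokens = pred.split()
        refs_tokens = [ref.split() for ref in refs]
        pred_len += len(pred_tokens)
        ref_len += min(len(tokens) for tokens in refs_tokens)
        for n in range(1, max_order + 1):
            pred_counts = ngram_counts(pred_tokens, n)
            max_ref_counts = Counter()
            for tokens in refs_tokens:
                max_ref_counts |= ngram_counts(tokens, n)  # per-n-gram maximum
            clipped = pred_counts & max_ref_counts          # clip to reference
            matches[n - 1] += sum(clipped.values())
            possible[n - 1] += sum(pred_counts.values())
    precisions = [m / p if p else 0.0 for m, p in zip(matches, possible)]
    if min(precisions) > 0:
        geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_order)
    else:
        geo_mean = 0.0  # any zero precision zeroes the unsmoothed score
    if pred_len == 0:
        brevity_penalty = 0.0
    elif pred_len > ref_len:
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1 - ref_len / pred_len)
    return geo_mean * brevity_penalty
```

One consequence worth keeping in mind: without smoothing, a corpus with no matching 4-grams scores exactly 0 for `max_order=4`, whereas the removed NLTK code used `SmoothingFunction().method1` to avoid zero scores on short sentences.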