forked from tensorflow/gan
Commit
Colab tutorial using TF 2.x, in response to issue "Colab Tutorial Not TF 2.x Compatible" (tensorflow/gan#18)
Showing 1 changed file with 334 additions and 0 deletions.
tensorflow_gan/Collab_tutorial_on_GAN_using_tf2_x.ipynb
@@ -0,0 +1,334 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Collab_tutorial_on_GAN_using_tf2.x.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "h8-7MxIETkG4",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "c5b2bb65-bf61-43ec-adbe-908c16741db9"
},
"source": [
"# Setting up\n",
"from __future__ import absolute_import, division, print_function, unicode_literals\n",
"try:\n",
"    # The %tensorflow_version magic only exists in Colab\n",
"    %tensorflow_version 2.x\n",
"\n",
"except Exception:\n",
"    pass\n"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"TensorFlow 2.x selected.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6MD16UIYVt7n",
"colab_type": "text"
},
"source": [
"# SETUP"
]
},
{
"cell_type": "code",
"metadata": {
"id": "A5gKsgrGUZ39",
"colab_type": "code",
"colab": {}
},
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"import time\n",
"\n",
"# Fix the random seeds for reproducibility\n",
"tf.random.set_seed(7)\n",
"np.random.seed(7)"
],
"execution_count": 0,
"outputs": []
},
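{
"cell_type": "code",
"metadata": {},
"source": [
"# Sanity check (an illustrative addition, not part of the original commit):\n",
"# confirm that a TensorFlow 2.x runtime is actually active.\n",
"print(tf.__version__)\n",
"assert tf.__version__.startswith('2'), 'This tutorial expects TensorFlow 2.x'"
],
"execution_count": 0,
"outputs": []
},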
{
"cell_type": "markdown",
"metadata": {
"id": "J0iuziuiVwkA",
"colab_type": "text"
},
"source": [
"# HYPERPARAMETERS\n",
"\n",
"X_Train is the array of images the user wants to use as training data; labels and test data aren't needed. The next cell loads MNIST as one example of such data."
]
},
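{
"cell_type": "code",
"metadata": {},
"source": [
"# Example data (an assumption, not part of the original commit): MNIST is used\n",
"# here because the generator below outputs 784-dim vectors (28x28 images).\n",
"# Pixels are scaled to [-1, 1] to match the generator's tanh output.\n",
"(X_Train, _), (_, _) = tf.keras.datasets.mnist.load_data()\n",
"X_Train = (X_Train.reshape(-1, 784).astype('float32') - 127.5) / 127.5"
],
"execution_count": 0,
"outputs": []
},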
{
"cell_type": "code",
"metadata": {
"id": "ZlWkvpA0UcK5",
"colab_type": "code",
"colab": {}
},
"source": [
"EPOCHS = 60\n",
"BATCH_SIZE = 256\n",
"NO_OF_BATCHES = int(X_Train.shape[0]/BATCH_SIZE)\n",
"HALF_BATCH_SIZE = 128\n",
"NOISE_DIM = 100\n",
"# Customised Adam optimizer (`lr` was renamed to `learning_rate` in TF 2.x)\n",
"adam = tf.keras.optimizers.Adam(learning_rate=2e-4,beta_1=0.5)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "YEnk012DV_Et",
"colab_type": "text"
},
"source": [
"# Building the models for the GAN"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ewWCDClQUhGo",
"colab_type": "code",
"colab": {}
},
"source": [
"# Keras layers and model classes used below\n",
"from tensorflow.keras.layers import Dense, Input, LeakyReLU\n",
"from tensorflow.keras.models import Sequential, Model"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "gMbuIfzhWC6K",
"colab_type": "text"
},
"source": [
"# GENERATOR"
]
},
{
"cell_type": "code",
"metadata": {
"id": "oXt9-uPMUjUo",
"colab_type": "code",
"colab": {}
},
"source": [
"# Generator: maps a NOISE_DIM-dim noise vector to a 784-dim image vector\n",
"generator = Sequential()\n",
"generator.add(Dense(256,input_shape=(NOISE_DIM,)))\n",
"generator.add(LeakyReLU(0.2))\n",
"generator.add(Dense(512))\n",
"generator.add(LeakyReLU(0.2))\n",
"generator.add(Dense(1024))\n",
"generator.add(LeakyReLU(0.2))\n",
"generator.add(Dense(784,activation='tanh'))\n",
"\n",
"# Compile (optional: the generator is only trained through the combined model)\n",
"generator.compile(loss='binary_crossentropy',optimizer=adam)\n",
"\n",
"generator.summary()"
],
"execution_count": 0,
"outputs": []
},
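{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative check (an addition, not part of the original commit): a single\n",
"# forward pass through the untrained generator should yield one 784-dim vector.\n",
"sample = generator.predict(np.random.normal(0, 1, size=(1, NOISE_DIM)))\n",
"print(sample.shape)  # expected: (1, 784)"
],
"execution_count": 0,
"outputs": []
},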
{
"cell_type": "markdown",
"metadata": {
"id": "RZownvPRWGT-",
"colab_type": "text"
},
"source": [
"# DISCRIMINATOR"
]
},
{
"cell_type": "code",
"metadata": {
"id": "KAhiXuXWUnyT",
"colab_type": "code",
"colab": {}
},
"source": [
"# Discriminator: maps a 784-dim image vector to a real/fake probability\n",
"discriminator = Sequential()\n",
"discriminator.add(Dense(512,input_shape=(784,)))\n",
"discriminator.add(LeakyReLU(0.2))\n",
"discriminator.add(Dense(256))\n",
"discriminator.add(LeakyReLU(0.2))\n",
"discriminator.add(Dense(1,activation='sigmoid'))\n",
"\n",
"discriminator.compile(loss='binary_crossentropy',optimizer=adam)\n",
"discriminator.summary()"
],
"execution_count": 0,
"outputs": []
},
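{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative check (an addition, not part of the original commit): the\n",
"# untrained discriminator maps a 784-dim vector to a probability in (0, 1).\n",
"p = discriminator.predict(np.random.normal(0, 1, size=(1, 784)))\n",
"print(p.shape, float(p[0, 0]))  # expected: (1, 1) and a value in (0, 1)"
],
"execution_count": 0,
"outputs": []
},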
{
"cell_type": "markdown",
"metadata": {
"id": "JfnJ1IVIWJt2",
"colab_type": "text"
},
"source": [
"# Combined Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "VKzNOfNIUpxa",
"colab_type": "code",
"colab": {}
},
"source": [
"# Combined model (Generator + Discriminator) -> Functional API\n",
"# Freeze the discriminator inside the combined model so that only the\n",
"# generator's weights are updated when the combined model is trained.\n",
"discriminator.trainable = False\n",
"gan_input = Input(shape=(NOISE_DIM,))\n",
"generator_output = generator(gan_input)\n",
"gan_output = discriminator(generator_output)\n",
"\n",
"model = Model(gan_input,gan_output)\n",
"model.compile(loss='binary_crossentropy',optimizer=adam)\n",
"\n",
"model.summary()"
],
"execution_count": 0,
"outputs": []
},
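{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative check (an addition, not part of the original commit): because\n",
"# the discriminator was compiled before `trainable` was set to False, calling\n",
"# discriminator.train_on_batch still updates its weights, while the combined\n",
"# model only trains the generator's weights.\n",
"print(len(model.trainable_weights))  # expected: 8 (the generator's 4 Dense layers)"
],
"execution_count": 0,
"outputs": []
},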
{
"cell_type": "code",
"metadata": {
"id": "tWMncFAzUs_D",
"colab_type": "code",
"colab": {}
},
"source": [
"def showImgs(epoch):\n",
"    noise = np.random.normal(0,1,size=(100,NOISE_DIM))  # Random noise\n",
"    generated_imgs = generator.predict(noise)\n",
"    generated_imgs = generated_imgs.reshape(-1,28,28)\n",
"\n",
"    # Display the images in a 10x10 grid\n",
"    plt.figure(figsize=(10,10))\n",
"    for i in range(100):\n",
"        plt.subplot(10,10,i+1)\n",
"        plt.imshow(generated_imgs[i],cmap='gray',interpolation='nearest')\n",
"        plt.axis(\"off\")\n",
"    plt.tight_layout()\n",
"    plt.show()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "21WSSthiWPuo",
"colab_type": "text"
},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "TTUSWup6U59U",
"colab_type": "code",
"colab": {}
},
"source": [
"d_losses = []\n",
"g_losses = []\n",
"\n",
"# Training loop\n",
"for epoch in range(EPOCHS):\n",
"    epoch_d_loss = 0.0\n",
"    epoch_g_loss = 0.0\n",
"\n",
"    # Mini-batches\n",
"    for step in range(NO_OF_BATCHES):\n",
"        # Sample a half batch of real images\n",
"        idx = np.random.randint(0,X_Train.shape[0],HALF_BATCH_SIZE)\n",
"        real_imgs = X_Train[idx]\n",
"\n",
"        # Generate a half batch of fake images (the generator is only used\n",
"        # for inference here; its weights are not updated)\n",
"        noise = np.random.normal(0,1,size=(HALF_BATCH_SIZE,NOISE_DIM))\n",
"        fake_imgs = generator.predict(noise)\n",
"\n",
"        # Labels (one-sided label smoothing: real labels are 0.9, not 1.0)\n",
"        real_y = np.ones((HALF_BATCH_SIZE,1))*0.9\n",
"        fake_y = np.zeros((HALF_BATCH_SIZE,1))\n",
"\n",
"        # Train the discriminator on real and fake images\n",
"        d_real_loss = discriminator.train_on_batch(real_imgs,real_y)\n",
"        d_fake_loss = discriminator.train_on_batch(fake_imgs,fake_y)\n",
"\n",
"        d_loss = 0.5*d_real_loss + 0.5*d_fake_loss\n",
"        epoch_d_loss += d_loss\n",
"\n",
"        # Train the generator (via the combined model, with the discriminator\n",
"        # frozen), asking it to make the discriminator output \"real\"\n",
"        noise = np.random.normal(0,1,size=(BATCH_SIZE,NOISE_DIM))\n",
"        real_y = np.ones((BATCH_SIZE,1))\n",
"        g_loss = model.train_on_batch(noise,real_y)\n",
"        epoch_g_loss += g_loss\n",
"\n",
"    d_losses.append(epoch_d_loss)\n",
"    g_losses.append(epoch_g_loss)\n",
"\n",
"    print(\"Epoch %d D Loss %.4f G Loss %.4f\"%((epoch+1),epoch_d_loss,epoch_g_loss))\n",
"    if (epoch%5)==0:  # display generated images every 5 epochs\n",
"        # generator.save(\"model/gen_{0}.h5\".format(epoch))  # saving the model\n",
"        showImgs(epoch)  # show the progress of the generated images"
],
"execution_count": 0,
"outputs": []
},
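{
"cell_type": "code",
"metadata": {},
"source": [
"# Follow-up (an addition, not part of the original commit): plot the recorded\n",
"# per-epoch losses to inspect the training dynamics.\n",
"plt.plot(d_losses, label='Discriminator loss')\n",
"plt.plot(g_losses, label='Generator loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Summed per-epoch loss')\n",
"plt.legend()\n",
"plt.show()"
],
"execution_count": 0,
"outputs": []
},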
{
"cell_type": "code",
"metadata": {
"id": "tP85ZHAvVksH",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}