This project demonstrates that a generic neural model can be trained to perform very complex arithmetic operations without designing the model architecture explicitly for the task. Our model is able to compute 5-digit by 5-digit decimal multiplication at 100% accuracy. Specifically, we train the GPT-2 model on a large number of generated expressions that spell out the process of computing a multiplication step by step. Provide 78268*53567 and it should give 4192581956.
See the "How does it work" section for an explanation of the inner mechanics.
Install dependencies with pip install -r requirements.txt.
Open train.py and edit the following lines:
- Change gpus=4 to the number of GPUs on your machine.
- Change batch_size=16 to fit the memory size of your GPU.
- Adjust accumulate_grad_batches=8 accordingly to control the effective batch size (gpus*batch_size*accumulate_grad_batches). In my early tests I found that larger batch sizes train faster and more stably.
- Optionally set the resume_from_checkpoint argument to a checkpoint file to resume training from a previous checkpoint. PyTorch Lightning should automatically save checkpoints during training.
- You may also want to adjust the hyperparameters in main.py. Since the task is relatively simple, a smaller model should still work, and if it does it will train faster.
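As a quick check of the effective batch size formula, the sketch below plugs in the default values quoted in this README. The helper names effective_batch_size and accumulation_for are hypothetical and are not part of train.py; they only illustrate the arithmetic.

```python
def effective_batch_size(gpus, batch_size, accumulate_grad_batches):
    """Number of samples that contribute to one optimizer step."""
    return gpus * batch_size * accumulate_grad_batches


def accumulation_for(target, gpus, batch_size):
    """Smallest accumulate_grad_batches that reaches at least `target` samples per step."""
    per_pass = gpus * batch_size
    return max(1, -(-target // per_pass))  # ceiling division


# Defaults mentioned in this README: gpus=4, batch_size=16, accumulate_grad_batches=8
print(effective_batch_size(4, 16, 8))   # 512

# Hypothetical smaller machine: 1 GPU that only fits batch_size=8
print(accumulation_for(512, 1, 8))      # 64
```

If you lower gpus or batch_size, raising accumulate_grad_batches by the same factor keeps the effective batch size unchanged.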
Then just run python train.py and watch it train. It typically takes around 100 epochs (10000000 samples) to converge, at which point it should reach 100% accuracy.
If you don't want to train the model and just want to test it out, you can download the pretrained model at https://drive.google.com/file/d/1YKzHTec5FsN6NftR3uc5j-SpmNpxTDsc/view?usp=sharing. Put the model (named epoch=98.ckpt) in the same folder as all other files. Otherwise, modify test.py and change the filename of the checkpoint to the one generated by train.py.
Run python test.py and give it prompts in the format xxxxx*xxxxx; (including the trailing semicolon). Replace each x with a decimal digit (the first digit of each number must not be zero). For example, enter 12345*54321; and watch it compute the answer.
The answer is expected to be found between the last = sign and the $ sign.
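Pulling the answer out of a generated sequence can be sketched as below. This assumes the model output is available as a Python string in the same format as the training sequences; the function name extract_answer is hypothetical and does not come from test.py.

```python
def extract_answer(sequence: str) -> str:
    """Return the text between the last '=' before the '$' and the '$' itself."""
    end = sequence.index("$")                # end-of-sequence marker
    start = sequence.rindex("=", 0, end)     # last '=' that precedes the marker
    return sequence[start + 1:end]


# The 39*96 sequence shown later in this README
sample = ("39*96;39*6;9*6;=54;3*6;=18;5+18;58030112000=23;=234;"
          "39*9;9*9;=81;3*9;=27;8+27;87050213000=35;=351;"
          "23+351;310425070303000=374;=3744$")
print(extract_answer(sample))  # 3744
```

Both index and rindex raise ValueError when the marker is missing, which is a reasonable signal that the model did not finish its computation.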
Think about how humans perform decimal multiplications:
       39
    x  96
    -----
      234
     351
    -----
     3744
Multiplications can be computed vertically. We first decompose the multiplier into digits, and multiply each digit by the multiplicand, then we shift these results by the position of the digit and sum them together. There are two things left to do:
- Single-digit multiplication. In the above example we need to compute 39*6=234 and 39*9=351. This is again done by decomposing the multiplicand into digits, multiplying each digit by the multiplier, and summing the results. The last unit of computation, single-digit by single-digit multiplication, can be done by looking up the multiplication table.
- While the process of summing up the inputs may seem trivial, the GPT-3 model is known to struggle with multi-operation arithmetic (see "Single Digit Three Ops" performance in the original GPT-3 paper). We need to break it up and compute the sum of two numbers at a time. We compute the sum from right to left, and at each position we compute the carry first and then the actual digit. To make the task even easier for the model, the model first writes out the two digits at the current position and then computes the carry digit and the ones digit, so it doesn't have to worry about where to find the digits to perform the operation. When the two numbers have different numbers of digits, the missing digits are filled in with zero. However, the operation does not stop when the digits of both numbers are used up, since a carry can extend past the most significant digit, as in 9999+101. Instead, the operation stops when the two digits and the carry are all zero. The sequence for 9999+101 looks like this (spaces added for clarity; they don't exist in actual training samples): 9999+101;9100 9010 9111 9010 0011 000=10100. Each four-character group holds the two digits being added, the carry coming in from the previous position, and the resulting ones digit; the final 000 marks the stop.
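The addition trace described above can be sketched as follows. This is a reconstruction from the 9999+101 example, not the code in expr.py. The function name addition_trace is hypothetical, and the sketch assumes the terminating 000 is written once both numbers have run out of digits and no carry remains.

```python
def addition_trace(a: int, b: int) -> str:
    """Render a+b as groups of (digit, digit, incoming carry, ones digit)."""
    xs, ys = str(a)[::-1], str(b)[::-1]      # least significant digit first
    groups, result_digits = [], []
    carry, i = 0, 0
    # Assumption: keep going while either number has digits left or a carry is pending
    while i < max(len(xs), len(ys)) or carry:
        x = int(xs[i]) if i < len(xs) else 0  # missing digits are treated as zero
        y = int(ys[i]) if i < len(ys) else 0
        total = x + y + carry
        groups.append(f"{x}{y}{carry}{total % 10}")
        result_digits.append(str(total % 10))
        carry = total // 10
        i += 1
    groups.append("000")                      # two zero digits and a zero carry: stop
    trace = "".join(groups)
    result = "".join(reversed(result_digits))
    return f"{a}+{b};{trace}={result}"


print(addition_trace(9999, 101))
# 9999+101;91009010911190100011000=10100
```

Writing the incoming carry before the ones digit means every token the model emits depends only on the few characters immediately before it.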
The full sequence for 39*96 is:
39*96;39*6;9*6;=54;3*6;=18;5+18;58030112000=23;=234;39*9;9*9;=81;3*9;=27;8+27;87050213000=35;=351;23+351;310425070303000=374;=3744$
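To make the structure of this sequence easier to follow, here is a minimal sketch that rebuilds it. It is reconstructed from the two sample sequences in this README and is not the generate_multiplication function from expr.py; all function names are hypothetical. It repeats a compact addition trace so that it runs on its own, and its behaviour for inputs containing zero digits has not been checked against the real generator.

```python
def add_step(a, b):
    """Render a+b as 'a+b;<groups>000=sum;' (group: digit, digit, carry-in, ones digit)."""
    xs, ys = str(a)[::-1], str(b)[::-1]
    groups, carry, i = "", 0, 0
    while i < max(len(xs), len(ys)) or carry:
        x = int(xs[i]) if i < len(xs) else 0
        y = int(ys[i]) if i < len(ys) else 0
        groups += f"{x}{y}{carry}{(x + y + carry) % 10}"
        carry = (x + y + carry) // 10
        i += 1
    return f"{a}+{b};{groups}000={a + b};"


def accumulate(values):
    """Sum values where each one is shifted one decimal place further left.

    The ones digit of the running total is set aside (it can no longer change),
    and the remaining digits are added to the next value.
    """
    steps, running, saved = "", values[0], ""
    for value in values[1:]:
        saved = str(running % 10) + saved
        steps += add_step(running // 10, value)
        running = running // 10 + value
    return steps, str(running) + saved


def sketch_multiplication_sequence(a, b):
    seq, rows = f"{a}*{b};", []
    for d in reversed(str(b)):                 # multiplier digits, right to left
        seq += f"{a}*{d};"
        products = []
        for digit in reversed(str(a)):         # multiplicand digits, right to left
            products.append(int(digit) * int(d))
            seq += f"{digit}*{d};={products[-1]};"
        steps, row = accumulate(products)
        seq += steps + f"={row};"
        rows.append(int(row))
    steps, total = accumulate(rows)            # add up the shifted partial products
    return seq + steps + f"={total}$"


print(sketch_multiplication_sequence(39, 96))
```

For 39 and 96 this prints exactly the sequence shown above.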
Refer to expr.py for details regarding how the expressions are generated. Try out the generate_multiplication function for yourself (enter any two numbers as arguments).
The first 20 training samples can be found in exprs.txt. In actual training they are generated on the fly using pseudo-random number generators. Here is the first line of exprs.txt:
78268*53567;78268*7;8*7;=56;6*7;=42;2*7;=14;8*7;=56;7*7;=49;5+42;52070404000=47;4+14;44080101000=18;1+56;16070505000=57;5+49;59040415000=54;=547876;78268*6;8*6;=48;6*6;=36;2*6;=12;8*6;=48;7*6;=42;4+36;46000314000=40;4+12;42060101000=16;1+48;18090404000=49;4+42;42060404000=46;=469608;78268*5;8*5;=40;6*5;=30;2*5;=10;8*5;=40;7*5;=35;4+30;40040303000=34;3+10;30030101000=13;1+40;10010404000=41;4+35;45090303000=39;=391340;78268*3;8*3;=24;6*3;=18;2*3;=6;8*3;=24;7*3;=21;2+18;28000112000=20;2+6;2608000=8;0+24;04040202000=24;2+21;21030202000=23;=234804;78268*5;8*5;=40;6*5;=30;2*5;=10;8*5;=40;7*5;=35;4+30;40040303000=34;3+10;30030101000=13;1+40;10010404000=41;4+35;45090303000=39;=391340;54787+469608;780580197603491456120415000=524395;52439+391340;900934074307210359040314000=443779;44377+234804;740170183801441943070202000=279181;27918+391340;800814059302711929010314000=419258;=4192581956$