Skip to content

Commit

Permalink
feat: loads the pre-trained weight from the official branch, related to
Browse files Browse the repository at this point in the history
  • Loading branch information
gwkrsrch committed Aug 4, 2022
1 parent 7e45119 commit f174bb0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ Gradio web demos are available! [![Demo](https://img.shields.io/badge/Demo-Gradi

|Task|Sec/Img|Score|Trained Model|<div id="demo">Demo</div>|
|---|---|---|---|---|
| [CORD](https://github.com/clovaai/cord) (Document Parsing) | 0.7 /<br> 0.7 /<br> 1.2 | 93.9 /<br> 93.6 /<br> 93.5 | [donut-base-finetuned-cord-v2](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v2) (1280) /<br> [donut-base-finetuned-cord-v1](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1) (1280) /<br> [donut-base-finetuned-cord-v1-2560](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1-2560) | [gradio space web demo](https://huggingface.co/spaces/naver-clova-ix/donut-base-finetuned-cord-v2),<br>[google colab demo](https://colab.research.google.com/drive/1o07hty-3OQTvGnc_7lgQFLvvKQuLjqiw?usp=sharing) |
| [Train Ticket](https://github.com/beacandler/EATEN) (Document Parsing) | 0.6 | 98.8 | [donut-base-finetuned-zhtrainticket](https://huggingface.co/naver-clova-ix/donut-base-finetuned-zhtrainticket) | [google colab demo](https://colab.research.google.com/drive/16O-hMvGiXrYZnlXA_tfJ9_q760YcLoOj?usp=sharing) |
| [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip) (Document Classification) | 0.75 | 95.3 | [donut-base-finetuned-rvlcdip](https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip) | [google colab demo](https://colab.research.google.com/drive/1xUDmLqlthx8A8rWKLMSLThZ7oeRJkDuU?usp=sharing) |
| [DocVQA Task1](https://rrc.cvc.uab.es/?ch=17) (Document VQA) | 0.78 | 67.5 | [donut-base-finetuned-docvqa](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa) | [google colab demo](https://colab.research.google.com/drive/1Z4WG8Wunj3HE0CERjt608ALSgSzRC9ig?usp=sharing) |
| [CORD](https://github.com/clovaai/cord) (Document Parsing) | 0.7 /<br> 0.7 /<br> 1.2 | 93.9 /<br> 93.6 /<br> 93.5 | [donut-base-finetuned-cord-v2](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v2/tree/official) (1280) /<br> [donut-base-finetuned-cord-v1](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1/tree/official) (1280) /<br> [donut-base-finetuned-cord-v1-2560](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1-2560/tree/official) | [gradio space web demo](https://huggingface.co/spaces/naver-clova-ix/donut-base-finetuned-cord-v2),<br>[google colab demo](https://colab.research.google.com/drive/1o07hty-3OQTvGnc_7lgQFLvvKQuLjqiw?usp=sharing) |
| [Train Ticket](https://github.com/beacandler/EATEN) (Document Parsing) | 0.6 | 98.8 | [donut-base-finetuned-zhtrainticket](https://huggingface.co/naver-clova-ix/donut-base-finetuned-zhtrainticket/tree/official) | [google colab demo](https://colab.research.google.com/drive/16O-hMvGiXrYZnlXA_tfJ9_q760YcLoOj?usp=sharing) |
| [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip) (Document Classification) | 0.75 | 95.3 | [donut-base-finetuned-rvlcdip](https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip/tree/official) | [google colab demo](https://colab.research.google.com/drive/1xUDmLqlthx8A8rWKLMSLThZ7oeRJkDuU?usp=sharing) |
| [DocVQA Task1](https://rrc.cvc.uab.es/?ch=17) (Document VQA) | 0.78 | 67.5 | [donut-base-finetuned-docvqa](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa/tree/official) | [google colab demo](https://colab.research.google.com/drive/1Z4WG8Wunj3HE0CERjt608ALSgSzRC9ig?usp=sharing) |

The links to the pre-trained backbones are here:
- [`donut-base`](https://huggingface.co/naver-clova-ix/donut-base): trained with 64 A100 GPUs (~2.5 days), number of layers (encoder: {2,2,14,2}, decoder: 4), input size 2560x1920, swin window size 10, IIT-CDIP (11M) and SynthDoG (ECJK, 0.5M x 4).
- [`donut-proto`](https://huggingface.co/naver-clova-ix/donut-proto): (preliminary model) trained with 8 V100 GPUs (~5 days), number of layers (encoder: {2,2,18,2}, decoder: 4), input size 2048x1536, swin window size 8, and SynthDoG (EJK, 0.4M x 3).
- [`donut-base`](https://huggingface.co/naver-clova-ix/donut-base/tree/official): trained with 64 A100 GPUs (~2.5 days), number of layers (encoder: {2,2,14,2}, decoder: 4), input size 2560x1920, swin window size 10, IIT-CDIP (11M) and SynthDoG (ECJK, 0.5M x 4).
- [`donut-proto`](https://huggingface.co/naver-clova-ix/donut-proto/tree/official): (preliminary model) trained with 8 V100 GPUs (~5 days), number of layers (encoder: {2,2,18,2}, decoder: 4), input size 2048x1536, swin window size 8, and SynthDoG (EJK, 0.4M x 3).

Please see [our paper](#how-to-cite) for more details.

Expand Down
2 changes: 1 addition & 1 deletion donut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def from_pretrained(
Name of a pretrained model name either registered in huggingface.co. or saved in local,
e.g., `naver-clova-ix/donut-base`, or `naver-clova-ix/donut-base-finetuned-rvlcdip`
"""
model = super(DonutModel, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model = super(DonutModel, cls).from_pretrained(pretrained_model_name_or_path, revision="official", *model_args, **kwargs)

# truncate or interplolate position embeddings of donut decoder
max_length = kwargs.get("max_length", model.config.max_position_embeddings)
Expand Down

0 comments on commit f174bb0

Please sign in to comment.