Skip to content

Commit

Permalink
[Feature] Support json output for MMMU offical evaluation (#130)
Browse files Browse the repository at this point in the history
* Fix DocVQA bug and update md5

* update DocVQA md5

* Support json output for MMMU offical evaluation

* Fix

* Fix

* Fix

* Fix

* update
  • Loading branch information
SparksJoe authored Mar 26, 2024
1 parent 3f146c2 commit 6c00449
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
7 changes: 6 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from vlmeval.evaluate import *
from vlmeval.inference import infer_data_job, prefetch_acc
from vlmeval.config import supported_VLM
from vlmeval.utils import dataset_URLs, DATASET_TYPE, abbr2full
from vlmeval.utils import dataset_URLs, DATASET_TYPE, abbr2full, MMMU_result_transfer


def parse_args():
Expand Down Expand Up @@ -96,6 +96,11 @@ def main():
'will skip the evaluation. '
)
continue
# noqa W293
if rank == 0:
if dataset_name in ['MMMU_TEST']:
result_json = MMMU_result_transfer(result_file)
logger.info(f'Transfer MMMU_TEST result to json for official evaluation, json file saved in {result_json}') # noqa E501

if rank == 0 and args.prefetch:
time.sleep(3)
Expand Down
2 changes: 1 addition & 1 deletion vlmeval/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .mp_util import track_progress_rich
from .custom_prompt import CustomPrompt
from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full
from .dataset import TSVDataset, split_MMMU
from .dataset import TSVDataset, split_MMMU, MMMU_result_transfer


__all__ = [
Expand Down
24 changes: 24 additions & 0 deletions vlmeval/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..smp import *
from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE
from .custom_prompt import CustomPrompt
from .matching_util import can_infer


def isliststr(s):
Expand Down Expand Up @@ -40,6 +41,29 @@ def split_MMMU(struct):
return segs


def MMMU_result_transfer(result_path):
res = {}
result_data = load(result_path)
mcq = result_data['A'].notna()
lt = len(result_data)
for i in range(lt):
line = result_data.iloc[i]
if mcq[i]:
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
prediction = line['prediction']
infer_prediction = can_infer(prediction, options)
res[line['id']] = infer_prediction
else:
res[line['id']] = line['prediction']
result_json = result_path.replace('.xlsx', '.json')
dump(res, result_json)
return result_json


class TSVDataset(CustomPrompt):

def __init__(self, dataset='MMBench', skip_noimg=True):
Expand Down

0 comments on commit 6c00449

Please sign in to comment.