Skip to content

Commit 4fa44d6

Browse files
chore: improve mmmu benchmark (#7000)
Signed-off-by: Xinyuan Tong <[email protected]>
Co-authored-by: Xinyuan Tong <[email protected]>
1 parent e6312d2 commit 4fa44d6

2 files changed

Lines changed: 24 additions & 14 deletions

File tree

benchmark/mmmu/bench_sglang.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ async def eval_mmmu(args) -> None:
125125
client = openai.AsyncOpenAI(
126126
api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
127127
)
128-
semaphore = asyncio.Semaphore(args.concurrency)
129128
start = time.perf_counter()
130129
base_url = f"http://127.0.0.1:{args.port}"
131130

@@ -139,16 +138,26 @@ async def eval_mmmu(args) -> None:
139138

140139
samples = samples[: args.profile_number]
141140

142-
tasks = [
143-
process_sample_with_semaphore(
144-
semaphore, client, sample, sampling_params, lora_path
145-
)
146-
for sample in samples
147-
]
148-
149-
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
150-
sample, response = await coro
151-
process_result(response, sample, answer_dict, out_samples)
141+
if args.concurrency == 1:
142+
# For concurrency == 1, run in sequential mode to ensure consistent order
143+
# This is mainly intended for profiling
144+
for sample in tqdm(samples):
145+
_, response = await process_sample(
146+
client, sample, sampling_params, lora_path
147+
)
148+
process_result(response, sample, answer_dict, out_samples)
149+
else:
150+
semaphore = asyncio.Semaphore(args.concurrency)
151+
tasks = [
152+
process_sample_with_semaphore(
153+
semaphore, client, sample, sampling_params, lora_path
154+
)
155+
for sample in samples
156+
]
157+
158+
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
159+
sample, response = await coro
160+
process_result(response, sample, answer_dict, out_samples)
152161

153162
if args.profile:
154163
print("Stopping profiler...")

benchmark/mmmu/eval_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
class EvalArgs:
2828
seed: int = 42
2929
split: str = "validation"
30-
# Default setting to make the benchmark available on A100 for most 7B models
31-
image_pixels_limit: int = 4300000
30+
image_pixels_limit: int = -1
3231
result_filename: str = ""
3332
prompt_format_file: str = "prompt_format.yaml"
3433
dataset_path: str = "MMMU/MMMU"
@@ -190,7 +189,7 @@ def process_sample(i, sample):
190189
sample = construct_prompt(sample, eval_args.config)
191190
image = sample["image"]
192191
width, height = image.size
193-
if width * height >= eval_args.image_pixels_limit:
192+
if 0 < eval_args.image_pixels_limit <= width * height:
194193
return None, True
195194
# Use a unique identifier for the image path to avoid potential collisions if indices reset
196195
image_path = f"{images_path}/image_{sample['id']}.png"
@@ -217,6 +216,8 @@ def process_sample(i, sample):
217216
elif sample:
218217
samples.append(sample)
219218

219+
samples.sort(key=lambda x: x["final_input_prompt"])
220+
220221
print(
221222
f"Skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
222223
)

0 commit comments

Comments (0)