Skip to content

Commit 15a73ff

Browse files
use yaml file when using --data flag
1 parent 4e32cff commit 15a73ff

2 files changed

Lines changed: 75 additions & 114 deletions

File tree

boxmot/engine/cli.py

Lines changed: 49 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,50 @@ def resolve_dataset_cfg_path(data_ref) -> Optional[Path]:
6161
return None
6262

6363

64+
def resolve_source_from_data_ref(data_ref: str, overwrite: bool = False) -> tuple[str, str, str]:
65+
"""
66+
Resolve a data/source reference into (source_path, benchmark_name, split_name).
67+
68+
Supports either direct dataset paths or dataset yaml references.
69+
"""
70+
source_path = Path(str(data_ref))
71+
benchmark = source_path.parent.name
72+
split = source_path.name
73+
74+
cfg_path = resolve_dataset_cfg_path(data_ref)
75+
if cfg_path is None:
76+
return str(data_ref), benchmark, split
77+
78+
cfg = load_dataset_cfg_path(cfg_path)
79+
bench_name = Path(cfg["benchmark"]["source"]).name
80+
dataset_url = cfg["download"]["dataset_url"]
81+
82+
if dataset_url:
83+
if dataset_url.startswith("hf://"):
84+
dataset_dest = TRACKEVAL / bench_name
85+
else:
86+
dataset_dest = TRACKEVAL / f"{bench_name}.zip"
87+
else:
88+
dataset_dest = Path(cfg["download"].get("dataset_dest", f"assets/{bench_name}"))
89+
90+
download_eval_data(
91+
runs_url=cfg["download"]["runs_url"],
92+
dataset_url=dataset_url,
93+
dataset_dest=dataset_dest,
94+
overwrite=overwrite,
95+
)
96+
97+
split = cfg["benchmark"]["split"]
98+
if dataset_url:
99+
resolved_source = TRACKEVAL / f"{bench_name}/{split}"
100+
elif "source" in cfg["benchmark"]:
101+
resolved_source = Path(cfg["benchmark"]["source"]) / split
102+
else:
103+
resolved_source = dataset_dest / split
104+
105+
return str(resolved_source), bench_name, split
106+
107+
64108
def ensure_model_extension(model_path):
65109
"""
66110
Ensure model path has .pt extension.
@@ -376,8 +420,7 @@ def track(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwarg
376420
if tracker:
377421
kwargs['tracking_method'] = tracker
378422
src = kwargs.pop('source')
379-
source_path = Path(src)
380-
bench, split = source_path.parent.name, source_path.name
423+
resolved_source, bench, split = resolve_source_from_data_ref(src)
381424

382425
# Auto-append .pt extension if missing
383426
yolo_model = ensure_model_extension(yolo_model)
@@ -387,37 +430,10 @@ def track(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwarg
387430
'yolo_model': yolo_model,
388431
'reid_model': reid_model,
389432
'classes': parse_classes(classes),
390-
'source': src,
433+
'source': resolved_source,
391434
'benchmark': bench,
392435
'split': split}
393436
args = SimpleNamespace(**params)
394-
395-
# 2) if doing MOT17/20-ablation, pull down the dataset and rewire args.source/split
396-
cfg_path = resolve_dataset_cfg_path(args.source)
397-
if cfg_path is not None:
398-
cfg = load_dataset_cfg_path(cfg_path)
399-
400-
# Determine dataset destination
401-
if cfg["download"]["dataset_url"]:
402-
dataset_dest = TRACKEVAL / f"{Path(cfg['benchmark']['source']).name}.zip"
403-
else:
404-
# For custom datasets without URL, use the path from config if available, or default to assets
405-
dataset_dest = Path(cfg["download"].get("dataset_dest", f"assets/{Path(cfg['benchmark']['source']).name}"))
406-
407-
download_eval_data(
408-
runs_url=cfg["download"]["runs_url"],
409-
dataset_url=cfg["download"]["dataset_url"],
410-
dataset_dest=dataset_dest,
411-
overwrite=False
412-
)
413-
args.benchmark = Path(cfg["benchmark"]["source"]).name
414-
args.split = cfg["benchmark"]["split"]
415-
if cfg["download"]["dataset_url"]:
416-
args.source = TRACKEVAL / f"{args.benchmark}/{args.split}"
417-
elif "source" in cfg["benchmark"]:
418-
args.source = Path(cfg["benchmark"]["source"]) / args.split
419-
else:
420-
args.source = dataset_dest / args.split
421437

422438
from boxmot.engine.tracker import main as run_track
423439
run_track(args)
@@ -437,8 +453,7 @@ def generate(ctx, detector, reid, yolo_model, reid_model, classes, **kwargs):
437453
if reid:
438454
reid_model = [ensure_model_extension(reid)]
439455
src = kwargs.pop('data')
440-
source_path = Path(src)
441-
bench, split = source_path.parent.name, source_path.name
456+
resolved_source, bench, split = resolve_source_from_data_ref(src)
442457

443458
# Auto-append .pt extension if missing
444459
yolo_model = [ensure_model_extension(m) for m in yolo_model]
@@ -449,7 +464,7 @@ def generate(ctx, detector, reid, yolo_model, reid_model, classes, **kwargs):
449464
'reid_model': list(reid_model),
450465
'classes': parse_classes(classes),
451466
'data': src,
452-
'source': src,
467+
'source': resolved_source,
453468
'benchmark': bench,
454469
'split': split}
455470
args = SimpleNamespace(**params)
@@ -475,8 +490,6 @@ def eval(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
475490
if tracker:
476491
kwargs['tracking_method'] = tracker
477492
src = kwargs.pop('data')
478-
source_path = Path(src)
479-
bench, split = source_path.parent.name, source_path.name
480493

481494
# Auto-append .pt extension if missing
482495
yolo_model = [ensure_model_extension(m) for m in yolo_model]
@@ -487,9 +500,6 @@ def eval(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
487500
'reid_model': list(reid_model),
488501
'classes': parse_classes(classes),
489502
'data': src,
490-
'source': src,
491-
'benchmark': bench,
492-
'split': split,
493503
'imgsz': [1088, 1920]}
494504
args = SimpleNamespace(**params)
495505
from boxmot.engine.evaluator import main as run_eval
@@ -515,8 +525,6 @@ def tune(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
515525
if tracker:
516526
kwargs['tracking_method'] = tracker
517527
src = kwargs.pop('data')
518-
source_path = Path(src)
519-
bench, split = source_path.parent.name, source_path.name
520528

521529
# Auto-append .pt extension if missing
522530
yolo_model = [ensure_model_extension(m) for m in yolo_model]
@@ -526,10 +534,7 @@ def tune(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
526534
'yolo_model': list(yolo_model),
527535
'reid_model': list(reid_model),
528536
'classes': parse_classes(classes),
529-
'data': src,
530-
'source': src,
531-
'benchmark': bench,
532-
'split': split}
537+
'data': src}
533538
args = SimpleNamespace(**params)
534539
from boxmot.engine.tuner import main as run_tuning
535540
run_tuning(args)

boxmot/engine/evaluator.py

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@
2424
from boxmot.utils.plots import MetricsPlotter
2525
from boxmot.utils.misc import increment_path, prompt_overwrite
2626
from boxmot.utils.timing import TimingStats, wrap_tracker_reid
27-
from typing import Optional, List, Dict, Generator, Union
27+
from typing import Optional, List, Dict, Generator
2828

2929
from boxmot.utils.dataloaders.MOT17 import MOT17DetEmbDataset
3030
from boxmot.postprocessing.gsi import gsi
3131

3232
from boxmot.engine.inference import DetectorReIDPipeline, extract_detections, filter_detections
3333
from boxmot.detectors import default_imgsz
3434
from boxmot.utils.mot_utils import convert_to_mot_format, write_mot_results
35-
from boxmot.utils.download import download_eval_data, download_trackeval
35+
from boxmot.utils.download import download_trackeval
3636

3737
checker = RequirementsChecker()
3838
checker.check_packages(('ultralytics', )) # install
@@ -57,47 +57,6 @@ def load_dataset_cfg(name: str) -> dict:
5757
return yaml.safe_load(f)
5858

5959

60-
def load_dataset_cfg_path(path: Path) -> dict:
61-
"""Load a dataset config directly from a yaml file path."""
62-
with open(path, 'r') as f:
63-
return yaml.safe_load(f)
64-
65-
66-
def resolve_dataset_cfg_path(data_ref: Union[str, Path, None]) -> Optional[Path]:
67-
"""
68-
Resolve dataset config references from --data/--source.
69-
70-
Supports:
71-
- config name: "MOT17-ablation"
72-
- config filename: "MOT17-ablation.yaml"
73-
- explicit path: "/path/to/custom.yaml" or "./custom.yaml"
74-
"""
75-
if data_ref is None:
76-
return None
77-
78-
ref = Path(str(data_ref))
79-
80-
# Explicit local yaml path (absolute or relative)
81-
if ref.suffix in {".yaml", ".yml"} and ref.exists() and ref.is_file():
82-
return ref.resolve()
83-
84-
# Config filename in boxmot/configs/datasets
85-
if ref.suffix in {".yaml", ".yml"}:
86-
cfg_by_filename = DATASET_CONFIGS / ref.name
87-
if cfg_by_filename.exists() and cfg_by_filename.is_file():
88-
return cfg_by_filename
89-
cfg_by_stem = DATASET_CONFIGS / f"{ref.stem}.yaml"
90-
if cfg_by_stem.exists() and cfg_by_stem.is_file():
91-
return cfg_by_stem
92-
93-
# Config name in boxmot/configs/datasets
94-
cfg_by_name = DATASET_CONFIGS / f"{str(data_ref)}.yaml"
95-
if cfg_by_name.exists() and cfg_by_name.is_file():
96-
return cfg_by_name
97-
98-
return None
99-
100-
10160
def eval_init(args,
10261
trackeval_dest: Path = TRACKEVAL,
10362
branch: str = "master",
@@ -111,33 +70,28 @@ def eval_init(args,
11170
# 1) download the TrackEval code
11271
download_trackeval(dest=trackeval_dest, branch=branch, overwrite=overwrite)
11372

114-
# 2) if using a dataset yaml, pull down data (if needed) and rewire args.source/split
115-
data_ref = getattr(args, "data", None) or getattr(args, "source", None)
116-
cfg_path = resolve_dataset_cfg_path(data_ref)
117-
if cfg_path is not None:
118-
cfg = load_dataset_cfg_path(cfg_path)
119-
120-
# Determine dataset destination
121-
if cfg["download"]["dataset_url"]:
122-
dataset_dest = TRACKEVAL / f"{Path(cfg['benchmark']['source']).name}.zip"
123-
else:
124-
# For custom datasets without URL, use the path from config if available, or default to assets
125-
dataset_dest = Path(cfg["download"].get("dataset_dest", f"assets/{Path(cfg['benchmark']['source']).name}"))
126-
127-
download_eval_data(
128-
runs_url=cfg["download"]["runs_url"],
129-
dataset_url=cfg["download"]["dataset_url"],
130-
dataset_dest=dataset_dest,
131-
overwrite=overwrite
132-
)
133-
args.benchmark = Path(cfg["benchmark"]["source"]).name
134-
args.split = cfg["benchmark"]["split"]
135-
if cfg["download"]["dataset_url"]:
136-
args.source = TRACKEVAL / f"{args.benchmark}/{args.split}"
137-
elif "source" in cfg["benchmark"]:
138-
args.source = Path(cfg["benchmark"]["source"]) / args.split
73+
# 2) resolve dataset reference via shared CLI helper.
74+
# For eval/tune, prefer --data and do not use --source when data is provided.
75+
from boxmot.engine.cli import resolve_source_from_data_ref
76+
77+
data_ref = getattr(args, "data", None)
78+
if data_ref is not None:
79+
resolved_source, benchmark, split = resolve_source_from_data_ref(data_ref, overwrite=overwrite)
80+
args.source = resolved_source
81+
args.benchmark = benchmark
82+
args.split = split
83+
else:
84+
# Backward-compatible fallback for callers that pass source directly.
85+
source_path = Path(str(getattr(args, "source", "")))
86+
if source_path.exists() and source_path.is_dir():
87+
args.source = source_path
88+
args.benchmark = getattr(args, "benchmark", source_path.parent.name)
89+
args.split = getattr(args, "split", source_path.name)
13990
else:
140-
args.source = dataset_dest / args.split
91+
resolved_source, benchmark, split = resolve_source_from_data_ref(getattr(args, "source", None), overwrite=overwrite)
92+
args.source = resolved_source
93+
args.benchmark = benchmark
94+
args.split = split
14195

14296
# 3) finally, make source an absolute Path everywhere
14397
args.source = Path(args.source).resolve()
@@ -1123,6 +1077,8 @@ def run_trackeval(opt: argparse.Namespace, verbose: bool = True) -> dict:
11231077

11241078

11251079
def main(args):
1080+
data_ref = getattr(args, "data", getattr(args, "source", None))
1081+
11261082
# Print evaluation pipeline header (blue palette)
11271083
LOGGER.info("")
11281084
LOGGER.opt(colors=True).info("<blue>" + "="*60 + "</blue>")
@@ -1131,7 +1087,7 @@ def main(args):
11311087
LOGGER.opt(colors=True).info(f"<bold>Detector:</bold> <cyan>{args.yolo_model[0]}</cyan>")
11321088
LOGGER.opt(colors=True).info(f"<bold>ReID:</bold> <cyan>{args.reid_model[0]}</cyan>")
11331089
LOGGER.opt(colors=True).info(f"<bold>Tracker:</bold> <cyan>{args.tracking_method}</cyan>")
1134-
LOGGER.opt(colors=True).info(f"<bold>Benchmark:</bold> <cyan>{args.source}</cyan>")
1090+
LOGGER.opt(colors=True).info(f"<bold>Benchmark:</bold> <cyan>{data_ref}</cyan>")
11351091
LOGGER.opt(colors=True).info(f"<bold>Image size:</bold> <cyan>{getattr(args, 'imgsz', None)}</cyan>")
11361092
LOGGER.opt(colors=True).info("<blue>" + "="*60 + "</blue>")
11371093

0 commit comments

Comments
 (0)