Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions data/coco128.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# COCO 2017 dataset http://cocodataset.org - first 128 training images
# Train command: python train.py --data coco128.yaml
# Default dataset location is next to /yolov5:
# /parent_folder
# /coco128
# /yolov5


# download command/URL (optional)
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip

# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: ../coco128/images/train2017/ # 128 images
val: ../coco128/images/train2017/ # 128 images

# number of classes
nc: 80

# class names
names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush']
25 changes: 22 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def test(data,
dataloader=None,
save_dir='',
merge=False,
save_txt=False):
save_txt=False,
log_imgs=0):
# Initialize/load model and set device
training = model is not None
if training: # called by train.py
Expand Down Expand Up @@ -86,6 +87,13 @@ def test(data,
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for [email protected]:0.95
niou = iouv.numel()

# Logging
log_imgs, wandb = min(log_imgs, 100), None # ceil
try:
import wandb # Weights & Biases
except ImportError:
log_imgs = 0

# Dataloader
if not training:
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
Expand All @@ -99,11 +107,12 @@ def test(data,
names = model.names if hasattr(model, 'names') else model.module.names
except:
names = load_classes(opt.names)
names_dict = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
coco91class = coco80_to_coco91_class()
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', '[email protected]', '[email protected]:.95')
p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class = [], [], [], []
jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
img = img.to(device, non_blocking=True)
img = img.half() if half else img.float() # uint8 to fp16/32
Expand Down Expand Up @@ -149,7 +158,15 @@ def test(data,
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
with open(txt_path + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format

# W&B logging
if len(wandb_images) < log_imgs:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": int(cls),
"box_caption": "%s %.3f" % (names_dict[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names_dict}} # inference-space
wandb_images.append(wandb.Image(img[si], boxes=boxes))
# Clip boxes to image bounds
clip_coords(pred, (height, width))

Expand Down Expand Up @@ -229,6 +246,8 @@ def test(data,
if not training:
print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

if wandb and wandb.run:
wandb.log({"Bouding Box Debugging/Images": wandb_images})
# Save JSON
if save_json and len(jdict):
f = 'detections_val2017_%s_results.json' % \
Expand Down
41 changes: 29 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@
from utils.google_utils import attempt_download
from utils.torch_utils import init_seeds, ModelEMA, select_device, intersect_dicts

try:
import wandb
except ImportError:
wandb = None
logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' ")

def train(hyp, opt, device, tb_writer=None):
def train(hyp, opt, device, tb_writer=None, wandb=None):
print(f'Hyperparameters {hyp}')
log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory
wdir = str(log_dir / 'weights') + os.sep # weights directory
Expand Down Expand Up @@ -99,6 +104,13 @@ def train(hyp, opt, device, tb_writer=None):
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# plot_lr_scheduler(optimizer, scheduler, epochs)

if wandb and wandb.run is None:
opt.hyp = hyp # add hyperparameters
wandb_run = wandb.init(config=opt, resume="allow",
project='YOLOv4' if opt.logdir == 'runs/' else Path(opt.logdir).stem,
name=log_dir.stem,
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)

# Resume
start_epoch, best_fitness = 0, 0.0
if pretrained:
Expand Down Expand Up @@ -305,21 +317,23 @@ def train(hyp, opt, device, tb_writer=None):
model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=log_dir)
save_dir=log_dir,
log_imgs=opt.log_imgs if wandb else 0)

# Write
with open(results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
if len(opt.name) and opt.bucket:
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

# Tensorboard
if tb_writer:
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
for x, tag in zip(list(mloss[:-1]) + list(results), tags):

tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
for x, tag in zip(list(mloss[:-1]) + list(results), tags):
if tb_writer:
tb_writer.add_scalar(tag, x, epoch)
if wandb:
wandb.log({tag: x})

# Update best mAP
fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
Expand All @@ -334,7 +348,8 @@ def train(hyp, opt, device, tb_writer=None):
'best_fitness': best_fitness,
'training_results': f.read(),
'model': ema.ema.module.state_dict() if hasattr(ema, 'module') else ema.ema.state_dict(),
'optimizer': None if final_epoch else optimizer.state_dict()}
'optimizer': None if final_epoch else optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}

# Save last, best and delete
torch.save(ckpt, last)
Expand Down Expand Up @@ -362,6 +377,7 @@ def train(hyp, opt, device, tb_writer=None):
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))

dist.destroy_process_group() if rank not in [-1, 0] else None
wandb.run.finish() if wandb and wandb.run else None
torch.cuda.empty_cache()
return results

Expand Down Expand Up @@ -392,6 +408,7 @@ def train(hyp, opt, device, tb_writer=None):
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
opt = parser.parse_args()

# Resume
Expand Down Expand Up @@ -435,7 +452,7 @@ def train(hyp, opt, device, tb_writer=None):
print('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir)
tb_writer = SummaryWriter(log_dir=increment_dir(Path(opt.logdir) / 'exp', opt.name)) # runs/exp

train(hyp, opt, device, tb_writer)
train(hyp, opt, device, tb_writer, wandb)

# Evolve hyperparameters (optional)
else:
Expand Down Expand Up @@ -503,7 +520,7 @@ def train(hyp, opt, device, tb_writer=None):
hyp[k] = round(hyp[k], 5) # significant digits

# Train mutation
results = train(hyp.copy(), opt, device)
results = train(hyp.copy(), opt, device, wandb)

# Write mutation results
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
Expand Down