diff --git a/data/coco128.yaml b/data/coco128.yaml new file mode 100644 index 00000000..12e1d799 --- /dev/null +++ b/data/coco128.yaml @@ -0,0 +1,28 @@ +# COCO 2017 dataset http://cocodataset.org - first 128 training images +# Train command: python train.py --data coco128.yaml +# Default dataset location is next to /yolov5: +# /parent_folder +# /coco128 +# /yolov5 + + +# download command/URL (optional) +download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip + +# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/] +train: ../coco128/images/train2017/ # 128 images +val: ../coco128/images/train2017/ # 128 images + +# number of classes +nc: 80 + +# class names +names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', + 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush'] diff --git a/test.py b/test.py index f5d24335..223f541f 100644 --- a/test.py +++ b/test.py @@ -42,7 +42,8 @@ def test(data, dataloader=None, save_dir='', merge=False, - save_txt=False): + save_txt=False, + log_imgs=0): # Initialize/load model and set device training = model is not None if training: # called by train.py @@ -86,6 +87,13 @@ def 
test(data, iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95 niou = iouv.numel() + # Logging + log_imgs, wandb = min(log_imgs, 100), None # ceil + try: + import wandb # Weights & Biases + except ImportError: + log_imgs = 0 + # Dataloader if not training: img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img @@ -99,11 +107,12 @@ def test(data, names = model.names if hasattr(model, 'names') else model.module.names except: names = load_classes(opt.names) + names_dict = dict(enumerate(names)) coco91class = coco80_to_coco91_class() s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95') p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0. loss = torch.zeros(3, device=device) - jdict, stats, ap, ap_class = [], [], [], [] + jdict, stats, ap, ap_class, wandb_images = [], [], [], [], [] for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)): img = img.to(device, non_blocking=True) img = img.half() if half else img.float() # uint8 to fp16/32 @@ -149,7 +158,15 @@ def test(data, xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh with open(txt_path + '.txt', 'a') as f: f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format - + # W&B logging + if len(wandb_images) < log_imgs: + box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, + "class_id": int(cls), + "box_caption": "%s %.3f" % (names_dict[cls], conf), + "scores": {"class_score": conf}, + "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()] + boxes = {"predictions": {"box_data": box_data, "class_labels": names_dict}} # inference-space + wandb_images.append(wandb.Image(img[si], boxes=boxes)) # Clip boxes to image bounds clip_coords(pred, (height, width)) @@ -229,6 +246,8 @@ def test(data, if not training: print('Speed: %.1f/%.1f/%.1f ms 
inference/NMS/total per %gx%g image at batch-size %g' % t) + if wandb and wandb.run: + wandb.log({"Bounding Box Debugging/Images": wandb_images}) # Save JSON if save_json and len(jdict): f = 'detections_val2017_%s_results.json' % \ diff --git a/train.py b/train.py index 78b04b98..1d47894c 100644 --- a/train.py +++ b/train.py @@ -27,8 +27,13 @@ from utils.google_utils import attempt_download from utils.torch_utils import init_seeds, ModelEMA, select_device, intersect_dicts +try: + import wandb +except ImportError: + wandb = None + logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' ") -def train(hyp, opt, device, tb_writer=None): +def train(hyp, opt, device, tb_writer=None, wandb=None): print(f'Hyperparameters {hyp}') log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory wdir = str(log_dir / 'weights') + os.sep # weights directory @@ -99,6 +104,13 @@ def train(hyp, opt, device, tb_writer=None): scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs) + if wandb and wandb.run is None: + opt.hyp = hyp # add hyperparameters + wandb_run = wandb.init(config=opt, resume="allow", + project='YOLOv4' if opt.logdir == 'runs/' else Path(opt.logdir).stem, + name=log_dir.stem, + id=ckpt.get('wandb_id') if 'ckpt' in locals() else None) + # Resume start_epoch, best_fitness = 0, 0.0 if pretrained: @@ -305,21 +317,23 @@ def train(hyp, opt, device, tb_writer=None): model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema, single_cls=opt.single_cls, dataloader=testloader, - save_dir=log_dir) + save_dir=log_dir, + log_imgs=opt.log_imgs if wandb else 0) # Write with open(results_file, 'a') as f: f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls) if len(opt.name) and opt.bucket: os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name)) - - # Tensorboard - if tb_writer: - tags = 
['train/giou_loss', 'train/obj_loss', 'train/cls_loss', - 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', - 'val/giou_loss', 'val/obj_loss', 'val/cls_loss'] - for x, tag in zip(list(mloss[:-1]) + list(results), tags): + + tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', + 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', + 'val/giou_loss', 'val/obj_loss', 'val/cls_loss'] + for x, tag in zip(list(mloss[:-1]) + list(results), tags): + if tb_writer: tb_writer.add_scalar(tag, x, epoch) + if wandb: + wandb.log({tag: x}) # Update best mAP fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1] @@ -334,7 +348,8 @@ def train(hyp, opt, device, tb_writer=None): 'best_fitness': best_fitness, 'training_results': f.read(), 'model': ema.ema.module.state_dict() if hasattr(ema, 'module') else ema.ema.state_dict(), - 'optimizer': None if final_epoch else optimizer.state_dict()} + 'optimizer': None if final_epoch else optimizer.state_dict(), + 'wandb_id': wandb.run.id if wandb and wandb.run else None} # Save last, best and delete torch.save(ckpt, last) @@ -362,6 +377,7 @@ def train(hyp, opt, device, tb_writer=None): print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) dist.destroy_process_group() if rank not in [-1, 0] else None + wandb.run.finish() if wandb and wandb.run else None torch.cuda.empty_cache() return results @@ -392,6 +408,7 @@ def train(hyp, opt, device, tb_writer=None): parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') parser.add_argument('--logdir', type=str, default='runs/', help='logging directory') + parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100') opt = parser.parse_args() # Resume @@ -435,7 +452,7 @@ 
def train(hyp, opt, device, tb_writer=None): print('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir) tb_writer = SummaryWriter(log_dir=increment_dir(Path(opt.logdir) / 'exp', opt.name)) # runs/exp - train(hyp, opt, device, tb_writer) + train(hyp, opt, device, tb_writer, wandb) # Evolve hyperparameters (optional) else: @@ -503,7 +520,7 @@ def train(hyp, opt, device, tb_writer=None): hyp[k] = round(hyp[k], 5) # significant digits # Train mutation - results = train(hyp.copy(), opt, device) + results = train(hyp.copy(), opt, device, wandb=wandb) # Write mutation results print_mutation(hyp.copy(), results, yaml_file, opt.bucket)