From b08fcbebda914988bfc6df33fb811046fee76ab4 Mon Sep 17 00:00:00 2001 From: "Perry Gibson (gabriel)" Date: Sun, 6 Dec 2020 15:38:30 +0000 Subject: [PATCH 1/2] added tinyyolov3 training --- imageai/Detection/Custom/__init__.py | 109 ++++++-- .../Detection/Custom/tinyyolov3_generator.py | 233 ++++++++++++++++++ imageai/Detection/Custom/yolo.py | 89 ++++++- imageai/Detection/YOLOv3/models.py | 2 +- 4 files changed, 403 insertions(+), 30 deletions(-) create mode 100644 imageai/Detection/Custom/tinyyolov3_generator.py diff --git a/imageai/Detection/Custom/__init__.py b/imageai/Detection/Custom/__init__.py index 50c4931b..ec6f7aed 100644 --- a/imageai/Detection/Custom/__init__.py +++ b/imageai/Detection/Custom/__init__.py @@ -3,9 +3,10 @@ import numpy as np import json from imageai.Detection.Custom.voc import parse_voc_annotation -from imageai.Detection.Custom.yolo import create_yolov3_model, dummy_loss -from imageai.Detection.YOLOv3.models import yolo_main +from imageai.Detection.Custom.yolo import create_yolov3_model, create_tinyyolov3_model, dummy_loss +from imageai.Detection.YOLOv3.models import yolo_main, tiny_yolo_main from imageai.Detection.Custom.generator import BatchGenerator +from imageai.Detection.Custom.tinyyolov3_generator import TinyYOLOv3BatchGenerator from imageai.Detection.Custom.utils.utils import normalize, evaluate, makedirs from keras.callbacks import ReduceLROnPlateau from keras.optimizers import Adam @@ -74,6 +75,14 @@ def setModelTypeAsYOLOv3(self): """ self.__model_type = "yolov3" + def setModelTypeAsTinyYOLOv3(self): + """ + 'setModelTypeAsTinyYOLOv3()' is used to set the model type to the TinyYOLOv3 model + for the training instance object . + :return: + """ + self.__model_type = "tinyyolov3" + def setDataDirectory(self, data_directory): """ @@ -220,6 +229,8 @@ def trainModel(self): ############################### # Create the generators ############################### + if self.__model_type == "tinyyolov3": + BatchGenerator = TinyYOLOv3BatchGenerator train_generator = BatchGenerator( instances=train_ints, anchors=self.__model_anchors, @@ -527,12 +538,47 @@ def _create_model( ): if len(multi_gpu) > 1: with tf.device('/cpu:0'): + if self.__model_type == "yolov3": + template_model, infer_model = create_yolov3_model( + nb_class=nb_class, + anchors=anchors, + max_box_per_image=max_box_per_image, + max_grid=max_grid, + batch_size=batch_size // len(multi_gpu), + warmup_batches=warmup_batches, + ignore_thresh=ignore_thresh, + grid_scales=grid_scales, + obj_scale=obj_scale, + noobj_scale=noobj_scale, + xywh_scale=xywh_scale, + class_scale=class_scale + ) + elif self.__model_type == "tinyyolov3": + template_model, infer_model = create_tinyyolov3_model( + nb_class=nb_class, + anchors=anchors, + max_box_per_image=max_box_per_image, + max_grid=max_grid, + batch_size=batch_size // len(multi_gpu), + warmup_batches=warmup_batches, + ignore_thresh=ignore_thresh, + grid_scales=grid_scales, + obj_scale=obj_scale, + noobj_scale=noobj_scale, + xywh_scale=xywh_scale, + class_scale=class_scale + ) + else: + raise ValueError(f'Unsupported model type: {self.__model_type}') + else: + print('Hello world\n\n\n\n\n', self.__model_type) + if self.__model_type == "yolov3": template_model, infer_model = create_yolov3_model( nb_class=nb_class, anchors=anchors, max_box_per_image=max_box_per_image, max_grid=max_grid, - batch_size=batch_size // len(multi_gpu), + batch_size=batch_size, warmup_batches=warmup_batches, ignore_thresh=ignore_thresh, grid_scales=grid_scales, @@ -541,21 +587,23 @@ def _create_model( xywh_scale=xywh_scale, class_scale=class_scale ) - else: - template_model, infer_model = create_yolov3_model( - nb_class=nb_class, - anchors=anchors, - max_box_per_image=max_box_per_image, - max_grid=max_grid, - batch_size=batch_size, - warmup_batches=warmup_batches, - ignore_thresh=ignore_thresh, - grid_scales=grid_scales, - obj_scale=obj_scale, - noobj_scale=noobj_scale, - xywh_scale=xywh_scale, - class_scale=class_scale - ) + elif self.__model_type == "tinyyolov3": + template_model, infer_model = create_tinyyolov3_model( + nb_class=nb_class, + anchors=anchors, + max_box_per_image=max_box_per_image, + max_grid=max_grid, + batch_size=batch_size, + warmup_batches=warmup_batches, + ignore_thresh=ignore_thresh, + grid_scales=grid_scales, + obj_scale=obj_scale, + noobj_scale=noobj_scale, + xywh_scale=xywh_scale, + class_scale=class_scale + ) + else: + raise ValueError(f'Unsupported model type: {self.__model_type}') # load the pretrained weight if exists, otherwise load the backend weight only @@ -639,6 +687,18 @@ def loadModel(self): self.__model = yolo_main(Input(shape=(None, None, 3)), 3, len(self.__model_labels)) self.__model.load_weights(self.__model_path) + elif self.__model_type == "tinyyolov3": + detection_model_json = json.load(open(self.__detection_config_json_path)) + + self.__model_labels = detection_model_json["labels"] + self.__model_anchors = detection_model_json["anchors"] + + self.__detection_utils = CustomDetectionUtils(labels=self.__model_labels) + + self.__model = tiny_yolo_main(Input(shape=(None, None, 3)), 3, len(self.__model_labels)) + + self.__model.load_weights(self.__model_path) + def detectObjectsFromImage(self, input_image="", output_image_path="", input_type="file", output_type="file", extract_detected_objects=False, minimum_percentage_probability=50, nms_treshold=0.4, @@ -762,7 +822,7 @@ def detectObjectsFromImage(self, input_image="", output_image_path="", input_typ # expand the image to batch image = np.expand_dims(image, 0) - if self.__model_type == "yolov3": + if self.__model_type == "yolov3" or self.__model_type == "tinyyolov3": if thread_safe == True: with K.get_session().graph.as_default(): yolo_results = self.__model.predict(image) @@ -901,6 +961,15 @@ def loadModel(self): self.__detector = detector self.__model_loaded = True + elif(self.__model_type == "tinyyolov3"): + detector = CustomObjectDetection() + detector.setModelTypeAsTinyYOLOv3() + detector.setModelPath(self.__model_path) + detector.setJsonPath(self.__detection_config_json_path) + detector.loadModel() + + self.__detector = detector + self.__model_loaded = True def detectObjectsFromVideo(self, input_file_path="", camera_input=None, output_file_path="", frames_per_second=20, @@ -1015,7 +1084,7 @@ def detectObjectsFromVideo(self, input_file_path="", camera_input=None, output_f video_frames_count = 0 - if(self.__model_type == "yolov3"): + if(self.__model_type == "yolov3" or self.__model_type == "tinyyolov3"): diff --git a/imageai/Detection/Custom/tinyyolov3_generator.py b/imageai/Detection/Custom/tinyyolov3_generator.py new file mode 100644 index 00000000..329c4295 --- /dev/null +++ b/imageai/Detection/Custom/tinyyolov3_generator.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python + +import cv2 +import copy +import numpy as np +from keras.utils import Sequence +from imageai.Detection.Custom.utils.bbox import BoundBox, bbox_iou +from imageai.Detection.Custom.utils.image import apply_random_scale_and_crop, random_distort_image, random_flip, correct_bounding_boxes + +class TinyYOLOv3BatchGenerator(Sequence): + def __init__(self, + instances, + anchors, + labels, + downsample=32, # ratio between network input's size and network output's size, 32 for YOLOv3 + max_box_per_image=30, + batch_size=1, + min_net_size=320, + max_net_size=608, + shuffle=True, + jitter=True, + norm=None + ): + self.instances = instances + self.batch_size = batch_size + self.labels = labels + self.downsample = downsample + self.max_box_per_image = max_box_per_image + self.min_net_size = (min_net_size//self.downsample)*self.downsample + self.max_net_size = (max_net_size//self.downsample)*self.downsample + self.shuffle = shuffle + self.jitter = jitter + self.norm = norm + self.anchors = [BoundBox(0, 0, anchors[2*i], anchors[2*i+1]) for i in range(len(anchors)//2)] + self.net_h = 416 + self.net_w = 416 + + if shuffle: np.random.shuffle(self.instances) + + def __len__(self): + return int(np.ceil(float(len(self.instances))/self.batch_size)) + + def __getitem__(self, idx): + # get image input size, change every 10 batches + net_h, net_w = self._get_net_size(idx) + base_grid_h, base_grid_w = net_h//self.downsample, net_w//self.downsample + + # determine the first and the last indices of the batch + l_bound = idx * self.batch_size + r_bound = (idx+1) * self.batch_size + + if r_bound > len(self.instances): + r_bound = len(self.instances) + l_bound = r_bound - self.batch_size + + x_batch = np.zeros((r_bound - l_bound, net_h, net_w, 3)) # input images + t_batch = np.zeros((r_bound - l_bound, 1, 1, 1, self.max_box_per_image, 4)) # list of groundtruth boxes + + # initialize the inputs and the outputs + yolo_1 = np.zeros((r_bound - l_bound, 1*base_grid_h, 1*base_grid_w, len(self.anchors)//3, 4+1+len(self.labels))) # desired network output 1 + yolo_2 = np.zeros((r_bound - l_bound, 2*base_grid_h, 2*base_grid_w, len(self.anchors)//3, 4+1+len(self.labels))) # desired network output 2 + yolos = [yolo_2, yolo_1] + + dummy_yolo_1 = np.zeros((r_bound - l_bound, 1)) + dummy_yolo_2 = np.zeros_like(dummy_yolo_1) + + instance_count = 0 + true_box_index = 0 + + # do the logic to fill in the inputs and the output + for train_instance in self.instances[l_bound:r_bound]: + # augment input image and fix object's position and size + img, all_objs = self._aug_image(train_instance, net_h, net_w) + + for obj in all_objs: + # find the best anchor box for this object + max_anchor = None + max_index = -1 + max_iou = -1 + + shifted_box = BoundBox(0, + 0, + obj['xmax']-obj['xmin'], + obj['ymax']-obj['ymin']) + + for i in range(len(self.anchors)): + anchor = self.anchors[i] + iou = bbox_iou(shifted_box, anchor) + + if max_iou < iou: + max_anchor = anchor + max_index = i + max_iou = iou + + # determine the yolo to be responsible for this bounding box + + yolo = yolos[max_index // 2] + grid_h, grid_w = yolo.shape[1:3] + + # determine the position of the bounding box on the grid + center_x = .5*(obj['xmin'] + obj['xmax']) + center_x = center_x / float(net_w) * grid_w # sigma(t_x) + c_x + center_y = .5*(obj['ymin'] + obj['ymax']) + center_y = center_y / float(net_h) * grid_h # sigma(t_y) + c_y + + # determine the sizes of the bounding box + w = np.log((obj['xmax'] - obj['xmin']) / float(max_anchor.xmax)) # t_w + h = np.log((obj['ymax'] - obj['ymin']) / float(max_anchor.ymax)) # t_h + + box = [center_x, center_y, w, h] + + # determine the index of the label + obj_indx = self.labels.index(obj['name']) + + # determine the location of the cell responsible for this object + grid_x = int(np.floor(center_x)) + grid_y = int(np.floor(center_y)) + + # assign ground truth x, y, w, h, confidence and class probs to y_batch + yolo[instance_count, grid_y, grid_x, max_index%3] = 0 + yolo[instance_count, grid_y, grid_x, max_index%3, 0:4] = box + yolo[instance_count, grid_y, grid_x, max_index%3, 4 ] = 1. + yolo[instance_count, grid_y, grid_x, max_index%3, 5+obj_indx] = 1 + + # assign the true box to t_batch + true_box = [center_x, center_y, obj['xmax'] - obj['xmin'], obj['ymax'] - obj['ymin']] + t_batch[instance_count, 0, 0, 0, true_box_index] = true_box + + true_box_index += 1 + true_box_index = true_box_index % self.max_box_per_image + + # assign input image to x_batch + if self.norm != None: + x_batch[instance_count] = self.norm(img) + else: + # plot image and bounding boxes for sanity check + for obj in all_objs: + cv2.rectangle(img, (obj['xmin'],obj['ymin']), (obj['xmax'],obj['ymax']), (255,0,0), 3) + cv2.putText(img, obj['name'], + (obj['xmin']+2, obj['ymin']+12), + 0, 1.2e-3 * img.shape[0], + (0,255,0), 2) + + x_batch[instance_count] = img + + # increase instance counter in the current batch + instance_count += 1 + + return [x_batch, t_batch, yolo_1, yolo_2], [dummy_yolo_1, dummy_yolo_2] + + def _get_net_size(self, idx): + if idx % 10 == 0: + net_size = self.downsample*np.random.randint(self.min_net_size/self.downsample, \ + self.max_net_size/self.downsample+1) + + self.net_h, self.net_w = net_size, net_size + return self.net_h, self.net_w + + def _aug_image(self, instance, net_h, net_w): + image_name = instance['filename'] + image = cv2.imread(image_name) # BGR image + + if image is None: + print('Cannot find ', image_name) + + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # RGB image + + image_h, image_w, _ = image.shape + + # determine the amount of scaling and cropping + dw = self.jitter * image_w + dh = self.jitter * image_h + + new_ar = (image_w + np.random.uniform(-dw, dw)) / (image_h + np.random.uniform(-dh, dh)) + scale = np.random.uniform(0.25, 2) + + if new_ar < 1: + new_h = int(scale * net_h) + new_w = int(net_h * new_ar) + else: + new_w = int(scale * net_w) + new_h = int(net_w / new_ar) + + dx = int(np.random.uniform(0, net_w - new_w)) + dy = int(np.random.uniform(0, net_h - new_h)) + + # apply scaling and cropping + im_sized = apply_random_scale_and_crop(image, new_w, new_h, net_w, net_h, dx, dy) + + # randomly distort hsv space + im_sized = random_distort_image(im_sized) + + # randomly flip + flip = np.random.randint(2) + im_sized = random_flip(im_sized, flip) + + # correct the size and pos of bounding boxes + all_objs = correct_bounding_boxes(instance['object'], new_w, new_h, net_w, net_h, dx, dy, flip, image_w, image_h) + + return im_sized, all_objs + + def on_epoch_end(self): + if self.shuffle: + np.random.shuffle(self.instances) + + def num_classes(self): + return len(self.labels) + + def size(self): + return len(self.instances) + + def get_anchors(self): + anchors = [] + + for anchor in self.anchors: + anchors += [anchor.xmax, anchor.ymax] + + return anchors + + def load_annotation(self, i): + annots = [] + + for obj in self.instances[i]['object']: + annot = [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax'], self.labels.index(obj['name'])] + annots += [annot] + + if len(annots) == 0: + annots = [[]] + + return np.array(annots) + + def load_image(self, i): + return cv2.imread(self.instances[i]['filename']) # BGR image diff --git a/imageai/Detection/Custom/yolo.py b/imageai/Detection/Custom/yolo.py index f9661226..e256d747 100644 --- a/imageai/Detection/Custom/yolo.py +++ b/imageai/Detection/Custom/yolo.py @@ -282,7 +282,8 @@ def create_yolov3_model( # Layer 80 => 82 pred_yolo_1 = _conv_block(x, [{'filter': 1024, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 80}, {'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 81}], do_skip=False) - loss_yolo_1 = YoloLayer(anchors[12:], + + loss_yolo_1 = YoloLayer(anchors[12:], [1*num for num in max_grid], batch_size, warmup_batches, @@ -298,13 +299,6 @@ def create_yolov3_model( x = UpSampling2D(2)(x) x = concatenate([x, skip_61]) - # Layer 87 => 91 - x = _conv_block(x, [{'filter': 256, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 87}, - {'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 88}, - {'filter': 256, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 89}, - {'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 90}, - {'filter': 256, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 91}], do_skip=False) - # Layer 92 => 94 pred_yolo_2 = _conv_block(x, [{'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 92}, {'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 93}], do_skip=False) @@ -348,5 +342,82 @@ def create_yolov3_model( return [train_model, infer_model] + +def create_tinyyolov3_model( + nb_class, + anchors, + max_box_per_image, + max_grid, + batch_size, + warmup_batches, + ignore_thresh, + grid_scales, + obj_scale, + noobj_scale, + xywh_scale, + class_scale +): + input_image = Input(shape=(None, None, 3)) # net_h, net_w, 3 + true_boxes = Input(shape=(1, 1, 1, max_box_per_image, 4)) + true_yolo_1 = Input(shape=(None, None, len(anchors)//6, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class + true_yolo_2 = Input(shape=(None, None, len(anchors)//6, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class + + # Layer 0 => 5 + network1 = _conv_block(input_image, [{'filter': 16, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 0}, + {'filter': 32, 'kernel': 3, 'stride': 2, 'bnorm': True, 'leaky': True, 'layer_idx': 1}, + {'filter': 64, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 2}, + {'filter': 128, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 3}, + {'filter': 256, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 4} + ], + do_skip=False) + + # Layer 6 => 8 + network2 = _conv_block(network1, [{'filter': 512, 'kernel': 3, 'stride': 2, 'bnorm': True, 'leaky': True, 'layer_idx': 6}, + {'filter': 1024, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 7}, + {'filter': 256, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 8}], do_skip=False) + + # Layer 10 => 11 + # network3 + pred_yolo_1 = _conv_block(network2, [{'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 9}, + {'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 10}], do_skip=False) + # check this layer + loss_yolo_1 = YoloLayer(anchors[12:], + [1*num for num in max_grid], + batch_size, + warmup_batches, + ignore_thresh, + grid_scales[0], + obj_scale, + noobj_scale, + xywh_scale, + class_scale)([input_image, pred_yolo_1, true_yolo_1, true_boxes]) + + # Layer 12 + network2 = _conv_block(network2, [{'filter': 128, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 12},], do_skip=False) + + network2 = UpSampling2D(2)(network2) + network4 = concatenate([network2, network1]) + # network4 = _conv_block(network4, ) + + # Layer 92 => 94 + pred_yolo_2 = _conv_block(network4, [{'filter': 256, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 92}, + {'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 93}], do_skip=False) + loss_yolo_2 = YoloLayer(anchors[:6], + [4*num for num in max_grid], + batch_size, + warmup_batches, + ignore_thresh, + grid_scales[2], + obj_scale, + noobj_scale, + xywh_scale, + class_scale)([input_image, pred_yolo_2, true_yolo_2, true_boxes]) + + train_model = Model([input_image, true_boxes, true_yolo_1, true_yolo_2], [loss_yolo_1, loss_yolo_2]) + infer_model = Model(input_image, [pred_yolo_1, pred_yolo_2]) + + return [train_model, infer_model] + + def dummy_loss(y_true, y_pred): - return tf.sqrt(tf.reduce_sum(y_pred)) \ No newline at end of file + return tf.sqrt(tf.reduce_sum(y_pred)) diff --git a/imageai/Detection/YOLOv3/models.py b/imageai/Detection/YOLOv3/models.py index 6d12124e..4ef585da 100644 --- a/imageai/Detection/YOLOv3/models.py +++ b/imageai/Detection/YOLOv3/models.py @@ -103,4 +103,4 @@ def tiny_yolo_main(input, num_anchors, num_classes): network_4 = NetworkConv2D_BN_Leaky(input=network_4, channels=256, kernel_size=(3, 3)) network_4 = Conv2D(num_anchors * (num_classes + 5), kernel_size=(1,1))(network_4) - return Model(input, [network_3, network_4]) \ No newline at end of file + return Model(input, [network_3, network_4]) From 1f7dd5a0bdfb71a9f93e9e146cc4a3991b748439 Mon Sep 17 00:00:00 2001 From: "Perry Gibson (gabriel)" Date: Sun, 6 Dec 2020 16:48:32 +0000 Subject: [PATCH 2/2] updated TinyYOLOv3 --- imageai/Detection/Custom/__init__.py | 9 ++++++++- imageai/Detection/Custom/gen_anchors.py | 4 ++-- imageai/Detection/Custom/tinyyolov3_generator.py | 15 +++++++-------- imageai/Detection/Custom/yolo.py | 7 ++++--- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/imageai/Detection/Custom/__init__.py b/imageai/Detection/Custom/__init__.py index ec6f7aed..2aabef31 100644 --- a/imageai/Detection/Custom/__init__.py +++ b/imageai/Detection/Custom/__init__.py @@ -175,9 +175,16 @@ def setTrainConfig(self, object_names_array, batch_size=4, num_experiments=100, :return: """ + if self.__model_type == "tinyyolov3": + num_anchors = 6 + elif self.__model_type == "yolov3": + num_anchors = 9 + else: + raise ValueError(f'Unsupported model type: {self.__model_type}') self.__model_anchors, self.__inference_anchors = generateAnchors(self.__train_annotations_folder, self.__train_images_folder, - self.__train_cache_file, self.__model_labels) + self.__train_cache_file, self.__model_labels, + num_anchors) self.__model_labels = sorted(object_names_array) self.__num_objects = len(object_names_array) diff --git a/imageai/Detection/Custom/gen_anchors.py b/imageai/Detection/Custom/gen_anchors.py index 693e6b21..4e4a29fc 100644 --- a/imageai/Detection/Custom/gen_anchors.py +++ b/imageai/Detection/Custom/gen_anchors.py @@ -70,10 +70,10 @@ def run_kmeans(ann_dims, anchor_num): old_distances = distances.copy() -def generateAnchors(train_annotation_folder, train_image_folder, train_cache_file, model_labels): +def generateAnchors(train_annotation_folder, train_image_folder, train_cache_file, model_labels, + num_anchors=9): print("Generating anchor boxes for training images and annotation...") - num_anchors = 9 train_imgs, train_labels = parse_voc_annotation( train_annotation_folder, diff --git a/imageai/Detection/Custom/tinyyolov3_generator.py b/imageai/Detection/Custom/tinyyolov3_generator.py index 329c4295..ef3a48c6 100644 --- a/imageai/Detection/Custom/tinyyolov3_generator.py +++ b/imageai/Detection/Custom/tinyyolov3_generator.py @@ -57,8 +57,8 @@ def __getitem__(self, idx): t_batch = np.zeros((r_bound - l_bound, 1, 1, 1, self.max_box_per_image, 4)) # list of groundtruth boxes # initialize the inputs and the outputs - yolo_1 = np.zeros((r_bound - l_bound, 1*base_grid_h, 1*base_grid_w, len(self.anchors)//3, 4+1+len(self.labels))) # desired network output 1 - yolo_2 = np.zeros((r_bound - l_bound, 2*base_grid_h, 2*base_grid_w, len(self.anchors)//3, 4+1+len(self.labels))) # desired network output 2 + yolo_1 = np.zeros((r_bound - l_bound, 1*base_grid_h, 1*base_grid_w, len(self.anchors)//2, 4+1+len(self.labels))) # desired network output 1 + yolo_2 = np.zeros((r_bound - l_bound, 2*base_grid_h, 2*base_grid_w, len(self.anchors)//2, 4+1+len(self.labels))) # desired network output 2 yolos = [yolo_2, yolo_1] dummy_yolo_1 = np.zeros((r_bound - l_bound, 1)) @@ -93,8 +93,7 @@ def __getitem__(self, idx): max_iou = iou # determine the yolo to be responsible for this bounding box - - yolo = yolos[max_index // 2] + yolo = yolos[max_index // 4] grid_h, grid_w = yolo.shape[1:3] # determine the position of the bounding box on the grid @@ -117,10 +116,10 @@ def __getitem__(self, idx): grid_y = int(np.floor(center_y)) # assign ground truth x, y, w, h, confidence and class probs to y_batch - yolo[instance_count, grid_y, grid_x, max_index%3] = 0 - yolo[instance_count, grid_y, grid_x, max_index%3, 0:4] = box - yolo[instance_count, grid_y, grid_x, max_index%3, 4 ] = 1. - yolo[instance_count, grid_y, grid_x, max_index%3, 5+obj_indx] = 1 + yolo[instance_count, grid_y, grid_x, max_index%1] = 0 + yolo[instance_count, grid_y, grid_x, max_index%1, 0:4] = box + yolo[instance_count, grid_y, grid_x, max_index%1, 4 ] = 1. + yolo[instance_count, grid_y, grid_x, max_index%1, 5+obj_indx] = 1 # assign the true box to t_batch true_box = [center_x, center_y, obj['xmax'] - obj['xmin'], obj['ymax'] - obj['ymin']] diff --git a/imageai/Detection/Custom/yolo.py b/imageai/Detection/Custom/yolo.py index e256d747..a63ec81f 100644 --- a/imageai/Detection/Custom/yolo.py +++ b/imageai/Detection/Custom/yolo.py @@ -357,10 +357,11 @@ def create_tinyyolov3_model( xywh_scale, class_scale ): + print('test hello\n\n\n\n\n\n', len(anchors), len(anchors)//4) input_image = Input(shape=(None, None, 3)) # net_h, net_w, 3 true_boxes = Input(shape=(1, 1, 1, max_box_per_image, 4)) - true_yolo_1 = Input(shape=(None, None, len(anchors)//6, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class - true_yolo_2 = Input(shape=(None, None, len(anchors)//6, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class + true_yolo_1 = Input(shape=(None, None, len(anchors)//4, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class + true_yolo_2 = Input(shape=(None, None, len(anchors)//4, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class # Layer 0 => 5 network1 = _conv_block(input_image, [{'filter': 16, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 0}, @@ -381,7 +382,7 @@ def create_tinyyolov3_model( pred_yolo_1 = _conv_block(network2, [{'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 9}, {'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 10}], do_skip=False) # check this layer - loss_yolo_1 = YoloLayer(anchors[12:], + loss_yolo_1 = YoloLayer(anchors[6:], [1*num for num in max_grid], batch_size, warmup_batches,