From b08fcbebda914988bfc6df33fb811046fee76ab4 Mon Sep 17 00:00:00 2001
From: "Perry Gibson (gabriel)"
Date: Sun, 6 Dec 2020 15:38:30 +0000
Subject: [PATCH 1/2] added tinyyolov3 training
---
imageai/Detection/Custom/__init__.py | 109 ++++++--
.../Detection/Custom/tinyyolov3_generator.py | 233 ++++++++++++++++++
imageai/Detection/Custom/yolo.py | 89 ++++++-
imageai/Detection/YOLOv3/models.py | 2 +-
4 files changed, 403 insertions(+), 30 deletions(-)
create mode 100644 imageai/Detection/Custom/tinyyolov3_generator.py
diff --git a/imageai/Detection/Custom/__init__.py b/imageai/Detection/Custom/__init__.py
index 50c4931b..ec6f7aed 100644
--- a/imageai/Detection/Custom/__init__.py
+++ b/imageai/Detection/Custom/__init__.py
@@ -3,9 +3,10 @@
import numpy as np
import json
from imageai.Detection.Custom.voc import parse_voc_annotation
-from imageai.Detection.Custom.yolo import create_yolov3_model, dummy_loss
-from imageai.Detection.YOLOv3.models import yolo_main
+from imageai.Detection.Custom.yolo import create_yolov3_model, create_tinyyolov3_model, dummy_loss
+from imageai.Detection.YOLOv3.models import yolo_main, tiny_yolo_main
from imageai.Detection.Custom.generator import BatchGenerator
+from imageai.Detection.Custom.tinyyolov3_generator import TinyYOLOv3BatchGenerator
from imageai.Detection.Custom.utils.utils import normalize, evaluate, makedirs
from keras.callbacks import ReduceLROnPlateau
from keras.optimizers import Adam
@@ -74,6 +75,14 @@ def setModelTypeAsYOLOv3(self):
"""
self.__model_type = "yolov3"
+ def setModelTypeAsTinyYOLOv3(self):
+ """
+ 'setModelTypeAsTinyYOLOv3()' is used to set the model type to the TinyYOLOv3 model
+ for the training instance object .
+ :return:
+ """
+ self.__model_type = "tinyyolov3"
+
def setDataDirectory(self, data_directory):
"""
@@ -220,6 +229,8 @@ def trainModel(self):
###############################
# Create the generators
###############################
+ if self.__model_type == "tinyyolov3":
+ BatchGenerator = TinyYOLOv3BatchGenerator
train_generator = BatchGenerator(
instances=train_ints,
anchors=self.__model_anchors,
@@ -527,12 +538,47 @@ def _create_model(
):
if len(multi_gpu) > 1:
with tf.device('/cpu:0'):
+ if self.__model_type == "yolov3":
+ template_model, infer_model = create_yolov3_model(
+ nb_class=nb_class,
+ anchors=anchors,
+ max_box_per_image=max_box_per_image,
+ max_grid=max_grid,
+ batch_size=batch_size // len(multi_gpu),
+ warmup_batches=warmup_batches,
+ ignore_thresh=ignore_thresh,
+ grid_scales=grid_scales,
+ obj_scale=obj_scale,
+ noobj_scale=noobj_scale,
+ xywh_scale=xywh_scale,
+ class_scale=class_scale
+ )
+ elif self.__model_type == "tinyyolov3":
+ template_model, infer_model = create_tinyyolov3_model(
+ nb_class=nb_class,
+ anchors=anchors,
+ max_box_per_image=max_box_per_image,
+ max_grid=max_grid,
+ batch_size=batch_size // len(multi_gpu),
+ warmup_batches=warmup_batches,
+ ignore_thresh=ignore_thresh,
+ grid_scales=grid_scales,
+ obj_scale=obj_scale,
+ noobj_scale=noobj_scale,
+ xywh_scale=xywh_scale,
+ class_scale=class_scale
+ )
+ else:
+ raise ValueError(f'Unsupported model type: {self.__model_type}')
+ else:
+ print('Hello world\n\n\n\n\n', self.__model_type)
+ if self.__model_type == "yolov3":
template_model, infer_model = create_yolov3_model(
nb_class=nb_class,
anchors=anchors,
max_box_per_image=max_box_per_image,
max_grid=max_grid,
- batch_size=batch_size // len(multi_gpu),
+ batch_size=batch_size,
warmup_batches=warmup_batches,
ignore_thresh=ignore_thresh,
grid_scales=grid_scales,
@@ -541,21 +587,23 @@ def _create_model(
xywh_scale=xywh_scale,
class_scale=class_scale
)
- else:
- template_model, infer_model = create_yolov3_model(
- nb_class=nb_class,
- anchors=anchors,
- max_box_per_image=max_box_per_image,
- max_grid=max_grid,
- batch_size=batch_size,
- warmup_batches=warmup_batches,
- ignore_thresh=ignore_thresh,
- grid_scales=grid_scales,
- obj_scale=obj_scale,
- noobj_scale=noobj_scale,
- xywh_scale=xywh_scale,
- class_scale=class_scale
- )
+ elif self.__model_type == "tinyyolov3":
+ template_model, infer_model = create_tinyyolov3_model(
+ nb_class=nb_class,
+ anchors=anchors,
+ max_box_per_image=max_box_per_image,
+ max_grid=max_grid,
+ batch_size=batch_size,
+ warmup_batches=warmup_batches,
+ ignore_thresh=ignore_thresh,
+ grid_scales=grid_scales,
+ obj_scale=obj_scale,
+ noobj_scale=noobj_scale,
+ xywh_scale=xywh_scale,
+ class_scale=class_scale
+ )
+ else:
+ raise ValueError(f'Unsupported model type: {self.__model_type}')
# load the pretrained weight if exists, otherwise load the backend weight only
@@ -639,6 +687,18 @@ def loadModel(self):
self.__model = yolo_main(Input(shape=(None, None, 3)), 3, len(self.__model_labels))
self.__model.load_weights(self.__model_path)
+ elif self.__model_type == "tinyyolov3":
+ detection_model_json = json.load(open(self.__detection_config_json_path))
+
+ self.__model_labels = detection_model_json["labels"]
+ self.__model_anchors = detection_model_json["anchors"]
+
+ self.__detection_utils = CustomDetectionUtils(labels=self.__model_labels)
+
+ self.__model = tiny_yolo_main(Input(shape=(None, None, 3)), 3, len(self.__model_labels))
+
+ self.__model.load_weights(self.__model_path)
+
def detectObjectsFromImage(self, input_image="", output_image_path="", input_type="file", output_type="file",
extract_detected_objects=False, minimum_percentage_probability=50, nms_treshold=0.4,
@@ -762,7 +822,7 @@ def detectObjectsFromImage(self, input_image="", output_image_path="", input_typ
# expand the image to batch
image = np.expand_dims(image, 0)
- if self.__model_type == "yolov3":
+ if self.__model_type == "yolov3" or self.__model_type == "tinyyolov3":
if thread_safe == True:
with K.get_session().graph.as_default():
yolo_results = self.__model.predict(image)
@@ -901,6 +961,15 @@ def loadModel(self):
self.__detector = detector
self.__model_loaded = True
+ elif(self.__model_type == "tinyyolov3"):
+ detector = CustomObjectDetection()
+ detector.setModelTypeAsTinyYOLOv3()
+ detector.setModelPath(self.__model_path)
+ detector.setJsonPath(self.__detection_config_json_path)
+ detector.loadModel()
+
+ self.__detector = detector
+ self.__model_loaded = True
def detectObjectsFromVideo(self, input_file_path="", camera_input=None, output_file_path="", frames_per_second=20,
@@ -1015,7 +1084,7 @@ def detectObjectsFromVideo(self, input_file_path="", camera_input=None, output_f
video_frames_count = 0
- if(self.__model_type == "yolov3"):
+ if(self.__model_type == "yolov3" or self.__model_type == "tinyyolov3"):
diff --git a/imageai/Detection/Custom/tinyyolov3_generator.py b/imageai/Detection/Custom/tinyyolov3_generator.py
new file mode 100644
index 00000000..329c4295
--- /dev/null
+++ b/imageai/Detection/Custom/tinyyolov3_generator.py
@@ -0,0 +1,233 @@
+#!/usr/bin/env python
+
+import cv2
+import copy
+import numpy as np
+from keras.utils import Sequence
+from imageai.Detection.Custom.utils.bbox import BoundBox, bbox_iou
+from imageai.Detection.Custom.utils.image import apply_random_scale_and_crop, random_distort_image, random_flip, correct_bounding_boxes
+
+class TinyYOLOv3BatchGenerator(Sequence):
+ def __init__(self,
+ instances,
+ anchors,
+ labels,
+ downsample=32, # ratio between network input's size and network output's size, 32 for YOLOv3
+ max_box_per_image=30,
+ batch_size=1,
+ min_net_size=320,
+ max_net_size=608,
+ shuffle=True,
+ jitter=True,
+ norm=None
+ ):
+ self.instances = instances
+ self.batch_size = batch_size
+ self.labels = labels
+ self.downsample = downsample
+ self.max_box_per_image = max_box_per_image
+ self.min_net_size = (min_net_size//self.downsample)*self.downsample
+ self.max_net_size = (max_net_size//self.downsample)*self.downsample
+ self.shuffle = shuffle
+ self.jitter = jitter
+ self.norm = norm
+ self.anchors = [BoundBox(0, 0, anchors[2*i], anchors[2*i+1]) for i in range(len(anchors)//2)]
+ self.net_h = 416
+ self.net_w = 416
+
+ if shuffle: np.random.shuffle(self.instances)
+
+ def __len__(self):
+ return int(np.ceil(float(len(self.instances))/self.batch_size))
+
+ def __getitem__(self, idx):
+ # get image input size, change every 10 batches
+ net_h, net_w = self._get_net_size(idx)
+ base_grid_h, base_grid_w = net_h//self.downsample, net_w//self.downsample
+
+ # determine the first and the last indices of the batch
+ l_bound = idx * self.batch_size
+ r_bound = (idx+1) * self.batch_size
+
+ if r_bound > len(self.instances):
+ r_bound = len(self.instances)
+ l_bound = r_bound - self.batch_size
+
+ x_batch = np.zeros((r_bound - l_bound, net_h, net_w, 3)) # input images
+ t_batch = np.zeros((r_bound - l_bound, 1, 1, 1, self.max_box_per_image, 4)) # list of groundtruth boxes
+
+ # initialize the inputs and the outputs
+ yolo_1 = np.zeros((r_bound - l_bound, 1*base_grid_h, 1*base_grid_w, len(self.anchors)//3, 4+1+len(self.labels))) # desired network output 1
+ yolo_2 = np.zeros((r_bound - l_bound, 2*base_grid_h, 2*base_grid_w, len(self.anchors)//3, 4+1+len(self.labels))) # desired network output 2
+ yolos = [yolo_2, yolo_1]
+
+ dummy_yolo_1 = np.zeros((r_bound - l_bound, 1))
+ dummy_yolo_2 = np.zeros_like(dummy_yolo_1)
+
+ instance_count = 0
+ true_box_index = 0
+
+ # do the logic to fill in the inputs and the output
+ for train_instance in self.instances[l_bound:r_bound]:
+ # augment input image and fix object's position and size
+ img, all_objs = self._aug_image(train_instance, net_h, net_w)
+
+ for obj in all_objs:
+ # find the best anchor box for this object
+ max_anchor = None
+ max_index = -1
+ max_iou = -1
+
+ shifted_box = BoundBox(0,
+ 0,
+ obj['xmax']-obj['xmin'],
+ obj['ymax']-obj['ymin'])
+
+ for i in range(len(self.anchors)):
+ anchor = self.anchors[i]
+ iou = bbox_iou(shifted_box, anchor)
+
+ if max_iou < iou:
+ max_anchor = anchor
+ max_index = i
+ max_iou = iou
+
+ # determine the yolo to be responsible for this bounding box
+
+ yolo = yolos[max_index // 2]
+ grid_h, grid_w = yolo.shape[1:3]
+
+ # determine the position of the bounding box on the grid
+ center_x = .5*(obj['xmin'] + obj['xmax'])
+ center_x = center_x / float(net_w) * grid_w # sigma(t_x) + c_x
+ center_y = .5*(obj['ymin'] + obj['ymax'])
+ center_y = center_y / float(net_h) * grid_h # sigma(t_y) + c_y
+
+ # determine the sizes of the bounding box
+ w = np.log((obj['xmax'] - obj['xmin']) / float(max_anchor.xmax)) # t_w
+ h = np.log((obj['ymax'] - obj['ymin']) / float(max_anchor.ymax)) # t_h
+
+ box = [center_x, center_y, w, h]
+
+ # determine the index of the label
+ obj_indx = self.labels.index(obj['name'])
+
+ # determine the location of the cell responsible for this object
+ grid_x = int(np.floor(center_x))
+ grid_y = int(np.floor(center_y))
+
+ # assign ground truth x, y, w, h, confidence and class probs to y_batch
+ yolo[instance_count, grid_y, grid_x, max_index%3] = 0
+ yolo[instance_count, grid_y, grid_x, max_index%3, 0:4] = box
+ yolo[instance_count, grid_y, grid_x, max_index%3, 4 ] = 1.
+ yolo[instance_count, grid_y, grid_x, max_index%3, 5+obj_indx] = 1
+
+ # assign the true box to t_batch
+ true_box = [center_x, center_y, obj['xmax'] - obj['xmin'], obj['ymax'] - obj['ymin']]
+ t_batch[instance_count, 0, 0, 0, true_box_index] = true_box
+
+ true_box_index += 1
+ true_box_index = true_box_index % self.max_box_per_image
+
+ # assign input image to x_batch
+ if self.norm != None:
+ x_batch[instance_count] = self.norm(img)
+ else:
+ # plot image and bounding boxes for sanity check
+ for obj in all_objs:
+ cv2.rectangle(img, (obj['xmin'],obj['ymin']), (obj['xmax'],obj['ymax']), (255,0,0), 3)
+ cv2.putText(img, obj['name'],
+ (obj['xmin']+2, obj['ymin']+12),
+ 0, 1.2e-3 * img.shape[0],
+ (0,255,0), 2)
+
+ x_batch[instance_count] = img
+
+ # increase instance counter in the current batch
+ instance_count += 1
+
+ return [x_batch, t_batch, yolo_1, yolo_2], [dummy_yolo_1, dummy_yolo_2]
+
+ def _get_net_size(self, idx):
+ if idx % 10 == 0:
+ net_size = self.downsample*np.random.randint(self.min_net_size/self.downsample, \
+ self.max_net_size/self.downsample+1)
+
+ self.net_h, self.net_w = net_size, net_size
+ return self.net_h, self.net_w
+
+ def _aug_image(self, instance, net_h, net_w):
+ image_name = instance['filename']
+ image = cv2.imread(image_name) # BGR image
+
+ if image is None:
+ print('Cannot find ', image_name)
+
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # RGB image
+
+ image_h, image_w, _ = image.shape
+
+ # determine the amount of scaling and cropping
+ dw = self.jitter * image_w
+ dh = self.jitter * image_h
+
+ new_ar = (image_w + np.random.uniform(-dw, dw)) / (image_h + np.random.uniform(-dh, dh))
+ scale = np.random.uniform(0.25, 2)
+
+ if new_ar < 1:
+ new_h = int(scale * net_h)
+ new_w = int(net_h * new_ar)
+ else:
+ new_w = int(scale * net_w)
+ new_h = int(net_w / new_ar)
+
+ dx = int(np.random.uniform(0, net_w - new_w))
+ dy = int(np.random.uniform(0, net_h - new_h))
+
+ # apply scaling and cropping
+ im_sized = apply_random_scale_and_crop(image, new_w, new_h, net_w, net_h, dx, dy)
+
+ # randomly distort hsv space
+ im_sized = random_distort_image(im_sized)
+
+ # randomly flip
+ flip = np.random.randint(2)
+ im_sized = random_flip(im_sized, flip)
+
+ # correct the size and pos of bounding boxes
+ all_objs = correct_bounding_boxes(instance['object'], new_w, new_h, net_w, net_h, dx, dy, flip, image_w, image_h)
+
+ return im_sized, all_objs
+
+ def on_epoch_end(self):
+ if self.shuffle:
+ np.random.shuffle(self.instances)
+
+ def num_classes(self):
+ return len(self.labels)
+
+ def size(self):
+ return len(self.instances)
+
+ def get_anchors(self):
+ anchors = []
+
+ for anchor in self.anchors:
+ anchors += [anchor.xmax, anchor.ymax]
+
+ return anchors
+
+ def load_annotation(self, i):
+ annots = []
+
+ for obj in self.instances[i]['object']:
+ annot = [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax'], self.labels.index(obj['name'])]
+ annots += [annot]
+
+ if len(annots) == 0:
+ annots = [[]]
+
+ return np.array(annots)
+
+ def load_image(self, i):
+ return cv2.imread(self.instances[i]['filename']) # BGR image
diff --git a/imageai/Detection/Custom/yolo.py b/imageai/Detection/Custom/yolo.py
index f9661226..e256d747 100644
--- a/imageai/Detection/Custom/yolo.py
+++ b/imageai/Detection/Custom/yolo.py
@@ -282,7 +282,8 @@ def create_yolov3_model(
# Layer 80 => 82
pred_yolo_1 = _conv_block(x, [{'filter': 1024, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 80},
{'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 81}], do_skip=False)
- loss_yolo_1 = YoloLayer(anchors[12:],
+
+ loss_yolo_1 = YoloLayer(anchors[12:],
[1*num for num in max_grid],
batch_size,
warmup_batches,
@@ -298,13 +299,6 @@ def create_yolov3_model(
x = UpSampling2D(2)(x)
x = concatenate([x, skip_61])
- # Layer 87 => 91
- x = _conv_block(x, [{'filter': 256, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 87},
- {'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 88},
- {'filter': 256, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 89},
- {'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 90},
- {'filter': 256, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 91}], do_skip=False)
-
# Layer 92 => 94
pred_yolo_2 = _conv_block(x, [{'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 92},
{'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 93}], do_skip=False)
@@ -348,5 +342,82 @@ def create_yolov3_model(
return [train_model, infer_model]
+
+def create_tinyyolov3_model(
+ nb_class,
+ anchors,
+ max_box_per_image,
+ max_grid,
+ batch_size,
+ warmup_batches,
+ ignore_thresh,
+ grid_scales,
+ obj_scale,
+ noobj_scale,
+ xywh_scale,
+ class_scale
+):
+ input_image = Input(shape=(None, None, 3)) # net_h, net_w, 3
+ true_boxes = Input(shape=(1, 1, 1, max_box_per_image, 4))
+ true_yolo_1 = Input(shape=(None, None, len(anchors)//6, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class
+ true_yolo_2 = Input(shape=(None, None, len(anchors)//6, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class
+
+ # Layer 0 => 5
+ network1 = _conv_block(input_image, [{'filter': 16, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 0},
+ {'filter': 32, 'kernel': 3, 'stride': 2, 'bnorm': True, 'leaky': True, 'layer_idx': 1},
+ {'filter': 64, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 2},
+ {'filter': 128, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 3},
+ {'filter': 256, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 4}
+ ],
+ do_skip=False)
+
+ # Layer 6 => 8
+ network2 = _conv_block(network1, [{'filter': 512, 'kernel': 3, 'stride': 2, 'bnorm': True, 'leaky': True, 'layer_idx': 6},
+ {'filter': 1024, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 7},
+ {'filter': 256, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 8}], do_skip=False)
+
+ # Layer 10 => 11
+ # network3
+ pred_yolo_1 = _conv_block(network2, [{'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 9},
+ {'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 10}], do_skip=False)
+ # check this layer
+ loss_yolo_1 = YoloLayer(anchors[12:],
+ [1*num for num in max_grid],
+ batch_size,
+ warmup_batches,
+ ignore_thresh,
+ grid_scales[0],
+ obj_scale,
+ noobj_scale,
+ xywh_scale,
+ class_scale)([input_image, pred_yolo_1, true_yolo_1, true_boxes])
+
+ # Layer 12
+ network2 = _conv_block(network2, [{'filter': 128, 'kernel': 1, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 12},], do_skip=False)
+
+ network2 = UpSampling2D(2)(network2)
+ network4 = concatenate([network2, network1])
+ # network4 = _conv_block(network4, )
+
+ # Layer 92 => 94
+ pred_yolo_2 = _conv_block(network4, [{'filter': 256, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 92},
+ {'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 93}], do_skip=False)
+ loss_yolo_2 = YoloLayer(anchors[:6],
+ [4*num for num in max_grid],
+ batch_size,
+ warmup_batches,
+ ignore_thresh,
+ grid_scales[2],
+ obj_scale,
+ noobj_scale,
+ xywh_scale,
+ class_scale)([input_image, pred_yolo_2, true_yolo_2, true_boxes])
+
+ train_model = Model([input_image, true_boxes, true_yolo_1, true_yolo_2], [loss_yolo_1, loss_yolo_2])
+ infer_model = Model(input_image, [pred_yolo_1, pred_yolo_2])
+
+ return [train_model, infer_model]
+
+
def dummy_loss(y_true, y_pred):
- return tf.sqrt(tf.reduce_sum(y_pred))
\ No newline at end of file
+ return tf.sqrt(tf.reduce_sum(y_pred))
diff --git a/imageai/Detection/YOLOv3/models.py b/imageai/Detection/YOLOv3/models.py
index 6d12124e..4ef585da 100644
--- a/imageai/Detection/YOLOv3/models.py
+++ b/imageai/Detection/YOLOv3/models.py
@@ -103,4 +103,4 @@ def tiny_yolo_main(input, num_anchors, num_classes):
network_4 = NetworkConv2D_BN_Leaky(input=network_4, channels=256, kernel_size=(3, 3))
network_4 = Conv2D(num_anchors * (num_classes + 5), kernel_size=(1,1))(network_4)
- return Model(input, [network_3, network_4])
\ No newline at end of file
+ return Model(input, [network_3, network_4])
From 1f7dd5a0bdfb71a9f93e9e146cc4a3991b748439 Mon Sep 17 00:00:00 2001
From: "Perry Gibson (gabriel)"
Date: Sun, 6 Dec 2020 16:48:32 +0000
Subject: [PATCH 2/2] updated TinyYOLOv3
---
imageai/Detection/Custom/__init__.py | 9 ++++++++-
imageai/Detection/Custom/gen_anchors.py | 4 ++--
imageai/Detection/Custom/tinyyolov3_generator.py | 15 +++++++--------
imageai/Detection/Custom/yolo.py | 7 ++++---
4 files changed, 21 insertions(+), 14 deletions(-)
diff --git a/imageai/Detection/Custom/__init__.py b/imageai/Detection/Custom/__init__.py
index ec6f7aed..2aabef31 100644
--- a/imageai/Detection/Custom/__init__.py
+++ b/imageai/Detection/Custom/__init__.py
@@ -175,9 +175,16 @@ def setTrainConfig(self, object_names_array, batch_size=4, num_experiments=100,
:return:
"""
+ if self.__model_type == "tinyyolov3":
+ num_anchors = 6
+ elif self.__model_type == "yolov3":
+ num_anchors = 9
+ else:
+ raise ValueError(f'Unsupported model type: {self.__model_type}')
self.__model_anchors, self.__inference_anchors = generateAnchors(self.__train_annotations_folder,
self.__train_images_folder,
- self.__train_cache_file, self.__model_labels)
+ self.__train_cache_file, self.__model_labels,
+ num_anchors)
self.__model_labels = sorted(object_names_array)
self.__num_objects = len(object_names_array)
diff --git a/imageai/Detection/Custom/gen_anchors.py b/imageai/Detection/Custom/gen_anchors.py
index 693e6b21..4e4a29fc 100644
--- a/imageai/Detection/Custom/gen_anchors.py
+++ b/imageai/Detection/Custom/gen_anchors.py
@@ -70,10 +70,10 @@ def run_kmeans(ann_dims, anchor_num):
old_distances = distances.copy()
-def generateAnchors(train_annotation_folder, train_image_folder, train_cache_file, model_labels):
+def generateAnchors(train_annotation_folder, train_image_folder, train_cache_file, model_labels,
+ num_anchors=9):
print("Generating anchor boxes for training images and annotation...")
- num_anchors = 9
train_imgs, train_labels = parse_voc_annotation(
train_annotation_folder,
diff --git a/imageai/Detection/Custom/tinyyolov3_generator.py b/imageai/Detection/Custom/tinyyolov3_generator.py
index 329c4295..ef3a48c6 100644
--- a/imageai/Detection/Custom/tinyyolov3_generator.py
+++ b/imageai/Detection/Custom/tinyyolov3_generator.py
@@ -57,8 +57,8 @@ def __getitem__(self, idx):
t_batch = np.zeros((r_bound - l_bound, 1, 1, 1, self.max_box_per_image, 4)) # list of groundtruth boxes
# initialize the inputs and the outputs
- yolo_1 = np.zeros((r_bound - l_bound, 1*base_grid_h, 1*base_grid_w, len(self.anchors)//3, 4+1+len(self.labels))) # desired network output 1
- yolo_2 = np.zeros((r_bound - l_bound, 2*base_grid_h, 2*base_grid_w, len(self.anchors)//3, 4+1+len(self.labels))) # desired network output 2
+ yolo_1 = np.zeros((r_bound - l_bound, 1*base_grid_h, 1*base_grid_w, len(self.anchors)//2, 4+1+len(self.labels))) # desired network output 1
+ yolo_2 = np.zeros((r_bound - l_bound, 2*base_grid_h, 2*base_grid_w, len(self.anchors)//2, 4+1+len(self.labels))) # desired network output 2
yolos = [yolo_2, yolo_1]
dummy_yolo_1 = np.zeros((r_bound - l_bound, 1))
@@ -93,8 +93,7 @@ def __getitem__(self, idx):
max_iou = iou
# determine the yolo to be responsible for this bounding box
-
- yolo = yolos[max_index // 2]
+ yolo = yolos[max_index // 4]
grid_h, grid_w = yolo.shape[1:3]
# determine the position of the bounding box on the grid
@@ -117,10 +116,10 @@ def __getitem__(self, idx):
grid_y = int(np.floor(center_y))
# assign ground truth x, y, w, h, confidence and class probs to y_batch
- yolo[instance_count, grid_y, grid_x, max_index%3] = 0
- yolo[instance_count, grid_y, grid_x, max_index%3, 0:4] = box
- yolo[instance_count, grid_y, grid_x, max_index%3, 4 ] = 1.
- yolo[instance_count, grid_y, grid_x, max_index%3, 5+obj_indx] = 1
+ yolo[instance_count, grid_y, grid_x, max_index%1] = 0
+ yolo[instance_count, grid_y, grid_x, max_index%1, 0:4] = box
+ yolo[instance_count, grid_y, grid_x, max_index%1, 4 ] = 1.
+ yolo[instance_count, grid_y, grid_x, max_index%1, 5+obj_indx] = 1
# assign the true box to t_batch
true_box = [center_x, center_y, obj['xmax'] - obj['xmin'], obj['ymax'] - obj['ymin']]
diff --git a/imageai/Detection/Custom/yolo.py b/imageai/Detection/Custom/yolo.py
index e256d747..a63ec81f 100644
--- a/imageai/Detection/Custom/yolo.py
+++ b/imageai/Detection/Custom/yolo.py
@@ -357,10 +357,11 @@ def create_tinyyolov3_model(
xywh_scale,
class_scale
):
+ print('test hello\n\n\n\n\n\n', len(anchors), len(anchors)//4)
input_image = Input(shape=(None, None, 3)) # net_h, net_w, 3
true_boxes = Input(shape=(1, 1, 1, max_box_per_image, 4))
- true_yolo_1 = Input(shape=(None, None, len(anchors)//6, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class
- true_yolo_2 = Input(shape=(None, None, len(anchors)//6, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class
+ true_yolo_1 = Input(shape=(None, None, len(anchors)//4, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class
+ true_yolo_2 = Input(shape=(None, None, len(anchors)//4, 4+1+nb_class)) # grid_h, grid_w, nb_anchor, 5+nb_class
# Layer 0 => 5
network1 = _conv_block(input_image, [{'filter': 16, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 0},
@@ -381,7 +382,7 @@ def create_tinyyolov3_model(
pred_yolo_1 = _conv_block(network2, [{'filter': 512, 'kernel': 3, 'stride': 1, 'bnorm': True, 'leaky': True, 'layer_idx': 9},
{'filter': (3*(5+nb_class)), 'kernel': 1, 'stride': 1, 'bnorm': False, 'leaky': False, 'layer_idx': 10}], do_skip=False)
# check this layer
- loss_yolo_1 = YoloLayer(anchors[12:],
+ loss_yolo_1 = YoloLayer(anchors[6:],
[1*num for num in max_grid],
batch_size,
warmup_batches,