Skip to content
Merged
90 changes: 69 additions & 21 deletions orangecontrib/imageanalytics/image_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@
'target_image_size': (256, 256),
'layers': ['penultimate']
},
'deeploc': {
'name': 'DeepLoc',
'description': 'A model trained to analyze yeast cell images.',
'target_image_size': (64, 64),
'layers': ['penultimate']
},
'vgg16': {
'name': 'VGG-16',
'description': '16-layer image recognition model trained on ImageNet.',
'target_image_size': (224, 224),
'layers': ['penultimate']
},
'vgg19': {
'name': 'VGG-19',
'description': '19-layer image recognition model trained on ImageNet.',
'target_image_size': (224, 224),
'layers': ['penultimate']
}
}


Expand Down Expand Up @@ -66,9 +84,11 @@ class ImageEmbedder(Http2Client):
>>> embedded_images, skipped_images, num_skipped = image_embedder(images)
"""
_cache_file_blueprint = '{:s}_{:s}_embeddings.pickle'
MAX_REPEATS = 4
CANNOT_LOAD = "cannot load"

def __init__(self, model="inception-v3", layer="penultimate",
server_url='api.biolab.si:8080'):
server_url='api.garaza.io:443'):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's better to switch to a different URL with a separate PR, and not add it here?
Note that this URL is also temporary and will probably change as soon as we promote the new test k8s cluster into a production.

super().__init__(server_url)
model_settings = self._get_model_settings_confidently(model, layer)
self._model = model
Expand Down Expand Up @@ -163,26 +183,45 @@ def from_file_paths(self, file_paths, image_processed_callback=None):
if not self.is_connected_to_server():
self.reconnect_to_server()

all_embeddings = []
all_embeddings = [None] * len(file_paths)
repeats_counter = 0

# repeat while all images has embeddings or
# while counter counts out (prevents cycling)
while len([el for el in all_embeddings if el is None]) > 0 and \
repeats_counter < self.MAX_REPEATS:

# take all images without embeddings yet
selected_indices = [i for i, v in enumerate(all_embeddings)
if v is None]
file_paths_wo_emb = [(file_paths[i], i) for i in selected_indices]

for batch in self._yield_in_batches(file_paths_wo_emb):
b_images, b_indices = zip(*batch)
try:
embeddings = self._send_to_server(
b_images, image_processed_callback
)
except MaxNumberOfRequestsError:
# maximum number of http2 requests through a single
# connection is exceeded and a remote peer has closed
# the connection so establish a new connection and retry
# with the same batch (should happen rarely as the setting
# is usually set to >= 1000 requests in http2)
self.reconnect_to_server()
embeddings = [None] * len(batch)

# insert embeddings into the list
for i, emb in zip(b_indices, embeddings):
all_embeddings[i] = emb

for batch in self._yield_in_batches(file_paths):
try:
embeddings = self._send_to_server(
batch, image_processed_callback
)
except MaxNumberOfRequestsError:
# maximum number of http2 requests through a single
# connection is exceeded and a remote peer has closed
# the connection so establish a new connection and retry
# with the same batch (should happen rarely as the setting
# is usually set to >= 1000 requests in http2)
self.reconnect_to_server()
embeddings = self._send_to_server(
batch, image_processed_callback
)
self.persist_cache()
repeats_counter += 1

all_embeddings += embeddings
self.persist_cache()
# change images that were not loaded from 'cannot loaded' to None
all_embeddings = \
[None if not isinstance(el, np.ndarray) and el == self.CANNOT_LOAD
else el for el in all_embeddings]

return np.array(all_embeddings)

Expand Down Expand Up @@ -304,14 +343,23 @@ def _get_responses_from_server(self, http_streams, cache_keys,
if self.cancelled:
raise EmbeddingCancelledException()

if not stream_id and not cache_key:
# when image cannot be loaded
embeddings.append(self.CANNOT_LOAD)

if image_processed_callback:
image_processed_callback(success=False)
continue


if not stream_id:
# skip rest of the waiting because image was either
# skipped at loading or is present in the local cache
embedding = self._get_cached_result_or_none(cache_key)
embeddings.append(embedding)

if image_processed_callback:
image_processed_callback()
image_processed_callback(success=embedding is not None)
continue

try:
Expand All @@ -331,7 +379,7 @@ def _get_responses_from_server(self, http_streams, cache_keys,
self._cache_dict[cache_key] = embedding

if image_processed_callback:
image_processed_callback()
image_processed_callback(embeddings[-1] is not None)

return embeddings

Expand Down
5 changes: 3 additions & 2 deletions orangecontrib/imageanalytics/widgets/owimageembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,9 @@ def commit(self):
set_progress = qconcurrent.methodinvoke(
self, "__progress_set", (float,))

def advance():
set_progress(next(ticks))
def advance(success=True):
if success:
set_progress(next(ticks))

def cancel():
task.future.cancel()
Expand Down