diff --git a/.gitignore b/.gitignore index 1dd766a18c..3d0511c445 100644 --- a/.gitignore +++ b/.gitignore @@ -190,7 +190,6 @@ hyper.sublime-* .kube-temp hyper_search/ -yolo/ logs/ # other diff --git a/deeplake/__init__.py b/deeplake/__init__.py index 095f076144..31f845174a 100644 --- a/deeplake/__init__.py +++ b/deeplake/__init__.py @@ -44,6 +44,7 @@ ingest = api_dataset.ingest connect = api_dataset.connect ingest_coco = api_dataset.ingest_coco +ingest_yolo = api_dataset.ingest_yolo ingest_kaggle = api_dataset.ingest_kaggle ingest_dataframe = api_dataset.ingest_dataframe ingest_huggingface = huggingface.ingest_huggingface diff --git a/deeplake/api/dataset.py b/deeplake/api/dataset.py index 2b1b6832b2..56a175de99 100644 --- a/deeplake/api/dataset.py +++ b/deeplake/api/dataset.py @@ -8,6 +8,7 @@ from deeplake.auto.unstructured.kaggle import download_kaggle_dataset from deeplake.auto.unstructured.image_classification import ImageClassification from deeplake.auto.unstructured.coco.coco import CocoDataset +from deeplake.auto.unstructured.yolo.yolo import YoloDataset from deeplake.client.client import DeepLakeBackendClient from deeplake.client.log import logger from deeplake.core.dataset import Dataset, dataset_factory @@ -1023,7 +1024,7 @@ def ingest_coco( >>> ds = deeplake.ingest_coco( >>> "path/to/images/directory", >>> ["path/to/annotation/file1.json", "path/to/annotation/file2.json"], - >>> dest="hub://username/dataset", + >>> dest="hub://org_id/dataset", >>> key_to_tensor_mapping={"category_id": "labels", "bbox": "boxes"}, >>> file_to_group_mapping={"file1.json": "group1", "file2.json": "group2"}, >>> ignore_keys=["area", "image_id", "id"], @@ -1034,7 +1035,7 @@ def ingest_coco( >>> ds = deeplake.ingest_coco( >>> "s3://bucket/images/directory", >>> "s3://bucket/annotation/file1.json", - >>> dest="hub://username/dataset", + >>> dest="hub://org_id/dataset", >>> ignore_one_group=True, >>> ignore_keys=["area", "image_id", "id"], >>> image_settings={"name": "images", "linked": True, creds_key="my_managed_creds_key", "sample_compression": "jpeg"}, @@ -1048,7 +1049,7 @@ def ingest_coco( annotation_files (str, pathlib.Path, List[str]): Path to JSON annotation files in COCO format. dest (str, pathlib.Path): - The full path to the dataset. Can be: - - a Deep Lake cloud path of the form ``hub://username/datasetname``. To write to Deep Lake cloud datasets, ensure that you are logged in to Deep Lake (use 'activeloop login' from command line) + - a Deep Lake cloud path of the form ``hub://org_id/datasetname``. To write to Deep Lake cloud datasets, ensure that you are logged in to Deep Lake (use 'activeloop login' from command line), or pass in a token using the 'token' parameter. - an s3 path of the form ``s3://bucketname/path/to/dataset``. Credentials are required in either the environment or passed to the creds argument. - a local file system path of the form ``./path/to/dataset`` or ``~/path/to/dataset`` or ``path/to/dataset``. - a memory path of the form ``mem://path/to/dataset`` which doesn't save the dataset but keeps it in memory instead. Should be used only for testing as it does not persist. @@ -1070,6 +1071,7 @@ def ingest_coco( Raises: IngestionError: If either ``key_to_tensor_mapping`` or ``file_to_group_mapping`` are not one-to-one. 
""" + dest = convert_pathlib_to_string_if_needed(dest) images_directory = convert_pathlib_to_string_if_needed(images_directory) annotation_files = ( @@ -1078,7 +1080,12 @@ def ingest_coco( else convert_pathlib_to_string_if_needed(annotation_files) ) - ds = deeplake.empty(dest, creds=dest_creds, verbose=False, **dataset_kwargs) + feature_report_path( + dest, + "ingest_coco", + {"num_workers": num_workers}, + token=dataset_kwargs.get("token", None), + ) unstructured = CocoDataset( source=images_directory, @@ -1091,6 +1098,130 @@ def ingest_coco( creds=src_creds, ) structure = unstructured.prepare_structure(inspect_limit) + + ds = deeplake.empty(dest, creds=dest_creds, verbose=False, **dataset_kwargs) + + structure.create_missing(ds) + + unstructured.structure( + ds, + progressbar, + num_workers, + ) + + return ds + + @staticmethod + def ingest_yolo( + data_directory: Union[str, pathlib.Path], + dest: Union[str, pathlib.Path], + class_names_file: Optional[Union[str, pathlib.Path]] = None, + annotations_directory: Optional[Union[str, pathlib.Path]] = None, + allow_no_annotation: bool = False, + image_params: Optional[Dict] = None, + label_params: Optional[Dict] = None, + coordinates_params: Optional[Dict] = None, + src_creds: Optional[Dict] = None, + dest_creds: Optional[Dict] = None, + image_creds_key: Optional[str] = None, + inspect_limit: int = 1000, + progressbar: bool = True, + num_workers: int = 0, + connect_kwargs: Optional[Dict] = None, + **dataset_kwargs, + ) -> Dataset: + """Ingest images and annotations (bounding boxes or polygons) in YOLO format to a Deep Lake Dataset. + + Examples: + >>> ds = deeplake.ingest_yolo( + >>> "path/to/data/directory", + >>> dest="hub://org_id/dataset", + >>> allow_no_annotation=True, + >>> token="my_activeloop_token", + >>> num_workers=4, + >>> ) + >>> # or ingest data from cloud + >>> ds = deeplake.ingest_yolo( + >>> "s3://bucket/data_directory", + >>> dest="hub://org_id/dataset", + >>> image_params={"name": "image_links", "htype": "link[image]"}, + >>> image_creds_key='my_s3_managed_crerendials", + >>> src_creds=aws_creds, # Can also be inferred from environment + >>> token="my_activeloop_token", + >>> num_workers=4, + >>> ) + + Args: + data_directory (str, pathlib.Path): The path to the directory containing the data (images files and annotation files(see 'annotations_directory' input for specifying annotations in a separate directory). + dest (str, pathlib.Path): + - The full path to the dataset. Can be: + - a Deep Lake cloud path of the form ``hub://org_id/datasetname``. To write to Deep Lake cloud datasets, ensure that you are logged in to Deep Lake (use 'activeloop login' from command line), or pass in a token using the 'token' parameter. + - an s3 path of the form ``s3://bucketname/path/to/dataset``. Credentials are required in either the environment or passed to the creds argument. + - a local file system path of the form ``./path/to/dataset`` or ``~/path/to/dataset`` or ``path/to/dataset``. + - a memory path of the form ``mem://path/to/dataset`` which doesn't save the dataset but keeps it in memory instead. Should be used only for testing as it does not persist. + class_names_file: Path to the file containing the class names on separate lines. This is typically a file titled classes.names. + annotations_directory (Optional[Union[str, pathlib.Path]]): Path to directory containing the annotations. If specified, the 'data_directory' will not be examined for annotations. 
+ allow_no_annotation (bool): Flag to determine whether missing annotation files corresponding to an image should be treated as empty annotations. Set to ``False`` by default. + image_params (Optional[Dict]): A dictionary containing parameters for the images tensor. + label_params (Optional[Dict]): A dictionary containing parameters for the labels tensor. + coordinates_params (Optional[Dict]): A dictionary containing parameters for the coordinates tensor. This tensor either contains bounding boxes or polygons. + src_creds (Optional[Dict]): Credentials to access the source path. If not provided, will be inferred from the environment. + dest_creds (Optional[Dict]): A dictionary containing credentials used to access the destination path of the dataset. + image_creds_key (Optional[str]): creds_key for linked tensors, applicable if the htype for the images tensor is specified as 'link[image]' in the 'image_params' input. + inspect_limit (int): The maximum number of annotations to inspect, in order to infer whether they are bounding boxes or polygons. This input is ignored if the htype is specified in the 'coordinates_params'. + progressbar (bool): Enables or disables ingestion progress bar. Set to ``True`` by default. + num_workers (int): The number of workers to use for ingestion. Set to ``0`` by default. + connect_kwargs (Optional[Dict]): If specified, the dataset will be connected to Platform, and connect_kwargs will be passed to :func:`ds.connect`. + **dataset_kwargs: Any arguments passed here will be forwarded to the dataset creator function. See :func:`deeplake.empty`. + + Returns: + Dataset: The Dataset created from the images and YOLO annotations. + + Raises: + IngestionError: If annotations are not found for all the images and 'allow_no_annotation' is ``False``. + """ + + dest = convert_pathlib_to_string_if_needed(dest) + data_directory = convert_pathlib_to_string_if_needed(data_directory) + + annotations_directory = ( + convert_pathlib_to_string_if_needed(annotations_directory) + if annotations_directory is not None + else None + ) + + class_names_file = ( + convert_pathlib_to_string_if_needed(class_names_file) + if class_names_file is not None + else None + ) + + feature_report_path( + dest, + "ingest_yolo", + {"num_workers": num_workers}, + token=dataset_kwargs.get("token", None), + ) + + unstructured = YoloDataset( + data_directory=data_directory, + class_names_file=class_names_file, + annotations_directory=annotations_directory, + image_params=image_params, + label_params=label_params, + coordinates_params=coordinates_params, + allow_no_annotation=allow_no_annotation, + creds=src_creds, + image_creds_key=image_creds_key, + inspect_limit=inspect_limit, + ) + + structure = unstructured.prepare_structure() + + ds = deeplake.empty(dest, creds=dest_creds, verbose=False, **dataset_kwargs) + if connect_kwargs is not None: + ds.connect(**connect_kwargs, token=dataset_kwargs.get("token", None)) + structure.create_missing(ds) unstructured.structure( @@ -1185,6 +1316,7 @@ def ingest( "Progressbar": progressbar, "Summary": summary, }, + token=dataset_kwargs.get("token", None), ) src = convert_pathlib_to_string_if_needed(src) @@ -1212,11 +1344,11 @@ def ingest( if images_compression is None: raise InvalidFileExtension(src) - ds = deeplake.dataset(dest, creds=dest_creds, **dataset_kwargs) - # TODO: support more than just image classification (and update docstring) unstructured = ImageClassification(source=src) + ds = deeplake.dataset(dest, creds=dest_creds, **dataset_kwargs) + # TODO: auto detect
compression unstructured.structure( ds, # type: ignore @@ -1278,6 +1410,7 @@ def ingest_kaggle( "Progressbar": progressbar, "Summary": summary, }, + token=dataset_kwargs.get("token", None), ) if os.path.isdir(src) and os.path.isdir(dest): @@ -1334,36 +1467,49 @@ def ingest_dataframe( import pandas as pd from deeplake.auto.structured.dataframe import DataFrame + feature_report_path( + convert_pathlib_to_string_if_needed(dest), + "ingest_dataframe", + {}, + token=dataset_kwargs.get("token", None), + ) + if not isinstance(src, pd.DataFrame): raise Exception("Source provided is not a valid pandas dataframe object") + structured = DataFrame(src) + if isinstance(dest, Dataset): ds = dest else: dest = convert_pathlib_to_string_if_needed(dest) ds = deeplake.dataset(dest, creds=dest_creds, **dataset_kwargs) - structured = DataFrame(src) structured.fill_dataset(ds, progressbar) # type: ignore return ds # type: ignore @staticmethod - @deeplake_reporter.record_call def list( - workspace: str = "", + org_id: str = "", token: Optional[str] = None, ) -> None: """List all available Deep Lake cloud datasets. Args: - workspace (str): Specify user/organization name. If not given, + org_id (str): Specify organization id. If not given, returns a list of all datasets that can be accessed, regardless of what workspace they are in. - Otherwise, lists all datasets in the given workspace. + Otherwise, lists all datasets in the given organization. token (str, optional): Activeloop token, used for fetching credentials for Deep Lake datasets. This is optional, tokens are normally autogenerated. Returns: List: List of dataset names. """ + + deeplake_reporter.feature_report( + feature_name="list", + parameters={"org_id": org_id}, + ) + client = DeepLakeBackendClient(token=token) - datasets = client.get_datasets(workspace=workspace) + datasets = client.get_datasets(workspace=org_id) return datasets diff --git a/deeplake/auto/tests/test_coco_template.py b/deeplake/auto/tests/test_coco_template.py index fb093dae8e..d943768639 100644 --- a/deeplake/auto/tests/test_coco_template.py +++ b/deeplake/auto/tests/test_coco_template.py @@ -11,13 +11,12 @@ def test_full_dataset_structure(local_ds): dataset_structure = DatasetStructure(ignore_one_group=False) dataset_structure.add_first_level_tensor( - TensorStructure("tensor1", params={"htype": "generic"}, primary=False) + TensorStructure("tensor1", params={"htype": "generic"}) ) dataset_structure.add_first_level_tensor( TensorStructure( "images", params={"htype": "image", "sample_compression": "jpeg"}, - primary=True, ) ) @@ -51,13 +50,12 @@ def test_missing_dataset_structure(local_ds): local_ds.create_tensor("annotations/masks", htype="binary_mask") dataset_structure.add_first_level_tensor( - TensorStructure("tensor1", params={"htype": "generic"}, primary=False) + TensorStructure("tensor1", params={"htype": "generic"}) ) dataset_structure.add_first_level_tensor( TensorStructure( "images", params={"htype": "image", "sample_compression": "jpeg"}, - primary=True, ) ) diff --git a/deeplake/auto/tests/test_yolo_template.py b/deeplake/auto/tests/test_yolo_template.py new file mode 100644 index 0000000000..236b281ae0 --- /dev/null +++ b/deeplake/auto/tests/test_yolo_template.py @@ -0,0 +1,143 @@ +import deeplake +import pytest +from deeplake.util.exceptions import IngestionError + + +def test_minimal_yolo_ingestion(local_path, yolo_ingestion_data): + + params = { + "data_directory": yolo_ingestion_data["data_directory"], + "class_names_file": yolo_ingestion_data["class_names_file"], + } 
+ + ds = deeplake.ingest_yolo(**params, dest=local_path) + + assert ds.path == local_path + assert "images" in ds.tensors + assert "boxes" in ds.tensors + assert "labels" in ds.tensors + assert len(ds.labels.info["class_names"]) > 0 + assert ds.boxes.htype == "bbox" + + +def test_minimal_yolo_ingestion_no_class_names(local_path, yolo_ingestion_data): + + params = { + "data_directory": yolo_ingestion_data["data_directory"], + "class_names_file": None, + } + + ds = deeplake.ingest_yolo(**params, dest=local_path) + + assert ds.path == local_path + assert "images" in ds.tensors + assert "boxes" in ds.tensors + assert "labels" in ds.tensors + assert ds.labels.info["class_names"] == [] + assert ds.boxes.htype == "bbox" + + +def test_minimal_yolo_ingestion_separate_annotations(local_path, yolo_ingestion_data): + + params = { + "data_directory": yolo_ingestion_data["data_directory_no_annotations"], + "class_names_file": yolo_ingestion_data["class_names_file"], + "annotations_directory": yolo_ingestion_data["annotations_directory"], + } + + ds = deeplake.ingest_yolo(**params, dest=local_path) + + assert ds.path == local_path + assert "images" in ds.tensors + assert "boxes" in ds.tensors + assert "labels" in ds.tensors + assert len(ds.labels.info["class_names"]) > 0 + assert ds.boxes.htype == "bbox" + + +def test_minimal_yolo_ingestion_missing_annotations(local_path, yolo_ingestion_data): + + params = { + "data_directory": yolo_ingestion_data["data_directory_missing_annotations"], + "class_names_file": yolo_ingestion_data["class_names_file"], + "allow_no_annotation": True, + } + + ds = deeplake.ingest_yolo(**params, dest=local_path) + + assert ds.path == local_path + assert "images" in ds.tensors + assert "boxes" in ds.tensors + assert "labels" in ds.tensors + assert len(ds.labels.info["class_names"]) > 0 + assert ds.boxes.htype == "bbox" + + +def test_minimal_yolo_ingestion_unsupported_annotations( + local_path, yolo_ingestion_data +): + + params = { + "data_directory": yolo_ingestion_data["data_directory_unsupported_annotations"], + "class_names_file": yolo_ingestion_data["class_names_file"], + } + + with pytest.raises(IngestionError): + ds = deeplake.ingest_yolo(**params, dest=local_path) + + +def test_minimal_yolo_ingestion_bad_data_path(local_path, yolo_ingestion_data): + + params = { + "data_directory": yolo_ingestion_data["data_directory"] + "corrupt_this_path", + "class_names_file": yolo_ingestion_data["class_names_file"], + } + + with pytest.raises(IngestionError): + ds = deeplake.ingest_yolo(**params, dest=local_path) + + +def test_minimal_yolo_ingestion_poly(local_path, yolo_ingestion_data): + + params = { + "data_directory": yolo_ingestion_data["data_directory"], + "class_names_file": yolo_ingestion_data["class_names_file"], + } + + ds = deeplake.ingest_yolo( + **params, + dest=local_path, + coordinates_params={"name": "polygons", "htype": "polygon"}, + ) + + assert ds.path == local_path + assert "images" in ds.tensors + assert "polygons" in ds.tensors + assert "labels" in ds.tensors + assert len(ds.labels.info["class_names"]) > 0 + assert ds.polygons.htype == "polygon" + + +def test_minimal_yolo_ingestion_with_linked_images(local_path, yolo_ingestion_data): + + params = { + "data_directory": yolo_ingestion_data["data_directory"], + "class_names_file": yolo_ingestion_data["class_names_file"], + } + + ds = deeplake.ingest_yolo( + **params, + dest=local_path, + image_params={ + "name": "linked_images", + "htype": "link[image]", + "sample_compression": "png", + }, + ) + + assert ds.path == 
local_path + assert "linked_images" in ds.tensors + assert "boxes" in ds.tensors + assert "labels" in ds.tensors + assert len(ds.labels.info["class_names"]) > 0 + assert ds.linked_images.htype == "link[image]" diff --git a/deeplake/auto/unstructured/coco/coco.py b/deeplake/auto/unstructured/coco/coco.py index be013c166f..3e74e975fb 100644 --- a/deeplake/auto/unstructured/coco/coco.py +++ b/deeplake/auto/unstructured/coco/coco.py @@ -124,9 +124,7 @@ def _add_images_tensor(self, structure: DatasetStructure): ) name = self.image_settings.get("name", "images") - structure.add_first_level_tensor( - TensorStructure(name=name, primary=True, params=img_config) - ) + structure.add_first_level_tensor(TensorStructure(name=name, params=img_config)) def _ingest_images( self, diff --git a/deeplake/auto/unstructured/util.py b/deeplake/auto/unstructured/util.py index 84dc05fe07..33df9e8bb9 100644 --- a/deeplake/auto/unstructured/util.py +++ b/deeplake/auto/unstructured/util.py @@ -10,11 +10,9 @@ def __init__( self, name: str, params: Optional[Dict] = None, - primary: bool = False, ) -> None: self.name = name self.params = params if params is not None else dict() - self.primary = primary def create(self, ds: Dataset): ds.create_tensor(self.name, **self.params) diff --git a/deeplake/auto/unstructured/yolo/__init__.py b/deeplake/auto/unstructured/yolo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/deeplake/auto/unstructured/yolo/constants.py b/deeplake/auto/unstructured/yolo/constants.py new file mode 100644 index 0000000000..6de055c136 --- /dev/null +++ b/deeplake/auto/unstructured/yolo/constants.py @@ -0,0 +1,14 @@ +DEFAULT_IMAGE_TENSOR_PARAMS = { + "name": "images", + "htype": "image", +} + +DEFAULT_YOLO_LABEL_TENSOR_PARAMS = { + "name": "labels", + "htype": "class_label", + "sample_compression": None, +} + +DEFAULT_YOLO_COORDINATES_TENSOR_PARAMS = { + "sample_compression": None, +} diff --git a/deeplake/auto/unstructured/yolo/utils.py b/deeplake/auto/unstructured/yolo/utils.py new file mode 100644 index 0000000000..4607bedc1f --- /dev/null +++ b/deeplake/auto/unstructured/yolo/utils.py @@ -0,0 +1,195 @@ +import os + +import pathlib + +from collections import defaultdict +from typing import Tuple, List, Union, Optional, DefaultDict + +import deeplake +from deeplake.htype import HTYPE_SUPPORTED_COMPRESSIONS +from deeplake.util.exceptions import IngestionError +from deeplake.client.log import logger +from deeplake.util.storage import storage_provider_from_path +from deeplake.util.path import convert_pathlib_to_string_if_needed + +import numpy as np + + +class YoloData: + def __init__( + self, + data_directory: Union[str, pathlib.Path], + creds, + annotations_directory: Optional[Union[str, pathlib.Path]] = None, + class_names_file: Optional[Union[str, pathlib.Path]] = None, + ) -> None: + """Annotations can either be in hte data_directory, or in a separate annotations_directory""" + + self.root = convert_pathlib_to_string_if_needed(data_directory) + self.provider = storage_provider_from_path(self.root, creds=creds) + + if annotations_directory: + self.separate_annotations = True + self.root_annotations = convert_pathlib_to_string_if_needed( + annotations_directory + ) + self.provider_annotations = storage_provider_from_path( + self.root_annotations, creds=creds + ) + else: + self.separate_annotations = False + self.root_annotations = self.root + self.provider_annotations = self.provider + + if class_names_file: + class_names_file = 
convert_pathlib_to_string_if_needed(class_names_file) + self.root_class_names = os.path.dirname(class_names_file) + self.provider_class_names = storage_provider_from_path( + self.root_class_names, creds=creds + ) + self.class_names = ( + self.provider_class_names.get_bytes(os.path.basename(class_names_file)) + .decode() + .splitlines() + ) + else: + self.class_names = None + + ( + self.supported_images, + self.supported_annotations, + self.invalid_files, + self.image_extensions, + self.most_frequent_image_extension, + ) = self.parse_data() + + def parse_data( + self, + ) -> Tuple[List[str], List[str], List[str], List[str], Optional[str]]: + """Parses the given directory to generate a list of image and annotation paths. + Returns: + A tuple with, respectively, the list of supported images, the list of supported annotations, the list of encountered invalid files, the list of encountered image extensions, and the most frequent image extension + """ + supported_image_extensions = tuple( + HTYPE_SUPPORTED_COMPRESSIONS["image"] + ["jpg"] + ) + supported_images = [] + supported_annotations = [] + invalid_files = [] + image_extensions: DefaultDict[str, int] = defaultdict(int) + + if self.separate_annotations: + for file in self.provider_annotations: + if file.endswith(".txt"): + supported_annotations.append(file) + else: + invalid_files.append(file) + + for file in self.provider: + if file.endswith(supported_image_extensions): + supported_images.append(file) + ext = pathlib.Path(file).suffix[ + 1: + ] # Get extension without the . symbol + image_extensions[ext] += 1 + else: + invalid_files.append(file) + + else: + for file in self.provider: + if file.endswith(".txt"): + supported_annotations.append(file) + elif file.endswith(supported_image_extensions): + supported_images.append(file) + ext = pathlib.Path(file).suffix[ + 1: + ] # Get extension without the . symbol + image_extensions[ext] += 1 + else: + invalid_files.append(file) + + if len(invalid_files) > 0: + logger.warning( + f"Encountered {len(invalid_files)} unsupported files in the data folders and annotation folders (if specified)." + ) + + most_frequent_image_extension = max( + image_extensions, key=lambda k: image_extensions[k], default=None + ) + + return ( + supported_images, + supported_annotations, + invalid_files, + list(image_extensions.keys()), + most_frequent_image_extension, + ) + + def read_yolo_coordinates(self, file_name: str, is_box: bool = True): + """ + Function reads a label.txt YOLO file and returns a numpy array of labels, + and an object containing the coordinates. If is_box is True, the coordinates + object is an (Nx4) array, where N is the number of bounding boxes in the annotation + file. If is_box is False, we assume the coordinates represent a polygon, so the coordinates + object is a list of length N, where each element is an Mx2 array, where M is the number + of points in each polygon, and N is the number of polygons in the annotation file. + """ + + ann = self.get_annotation(file_name) + lines_split = ann.splitlines() + + yolo_labels = np.zeros(len(lines_split)) + + # Initialize box and polygon coordinates in order for mypy to pass, since types are different. This is computationally negligible.
+ yolo_coordinates_box = np.zeros((len(lines_split), 4)) + yolo_coordinates_poly = [] + + # Go through each line and parse data + for l, line in enumerate(lines_split): + line_split = line.split() + + if is_box: + yolo_coordinates_box[l, :] = np.array( + ( + float(line_split[1]), + float(line_split[2]), + float(line_split[3]), + float(line_split[4]), + ) + ) + else: # Assume it's a polygon + coordinates = np.array([float(item) for item in line_split[1:]]) + + if coordinates.size % 2 != 0: + raise IngestionError( + f"Error in annotation {file_name}. Polygons must have an even number of points." + ) + + yolo_coordinates_poly.append( + coordinates.reshape((int(coordinates.size / 2), 2)) + ) + + yolo_labels[l] = int(line_split[0]) + + if is_box: + return yolo_labels, yolo_coordinates_box + else: + return yolo_labels, yolo_coordinates_poly + + def get_full_path_image(self, image_name: str) -> str: + return os.path.join(self.root, image_name) + + def get_image( + self, + image: str, + is_link: Optional[bool] = False, + creds_key: Optional[str] = None, + ): + + if is_link: + return deeplake.link(self.get_full_path_image(image), creds_key=creds_key) + + return deeplake.read(self.get_full_path_image(image), storage=self.provider) + + def get_annotation(self, annotation: str): + return self.provider_annotations.get_bytes(annotation).decode() diff --git a/deeplake/auto/unstructured/yolo/yolo.py b/deeplake/auto/unstructured/yolo/yolo.py new file mode 100644 index 0000000000..c3c7db1c50 --- /dev/null +++ b/deeplake/auto/unstructured/yolo/yolo.py @@ -0,0 +1,329 @@ +import deeplake + +from pathlib import Path +from typing import Dict, Optional + +from deeplake.core.dataset import Dataset +from deeplake.util.exceptions import IngestionError +from deeplake.client.log import logger + +from ..base import UnstructuredDataset +from ..util import DatasetStructure, TensorStructure +from .utils import YoloData + +import numpy as np + +from .constants import ( + DEFAULT_YOLO_COORDINATES_TENSOR_PARAMS, + DEFAULT_YOLO_LABEL_TENSOR_PARAMS, + DEFAULT_IMAGE_TENSOR_PARAMS, +) + + +class YoloDataset(UnstructuredDataset): + def __init__( + self, + data_directory: str, + class_names_file: Optional[str] = None, + annotations_directory: Optional[str] = None, + image_params: Optional[Dict] = None, + label_params: Optional[Dict] = None, + coordinates_params: Optional[Dict] = None, + allow_no_annotation: Optional[bool] = False, + verify_class_names: Optional[bool] = True, + inspect_limit: Optional[int] = 1000, + creds: Optional[Dict] = None, + image_creds_key: Optional[str] = None, + ): + """Container for access to Yolo Data, parsing of key information, and conversions to a Deep Lake dataset""" + + super().__init__(data_directory) + + self.class_names_file = class_names_file + self.data_directory = data_directory + self.annotations_directory = annotations_directory + + self.allow_no_annotation = allow_no_annotation + self.verify_class_names = verify_class_names + self.creds = creds + self.image_creds_key = image_creds_key + self.inspect_limit = inspect_limit + + self.data = YoloData( + self.data_directory, + creds, + self.annotations_directory, + self.class_names_file, + ) + self._validate_data() + + # Create a separate list of tuples with the ingestion data (img_fn, annotation_fn). + # We do this in advance so missing files are discovered before the ingestion process.
+ self._create_ingestion_list() + + self._validate_ingestion_data() + + self._initialize_params( + image_params or {}, label_params or {}, coordinates_params or {} + ) + self._validate_image_params() + + def _parse_coordinates_type(self): + """Function inspects up to inspect_limit annotation files in order to infer whether they are polygons or bounding boxes""" + + # If the htype or name of the coordinates is not specified (htype could be bbox or polygon), auto-infer it by reading some of the annotation files + if ( + "htype" not in self.coordinates_params.keys() + or "name" not in self.coordinates_params.keys() + ): + + # Read the annotation files assuming they are polygons and check if there are any non-empty annotations without 4 coordinates + coordinates_htype = "bbox" # Initialize to bbox and change if contradicted + coordinates_name = "boxes" # Initialize to boxes and change if contradicted + count = 0 + while count < min(self.inspect_limit, len(self.ingestion_data)): + fn = self.ingestion_data[count][1] + if fn is not None: + _, coordinates = self.data.read_yolo_coordinates(fn, is_box=False) + for c in coordinates: + coord_size = c.size + if coord_size > 0 and coord_size != 4: + coordinates_htype = "polygon" + coordinates_name = "polygons" + + count = ( + self.inspect_limit + 1 + ) # Set this to exit the while loop + break + + ## TODO: Add fancier math to see whether even coordinates with 4 elements could be polygons + count += 1 + + if "htype" not in self.coordinates_params.keys(): + self.coordinates_params["htype"] = coordinates_htype + + if "name" not in self.coordinates_params.keys(): + self.coordinates_params["name"] = coordinates_name + + def _initialize_params(self, image_params, label_params, coordinates_params): + image_params_updated = DEFAULT_IMAGE_TENSOR_PARAMS.copy() + for k, v in image_params.items(): + image_params_updated[k] = v + self.image_params = image_params_updated + + coordinates_params_updated = DEFAULT_YOLO_COORDINATES_TENSOR_PARAMS.copy() + for k, v in coordinates_params.items(): + coordinates_params_updated[k] = v + self.coordinates_params = coordinates_params_updated + + label_params_updated = DEFAULT_YOLO_LABEL_TENSOR_PARAMS.copy() + for k, v in label_params.items(): + label_params_updated[k] = v + self.label_params = label_params_updated + + self._parse_coordinates_type() + + def _create_ingestion_list(self): + """Function creates a list of tuples (image_filename, annotation_filename) that is passed to a deeplake.compute ingestion function""" + + ingestion_data = [] + for img_fn in self.data.supported_images: + base_name = Path(img_fn).stem + if base_name + ".txt" in self.data.supported_annotations: + ingestion_data.append((img_fn, base_name + ".txt")) + else: + if self.allow_no_annotation: + logger.warning( + f"Annotation was not found for {img_fn}. Empty annotation data will be appended for this image." + ) + + else: + raise IngestionError( + f"Annotation was not found for {img_fn}. Please add an annotation for this image, or specify allow_no_annotation=True, which will automatically append an empty annotation to the Deep Lake dataset."
) + ingestion_data.append((img_fn, None)) + + self.ingestion_data = ingestion_data + + def prepare_structure(self) -> DatasetStructure: + structure = DatasetStructure(ignore_one_group=True) + self._add_annotation_tensors(structure) + self._add_images_tensor(structure) + + return structure + + def _validate_data(self): + if ( + len(self.data.supported_images) != len(self.data.supported_annotations) + and self.allow_no_annotation == False + ): + raise IngestionError( + "The number of supported images and annotations in the input data is not equal. Please ensure that each image has a corresponding annotation, or set allow_no_annotation = True" + ) + + if len(self.data.supported_images) == 0: + raise IngestionError( + "There are no supported images in the input data. Please verify the source directory." + ) + + def _validate_ingestion_data(self): + if len(self.ingestion_data) == 0: + raise IngestionError( + "The data parser was not able to find any annotations corresponding to the images. Please check your directories, filenames, and extensions, or consider setting allow_no_annotation = True in order to upload empty annotations." + ) + + def _validate_image_params(self): + if "name" not in self.image_params: + raise IngestionError( + "Image params must contain a name for the image tensor." + ) + + def _add_annotation_tensors( + self, + structure: DatasetStructure, + ): + + structure.add_first_level_tensor( + TensorStructure( + name=self.label_params["name"], + params={ + i: self.label_params[i] for i in self.label_params if i != "name" + }, + ) + ) + + structure.add_first_level_tensor( + TensorStructure( + name=self.coordinates_params["name"], + params={ + i: self.coordinates_params[i] + for i in self.coordinates_params + if i != "name" + }, + ) + ) + + def _add_images_tensor(self, structure: DatasetStructure): + img_params = self.image_params.copy() + + img_params["sample_compression"] = self.image_params.get( + "sample_compression", self.data.most_frequent_image_extension + ) + name = self.image_params.get("name") + + structure.add_first_level_tensor( + TensorStructure( + name=name, + params={i: img_params[i] for i in img_params if i != "name"}, + ) + ) + + def _ingest_data(self, ds: Dataset, progressbar: bool = True, num_workers: int = 0): + """Function appends the data to the dataset object using deeplake.compute""" + + if self.image_creds_key is not None: + ds.add_creds_key(self.image_creds_key, managed=True) + + # Wrap tensor data needed by the deeplake.compute function into a single dict.
+ tensor_meta = { + "images": ds[self.image_params["name"]].meta, + "labels": ds[self.label_params["name"]].meta, + "coordinates": ds[self.coordinates_params["name"]].meta, + } + + @deeplake.compute + def append_data_bbox(data, sample_out, tensor_meta: Dict = tensor_meta): + + # If the ingestion data is None, create empty annotations corresponding to the file + if data[1]: + yolo_labels, yolo_coordinates = self.data.read_yolo_coordinates( + data[1], is_box=True + ) + else: + yolo_labels = np.zeros((0)) + yolo_coordinates = np.zeros((4, 0)) + + sample_out.append( + { + self.image_params["name"]: self.data.get_image( + data[0], + tensor_meta["images"].is_link, + self.image_creds_key, + ), + self.label_params["name"]: yolo_labels.astype( + tensor_meta["labels"].dtype + ), + self.coordinates_params["name"]: yolo_coordinates.astype( + tensor_meta["coordinates"].dtype + ), + } + ) + + @deeplake.compute + def append_data_polygon(data, sample_out, tensor_meta: Dict = tensor_meta): + + # If the ingestion data is None, create empty annotations corresponding to the file + if data[1]: + yolo_labels, yolo_coordinates = self.data.read_yolo_coordinates( + data[1], is_box=False + ) + else: + yolo_labels = np.zeros((0)) + yolo_coordinates = [] + + sample_out.append( + { + self.image_params["name"]: self.data.get_image( + data[0], + tensor_meta["images"].is_link, + self.image_creds_key, + ), + self.label_params["name"]: yolo_labels.astype( + tensor_meta["labels"].dtype + ), + self.coordinates_params["name"]: yolo_coordinates, + } + ) + + if tensor_meta["coordinates"].htype == "bbox": + append_data_bbox(tensor_meta=tensor_meta).eval( + self.ingestion_data, + ds, + progressbar=progressbar, + num_workers=num_workers, + ) + else: + append_data_polygon(tensor_meta=tensor_meta).eval( + self.ingestion_data, + ds, + progressbar=progressbar, + num_workers=num_workers, + ) + + def structure(self, ds: Dataset, progressbar: bool = True, num_workers: int = 0): # type: ignore + + # Set class names in the dataset + if self.data.class_names: + ds[self.label_params["name"]].info["class_names"] = self.data.class_names + + # Set bounding box format in the dataset + if ds[self.coordinates_params["name"]].meta.htype == "bbox": + ds[self.coordinates_params["name"]].info["coords"] = { + "type": "fractional", + "mode": "CCWH", + } + + self._ingest_data(ds, progressbar, num_workers) + + if self.verify_class_names and self.data.class_names: + + labels = ds[self.label_params.get("name")].numpy(aslist=True) + + max_label = max( + [l.max(initial=0) for l in labels] + ) # Assume a label is 0 if array is empty. This is technically incorrect, but it's highly unlikely that all labels are empty + + if max_label != len(ds[self.label_params.get("name")].info.class_names) - 1: + raise IngestionError( + "Dataset has been created but the largest numeric label in the annotations is inconsistent with the number of classes in the classes file." 
+ ) diff --git a/deeplake/tests/path_fixtures.py b/deeplake/tests/path_fixtures.py index fc9b2e5ccb..1b2904f9ed 100644 --- a/deeplake/tests/path_fixtures.py +++ b/deeplake/tests/path_fixtures.py @@ -99,6 +99,19 @@ def _download_hub_test_coco_data(): } +def _download_hub_test_yolo_data(): + path = _git_clone(_HUB_TEST_RESOURCES_URL) + return { + "data_directory": path + "/yolo/data", + "class_names_file": path + "/yolo/classes.names", + "data_directory_no_annotations": path + "/yolo/images_only", + "annotations_directory": path + "/yolo/annotations_only", + "data_directory_missing_annotations": path + "/yolo/data_missing_annotations", + "data_directory_unsupported_annotations": path + + "/yolo/data_unsupported_annotations", + } + + def _download_pil_test_images(ext=[".jpg", ".png"]): paths = {e: [] for e in ext} corrupt_file_keys = [ @@ -512,3 +525,8 @@ def hub_token(request): @pytest.fixture(scope="session") def coco_ingestion_data(): return _download_hub_test_coco_data() + + +@pytest.fixture(scope="session") +def yolo_ingestion_data(): + return _download_hub_test_yolo_data()
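For reference, a minimal usage sketch of the new deeplake.ingest_yolo entry point added by this patch. The paths, worker count, and printed output below are illustrative assumptions, not part of the change; the expected on-disk layout is the one the parser above handles: one .txt annotation per image, each line holding a class id followed by fractional CCWH box coordinates or an even-length list of polygon points, plus an optional classes.names file.

    import deeplake

    # Hypothetical paths, for illustration only.
    ds = deeplake.ingest_yolo(
        "path/to/yolo_data",  # images and .txt annotations side by side
        dest="path/to/deeplake_dataset",
        class_names_file="path/to/classes.names",  # optional; fills labels.info["class_names"]
        allow_no_annotation=True,  # images without a .txt get empty annotations appended
        num_workers=2,
    )
    print(list(ds.tensors))  # expected tensors: images, labels, and boxes (or polygons if auto-detected)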