diff --git a/datasets/datasets.py b/datasets/datasets.py index 84ff7fc..4b44a9c 100644 --- a/datasets/datasets.py +++ b/datasets/datasets.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- import json -from io import BytesIO +from io import TextIOWrapper from os import SEEK_SET from os.path import splitext +from tempfile import SpooledTemporaryFile from unicodedata import normalize from uuid import uuid4 @@ -16,11 +17,14 @@ from pandas.io.common import infer_compression from platiagro import load_dataset, save_dataset, stat_dataset, update_dataset_metadata from platiagro.featuretypes import infer_featuretypes, validate_featuretypes + +from datasets import monkeypatch # noqa: F401 from datasets.exceptions import BadRequest, NotFound from datasets.utils import data_pagination NOT_FOUND = NotFound("The specified dataset does not exist") +SPOOLED_MAX_SIZE = 1024 * 1024 # 1MB def list_datasets(): @@ -83,21 +87,24 @@ def create_dataset(file_object): featuretypes = infer_featuretypes(df) metadata = { + "columns": columns, "featuretypes": featuretypes, "original-filename": filename, + "total": len(df.index), } + file.seek(0, SEEK_SET) # uses PlatIAgro SDK to save the dataset - save_dataset(name, df, metadata=metadata) + save_dataset(name, file, metadata=metadata) columns = [{"name": col, "featuretype": ftype} for col, ftype in zip(columns, featuretypes)] - content = load_dataset(name=name) + # Replaces NaN value by a text "NaN" so JSON encode doesn't fail - content.replace(np.nan, "NaN", inplace=True, regex=True) - content.replace(np.inf, "Inf", inplace=True, regex=True) - content.replace(-np.inf, "-Inf", inplace=True, regex=True) - data = content.values.tolist() - return {"name": name, "columns": columns, "data": data, "total": len(content.index), "filename": filename} + df.replace(np.nan, "NaN", inplace=True, regex=True) + df.replace(np.inf, "Inf", inplace=True, regex=True) + df.replace(-np.inf, "-Inf", inplace=True, regex=True) + data = df.values.tolist() + return {"name": name, "columns": columns, "data": data, "total": len(df.index), "filename": filename} def create_google_drive_dataset(gfile): @@ -148,7 +155,7 @@ def create_google_drive_dataset(gfile): else: request = service.files().get_media(fileId=file_id) - fh = BytesIO() + fh = SpooledTemporaryFile(max_size=SPOOLED_MAX_SIZE) downloader = MediaIoBaseDownload(fh, request) done = False try: @@ -291,6 +298,7 @@ def read_into_dataframe(file, filename=None, nrows=100, max_characters=50): ----- If no filename is given, a hex uuid will be used as the file name. """ + detector = UniversalDetector() for line, text in enumerate(file): detector.feed(text) @@ -305,23 +313,23 @@ def read_into_dataframe(file, filename=None, nrows=100, max_characters=50): compression = infer_compression(filename, "infer") file.seek(0, SEEK_SET) - contents = file.read() - - with BytesIO(contents) as file: - df0 = pd.read_csv( - file, - encoding=encoding, - compression=compression, - sep=None, - engine="python", - header="infer", - nrows=nrows, - ) + + pdread = TextIOWrapper(file, encoding=encoding) + df0 = pd.read_csv( + pdread, + encoding=encoding, + compression=compression, + sep=None, + engine="python", + header="infer", + nrows=nrows, + ) df0_cols = list(df0.columns) # Check if all columns are strings and short strings(text values tend to be long) column_names_checker = all([type(item) == str for item in df0_cols]) + if column_names_checker: column_names_checker = all([len(item) < max_characters for item in df0_cols]) @@ -340,16 +348,17 @@ def read_into_dataframe(file, filename=None, nrows=100, max_characters=50): header = "infer" if final_checker else None prefix = None if header else "col" - with BytesIO(contents) as file: - df = pd.read_csv( - file, - encoding=encoding, - compression=compression, - sep=None, - engine="python", - header=header, - prefix=prefix, - ) + pdread.seek(0, SEEK_SET) + df = pd.read_csv( + pdread, + encoding=encoding, + compression=compression, + sep=None, + engine="python", + header=header, + nrows=nrows, + prefix=prefix, + ) return df diff --git a/datasets/monkeypatch.py b/datasets/monkeypatch.py new file mode 100644 index 0000000..315eccb --- /dev/null +++ b/datasets/monkeypatch.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +""" +Monkey's patched methods for the SpooledTemporaryFile class. +This is because the SpooledTemporaryFile does not inherit / implement the IOBase class. +""" +from tempfile import SpooledTemporaryFile + + +def _readable(self): + return self._file.readable() + + +def _writable(self): + return self._file.writable() + + +def _seekable(self): + return self._file.seekable() + + +SpooledTemporaryFile.readable = _readable +SpooledTemporaryFile.writable = _writable +SpooledTemporaryFile.seekable = _seekable diff --git a/tests/test_api.py b/tests/test_api.py index 6a3e5d3..04ca94a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -173,12 +173,11 @@ def test_get_dataset(self): {"name": "col4", "featuretype": "Numerical"}, {"name": "col5", "featuretype": "Categorical"}, ], - "data": [['01/01/2000', 5.1, 3.5, 1.4, 0.2, 'Iris-setosa'], - ['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], + "data": [['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], ['01/01/2002', 4.7, 3.2, 1.3, 0.2, 'Iris-setosa'], ['01/01/2003', 4.6, 3.1, 1.5, 0.2, 'Iris-setosa']], "filename": "iris.data", - "total": 4 + "total": 3 } self.assertIn("name", result) @@ -198,10 +197,10 @@ def test_get_dataset(self): {"name": "col4", "featuretype": "Numerical"}, {"name": "col5", "featuretype": "Categorical"}, ], - "data": [['01/01/2000', 5.1, 3.5, 1.4, 0.2, 'Iris-setosa'], - ['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa']], + "data": [['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], + ['01/01/2002', 4.7, 3.2, 1.3, 0.2, 'Iris-setosa']], "filename": "iris.data", - "total": 4 + "total": 3 } del result["name"] self.assertDictEqual(expected, result) @@ -218,11 +217,11 @@ def test_get_dataset(self): {"name": "col4", "featuretype": "Numerical"}, {"name": "col5", "featuretype": "Categorical"}, ], - "data": [['01/01/2000', 5.1, 3.5, 1.4, 0.2, 'Iris-setosa'], - ['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], - ['01/01/2002', 4.7, 3.2, 1.3, 0.2, 'Iris-setosa']], + "data": [['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], + ['01/01/2002', 4.7, 3.2, 1.3, 0.2, 'Iris-setosa'], + ['01/01/2003', 4.6, 3.1, 1.5, 0.2, 'Iris-setosa']], "filename": "iris.data", - "total": 4 + "total": 3 } del result["name"] self.assertDictEqual(expected, result) @@ -230,7 +229,7 @@ def test_get_dataset(self): rv = TEST_CLIENT.get("/datasets/iris.data?page=15&page_size=2") result = rv.json() - expected = {"message": "The specified page does not exist"} + expected = {'message': 'The specified page does not exist'} self.assertDictEqual(expected, result) self.assertEqual(rv.status_code, 404) @@ -262,12 +261,11 @@ def test_get_dataset(self): {"name": "col4", "featuretype": "Numerical"}, {"name": "col5", "featuretype": "Categorical"}, ], - "data": [['01/01/2000', 5.1, 3.5, 1.4, 0.2, 'Iris-setosa'], - ['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], + "data": [['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], ['01/01/2002', 4.7, 3.2, 1.3, 0.2, 'Iris-setosa'], ['01/01/2003', 4.6, 3.1, 1.5, 0.2, 'Iris-setosa']], "filename": "iris.data", - "total": 4 + "total": 3 } # name is machine-generated # we assert it exists, but we don't check its value @@ -287,10 +285,10 @@ def test_get_dataset(self): {"name": "col4", "featuretype": "Numerical"}, {"name": "col5", "featuretype": "Categorical"}, ], - "data": [['01/01/2000', 5.1, 3.5, 1.4, 0.2, 'Iris-setosa'], - ['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa']], + "data": [['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], + ['01/01/2002', 4.7, 3.2, 1.3, 0.2, 'Iris-setosa']], "filename": "iris.data", - "total": 4 + "total": 3 } # name is machine-generated # we assert it exists, but we don't check its value @@ -434,13 +432,12 @@ def test_patch_dataset(self): {"name": "col4", "featuretype": "Numerical"}, {"name": "col5", "featuretype": "Categorical"}, ], - "data": [['01/01/2000', 5.1, 3.5, 1.4, 0.2, 'Iris-setosa'], - ['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], + "data": [['01/01/2001', 4.9, 3.0, 1.4, 0.2, 'Iris-setosa'], ['01/01/2002', 4.7, 3.2, 1.3, 0.2, 'Iris-setosa'], ['01/01/2003', 4.6, 3.1, 1.5, 0.2, 'Iris-setosa']], "filename": "iris.data", "name": name, - "total": 4 + "total": 3 } self.assertDictEqual(expected, result) self.assertEqual(rv.status_code, 200)