
Commit eebde70

Merge pull request #2094 from activeloopai/fy_new_transform
[AL-2038] Faster transforms
2 parents b2bb8f0 + 357d0dd

File tree

10 files changed: +572 -435 lines changed

deeplake/core/chunk_engine.py

Lines changed: 24 additions & 11 deletions
@@ -664,9 +664,10 @@ def _convert_to_list(self, samples):
         return False
 
     def check_each_sample(self, samples, verify=True):
+        # overridden in LinkedChunkEngine
         return
 
-    def _sanitize_samples(self, samples, verify=True):
+    def _sanitize_samples(self, samples, verify=True, pg_callback=None):
         check_samples_type(samples)
         if isinstance(samples, list):
             samples = [
@@ -994,7 +995,9 @@ def _extend(self, samples, progressbar, pg_callback=None, update_commit_diff=Tru
             return
         if len(samples) == 0:
             return
-        samples, verified_samples = self._sanitize_samples(samples)
+        samples, verified_samples = self._sanitize_samples(
+            samples, pg_callback=pg_callback
+        )
         self._samples_to_chunks(
             samples,
             start_chunk=self.last_appended_chunk(),
@@ -1025,15 +1028,22 @@ def extend(
         if self.is_sequence:
             samples = tqdm(samples) if progressbar else samples
             verified_samples = []
-            for sample in samples:
-                if sample is None:
-                    sample = []
-                verified_sample = self._extend(
-                    sample, progressbar=False, update_commit_diff=False
-                )
-                self.sequence_encoder.register_samples(len(sample), 1)
-                self.commit_diff.add_data(1)
-                verified_samples.append(verified_sample or sample)
+            num_samples_added = 0
+            try:
+                for sample in samples:
+                    if sample is None:
+                        sample = []
+                    verified_sample = self._extend(
+                        sample, progressbar=False, update_commit_diff=False
+                    )
+                    self.sequence_encoder.register_samples(len(sample), 1)
+                    self.commit_diff.add_data(1)
+                    num_samples_added += 1
+                    verified_samples.append(verified_sample or sample)
+            except Exception:
+                for _ in range(num_samples_added):
+                    self.pop()
+                raise
         if link_callback:
             samples = [
                 None if is_empty_list(s) else s for s in verified_samples
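
The try/except added above makes sequence extends atomic: every sample registered before a failure is popped again, so a partially failed extend leaves the tensor length unchanged. A minimal standalone sketch of the same pattern (plain Python, not the deeplake API):

def extend_atomic(store, samples, process):
    # Append processed samples; undo everything on the first failure.
    added = 0
    try:
        for s in samples:
            store.append(process(s))
            added += 1
    except Exception:
        for _ in range(added):
            store.pop()  # roll back partial progress
        raise  # re-raise so the caller still sees the original error
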
@@ -2051,6 +2061,9 @@ def list_all_chunks(self) -> List[str]:
         """Return list of all chunks for current `version_state['commit_id']` and tensor"""
         commit_id = self.commit_id
         if commit_id == FIRST_COMMIT_ID:
+            arr = self.chunk_id_encoder._encoded
+            if not arr.size:
+                return []
             return [
                 ChunkIdEncoder.name_from_id(chunk_id)
                 for chunk_id in self.chunk_id_encoder._encoded[:, CHUNK_ID_COLUMN]

deeplake/core/linked_chunk_engine.py

Lines changed: 2 additions & 1 deletion
@@ -263,12 +263,13 @@ def check_each_sample(self, samples, verify=True):
                 verified_samples.append(sample)
             else:
                 try:
+                    _verify = verify and self.verify
                     verified_samples.append(
                         read_linked_sample(
                             sample.path,
                             sample.creds_key,
                             self.link_creds,
-                            verify=verify and self.verify,
+                            verify=_verify,
                         )
                     )
                 except Exception as e:

deeplake/core/polygon.py

Lines changed: 6 additions & 0 deletions
@@ -1,4 +1,6 @@
+from deeplake.util.exceptions import EmptyPolygonError
 from typing import Union, List
+
 import numpy as np
 import deeplake
 

@@ -7,6 +9,10 @@ class Polygon:
     """Represents a polygon."""
 
     def __init__(self, coords: Union[np.ndarray, List[float]], dtype="float32"):
+        if coords is None or len(coords) == 0:
+            raise EmptyPolygonError(
+                "A polygons sample can be empty or None but a polygon within a sample cannot be empty or None."
+            )
         self.coords = coords
         self.dtype = dtype
 
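
With this guard an empty inner polygon now fails fast at construction time instead of producing malformed shape data later. A hedged usage sketch, using only names visible in the diff above:

import numpy as np
from deeplake.core.polygon import Polygon
from deeplake.util.exceptions import EmptyPolygonError

Polygon(np.zeros((3, 2)))  # fine: three (x, y) points
try:
    Polygon([])  # an empty polygon inside a sample is rejected
except EmptyPolygonError:
    pass
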

deeplake/core/transform/test_transform.py

Lines changed: 53 additions & 1 deletion
@@ -3,7 +3,6 @@
 import numpy as np
 from click.testing import CliRunner
 from deeplake.core.storage.memory import MemoryProvider
-from deeplake.core.transform.transform_tensor import TransformTensor
 from deeplake.core.version_control.test_version_control import (
     compare_dataset_diff,
     compare_tensor_diff,
@@ -1439,3 +1438,56 @@ def upload(item, ds):
     assert ds["boxes"].meta.max_shape == [20, 4]
 
     assert ds["labels"].numpy().shape == (40, 10)
+
+
+def test_transform_numpy_only(local_ds):
+    @deeplake.compute
+    def upload(i, ds):
+        ds.abc.extend(i * np.ones((10, 5, 5)))
+
+    with local_ds as ds:
+        ds.create_tensor("abc")
+
+        upload().eval(list(range(100)), ds, num_workers=2)
+
+    assert len(local_ds) == 1000
+
+    for i in range(100):
+        np.testing.assert_array_equal(
+            ds.abc[i * 10 : (i + 1) * 10].numpy(), i * np.ones((10, 5, 5))
+        )
+
+
+@deeplake.compute
+def add_samples(i, ds, flower_path):
+    ds.abc.extend(i * np.ones((5, 5, 5)))
+    ds.images.extend([deeplake.read(flower_path) for _ in range(5)])
+
+
+@deeplake.compute
+def mul_by_2(sample_in, samples_out):
+    samples_out.abc.append(2 * sample_in.abc.numpy())
+    samples_out.images.append(sample_in.images.numpy() - 1)
+
+
+def test_pipeline(local_ds, flower_path):
+    pipeline = deeplake.compose([add_samples(flower_path), mul_by_2()])
+
+    flower_arr = np.array(deeplake.read(flower_path))
+
+    with local_ds as ds:
+        ds.create_tensor("abc")
+        ds.create_tensor("images", htype="image", sample_compression="png")
+
+        pipeline.eval(list(range(10)), ds, num_workers=2)
+
+    assert len(local_ds) == 50
+
+    for i in range(10):
+        np.testing.assert_array_equal(
+            ds.abc[i * 5 : (i + 1) * 5].numpy(), i * 2 * np.ones((5, 5, 5))
+        )
+        np.testing.assert_array_equal(
+            ds.images[i * 5 : (i + 1) * 5].numpy(),
+            np.tile(flower_arr - 1, (5, 1, 1, 1)),
+        )

deeplake/core/transform/transform.py

Lines changed: 18 additions & 14 deletions
@@ -35,6 +35,8 @@
 from deeplake.util.version_control import auto_checkout
 from deeplake.util.class_label import sync_labels
 
+import posixpath
+
 
 class ComputeFunction:
     def __init__(self, func, args, kwargs, name: Optional[str] = None):
@@ -55,6 +57,7 @@ def eval(
         check_lengths: bool = True,
         pad_data_in: bool = False,
         read_only_ok: bool = False,
+        cache_size: int = 16,
         checkpoint_interval: int = 0,
         ignore_errors: bool = False,
         **kwargs,
@@ -78,6 +81,7 @@ def eval(
                 Defaults to False.
             read_only_ok (bool): If ``True`` and output dataset is same as input dataset, the read-only check is skipped. This can be used to read data in parallel without making changes to underlying dataset.
                 Defaults to False.
+            cache_size (int): Cache size to be used by transform per worker.
             checkpoint_interval (int): If > 0, the transform will be checkpointed with a commit every ``checkpoint_interval`` input samples to avoid restarting full transform due to intermittent failures. If the transform is interrupted, the intermediate data is deleted and the dataset is reset to the last commit.
                 If <= 0, no checkpointing is done. Checkpoint interval should be a multiple of num_workers if num_workers > 0. Defaults to 0.
             ignore_errors (bool): If ``True``, input samples that cause the transform to fail will be skipped and the errors will be ignored **if possible**.
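
The new cache_size argument is threaded from eval() down to each worker (the docstring leaves the unit unspecified; only the per-worker default of 16 is given). A hedged usage sketch; the dataset path and tensor name here are made up:

import numpy as np
import deeplake

@deeplake.compute
def upload(i, ds):
    ds.abc.append(i * np.ones((5, 5)))

ds = deeplake.dataset("mem://cache_size_demo")  # hypothetical in-memory dataset
ds.create_tensor("abc")
# Request a larger per-worker transform cache than the default of 16.
upload().eval(list(range(100)), ds, num_workers=0, cache_size=32)
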
@@ -102,6 +106,7 @@ def eval(
             check_lengths,
             pad_data_in,
             read_only_ok,
+            cache_size,
             checkpoint_interval,
             ignore_errors,
             **kwargs,
@@ -130,6 +135,7 @@ def eval(
         check_lengths: bool = True,
         pad_data_in: bool = False,
         read_only_ok: bool = False,
+        cache_size: int = 16,
         checkpoint_interval: int = 0,
         ignore_errors: bool = False,
         **kwargs,
@@ -153,6 +159,7 @@ def eval(
                 Defaults to ``False``.
             read_only_ok (bool): If ``True`` and output dataset is same as input dataset, the read-only check is skipped.
                 Defaults to False.
+            cache_size (int): Cache size to be used by transform per worker.
             checkpoint_interval (int): If > 0, the transform will be checkpointed with a commit every ``checkpoint_interval`` input samples to avoid restarting full transform due to intermittent failures. If the transform is interrupted, the intermediate data is deleted and the dataset is reset to the last commit.
                 If <= 0, no checkpointing is done. Checkpoint interval should be a multiple of num_workers if num_workers > 0. Defaults to 0.
             ignore_errors (bool): If ``True``, input samples that cause the transform to fail will be skipped and the errors will be ignored **if possible**.
@@ -218,7 +225,7 @@ def my_fn(sample_in: Any, samples_out, my_arg0, my_arg1=0):
         initial_autoflush = target_ds.storage.autoflush
         target_ds.storage.autoflush = False
 
-        if not check_lengths:
+        if not check_lengths or read_only_ok:
            skip_ok = True
 
         checkpointing_enabled = checkpoint_interval > 0
@@ -267,6 +274,7 @@ def my_fn(sample_in: Any, samples_out, my_arg0, my_arg1=0):
                 overwrite,
                 skip_ok,
                 read_only_ok and overwrite,
+                cache_size,
                 pbar,
                 pqueue,
                 ignore_errors,
@@ -286,11 +294,12 @@ def my_fn(sample_in: Any, samples_out, my_arg0, my_arg1=0):
             index, sample = None, None
             if isinstance(e, TransformError):
                 index, sample = e.index, e.sample
+                e = e.__cause__  # type: ignore
             raise TransformError(
                 index=index,
                 sample=sample,
                 samples_processed=samples_processed,
-            ).with_traceback(e.__traceback__)
+            ) from e
         finally:
             reload_and_rechunk(
                 overwrite,
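
Replacing .with_traceback(e.__traceback__) with raise ... from e (after unwrapping a nested TransformError through e.__cause__) chains the original exception as __cause__, so tracebacks show the real failure rather than a doubly wrapped TransformError. A standalone illustration of the chaining behavior, with a hypothetical stand-in exception class:

class TransformishError(Exception):  # stand-in, not deeplake's TransformError
    pass

try:
    try:
        1 / 0
    except ZeroDivisionError as e:
        raise TransformishError("sample failed") from e
except TransformishError as err:
    assert isinstance(err.__cause__, ZeroDivisionError)
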
@@ -316,6 +325,7 @@ def run(
         overwrite: bool = False,
         skip_ok: bool = False,
         read_only: bool = False,
+        cache_size: int = 16,
         pbar=None,
         pqueue=None,
         ignore_errors: bool = False,
@@ -340,16 +350,14 @@
             else []
         )
         label_temp_tensors = {}
-        actual_tensors = (
-            None
-            if not class_label_tensors
-            else [target_ds[t].key for t in target_ds.tensors]
-        )
+
+        visible_tensors = list(target_ds.tensors)
+        visible_tensors = [target_ds[t].key for t in visible_tensors]
 
         if not read_only:
             for tensor in class_label_tensors:
                 actual_tensor = target_ds[tensor]
-                temp_tensor = f"__temp{tensor}_{uuid4().hex[:4]}"
+                temp_tensor = f"__temp{posixpath.relpath(tensor, target_ds.group_index)}_{uuid4().hex[:4]}"
                 with target_ds:
                     temp_tensor_obj = target_ds.create_tensor(
                         temp_tensor,
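
The switch to posixpath.relpath presumably keeps the temp-tensor name valid for tensors inside groups: target_ds is already scoped to group_index, so only the group-relative part of the path should go into the name. An illustrative check of the stdlib call, with made-up tensor and group names:

import posixpath

assert posixpath.relpath("group/labels", "group") == "labels"
assert posixpath.relpath("labels", "") == "labels"  # no group: name unchanged
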
@@ -361,13 +369,9 @@
                         create_id_tensor=False,
                     )
                     temp_tensor_obj.meta._disable_temp_transform = True
-                label_temp_tensors[tensor] = temp_tensor
+                label_temp_tensors[tensor] = temp_tensor_obj.key
             target_ds.flush()
 
-        visible_tensors = list(target_ds.tensors)
-        visible_tensors = [target_ds[t].key for t in visible_tensors]
-        visible_tensors = list(set(visible_tensors) - set(class_label_tensors))
-
         tensors = list(target_ds._tensors())
         tensors = [target_ds[t].key for t in tensors]
         tensors = list(set(tensors) - set(class_label_tensors))
@@ -384,12 +388,12 @@
             tensors,
             visible_tensors,
             label_temp_tensors,
-            actual_tensors,
             self,
             version_state,
             target_ds.link_creds,
             skip_ok,
             extend_only,
+            cache_size,
             ignore_errors,
         )
         map_inp = zip(slices, offsets, storages, repeat(args))
