|
1 | 1 | from unittest.mock import patch |
2 | 2 |
|
| 3 | +import numpy as np |
3 | 4 | import pyspark |
4 | 5 | import pytest |
5 | 6 |
|
| 7 | +from datasets import Features, Image, IterableDataset |
6 | 8 | from datasets.builder import InvalidConfigName |
7 | 9 | from datasets.data_files import DataFilesList |
8 | 10 | from datasets.packaged_modules.spark.spark import ( |
@@ -131,3 +133,38 @@ def test_repartition_df_if_needed_max_num_df_rows(): |
131 | 133 | spark_builder._repartition_df_if_needed(max_shard_size=1) |
132 | 134 | # The new number of partitions should not be greater than the number of rows. |
133 | 135 | assert spark_builder.df.rdd.getNumPartitions() == 100 |
| 136 | + |
| 137 | + |
@require_not_windows
@require_dill_gt_0_3_2
def test_iterable_image_features():
    """A Spark binary column read with `Image(decode=False)` yields raw byte dicts.

    Builds a one-row DataFrame holding raw image bytes and checks that iterating
    the resulting IterableDataset produces the undecoded {"path", "bytes"} form.
    """
    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    # 10x10 RGB zero image serialized to raw bytes for the binary column.
    raw_bytes = np.zeros((10, 10, 3), dtype=np.uint8).tobytes()
    df = spark.createDataFrame([(raw_bytes,)], "image: binary")
    dset = IterableDataset.from_spark(df, features=Features({"image": Image(decode=False)}))
    first_item = next(iter(dset))
    assert first_item.keys() == {"image"}
    # decode=False keeps the feature as an undecoded mapping, not a PIL image.
    assert first_item == {"image": {"path": None, "bytes": raw_bytes}}
| 150 | + |
| 151 | + |
@require_not_windows
@require_dill_gt_0_3_2
def test_iterable_image_features_decode():
    """A Spark binary column read with `Image()` decodes into PIL images.

    Builds a one-row DataFrame holding PNG-encoded bytes and checks that
    iterating the resulting IterableDataset yields a decoded PIL.Image.Image.
    """
    # Local imports: PIL is an optional dependency, only needed by this test.
    from io import BytesIO

    import PIL.Image

    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    # Encode a 10x10 black RGB image as PNG to store in the binary column.
    img = PIL.Image.fromarray(np.zeros((10, 10, 3), dtype=np.uint8), "RGB")
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    # BytesIO.getvalue() already returns bytes — no extra bytes() copy needed.
    img_bytes = buffer.getvalue()
    data = [(img_bytes,)]
    df = spark.createDataFrame(data, "image: binary")
    features = Features({"image": Image()})
    dset = IterableDataset.from_spark(df, features=features)
    item = next(iter(dset))
    assert item.keys() == {"image"}
    # With decoding enabled the feature materializes as a PIL image object.
    assert isinstance(item["image"], PIL.Image.Image)
0 commit comments