diff --git a/petastorm/tests/test_copy_dataset.py b/petastorm/tests/test_copy_dataset.py
index c51aef2e..f7076c1a 100644
--- a/petastorm/tests/test_copy_dataset.py
+++ b/petastorm/tests/test_copy_dataset.py
@@ -18,18 +18,12 @@
 
 import numpy as np
 import pytest
-from pyspark.sql import SparkSession
 from pyspark.sql.utils import AnalysisException
 
 from petastorm.reader import make_reader
 from petastorm.tools.copy_dataset import _main, copy_dataset
 
 
-@pytest.fixture()
-def spark_session():
-    return SparkSession.builder.appName('petastorm-copy').getOrCreate()
-
-
 def test_copy_and_overwrite_cli(tmpdir, synthetic_dataset):
     target_url = 'file:///' + os.path.join(tmpdir.strpath, 'copied_data')
     _main([synthetic_dataset.url, target_url])
diff --git a/petastorm/tests/test_dataset_metadata.py b/petastorm/tests/test_dataset_metadata.py
index fb617a67..36d848b1 100644
--- a/petastorm/tests/test_dataset_metadata.py
+++ b/petastorm/tests/test_dataset_metadata.py
@@ -13,9 +13,8 @@
 # limitations under the License.
 
 import numpy as np
-import pytest
 import pyarrow
-from pyspark.sql import SparkSession
+import pytest
 from pyspark.sql.types import IntegerType
 
 from petastorm.codecs import ScalarCodec
@@ -37,7 +36,7 @@ def test_get_schema_from_dataset_url_bogus_url():
         get_schema_from_dataset_url('/invalid_url')
 
 
-def test_serialize_filesystem_factory(tmpdir):
+def test_serialize_filesystem_factory(tmpdir, spark_test_ctx):
     SimpleSchema = Unischema('SimpleSchema', [
         UnischemaField('id', np.int32, (), ScalarCodec(IntegerType()), False),
         UnischemaField('foo', np.int32, (), ScalarCodec(IntegerType()), False),
@@ -50,11 +49,11 @@ def __getstate__(self):
     rows_count = 10
     output_url = "file://{0}/fs_factory_test".format(tmpdir)
     rowgroup_size_mb = 256
-    spark = SparkSession.builder.config('spark.driver.memory', '2g').master('local[2]').getOrCreate()
+    spark = spark_test_ctx.spark
     sc = spark.sparkContext
     with materialize_dataset(spark, output_url, SimpleSchema, rowgroup_size_mb, filesystem_factory=BogusFS):
-        rows_rdd = sc.parallelize(range(rows_count))\
-            .map(lambda x: {'id': x, 'foo': x})\
+        rows_rdd = sc.parallelize(range(rows_count)) \
+            .map(lambda x: {'id': x, 'foo': x}) \
             .map(lambda x: dict_to_spark_row(SimpleSchema, x))
 
         spark.createDataFrame(rows_rdd, SimpleSchema.as_spark_schema()) \
diff --git a/petastorm/tests/test_end_to_end.py b/petastorm/tests/test_end_to_end.py
index ba68ff96..68c03d15 100644
--- a/petastorm/tests/test_end_to_end.py
+++ b/petastorm/tests/test_end_to_end.py
@@ -743,13 +743,13 @@ def test_rowgroup_selector_wrong_index_name(synthetic_dataset, reader_factory):
         reader_factory(synthetic_dataset.url, rowgroup_selector=SingleIndexSelector('WrongIndexName', ['some_value']))
 
 
-def test_materialize_dataset_hadoop_config(tmpdir_factory):
+def test_materialize_dataset_hadoop_config(tmpdir_factory, spark_test_ctx):
     """Test that using materialize_dataset does not alter the hadoop_config"""
     path = tmpdir_factory.mktemp('data').strpath
     tmp_url = "file://" + path
 
     # This test does not properly check if parquet.enable.summary-metadata is restored properly with pyspark < 2.4
-    spark = SparkSession.builder.getOrCreate()
+    spark = spark_test_ctx.spark
     hadoop_config = spark.sparkContext._jsc.hadoopConfiguration()
 
     parquet_metadata_level = "COMMON_ONLY"
@@ -775,26 +775,26 @@ def test_materialize_dataset_hadoop_config(tmpdir_factory):
     spark.stop()
 
 
-def test_materialize_with_summary_metadata(tmpdir_factory):
+def test_materialize_with_summary_metadata(tmpdir_factory, spark_test_ctx):
     """Verify _summary_metadata appears, when requested"""
     path = tmpdir_factory.mktemp('data').strpath
     tmp_url = "file://" + path
 
-    spark = SparkSession.builder.getOrCreate()
+    spark = spark_test_ctx.spark
     create_test_dataset(tmp_url, range(10), spark=spark, use_summary_metadata=True)
 
     assert os.path.exists(os.path.join(path, "_metadata"))
 
     spark.stop()
 
 
-def test_pass_in_pyarrow_filesystem_to_materialize_dataset(synthetic_dataset, tmpdir):
+def test_pass_in_pyarrow_filesystem_to_materialize_dataset(synthetic_dataset, tmpdir, spark_test_ctx):
     a_moved_path = tmpdir.join('moved').strpath
     copytree(synthetic_dataset.path, a_moved_path)
     local_fs = pyarrow.LocalFileSystem
 
     os.remove(a_moved_path + '/_common_metadata')
-    spark = SparkSession.builder.getOrCreate()
+    spark = spark_test_ctx.spark
 
     with materialize_dataset(spark, a_moved_path, TestSchema, filesystem_factory=local_fs):
         pass
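
The diff above replaces per-test SparkSession creation with a shared spark_test_ctx fixture that exposes the session via a .spark attribute. The fixture itself is defined outside this diff (in the test suite's conftest), so the following is only a minimal sketch of what such a fixture could look like; the wrapper class name, builder options, and session scope are assumptions for illustration, not the project's actual implementation.

# Hypothetical conftest.py sketch; not part of this diff.
import pytest
from pyspark.sql import SparkSession


class SparkTestContext(object):
    """Thin wrapper so tests can reach the session as spark_test_ctx.spark."""

    def __init__(self, spark):
        self.spark = spark


@pytest.fixture(scope='session')
def spark_test_ctx():
    # Assumed local-mode configuration; the real fixture may differ.
    spark = SparkSession.builder \
        .appName('petastorm-tests') \
        .master('local[2]') \
        .getOrCreate()
    yield SparkTestContext(spark)
    spark.stop()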