9 changes: 5 additions & 4 deletions dev/sparktestsupport/modules.py
@@ -31,9 +31,10 @@ class Module(object):
files have changed.
"""

def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={},
sbt_test_goals=(), python_test_goals=(), excluded_python_implementations=(),
test_tags=(), should_run_r_tests=False, should_run_build_tests=False):
def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(),
environ=None, sbt_test_goals=(), python_test_goals=(),
excluded_python_implementations=(), test_tags=(), should_run_r_tests=False,
should_run_build_tests=False):
"""
Define a new module.

@@ -62,7 +63,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=
self.source_file_prefixes = source_file_regexes
self.sbt_test_goals = sbt_test_goals
self.build_profile_flags = build_profile_flags
self.environ = environ
self.environ = environ or {}
self.python_test_goals = python_test_goals
self.excluded_python_implementations = excluded_python_implementations
self.test_tags = test_tags
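The change above is the standard None-sentinel rewrite for mutable default arguments: a default `{}` is created once, at function-definition time, and then shared by every call that omits the argument. A minimal sketch of the failure mode and of the fix, using hypothetical names rather than Spark's `Module` class:

```python
def add_env_buggy(key, value, environ={}):
    # The same dict object is reused across calls, so entries accumulate.
    environ[key] = value
    return environ

def add_env_fixed(key, value, environ=None):
    # None sentinel: a fresh dict is created on each call that omits environ.
    environ = environ or {}
    environ[key] = value
    return environ

print(add_env_buggy("A", "1"))   # {'A': '1'}
print(add_env_buggy("B", "2"))   # {'A': '1', 'B': '2'}  <- state leaked from the first call
print(add_env_fixed("A", "1"))   # {'A': '1'}
print(add_env_fixed("B", "2"))   # {'B': '2'}
```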
2 changes: 1 addition & 1 deletion dev/tox.ini
@@ -19,6 +19,6 @@ max-line-length=100
exclude=python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*

[flake8]
select = E901,E999,F821,F822,F823,F401,F405
select = E901,E999,F821,F822,F823,F401,F405,B006
exclude = python/pyspark/cloudpickle/*.py,shared.py*,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,python/out,python/pyspark/sql/pandas/functions.pyi,python/pyspark/sql/column.pyi,python/pyspark/worker.pyi,python/pyspark/java_gateway.pyi
max-line-length = 100
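The new `B006` code in the flake8 `select` list is what makes the mutable-default pattern a lint error. Worth noting as context (not stated in the diff itself): `B006` is provided by the flake8-bugbear plugin, so it only takes effect if that plugin is installed in the lint environment; flake8's built-in codes are the E/W/F/C families. With it active, a definition like the hypothetical one below is reported:

```python
# Reported by flake8-bugbear as B006 ("do not use mutable data structures
# for argument defaults"); illustrative snippet, not Spark code.
def read_table(path, options={}):
    return path, dict(options)
```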
2 changes: 2 additions & 0 deletions python/mypy.ini
@@ -17,6 +17,8 @@

[mypy]

no_implicit_optional = True

[mypy-pyspark.cloudpickle.*]
ignore_errors = True

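`no_implicit_optional = True` turns off mypy's legacy behavior of silently treating an argument annotated as `X` with a `None` default as `Optional[X]`; with the flag set, the `Optional` has to be written out. A minimal sketch (hypothetical signatures, not the PySpark API) of what the flag rejects and accepts:

```python
from typing import Dict, Optional

# Rejected under no_implicit_optional: default is None but the annotation says Dict.
def parse(schema: str, options: Dict[str, str] = None) -> str: ...

# Accepted: the Optional is explicit.
def parse_ok(schema: str, options: Optional[Dict[str, str]] = None) -> str: ...
```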
4 changes: 2 additions & 2 deletions python/pyspark/ml/regression.py
@@ -1801,7 +1801,7 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams,
@keyword_only
def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), # noqa: B005
quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
@@ -1819,7 +1819,7 @@ def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="p
@since("1.6.0")
def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), # noqa: B005
quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
8 changes: 4 additions & 4 deletions python/pyspark/ml/tuning.py
@@ -536,13 +536,13 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
.. versionadded:: 1.4.0
"""

def __init__(self, bestModel, avgMetrics=[], subModels=None):
def __init__(self, bestModel, avgMetrics=None, subModels=None):
super(CrossValidatorModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
#: Average cross-validation metrics for each paramMap in
#: CrossValidator.estimatorParamMaps, in the corresponding order.
self.avgMetrics = avgMetrics
self.avgMetrics = avgMetrics or []
#: sub model list from cross validation
self.subModels = subModels

@@ -920,12 +920,12 @@ class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable,
.. versionadded:: 2.0.0
"""

def __init__(self, bestModel, validationMetrics=[], subModels=None):
def __init__(self, bestModel, validationMetrics=None, subModels=None):
super(TrainValidationSplitModel, self).__init__()
#: best model from train validation split
self.bestModel = bestModel
#: evaluated validation metrics
self.validationMetrics = validationMetrics
self.validationMetrics = validationMetrics or []
#: sub models from train validation split
self.subModels = subModels

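One small design note on the `avgMetrics or []` / `validationMetrics or []` spelling used above: `or` replaces any falsy value, so a caller who explicitly passes an empty list also gets a new list object rather than the one they provided. That is harmless for these attributes, but the `is None` check is the stricter alternative. A sketch of the difference with hypothetical names:

```python
def keep_falsy_default(metrics=None):
    return metrics or []                             # caller's [] is replaced by a new list

def keep_caller_object(metrics=None):
    return metrics if metrics is not None else []    # caller's [] is preserved

shared = []
assert keep_falsy_default(shared) is not shared
assert keep_caller_object(shared) is shared
```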
4 changes: 2 additions & 2 deletions python/pyspark/ml/tuning.pyi
@@ -104,7 +104,7 @@ class CrossValidatorModel(
def __init__(
self,
bestModel: Model,
avgMetrics: List[float] = ...,
avgMetrics: Optional[List[float]] = ...,
subModels: Optional[List[List[Model]]] = ...,
) -> None: ...
def copy(self, extra: Optional[ParamMap] = ...) -> CrossValidatorModel: ...
@@ -171,7 +171,7 @@ class TrainValidationSplitModel(
def __init__(
self,
bestModel: Model,
validationMetrics: List[float] = ...,
validationMetrics: Optional[List[float]] = ...,
subModels: Optional[List[Model]] = ...,
) -> None: ...
def setEstimator(self, value: Estimator) -> TrainValidationSplitModel: ...
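The stub edits track the runtime change: once the implementation's default is `None`, the `.pyi` annotation has to widen to `Optional[...]`, otherwise a type checker reading the stub would flag a call such as `CrossValidatorModel(best, avgMetrics=None)`. A minimal paired sketch of how a stub mirrors its implementation (hypothetical module, not the real files):

```python
# model.py -- runtime implementation
class MyModel:
    def __init__(self, best_model, avg_metrics=None):
        self.best_model = best_model
        self.avg_metrics = avg_metrics or []
```

```python
# model.pyi -- stub seen by type checkers
from typing import Any, List, Optional

class MyModel:
    def __init__(self, best_model: Any, avg_metrics: Optional[List[float]] = ...) -> None: ...
```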
6 changes: 3 additions & 3 deletions python/pyspark/resource/profile.py
@@ -34,13 +34,13 @@ class ResourceProfile(object):
This API is evolving.
"""

def __init__(self, _java_resource_profile=None, _exec_req={}, _task_req={}):
def __init__(self, _java_resource_profile=None, _exec_req=None, _task_req=None):
if _java_resource_profile is not None:
self._java_resource_profile = _java_resource_profile
else:
self._java_resource_profile = None
self._executor_resource_requests = _exec_req
self._task_resource_requests = _task_req
self._executor_resource_requests = _exec_req or {}
self._task_resource_requests = _task_req or {}

@property
def id(self):
4 changes: 2 additions & 2 deletions python/pyspark/sql/avro/functions.py
@@ -25,7 +25,7 @@
from pyspark.util import _print_missing_jar


def from_avro(data, jsonFormatSchema, options={}):
def from_avro(data, jsonFormatSchema, options=None):
"""
Converts a binary column of Avro format into its corresponding catalyst value.
The specified schema must match the read data, otherwise the behavior is undefined:
@@ -70,7 +70,7 @@ def from_avro(data, jsonFormatSchema, options={}):
sc = SparkContext._active_spark_context
try:
jc = sc._jvm.org.apache.spark.sql.avro.functions.from_avro(
_to_java_column(data), jsonFormatSchema, options)
_to_java_column(data), jsonFormatSchema, options or {})
except TypeError as e:
if str(e) == "'JavaPackage' object is not callable":
_print_missing_jar("Avro", "avro", "avro", sc.version)
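With `options=None` as the default, existing callers are unaffected and the empty dict is only materialized at the JVM call site (`options or {}`). A hedged usage sketch, assuming an active SparkSession, a DataFrame `df` with a binary `value` column, and the spark-avro package on the classpath; the `mode` option is shown purely as an illustration:

```python
from pyspark.sql.avro.functions import from_avro

avro_schema = """
{"type": "record", "name": "User",
 "fields": [{"name": "name", "type": "string"}]}
"""

# Default (None) options and explicit options hit the same JVM function.
users = df.select(from_avro("value", avro_schema).alias("user"))
users_strict = df.select(from_avro("value", avro_schema, {"mode": "FAILFAST"}).alias("user"))
```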
4 changes: 2 additions & 2 deletions python/pyspark/sql/avro/functions.pyi
@@ -16,12 +16,12 @@
# specific language governing permissions and limitations
# under the License.

from typing import Dict
from typing import Dict, Optional

from pyspark.sql._typing import ColumnOrName
from pyspark.sql.column import Column

def from_avro(
data: ColumnOrName, jsonFormatSchema: str, options: Dict[str, str] = ...
data: ColumnOrName, jsonFormatSchema: str, options: Optional[Dict[str, str]] = ...
) -> Column: ...
def to_avro(data: ColumnOrName, jsonFormatSchema: str = ...) -> Column: ...
18 changes: 10 additions & 8 deletions python/pyspark/sql/functions.py
@@ -80,8 +80,10 @@ def _invoke_binary_math_function(name, col1, col2):
)


def _options_to_str(options):
return {key: to_str(value) for (key, value) in options.items()}
def _options_to_str(options=None):
if options:
return {key: to_str(value) for (key, value) in options.items()}
return {}


def lit(col):
@@ -3415,7 +3417,7 @@ def json_tuple(col, *fields):
return Column(jc)


def from_json(col, schema, options={}):
def from_json(col, schema, options=None):
"""
Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType`
as keys type, :class:`StructType` or :class:`ArrayType` with
@@ -3471,7 +3473,7 @@ def from_json(col, schema, options={}):
return Column(jc)


def to_json(col, options={}):
def to_json(col, options=None):
Review thread on this line:

Member: Seems like we still have to modify a few annotations, right? Probably something like functions.pyi.patch.txt

Contributor (author): Good catch, I've added them

Contributor (author): Not sure why implicit optional is still allowed 🤔

"""
Converts a column containing a :class:`StructType`, :class:`ArrayType` or a :class:`MapType`
into a JSON string. Throws an exception, in the case of an unsupported type.
@@ -3518,7 +3520,7 @@ def to_json(col, options={}):
return Column(jc)


def schema_of_json(json, options={}):
def schema_of_json(json, options=None):
"""
Parses a JSON string and infers its schema in DDL format.

@@ -3555,7 +3557,7 @@ def schema_of_json(json, options={}):
return Column(jc)


def schema_of_csv(csv, options={}):
def schema_of_csv(csv, options=None):
"""
Parses a CSV string and infers its schema in DDL format.

@@ -3588,7 +3590,7 @@ def schema_of_csv(csv, options={}):
return Column(jc)


def to_csv(col, options={}):
def to_csv(col, options=None):
"""
Converts a column containing a :class:`StructType` into a CSV string.
Throws an exception, in the case of an unsupported type.
@@ -3999,7 +4001,7 @@ def sequence(start, stop, step=None):
_to_java_column(start), _to_java_column(stop), _to_java_column(step)))


def from_csv(col, schema, options={}):
def from_csv(col, schema, options=None):
"""
Parses a column containing a CSV string to a row with the specified schema.
Returns `null`, in the case of an unparseable string.
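Routing every `options` argument through `_options_to_str` keeps the None handling in one place: the wrappers above can pass `options` straight through, and the helper returns `{}` for both `None` and an empty dict while stringifying values otherwise. A standalone sketch of the same idea, with a stand-in for PySpark's `to_str` helper (assumed here to lowercase booleans and pass `None` through):

```python
def to_str(value):
    # Stand-in for PySpark's helper: JVM-style lowercase booleans, None passed through.
    if isinstance(value, bool):
        return str(value).lower()
    return None if value is None else str(value)

def _options_to_str(options=None):
    if options:
        return {key: to_str(value) for (key, value) in options.items()}
    return {}

assert _options_to_str() == {}
assert _options_to_str({}) == {}
assert _options_to_str({"multiLine": True}) == {"multiLine": "true"}
```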
2 changes: 1 addition & 1 deletion python/pyspark/sql/functions.pyi
@@ -194,7 +194,7 @@ def json_tuple(col: ColumnOrName, *fields: str) -> Column: ...
def from_json(
col: ColumnOrName,
schema: Union[ArrayType, StructType, Column, str],
options: Dict[str, str] = ...,
options: Optional[Dict[str, str]] = ...,
) -> Column: ...
def to_json(col: ColumnOrName, options: Dict[str, str] = ...) -> Column: ...
def schema_of_json(json: ColumnOrName, options: Dict[str, str] = ...) -> Column: ...