Skip to content

Commit 5d96a71

Browse files
gatorsmile authored and marmbrus committed
[SPARK-12188][SQL] Code refactoring and comment correction in Dataset APIs
This PR contains the following updates:

- Created a new private variable `boundTEncoder` that can be shared by multiple functions: `rdd`, `select`, and `collect`.
- Replaced all uses of `queryExecution.analyzed` with the function call `logicalPlan`.
- Fixed a few API comments that were using wrong class names (e.g., `DataFrame`) or parameter names (e.g., `n`).
- Corrected a few wrong API descriptions (e.g., `mapPartitions`).

marmbrus rxin cloud-fan Could you take a look and check if they are appropriate? Thank you!

Author: gatorsmile <gatorsmile@gmail.com>

Closes apache#10184 from gatorsmile/datasetClean.
1 parent c0b13d5 commit 5d96a71

1 file changed

Lines changed: 40 additions & 40 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 40 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -67,15 +67,21 @@ class Dataset[T] private[sql](
6767
tEncoder: Encoder[T]) extends Queryable with Serializable {
6868

6969
/**
70-
* An unresolved version of the internal encoder for the type of this dataset. This one is marked
71-
* implicit so that we can use it when constructing new [[Dataset]] objects that have the same
72-
* object type (that will be possibly resolved to a different schema).
70+
* An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
71+
* marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
72+
* same object type (that will be possibly resolved to a different schema).
7373
*/
7474
private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
7575

7676
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
7777
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
78-
unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
78+
unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
79+
80+
/**
81+
* The encoder where the expressions used to construct an object from an input row have been
82+
* bound to the ordinals of the given schema.
83+
*/
84+
private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
7985

8086
private implicit def classTag = resolvedTEncoder.clsTag
8187

@@ -89,7 +95,7 @@ class Dataset[T] private[sql](
8995
override def schema: StructType = resolvedTEncoder.schema
9096

9197
/**
92-
* Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format.
98+
* Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
9399
* @since 1.6.0
94100
*/
95101
override def printSchema(): Unit = toDF().printSchema()
@@ -111,7 +117,7 @@ class Dataset[T] private[sql](
111117
* ************* */
112118

113119
/**
114-
* Returns a new `Dataset` where each record has been mapped on to the specified type. The
120+
* Returns a new [[Dataset]] where each record has been mapped on to the specified type. The
115121
* method used to map columns depend on the type of `U`:
116122
* - When `U` is a class, fields for the class will be mapped to columns of the same name
117123
* (case sensitivity is determined by `spark.sql.caseSensitive`)
@@ -145,23 +151,20 @@ class Dataset[T] private[sql](
145151
def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
146152

147153
/**
148-
* Returns this Dataset.
154+
* Returns this [[Dataset]].
149155
* @since 1.6.0
150156
*/
151157
// This is declared with parentheses to prevent the Scala compiler from treating
152158
// `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset.
153159
def toDS(): Dataset[T] = this
154160

155161
/**
156-
* Converts this Dataset to an RDD.
162+
* Converts this [[Dataset]] to an [[RDD]].
157163
* @since 1.6.0
158164
*/
159165
def rdd: RDD[T] = {
160-
val tEnc = resolvedTEncoder
161-
val input = queryExecution.analyzed.output
162166
queryExecution.toRdd.mapPartitions { iter =>
163-
val bound = tEnc.bind(input)
164-
iter.map(bound.fromRow)
167+
iter.map(boundTEncoder.fromRow)
165168
}
166169
}
167170

@@ -189,15 +192,15 @@ class Dataset[T] private[sql](
189192
def show(numRows: Int): Unit = show(numRows, truncate = true)
190193

191194
/**
192-
* Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters
195+
* Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
193196
* will be truncated, and all cells will be aligned right.
194197
*
195198
* @since 1.6.0
196199
*/
197200
def show(): Unit = show(20)
198201

199202
/**
200-
* Displays the top 20 rows of [[DataFrame]] in a tabular form.
203+
* Displays the top 20 rows of [[Dataset]] in a tabular form.
201204
*
202205
* @param truncate Whether truncate long strings. If true, strings more than 20 characters will
203206
* be truncated and all cells will be aligned right
@@ -207,7 +210,7 @@ class Dataset[T] private[sql](
207210
def show(truncate: Boolean): Unit = show(20, truncate)
208211

209212
/**
210-
* Displays the [[DataFrame]] in a tabular form. For example:
213+
* Displays the [[Dataset]] in a tabular form. For example:
211214
* {{{
212215
* year month AVG('Adj Close) MAX('Adj Close)
213216
* 1980 12 0.503218 0.595103
@@ -291,7 +294,7 @@ class Dataset[T] private[sql](
291294

292295
/**
293296
* (Scala-specific)
294-
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
297+
* Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
295298
* @since 1.6.0
296299
*/
297300
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
@@ -307,7 +310,7 @@ class Dataset[T] private[sql](
307310

308311
/**
309312
* (Java-specific)
310-
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
313+
* Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
311314
* @since 1.6.0
312315
*/
313316
def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
@@ -341,28 +344,28 @@ class Dataset[T] private[sql](
341344

342345
/**
343346
* (Scala-specific)
344-
* Runs `func` on each element of this Dataset.
347+
* Runs `func` on each element of this [[Dataset]].
345348
* @since 1.6.0
346349
*/
347350
def foreach(func: T => Unit): Unit = rdd.foreach(func)
348351

349352
/**
350353
* (Java-specific)
351-
* Runs `func` on each element of this Dataset.
354+
* Runs `func` on each element of this [[Dataset]].
352355
* @since 1.6.0
353356
*/
354357
def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
355358

356359
/**
357360
* (Scala-specific)
358-
* Runs `func` on each partition of this Dataset.
361+
* Runs `func` on each partition of this [[Dataset]].
359362
* @since 1.6.0
360363
*/
361364
def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
362365

363366
/**
364367
* (Java-specific)
365-
* Runs `func` on each partition of this Dataset.
368+
* Runs `func` on each partition of this [[Dataset]].
366369
* @since 1.6.0
367370
*/
368371
def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
@@ -374,27 +377,27 @@ class Dataset[T] private[sql](
374377

375378
/**
376379
* (Scala-specific)
377-
* Reduces the elements of this Dataset using the specified binary function. The given function
380+
* Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
378381
* must be commutative and associative or the result may be non-deterministic.
379382
* @since 1.6.0
380383
*/
381384
def reduce(func: (T, T) => T): T = rdd.reduce(func)
382385

383386
/**
384387
* (Java-specific)
385-
* Reduces the elements of this Dataset using the specified binary function. The given function
388+
* Reduces the elements of this Dataset using the specified binary function. The given `func`
386389
* must be commutative and associative or the result may be non-deterministic.
387390
* @since 1.6.0
388391
*/
389392
def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
390393

391394
/**
392395
* (Scala-specific)
393-
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
396+
* Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
394397
* @since 1.6.0
395398
*/
396399
def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
397-
val inputPlan = queryExecution.analyzed
400+
val inputPlan = logicalPlan
398401
val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
399402
val executed = sqlContext.executePlan(withGroupingKey)
400403

@@ -429,18 +432,18 @@ class Dataset[T] private[sql](
429432

430433
/**
431434
* (Java-specific)
432-
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
435+
* Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
433436
* @since 1.6.0
434437
*/
435-
def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
436-
groupBy(f.call(_))(encoder)
438+
def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
439+
groupBy(func.call(_))(encoder)
437440

438441
/* ****************** *
439442
* Typed Relational *
440443
* ****************** */
441444

442445
/**
443-
* Selects a set of column based expressions.
446+
* Returns a new [[DataFrame]] by selecting a set of column based expressions.
444447
* {{{
445448
* df.select($"colA", $"colB" + 1)
446449
* }}}
@@ -464,8 +467,8 @@ class Dataset[T] private[sql](
464467
sqlContext,
465468
Project(
466469
c1.withInputType(
467-
resolvedTEncoder.bind(queryExecution.analyzed.output),
468-
queryExecution.analyzed.output).named :: Nil,
470+
boundTEncoder,
471+
logicalPlan.output).named :: Nil,
469472
logicalPlan))
470473
}
471474

@@ -477,7 +480,7 @@ class Dataset[T] private[sql](
477480
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
478481
val encoders = columns.map(_.encoder)
479482
val namedColumns =
480-
columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
483+
columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
481484
val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
482485

483486
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
@@ -654,25 +657,22 @@ class Dataset[T] private[sql](
654657
* Returns an array that contains all the elements in this [[Dataset]].
655658
*
656659
* Running collect requires moving all the data into the application's driver process, and
657-
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
660+
* doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
658661
*
659662
* For Java API, use [[collectAsList]].
660663
* @since 1.6.0
661664
*/
662665
def collect(): Array[T] = {
663666
// This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
664667
// to convert the rows into objects of type T.
665-
val tEnc = resolvedTEncoder
666-
val input = queryExecution.analyzed.output
667-
val bound = tEnc.bind(input)
668-
queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
668+
queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
669669
}
670670

671671
/**
672672
* Returns an array that contains all the elements in this [[Dataset]].
673673
*
674674
* Running collect requires moving all the data into the application's driver process, and
675-
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
675+
* doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
676676
*
677677
* For Java API, use [[collectAsList]].
678678
* @since 1.6.0
@@ -683,7 +683,7 @@ class Dataset[T] private[sql](
683683
* Returns the first `num` elements of this [[Dataset]] as an array.
684684
*
685685
* Running take requires moving data into the application's driver process, and doing so with
686-
* a very large `n` can crash the driver process with OutOfMemoryError.
686+
* a very large `num` can crash the driver process with OutOfMemoryError.
687687
* @since 1.6.0
688688
*/
689689
def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
@@ -692,7 +692,7 @@ class Dataset[T] private[sql](
692692
* Returns the first `num` elements of this [[Dataset]] as an array.
693693
*
694694
* Running take requires moving data into the application's driver process, and doing so with
695-
* a very large `n` can crash the driver process with OutOfMemoryError.
695+
* a very large `num` can crash the driver process with OutOfMemoryError.
696696
* @since 1.6.0
697697
*/
698698
def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)

0 commit comments

Comments (0)