@@ -67,15 +67,21 @@ class Dataset[T] private[sql](
6767 tEncoder : Encoder [T ]) extends Queryable with Serializable {
6868
6969 /**
70- * An unresolved version of the internal encoder for the type of this dataset . This one is marked
71- * implicit so that we can use it when constructing new [[Dataset ]] objects that have the same
72- * object type (that will be possibly resolved to a different schema).
70+ * An unresolved version of the internal encoder for the type of this [[ Dataset ]] . This one is
71+ * marked implicit so that we can use it when constructing new [[Dataset ]] objects that have the
72+ * same object type (that will be possibly resolved to a different schema).
7373 */
7474 private [sql] implicit val unresolvedTEncoder : ExpressionEncoder [T ] = encoderFor(tEncoder)
7575
7676 /** The encoder for this [[Dataset ]] that has been resolved to its output schema. */
7777 private [sql] val resolvedTEncoder : ExpressionEncoder [T ] =
78- unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes .outerScopes)
78+ unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes .outerScopes)
79+
80+ /**
81+ * The encoder where the expressions used to construct an object from an input row have been
82+ * bound to the ordinals of the given schema.
83+ */
84+ private [sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
7985
8086 private implicit def classTag = resolvedTEncoder.clsTag
8187
@@ -89,7 +95,7 @@ class Dataset[T] private[sql](
8995 override def schema : StructType = resolvedTEncoder.schema
9096
9197 /**
92- * Prints the schema of the underlying [[DataFrame ]] to the console in a nice tree format.
98+ * Prints the schema of the underlying [[Dataset ]] to the console in a nice tree format.
9399 * @since 1.6.0
94100 */
95101 override def printSchema (): Unit = toDF().printSchema()
@@ -111,7 +117,7 @@ class Dataset[T] private[sql](
111117 * ************* */
112118
113119 /**
114- * Returns a new ` Dataset` where each record has been mapped on to the specified type. The
120+ * Returns a new [[ Dataset ]] where each record has been mapped on to the specified type. The
115121 * method used to map columns depend on the type of `U`:
116122 * - When `U` is a class, fields for the class will be mapped to columns of the same name
117123 * (case sensitivity is determined by `spark.sql.caseSensitive`)
@@ -145,23 +151,20 @@ class Dataset[T] private[sql](
145151 def toDF (): DataFrame = DataFrame (sqlContext, logicalPlan)
146152
147153 /**
148- * Returns this Dataset.
154+ * Returns this [[ Dataset ]] .
149155 * @since 1.6.0
150156 */
151157 // This is declared with parentheses to prevent the Scala compiler from treating
152158 // `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset.
153159 def toDS (): Dataset [T ] = this
154160
155161 /**
156- * Converts this Dataset to an RDD.
162+ * Converts this [[ Dataset ]] to an [[ RDD ]] .
157163 * @since 1.6.0
158164 */
159165 def rdd : RDD [T ] = {
160- val tEnc = resolvedTEncoder
161- val input = queryExecution.analyzed.output
162166 queryExecution.toRdd.mapPartitions { iter =>
163- val bound = tEnc.bind(input)
164- iter.map(bound.fromRow)
167+ iter.map(boundTEncoder.fromRow)
165168 }
166169 }
167170
@@ -189,15 +192,15 @@ class Dataset[T] private[sql](
189192 def show (numRows : Int ): Unit = show(numRows, truncate = true )
190193
191194 /**
192- * Displays the top 20 rows of [[DataFrame ]] in a tabular form. Strings more than 20 characters
195+ * Displays the top 20 rows of [[Dataset ]] in a tabular form. Strings more than 20 characters
193196 * will be truncated, and all cells will be aligned right.
194197 *
195198 * @since 1.6.0
196199 */
197200 def show (): Unit = show(20 )
198201
199202 /**
200- * Displays the top 20 rows of [[DataFrame ]] in a tabular form.
203+ * Displays the top 20 rows of [[Dataset ]] in a tabular form.
201204 *
202205 * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
203206 * be truncated and all cells will be aligned right
@@ -207,7 +210,7 @@ class Dataset[T] private[sql](
207210 def show (truncate : Boolean ): Unit = show(20 , truncate)
208211
209212 /**
210- * Displays the [[DataFrame ]] in a tabular form. For example:
213+ * Displays the [[Dataset ]] in a tabular form. For example:
211214 * {{{
212215 * year month AVG('Adj Close) MAX('Adj Close)
213216 * 1980 12 0.503218 0.595103
@@ -291,7 +294,7 @@ class Dataset[T] private[sql](
291294
292295 /**
293296 * (Scala-specific)
294- * Returns a new [[Dataset ]] that contains the result of applying `func` to each element .
297+ * Returns a new [[Dataset ]] that contains the result of applying `func` to each partition .
295298 * @since 1.6.0
296299 */
297300 def mapPartitions [U : Encoder ](func : Iterator [T ] => Iterator [U ]): Dataset [U ] = {
@@ -307,7 +310,7 @@ class Dataset[T] private[sql](
307310
308311 /**
309312 * (Java-specific)
310- * Returns a new [[Dataset ]] that contains the result of applying `func` to each element .
313+ * Returns a new [[Dataset ]] that contains the result of applying `func` to each partition .
311314 * @since 1.6.0
312315 */
313316 def mapPartitions [U ](f : MapPartitionsFunction [T , U ], encoder : Encoder [U ]): Dataset [U ] = {
@@ -341,28 +344,28 @@ class Dataset[T] private[sql](
341344
342345 /**
343346 * (Scala-specific)
344- * Runs `func` on each element of this Dataset.
347+ * Runs `func` on each element of this [[ Dataset ]] .
345348 * @since 1.6.0
346349 */
347350 def foreach (func : T => Unit ): Unit = rdd.foreach(func)
348351
349352 /**
350353 * (Java-specific)
351- * Runs `func` on each element of this Dataset.
354+ * Runs `func` on each element of this [[ Dataset ]] .
352355 * @since 1.6.0
353356 */
354357 def foreach (func : ForeachFunction [T ]): Unit = foreach(func.call(_))
355358
356359 /**
357360 * (Scala-specific)
358- * Runs `func` on each partition of this Dataset.
361+ * Runs `func` on each partition of this [[ Dataset ]] .
359362 * @since 1.6.0
360363 */
361364 def foreachPartition (func : Iterator [T ] => Unit ): Unit = rdd.foreachPartition(func)
362365
363366 /**
364367 * (Java-specific)
365- * Runs `func` on each partition of this Dataset.
368+ * Runs `func` on each partition of this [[ Dataset ]] .
366369 * @since 1.6.0
367370 */
368371 def foreachPartition (func : ForeachPartitionFunction [T ]): Unit =
@@ -374,27 +377,27 @@ class Dataset[T] private[sql](
374377
375378 /**
376379 * (Scala-specific)
377- * Reduces the elements of this Dataset using the specified binary function. The given function
380+ * Reduces the elements of this [[ Dataset ]] using the specified binary function. The given `func`
378381 * must be commutative and associative or the result may be non-deterministic.
379382 * @since 1.6.0
380383 */
381384 def reduce (func : (T , T ) => T ): T = rdd.reduce(func)
382385
383386 /**
384387 * (Java-specific)
385- * Reduces the elements of this Dataset using the specified binary function. The given function
388+ * Reduces the elements of this Dataset using the specified binary function. The given `func`
386389 * must be commutative and associative or the result may be non-deterministic.
387390 * @since 1.6.0
388391 */
389392 def reduce (func : ReduceFunction [T ]): T = reduce(func.call(_, _))
390393
391394 /**
392395 * (Scala-specific)
393- * Returns a [[GroupedDataset ]] where the data is grouped by the given key function .
396+ * Returns a [[GroupedDataset ]] where the data is grouped by the given key `func` .
394397 * @since 1.6.0
395398 */
396399 def groupBy [K : Encoder ](func : T => K ): GroupedDataset [K , T ] = {
397- val inputPlan = queryExecution.analyzed
400+ val inputPlan = logicalPlan
398401 val withGroupingKey = AppendColumns (func, resolvedTEncoder, inputPlan)
399402 val executed = sqlContext.executePlan(withGroupingKey)
400403
@@ -429,18 +432,18 @@ class Dataset[T] private[sql](
429432
430433 /**
431434 * (Java-specific)
432- * Returns a [[GroupedDataset ]] where the data is grouped by the given key function .
435+ * Returns a [[GroupedDataset ]] where the data is grouped by the given key `func` .
433436 * @since 1.6.0
434437 */
435- def groupBy [K ](f : MapFunction [T , K ], encoder : Encoder [K ]): GroupedDataset [K , T ] =
436- groupBy(f .call(_))(encoder)
438+ def groupBy [K ](func : MapFunction [T , K ], encoder : Encoder [K ]): GroupedDataset [K , T ] =
439+ groupBy(func .call(_))(encoder)
437440
438441 /* ****************** *
439442 * Typed Relational *
440443 * ****************** */
441444
442445 /**
443- * Selects a set of column based expressions.
446+ * Returns a new [[ DataFrame ]] by selecting a set of column based expressions.
444447 * {{{
445448 * df.select($"colA", $"colB" + 1)
446449 * }}}
@@ -464,8 +467,8 @@ class Dataset[T] private[sql](
464467 sqlContext,
465468 Project (
466469 c1.withInputType(
467- resolvedTEncoder.bind(queryExecution.analyzed.output) ,
468- queryExecution.analyzed .output).named :: Nil ,
470+ boundTEncoder ,
471+ logicalPlan .output).named :: Nil ,
469472 logicalPlan))
470473 }
471474
@@ -477,7 +480,7 @@ class Dataset[T] private[sql](
477480 protected def selectUntyped (columns : TypedColumn [_, _]* ): Dataset [_] = {
478481 val encoders = columns.map(_.encoder)
479482 val namedColumns =
480- columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed .output).named)
483+ columns.map(_.withInputType(resolvedTEncoder, logicalPlan .output).named)
481484 val execution = new QueryExecution (sqlContext, Project (namedColumns, logicalPlan))
482485
483486 new Dataset (sqlContext, execution, ExpressionEncoder .tuple(encoders))
@@ -654,25 +657,22 @@ class Dataset[T] private[sql](
654657 * Returns an array that contains all the elements in this [[Dataset ]].
655658 *
656659 * Running collect requires moving all the data into the application's driver process, and
657- * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
660+ * doing so on a very large [[ Dataset ]] can crash the driver process with OutOfMemoryError.
658661 *
659662 * For Java API, use [[collectAsList ]].
660663 * @since 1.6.0
661664 */
662665 def collect (): Array [T ] = {
663666 // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
664667 // to convert the rows into objects of type T.
665- val tEnc = resolvedTEncoder
666- val input = queryExecution.analyzed.output
667- val bound = tEnc.bind(input)
668- queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
668+ queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
669669 }
670670
671671 /**
672672 * Returns an array that contains all the elements in this [[Dataset ]].
673673 *
674674 * Running collect requires moving all the data into the application's driver process, and
675- * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
675+ * doing so on a very large [[ Dataset ]] can crash the driver process with OutOfMemoryError.
676676 *
677677 * For Java API, use [[collectAsList ]].
678678 * @since 1.6.0
@@ -683,7 +683,7 @@ class Dataset[T] private[sql](
683683 * Returns the first `num` elements of this [[Dataset ]] as an array.
684684 *
685685 * Running take requires moving data into the application's driver process, and doing so with
686- * a very large `n ` can crash the driver process with OutOfMemoryError.
686+ * a very large `num ` can crash the driver process with OutOfMemoryError.
687687 * @since 1.6.0
688688 */
689689 def take (num : Int ): Array [T ] = withPlan(Limit (Literal (num), _)).collect()
@@ -692,7 +692,7 @@ class Dataset[T] private[sql](
692692 * Returns the first `num` elements of this [[Dataset ]] as an array.
693693 *
694694 * Running take requires moving data into the application's driver process, and doing so with
695- * a very large `n ` can crash the driver process with OutOfMemoryError.
695+ * a very large `num ` can crash the driver process with OutOfMemoryError.
696696 * @since 1.6.0
697697 */
698698 def takeAsList (num : Int ): java.util.List [T ] = java.util.Arrays .asList(take(num) : _* )
0 commit comments