9 changes: 8 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -310,6 +310,9 @@ abstract class RDD[T: ClassTag](
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
if (fraction < Double.MinValue || fraction > Double.MaxValue) {
Contributor:
Use require, i.e.

require(fraction > Double.MinValue && fraction < Double.MaxValue, "...")

Shouldn't you just check for fraction > 0 but < 1?

Contributor:

The lower bound should be >= 0.0. Sampling with replacement can have a fraction greater than 1.0.
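
A minimal sketch of the check the reviewers are converging on (the bounds follow the comments above; the message text and helper name are assumptions, not the PR's final code):

    // Sketch only: reject negative fractions, but enforce no upper bound,
    // because with-replacement (Poisson) sampling permits fraction > 1.0.
    def validateFraction(fraction: Double): Unit =
      require(fraction >= 0.0, "Invalid fraction value: " + fraction)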

Contributor Author:

Hi @rxin, I'm also a bit confused here; I think the name of the argument is misleading.

https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/rdd/RDD.scala#L357

The line above applies a multiplier to ensure that sampling returns enough sample points in most cases (I think so), so the fraction value can actually be larger than 1.

Also, this value determines the mean of the Poisson/Bernoulli distribution:

https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/rdd/RDD.scala#L314
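
To make this concrete, here is a rough sketch of the relationship (the multiplier value and variable names are illustrative, not the exact ones in the linked code):

    // Illustrative only: takeSample oversamples so that one sampling pass
    // usually returns at least `num` elements. With replacement, `fraction`
    // is the mean of a Poisson draw per element and may exceed 1.0; without
    // replacement it is a Bernoulli keep-probability in [0, 1].
    val num = 20                 // elements requested
    val initialCount = 100L      // total elements in the RDD
    val multiplier = 3.0         // hypothetical oversampling factor
    val fraction = multiplier * (num + 1) / initialCount  // here 0.63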

throw new Exception("Invalid fraction value:" + fraction)
}
if (withReplacement) {
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
} else {
@@ -344,6 +347,10 @@ abstract class RDD[T: ClassTag](
throw new IllegalArgumentException("Negative number of elements requested")
}

if (initialCount == 0) {
return new Array[T](0)
}

if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
@@ -362,7 +369,7 @@
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()

// If the first sample didn't turn out large enough, keep trying to take samples;
// this shouldn't happen often because we use a big multiplier for thei initial size
// this shouldn't happen often because we use a big multiplier for the initial size
while (samples.length < total) {
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
}
4 changes: 4 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -457,6 +457,10 @@ class RDDSuite extends FunSuite with SharedSparkContext {

test("takeSample") {
val data = sc.parallelize(1 to 100, 2)
val emptySet = data.filter(_ => false)
Contributor:
Is there a better way to create an empty RDD?

Contributor:
Yup, do:

data.mapPartitions { iter => Iterator.empty }
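
A quick way to exercise both constructions in the spark-shell (a sketch; the variable names are made up, and the first assertion mirrors the new test's expectation that takeSample on an empty RDD returns an empty array):

    // Sketch: both ways of building an empty RDD should behave the same.
    val data = sc.parallelize(1 to 100, 2)
    val emptyByFilter = data.filter(_ => false)
    val emptyByPartitions = data.mapPartitions { iter => Iterator.empty }
    assert(emptyByFilter.takeSample(false, 20, 1).isEmpty)
    assert(emptyByPartitions.count() == 0)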


val sample = emptySet.takeSample(false, 20, 1)
assert(sample.size === 0)
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
Expand Down