@@ -22,22 +22,22 @@ import java.lang.reflect.Method
 import java.security.PrivilegedExceptionAction
 import java.util.{Arrays, Comparator}
 
+import scala.collection.JavaConversions._
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
 import com.google.common.primitives.Longs
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
 import org.apache.hadoop.fs.FileSystem.Statistics
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
 import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.mapreduce.JobContext
 import org.apache.hadoop.security.{Credentials, UserGroupInformation}
 
-import org.apache.spark.{Logging, SparkConf, SparkException}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.util.Utils
-
-import scala.collection.JavaConversions._
-import scala.concurrent.duration._
-import scala.language.postfixOps
+import org.apache.spark.{Logging, SparkConf, SparkException}
 
 /**
  * :: DeveloperApi ::
@@ -199,13 +199,36 @@ class SparkHadoopUtil extends Logging {
    * that file.
    */
   def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = {
-    def recurse(path: Path): Array[FileStatus] = {
-      val (directories, leaves) = fs.listStatus(path).partition(_.isDir)
-      leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath))
+    listLeafStatuses(fs, fs.getFileStatus(basePath))
+  }
+
+  /**
+   * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the
+   * given path points to a file, return a single-element collection containing [[FileStatus]] of
+   * that file.
+   */
+  def listLeafStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = {
+    def recurse(status: FileStatus): Seq[FileStatus] = {
+      val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDir)
+      leaves ++ directories.flatMap(f => listLeafStatuses(fs, f))
+    }
+
+    if (baseStatus.isDir) recurse(baseStatus) else Seq(baseStatus)
+  }
+
+  def listLeafDirStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = {
+    listLeafDirStatuses(fs, fs.getFileStatus(basePath))
+  }
+
+  def listLeafDirStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = {
+    def recurse(status: FileStatus): Seq[FileStatus] = {
+      val (directories, files) = fs.listStatus(status.getPath).partition(_.isDir)
+      val leaves = if (directories.isEmpty) Seq(status) else Seq.empty[FileStatus]
+      leaves ++ directories.flatMap(dir => listLeafDirStatuses(fs, dir))
     }
 
-    val baseStatus = fs.getFileStatus(basePath)
-    if (baseStatus.isDir) recurse(basePath) else Array(baseStatus)
+    assert(baseStatus.isDir)
+    recurse(baseStatus)
   }
 
   /**
@@ -275,7 +298,7 @@ class SparkHadoopUtil extends Logging {
         logDebug(text + " matched " + HADOOP_CONF_PATTERN)
         val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. }
         val eval = Option[String](hadoopConf.get(key))
-        .map { value =>
+          .map { value =>
             logDebug("Substituted " + matched + " with " + value)
             text.replace(matched, value)
           }
0 commit comments