diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala new file mode 100644 index 000000000000..e5a4d82b9874 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.lib + +import org.apache.spark.graphx._ + +object ShortestPaths { + type SPMap = Map[VertexId, Int] // map of landmarks -> minimum distance to landmark + def SPMap(x: (VertexId, Int)*) = Map(x: _*) + def increment(spmap: SPMap): SPMap = spmap.map { case (v, d) => v -> (d + 1) } + def plus(spmap1: SPMap, spmap2: SPMap): SPMap = + (spmap1.keySet ++ spmap2.keySet).map{ + k => k -> scala.math.min(spmap1.getOrElse(k, Int.MaxValue), spmap2.getOrElse(k, Int.MaxValue)) + }.toMap + + /** + * Compute the shortest paths to each landmark for each vertex and + * return an RDD with the map of landmarks to their shortest-path + * lengths. + * + * @tparam VD the shortest paths map for the vertex + * @tparam ED the incremented shortest-paths map of the originating + * vertex (discarded in the computation) + * + * @param graph the graph for which to compute the shortest paths + * @param landmarks the list of landmark vertex ids + * + * @return a graph with vertex attributes containing a map of the + * shortest paths to each landmark + */ + def run[VD, ED](graph: Graph[VD, ED], landmarks: Seq[VertexId]) + (implicit m1: Manifest[VD], m2: Manifest[ED]): Graph[SPMap, SPMap] = { + + val spGraph = graph + .mapVertices{ (vid, attr) => + if (landmarks.contains(vid)) SPMap(vid -> 0) + else SPMap() + } + .mapTriplets{ edge => edge.srcAttr } + + val initialMessage = SPMap() + + def vertexProgram(id: VertexId, attr: SPMap, msg: SPMap): SPMap = { + plus(attr, msg) + } + + def sendMessage(edge: EdgeTriplet[SPMap, SPMap]): Iterator[(VertexId, SPMap)] = { + val newAttr = increment(edge.srcAttr) + if (edge.dstAttr != plus(newAttr, edge.dstAttr)) Iterator((edge.dstId, newAttr)) + else Iterator.empty + } + + def messageCombiner(s1: SPMap, s2: SPMap): SPMap = { + plus(s1, s2) + } + + Pregel(spGraph, initialMessage)( + vertexProgram, sendMessage, messageCombiner) + } + +} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala new file mode 100644 index 000000000000..d095d3e791b5 --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.lib + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import org.apache.spark.graphx._ +import org.apache.spark.graphx.lib._ +import org.apache.spark.graphx.util.GraphGenerators +import org.apache.spark.rdd._ + +class ShortestPathsSuite extends FunSuite with LocalSparkContext { + + test("Shortest Path Computations") { + withSpark { sc => + val shortestPaths = Set((1,Map(1 -> 0, 4 -> 2)), (2,Map(1 -> 1, 4 -> 2)), (3,Map(1 -> 2, 4 -> 1)), + (4,Map(1 -> 2, 4 -> 0)), (5,Map(1 -> 1, 4 -> 1)), (6,Map(1 -> 3, 4 -> 1))) + val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)).flatMap{ case e => Seq(e, e.swap) } + val edges = sc.parallelize(edgeSeq).map { case (v1, v2) => (v1.toLong, v2.toLong) } + val graph = Graph.fromEdgeTuples(edges, 1) + val landmarks = Seq(1, 4).map(_.toLong) + val results = ShortestPaths.run(graph, landmarks).vertices.collect.map { case (v, spMap) => (v, spMap.mapValues(_.get)) } + assert(results.toSet === shortestPaths) + } + } + +}