@@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
8282 Row(user=1, item=0, prediction=2.6258413791656494)
8383 >>> predictions[2]
8484 Row(user=2, item=0, prediction=-1.5018409490585327)
85+ >>> user_recs = model.recommendForAllUsers(3)
86+ >>> user_recs.where(user_recs.user == 0)\
87+ .select("recommendations.item", "recommendations.rating").collect()
88+ [Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])]
89+ >>> item_recs = model.recommendForAllItems(3)
90+ >>> item_recs.where(item_recs.item == 2)\
91+ .select("recommendations.user", "recommendations.rating").collect()
92+ [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])]
8593 >>> als_path = temp_path + "/als"
8694 >>> als.save(als_path)
8795 >>> als2 = ALS.load(als_path)
@@ -384,6 +392,28 @@ def itemFactors(self):
384392 """
385393 return self ._call_java ("itemFactors" )
386394
395+ @since ("2.2.0" )
396+ def recommendForAllUsers (self , numItems ):
397+ """
398+ Returns top `numItems` items recommended for each user, for all users.
399+
400+ :param numItems: max number of recommendations for each user
401+ :return: a DataFrame of (userCol, recommendations), where recommendations are
402+ stored as an array of (itemCol, rating) Rows.
403+ """
404+ return self ._call_java ("recommendForAllUsers" , numItems )
405+
406+ @since ("2.2.0" )
407+ def recommendForAllItems (self , numUsers ):
408+ """
409+ Returns top `numUsers` users recommended for each item, for all items.
410+
411+ :param numItems: max number of recommendations for each item
412+ :return: a DataFrame of (itemCol, recommendations), where recommendations are
413+ stored as an array of (userCol, rating) Rows.
414+ """
415+ return self ._call_java ("recommendForAllItems" , numUsers )
416+
387417
388418if __name__ == "__main__" :
389419 import doctest
0 commit comments