Skip to content

Commit 42890f2

Browse files
authored
Python evaluator module fix (#863)
* Python evaluator module fix Remove redundant and unnecessary sortings Refactor get_top_k_items to return DataFrame with 'rank' column same as pyspark's * Update test to catch corner case
1 parent 7662798 commit 42890f2

2 files changed

Lines changed: 23 additions & 23 deletions

File tree

reco_utils/evaluation/python_evaluation.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def merge_ranking_true_pred(
363363
364364
Returns:
365365
pd.DataFrame, pd.DataFrame, int:
366-
DataFrame of recommendation hits
366+
DataFrame of recommendation hits, sorted by `col_user` and `rank`
367367
DataFrmae of hit counts vs actual relevant items per user
368368
number of unique user ids
369369
"""
@@ -390,9 +390,6 @@ def merge_ranking_true_pred(
390390
col_rating=col_prediction,
391391
k=top_k,
392392
)
393-
df_hit["rank"] = df_hit.groupby(col_user)[col_prediction].rank(
394-
method="first", ascending=False
395-
)
396393
df_hit = pd.merge(df_hit, rating_true_common, on=[col_user, col_item])[
397394
[col_user, col_item, "rank"]
398395
]
@@ -559,7 +556,7 @@ def ndcg_at_k(
559556
# relevance in this case is always 1
560557
df_dcg["dcg"] = 1 / np.log1p(df_dcg["rank"])
561558
# sum up discount gained to get discount cumulative gain
562-
df_dcg = df_dcg.groupby(col_user, as_index=False).agg({"dcg": "sum"})
559+
df_dcg = df_dcg.groupby(col_user, as_index=False, sort=False).agg({"dcg": "sum"})
563560
# calculate ideal discounted cumulative gain
564561
df_ndcg = pd.merge(df_dcg, df_hit_count, on=[col_user])
565562
df_ndcg["idcg"] = df_ndcg["actual"].apply(
@@ -625,8 +622,8 @@ def map_at_k(
625622
return 0.0
626623

627624
# calculate reciprocal rank of items for each user and sum them up
628-
df_hit_sorted = df_hit.sort_values([col_user, "rank"])
629-
df_hit_sorted["rr"] = (df_hit.groupby(col_user).cumcount() + 1) / df_hit["rank"]
625+
df_hit_sorted = df_hit.copy()
626+
df_hit_sorted["rr"] = (df_hit_sorted.groupby(col_user).cumcount() + 1) / df_hit_sorted["rank"]
630627
df_hit_sorted = df_hit_sorted.groupby(col_user).agg({"rr": "sum"}).reset_index()
631628

632629
df_merge = pd.merge(df_hit_sorted, df_hit_count, on=col_user)
@@ -651,14 +648,17 @@ def get_top_k_items(
651648
k (int): number of items for each user
652649
653650
Returns:
654-
pd.DataFrame: DataFrame of top k items for each user
651+
pd.DataFrame: DataFrame of top k items for each user, sorted by `col_user` and `rank`
655652
"""
656-
657-
return (
653+
# Sort dataframe by col_user and (top k) col_rating
654+
top_k_items = (
658655
dataframe.groupby(col_user, as_index=False)
659656
.apply(lambda x: x.nlargest(k, col_rating))
660657
.reset_index(drop=True)
661658
)
659+
# Add ranks
660+
top_k_items["rank"] = top_k_items.groupby(col_user, sort=False).cumcount() + 1
661+
return top_k_items
662662

663663

664664
"""Function name and function mapper.

tests/unit/test_python_evaluation.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@
3535
def rating_true():
3636
return pd.DataFrame(
3737
{
38-
DEFAULT_USER_COL: [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
38+
DEFAULT_USER_COL: [1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1,],
3939
DEFAULT_ITEM_COL: [
40-
1,
41-
2,
4240
3,
4341
1,
4442
4,
@@ -55,8 +53,10 @@ def rating_true():
5553
12,
5654
13,
5755
14,
56+
1,
57+
2,
5858
],
59-
DEFAULT_RATING_COL: [5, 4, 3, 5, 5, 3, 3, 1, 5, 5, 5, 4, 4, 3, 3, 3, 2, 1],
59+
DEFAULT_RATING_COL: [3, 5, 5, 3, 3, 1, 5, 5, 5, 4, 4, 3, 3, 3, 2, 1, 5, 4,],
6060
}
6161
)
6262

@@ -65,10 +65,8 @@ def rating_true():
6565
def rating_pred():
6666
return pd.DataFrame(
6767
{
68-
DEFAULT_USER_COL: [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
68+
DEFAULT_USER_COL: [1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1,],
6969
DEFAULT_ITEM_COL: [
70-
3,
71-
10,
7270
12,
7371
10,
7472
3,
@@ -85,10 +83,10 @@ def rating_pred():
8583
2,
8684
11,
8785
14,
86+
3,
87+
10,
8888
],
8989
DEFAULT_PREDICTION_COL: [
90-
14,
91-
13,
9290
12,
9391
14,
9492
13,
@@ -105,8 +103,10 @@ def rating_pred():
105103
7,
106104
6,
107105
5,
106+
14,
107+
13,
108108
],
109-
DEFAULT_RATING_COL: [5, 4, 3, 5, 5, 3, 3, 1, 5, 5, 5, 4, 4, 3, 3, 3, 2, 1],
109+
DEFAULT_RATING_COL: [3, 5, 5, 3, 3, 1, 5, 5, 5, 4, 4, 3, 3, 3, 2, 1, 5, 4,],
110110
}
111111
)
112112

@@ -115,11 +115,9 @@ def rating_pred():
115115
def rating_nohit():
116116
return pd.DataFrame(
117117
{
118-
DEFAULT_USER_COL: [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
118+
DEFAULT_USER_COL: [1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1,],
119119
DEFAULT_ITEM_COL: [100] * 18,
120120
DEFAULT_PREDICTION_COL: [
121-
14,
122-
13,
123121
12,
124122
14,
125123
13,
@@ -136,6 +134,8 @@ def rating_nohit():
136134
7,
137135
6,
138136
5,
137+
14,
138+
13,
139139
],
140140
}
141141
)

0 commit comments

Comments
 (0)