Skip to content
17 changes: 17 additions & 0 deletions libmultilabel/linear/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,20 @@ def predict_values(model, x: sparse.csr_matrix) -> np.ndarray:
], 'csr')

return (x * model['weights']).A + model['threshold']

def predict_topk(preprocessor, x: np.ndarray, top_k: int = 5) -> 'list[tuple(str)]':
"""Make top k predictions from decision values.

Args:
preprocessor: The preprocessor object from ``Preprocessor`` API used to load and preprocess the data.
x (np.ndarray): A matrix of decision values with dimension number of instances * number of classes.
top_k (int): Determine how many classes per instance should be predicted.

Returns:
list of tuples which contain predicted top k labels.
"""
top_k_ind = np.argpartition(x, -top_k)[:, -top_k:]
pred_result = np.zeros(x.shape)
np.put_along_axis(pred_result, top_k_ind, 1, -1)
pred_result[pred_result != 1] = 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why line 588 as pred_result is initialized as the zero matrix?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inverse_transform has to take predictions in binary labels, so I first set pred_result to zero matrix and set the top k labels to 1. I realized that line 558 is redundant and removed it.

return preprocessor.binarizer.inverse_transform(pred_result)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do we have this binarizer.inverse_transform defined??

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we’re using sklearn’s MultiLabelBinarizer, I used this:
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html