16 changes: 16 additions & 0 deletions libmultilabel/linear/linear.py
@@ -540,3 +540,19 @@ def predict_values(model, x: sparse.csr_matrix) -> np.ndarray:
], 'csr')

    return (x * model['weights']).A + model['threshold']

def predict_topk(preprocessor, preds: np.ndarray, top_k: int = 5) -> 'list[tuple(str)]':
    """Make top k predictions from decision values.

    Args:
        preprocessor: The preprocessor object from ``Preprocessor`` API used to load and preprocess the data.
        preds (np.ndarray): A matrix of decision values with dimension number of instances * number of classes.
        top_k (int): Determine how many classes per instance should be predicted.

    Returns:
        list of tuples which contain predicted top k labels.
    """
    top_k_ind = np.argpartition(preds, -top_k)[:, -top_k:]
Contributor

top_k_ind is sorted ascending this way. top_k_ind = np.argpartition(preds, -top_k)[:, :top_k:-1] sorts descending.
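
A note on the ordering question: NumPy documents that `np.argpartition` leaves the order of elements inside each partition undefined, so the last `top_k` columns are the top-k indices but not in a guaranteed ascending or descending order. Also, on a row with `n` columns the slice `[:, :top_k:-1]` selects `n - top_k - 1` columns, not `top_k`. If score-ordered indices are actually wanted, one option is to sort only the k candidates afterwards. The following is a minimal sketch under that assumption; the helper name `top_k_indices_sorted` is hypothetical and not part of this PR.

```python
import numpy as np


def top_k_indices_sorted(preds: np.ndarray, top_k: int) -> np.ndarray:
    """Return, per row, the indices of the top_k largest values, highest score first.

    Hypothetical helper for illustration only.
    """
    # Top-k candidates per row; their order within the slice is undefined.
    candidates = np.argpartition(preds, -top_k, axis=-1)[:, -top_k:]
    # Look up the scores of those candidates.
    candidate_scores = np.take_along_axis(preds, candidates, axis=-1)
    # Sort only the k candidates by descending score.
    order = np.argsort(-candidate_scores, axis=-1)
    return np.take_along_axis(candidates, order, axis=-1)


preds = np.array([[0.1, 0.9, 0.3, 0.7, 0.5],
                  [0.8, 0.2, 0.6, 0.4, 0.0]])
top2 = top_k_indices_sorted(preds, 2)
print(top2)
```

For `predict_topk` as written, the order of `top_k_ind` does not affect the result, because the indices are only used to set entries of an indicator matrix to 1.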

    pred_result = np.zeros(preds.shape)
    np.put_along_axis(pred_result, top_k_ind, 1, -1)
    return preprocessor.binarizer.inverse_transform(pred_result)
Collaborator

where do we have this binarizer.inverse_transform defined??

Contributor Author

Since we’re using sklearn’s MultiLabelBinarizer, I used this:
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html
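
For readers unfamiliar with that method, here is a small standalone sketch of how `MultiLabelBinarizer.inverse_transform` maps a 0/1 indicator matrix back to label tuples. The label names and the indicator matrix are made up for illustration, and the sketch assumes `preprocessor.binarizer` is a fitted `MultiLabelBinarizer`, as stated in the reply above.

```python
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer

# Fit on toy label sets (assumed data); classes_ ends up sorted: bird, cat, dog.
binarizer = MultiLabelBinarizer()
binarizer.fit([("cat", "dog"), ("bird",)])

# One row per instance, one column per class; 1 marks a predicted class.
indicator = np.array([[0, 1, 1],
                      [1, 0, 0]])

# inverse_transform returns one tuple of labels per row, in classes_ order.
labels = binarizer.inverse_transform(indicator)
print(labels)
```

Because labels come back in `classes_` order rather than score order, the tuples returned this way carry no ranking information.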