Skip to content
17 changes: 17 additions & 0 deletions libmultilabel/linear/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,20 @@ def predict_values(model, x: sparse.csr_matrix) -> np.ndarray:
], 'csr')

return (x * model['weights']).A + model['threshold']


def get_topk_labels(label_mapping: np.ndarray, preds: np.ndarray, top_k: int = 5) -> 'list[list[str]]':
    """Get top k predictions from decision values.

    For every instance (row of ``preds``) the ``top_k`` classes with the
    largest decision values are selected. The labels of each instance are
    returned in the order they appear in ``label_mapping``; they are NOT
    ranked by decision value.

    Args:
        label_mapping (np.ndarray): A ndarray of class labels that maps each label to its index.
        preds (np.ndarray): A matrix of decision values with dimension number of instances * number of classes.
        top_k (int): Determine how many classes per instance should be predicted.
            Values larger than the number of classes are clamped to the number
            of classes; 0 yields an empty list for every instance.

    Returns:
        list of lists which contain top k labels.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError(f'top_k must be non-negative, got {top_k}')

    num_instances, num_classes = preds.shape
    # np.argpartition raises when kth is out of bounds, so never ask for more
    # classes than exist.
    top_k = min(top_k, num_classes)
    if top_k == 0:
        # Guard explicitly: a slice of ``-0:`` would select every column.
        return [[] for _ in range(num_instances)]

    # argpartition only guarantees that the last top_k columns hold the indices
    # of the top_k largest values; their relative order is unspecified (neither
    # ascending nor descending). That is fine here because the indices are only
    # used to build a membership mask.
    top_k_ind = np.argpartition(preds, -top_k)[:, -top_k:]

    selected = np.zeros(preds.shape, dtype=bool)
    np.put_along_axis(selected, top_k_ind, True, axis=-1)
    return [list(label_mapping.compress(row)) for row in selected]
2 changes: 2 additions & 0 deletions libmultilabel/linear/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def _generate_label_mapping(self, labels, classes=None):
self.binarizer = MultiLabelBinarizer(
sparse_output=True, classes=classes)
self.binarizer.fit(labels)
self.label_mapping = self.binarizer.classes_


def read_libmultilabel_format(data: pd.DataFrame) -> 'dict[str,list[str]]':
Expand All @@ -157,6 +158,7 @@ def read_libmultilabel_format(data: pd.DataFrame) -> 'dict[str,list[str]]':
data['label'] = data['label'].map(lambda s: s.split())
return data.to_dict('list')


def read_libsvm_format(file_path: str) -> 'tuple[list[list[int]], sparse.csr_matrix]':
"""Read multi-label LIBSVM-format data.

Expand Down