-
Notifications
You must be signed in to change notification settings - Fork 28
Linear make prediction #256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
5f4aa46
b46e73b
75e472f
30bde59
c59d9d5
fc74eb0
0f9437b
0a60f7f
96a6dc6
44563df
1de941b
687b58d
f3859a6
ff97bee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -540,3 +540,20 @@ def predict_values(model, x: sparse.csr_matrix) -> np.ndarray: | |
| ], 'csr') | ||
|
|
||
| return (x * model['weights']).A + model['threshold'] | ||
|
|
||
| def predict_topk(preprocessor, x: np.ndarray, top_k: int = 5) -> 'list[tuple(str)]': | ||
| """Make top k predictions from decision values. | ||
|
|
||
| Args: | ||
| preprocessor: The preprocessor object from ``Preprocessor`` API used to load and preprocess the data. | ||
| x (np.ndarray): A matrix of decision values with dimension number of instances * number of classes. | ||
Gordon119 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| top_k (int): Determine how many classes per instance should be predicted. | ||
|
|
||
| Returns: | ||
| list of tuples which contain predicted top k labels. | ||
| """ | ||
| top_k_ind = np.argpartition(x, -top_k)[:, -top_k:] | ||
| pred_result = np.zeros(x.shape) | ||
| np.put_along_axis(pred_result, top_k_ind, 1, -1) | ||
| pred_result[pred_result != 1] = 0 | ||
|
||
| return preprocessor.binarizer.inverse_transform(pred_result) | ||
|
||
Uh oh!
There was an error while loading. Please reload this page.