Commit d242c0e

Merge pull request #54 from JuliaAI/kind-of-learner

LearnAPI 2.0

2 parents: 1223f3f + 77de486
32 files changed: +625 −315 lines

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -29,7 +29,7 @@ jobs:
         with:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
-      - uses: actions/cache@v1
+      - uses: julia-actions/cache@v1
         env:
           cache-name: cache-artifacts
         with:
```

.gitignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -8,3 +8,4 @@ sandbox/
 /docs/site/
 /docs/Manifest.toml
 .vscode
+LocalPreferences.toml
```

Project.toml

Lines changed: 7 additions & 3 deletions

```diff
@@ -1,13 +1,17 @@
+authors = ["Anthony D. Blaom <[email protected]>"]
 name = "LearnAPI"
 uuid = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
-authors = ["Anthony D. Blaom <[email protected]>"]
-version = "1.0.1"
+version = "2.0.0"
 
 [compat]
+Preferences = "1.5.0"
 julia = "1.10"
 
+[deps]
+Preferences = "21216c6a-2e73-6563-6e65-726566657250"
+
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test",]
+test = ["Test"]
```

README.md

Lines changed: 3 additions & 1 deletion

```diff
@@ -24,7 +24,7 @@ Here `learner` specifies the configuration the algorithm (the hyperparameters) w
 `model` stores learned parameters and any byproducts of algorithm execution.
 
 LearnAPI.jl is mostly method stubs and lots of documentation. It does not provide
-meta-algorithms, such as cross-validation or hyperparameter optimization, but does aim to
+meta-algorithms, such as cross-validation, hyperparameter optimization, or model composition, but does aim to
 support such algorithms.
 
 ## Related packages
@@ -37,6 +37,8 @@ support such algorithms.
 
 - [StatisticalMeasures.jl](https://github.com/JuliaAI/StatisticalMeasures.jl): Package providing metrics, compatible with LearnAPI.jl
 
+- [StatsModels.jl](https://github.com/JuliaStats/StatsModels.jl): Provides the R-style formula implementation of data preprocessing handled by [LearnDataFrontEnds.jl](https://github.com/JuliaAI/LearnDataFrontEnds.jl)
+
 ### Selected packages providing alternative API's
 
 The following alphabetical list of packages provide public base API's. Some provide
```
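The learner/model split described in the README excerpt above can be sketched in plain Julia. This toy mean-predictor is illustrative only; the names `MeanLearner` and `MeanModel` are not part of LearnAPI.jl:

```julia
# Toy illustration of the learner/model split: `learner` holds
# configuration (hyperparameters), while `model`, the output of `fit`,
# holds whatever was learned from the data.
struct MeanLearner end        # no hyperparameters in this toy example
struct MeanModel
    mu::Float64               # the single learned parameter
end

# `fit` consumes data and returns the model; `predict` uses the model only:
fit(::MeanLearner, y::AbstractVector) = MeanModel(sum(y)/length(y))
predict(model::MeanModel, n::Integer) = fill(model.mu, n)

model = fit(MeanLearner(), [1.0, 2.0, 3.0])
predict(model, 2)  # → [2.0, 2.0]
```

Note the design choice this mirrors: because learned state lives in `model` rather than in `learner`, meta-algorithms can refit the same `learner` on different data without mutation.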

ROADMAP.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -14,7 +14,7 @@
   "Common Implementation Patterns". As real-world implementations roll out, we could
   increasingly point to those instead, to conserve effort
   - [x] regression
-  - [ ] classification
+  - [x] classification
   - [ ] clustering
   - [x] gradient descent
   - [x] iterative algorithms
```

docs/Project.toml

Lines changed: 0 additions & 1 deletion

```diff
@@ -2,7 +2,6 @@
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
 LearnAPI = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
-LearnTestAPI = "3111ed91-c4f2-40e7-bb19-7f6c618409b8"
 MLCore = "c2834f40-e789-41da-a90e-33b280584a8c"
 ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
```

docs/make.jl

Lines changed: 3 additions & 2 deletions

```diff
@@ -2,12 +2,12 @@ using Documenter
 using LearnAPI
 using ScientificTypesBase
 using DocumenterInterLinks
-using LearnTestAPI
+# using LearnTestAPI
 
 const REPO = Remotes.GitHub("JuliaAI", "LearnAPI.jl")
 
 makedocs(
-    modules=[LearnAPI, LearnTestAPI],
+    modules=[LearnAPI, ], #LearnTestAPI],
     format=Documenter.HTML(
         prettyurls = true,#get(ENV, "CI", nothing) == "true",
         collapselevel = 1,
@@ -18,6 +18,7 @@ makedocs(
         "Reference" => [
             "Overview" => "reference.md",
             "Public Names" => "list_of_public_names.md",
+            "Kinds of learner" => "kinds_of_learner.md",
            "fit/update" => "fit_update.md",
            "predict/transform" => "predict_transform.md",
            "Kinds of Target Proxy" => "kinds_of_target_proxy.md",
```

docs/src/anatomy_of_an_implementation.md

Lines changed: 74 additions & 56 deletions

````diff
@@ -1,6 +1,7 @@
 # Anatomy of an Implementation
 
-The core LearnAPI.jl pattern looks like this:
+LearnAPI.jl supports three core patterns. The default pattern, known as the
+[`LearnAPI.Descriminative`](@ref) pattern, looks like this:
 
 ```julia
 model = fit(learner, data)
@@ -10,38 +11,51 @@ predict(model, newdata)
 Here `learner` specifies [hyperparameters](@ref hyperparameters), while `model` stores
 learned parameters and any byproducts of algorithm execution.
 
-Variations on this pattern:
+[Transformers](@ref) ordinarily implement `transform` instead of `predict`. For more on
+`predict` versus `transform`, see [Predict or transform?](@ref)
 
-- [Transformers](@ref) ordinarily implement `transform` instead of `predict`. For more on
-  `predict` versus `transform`, see [Predict or transform?](@ref)
+Two other `fit`/`predict`/`transform` patterns supported by LearnAPI.jl are:
+[`LearnAPI.Generative`](@ref) which has the form:
 
-- ["Static" (non-generalizing) algorithms](@ref static_algorithms), which includes some
-  simple transformers and some clustering algorithms, have a `fit` that consumes no
-  `data`. Instead `predict` or `transform` does the heavy lifting.
+```julia
+model = fit(learner, data)
+predict(model) # a single distribution, for example
+```
 
-- In [density estimation](@ref density_estimation), the `newdata` argument in `predict` is
-  missing.
+and [`LearnAPI.Static`](@ref), which looks like this:
+
+```julia
+model = fit(learner) # no `data` argument
+predict(model, data) # may mutate `model` to record byproducts of computation
+```
 
-These are the basic possibilities.
+Do not read too much into the names for these patterns, which are formalized [here](@ref kinds_of_learner). Use may not always correspond to prior associations.
 
-Elaborating on the core pattern above, this tutorial details an implementation of the
-LearnAPI.jl for naive [ridge regression](https://en.wikipedia.org/wiki/Ridge_regression)
-with no intercept. The kind of workflow we want to enable has been previewed in [Sample
-workflow](@ref). Readers can also refer to the [demonstration](@ref workflow) of the
-implementation given later.
+Elaborating on the common `Descriminative` pattern above, this tutorial details an
+implementation of the LearnAPI.jl for naive [ridge
+regression](https://en.wikipedia.org/wiki/Ridge_regression) with no intercept. The kind of
+workflow we want to enable has been previewed in [Sample workflow](@ref). Readers can also
+refer to the [demonstration](@ref workflow) of the implementation given later.
 
-## A basic implementation
+!!! tip "Quick Start for new implementations"
 
-See [here](@ref code) for code without explanations.
+    1. From this tutorial, read at least "[A basic implementation](@ref)" below.
+    1. Looking over the examples in "[Common Implementation Patterns](@ref patterns)", identify the appropriate core learner pattern above for your algorithm.
+    1. Implement `fit` (probably following an existing example). Read the [`fit`](@ref) document string to see what else may need to be implemented, paying particular attention to the "New implementations" section.
+    3. Rinse and repeat with each new method implemented.
+    4. Identify any additional [learner traits](@ref traits) that have appropriate overloadings; use the [`@trait`](@ref) macro to define these in one block.
+    5. Ensure your implementation includes the compulsory method [`LearnAPI.learner`](@ref) and compulsory traits [`LearnAPI.constructor`](@ref) and [`LearnAPI.functions`](@ref). Read and apply "[Testing your implementation](@ref)".
 
-We suppose our algorithm's `fit` method consumes data in the form `(X, y)`, where
-`X` is a suitable table¹ (the features) and `y` a vector (the target).
+    If you get stuck, refer back to this tutorial and the [Reference](@ref reference) sections.
 
-!!! important
 
-    Implementations wishing to support other data
-    patterns may need to take additional steps explained under
-    [Other data patterns](@ref di) below.
+## A basic implementation
+
+See [here](@ref code) for code without explanations.
+
+Let us suppose our algorithm's `fit` method is to consume data in the form `(X, y)`, where
+`X` is a suitable table¹ (the features, a.k.a., covariates or predictors) and `y` a vector
+(the target, a.k.a., labels or response).
 
 The first line below imports the lightweight package LearnAPI.jl whose methods we will be
 extending. The second imports libraries needed for the core algorithm.
@@ -110,7 +124,7 @@ Note that we also include `learner` in the struct, for it must be possible to re
 The implementation of `fit` looks like this:
 
 ```@example anatomy
-function LearnAPI.fit(learner::Ridge, data; verbosity=1)
+function LearnAPI.fit(learner::Ridge, data; verbosity=LearnAPI.default_verbosity())
     X, y = data
 
     # data preprocessing:
@@ -158,6 +172,22 @@ If the kind of proxy is omitted, as in `predict(model, Xnew)`, then a fallback g
 first element of the tuple returned by [`LearnAPI.kinds_of_proxy(learner)`](@ref), which
 we overload appropriately below.
 
+### Data deconstructors: `target` and `features`
+
+LearnAPI.jl is flexible about the form of training `data`. However, to buy into
+meta-functionality, such as cross-validation, we'll need to say something about the
+structure of this data. We implement [`LearnAPI.target`](@ref) to say what
+part of the data constitutes a [target variable](@ref proxy), and
+[`LearnAPI.features`](@ref) to say what are the features (valid `newdata` in a
+`predict(model, newdata)` call):
+
+```@example anatomy
+LearnAPI.target(learner::Ridge, (X, y)) = y
+LearnAPI.features(learner::Ridge, (X, y)) = X
+```
+
+Another data deconstructor, for learners that support per-observation weights in training,
+is [`LearnAPI.weights`](@ref).
 
 ### [Accessor functions](@id af)
 
@@ -241,15 +271,11 @@ the *type* of the argument.
 ### The `functions` trait
 
 The last trait, `functions`, above returns a list of all LearnAPI.jl methods that can be
-meaningfully applied to the learner or associated model, with the exception of traits. You
-always include the first five you see here: `fit`, `learner`, `clone` ,`strip`,
-`obs`. Here [`clone`](@ref) is a utility function provided by LearnAPI that you never
-overload, while [`obs`](@ref) is discussed under [Providing a separate data front
-end](@ref) below and is always included because it has a meaningful fallback. The
-`features` method, here provided by a fallback, articulates how the features `X` can be
-extracted from the training data `(X, y)`. We must also include `target` here to flag our
-model as supervised; again the method itself is provided by a fallback valid in the
-present case.
+meaningfully applied to the learner or the output of `fit` (denoted `model` above), with
+the exception of traits. You always include the first five you see here: `fit`, `learner`,
+`clone` ,`strip`, `obs`. Here [`clone`](@ref) is a utility function provided by LearnAPI
+that you never overload, while [`obs`](@ref) is discussed under [Providing a separate data
+front end](@ref) below and is always included because it has a meaningful fallback.
 
 See [`LearnAPI.functions`](@ref) for a checklist of what the `functions` trait needs to
 return.
@@ -340,11 +366,6 @@ assumptions about data from those made above.
   under [Providing a separate data front end](@ref) below; or (ii) overload the trait
   [`LearnAPI.data_interface`](@ref) to specify a more relaxed data API.
 
-- Where the form of data consumed by `fit` is different from that consumed by
-  `predict/transform` (as in classical supervised learning) it may be necessary to
-  explicitly overload the functions [`LearnAPI.features`](@ref) and (if supervised)
-  [`LearnAPI.target`](@ref). The same holds if overloading [`obs`](@ref); see below.
-
 
 ## Providing a separate data front end
 
@@ -414,7 +435,7 @@ The [`obs`](@ref) methods exist to:
 
 !!! important
 
-    While many new learner implementations will want to adopt a canned data front end, such as those provided by [LearnDataFrontEnds.jl](https://juliaai.github.io/LearnAPI.jl/dev/), we
+    While many new learner implementations will want to adopt a canned data front end, such as those provided by [LearnDataFrontEnds.jl](https://juliaai.github.io/LearnDataFrontEnds.jl/dev/), we
 focus here on a self-contained implementation of `obs` for the ridge example above, to show
 how it works.
 
@@ -448,14 +469,14 @@ newobservations = MLCore.getobs(observations, test_indices)
 predict(model, newobservations)
 ```
 
-which works for any non-static learner implementing `predict`, no matter how one is
-supposed to accesses the individual observations of `data` or `newdata`. See also the
-demonstration [below](@ref advanced_demo). Furthermore, fallbacks ensure the above pattern
-still works if we choose not to implement a front end at all, which is allowed, if
-supported `data` and `newdata` already implement `getobs`/`numobs`.
+which works for any [`LearnAPI.Descriminative`](@ref) learner implementing `predict`, no
+matter how one is supposed to accesses the individual observations of `data` or
+`newdata`. See also the demonstration [below](@ref advanced_demo). Furthermore, fallbacks
+ensure the above pattern still works if we choose not to implement a front end at all,
+which is allowed, if supported `data` and `newdata` already implement `getobs`/`numobs`.
 
-Here we specifically wrap all the preprocessed data into single object, for which we
-introduce a new type:
+In the ridge regression example we specifically wrap all the preprocessed data into single
+object, for which we introduce a new type:
 
 ```@example anatomy2
 struct RidgeFitObs{T,M<:AbstractMatrix{T}}
@@ -476,13 +497,13 @@ function LearnAPI.obs(::Ridge, data)
 end
 ```
 
-We informally refer to the output of `obs` as "observations" (see [The `obs`
-contract](@ref) below). The previous core `fit` signature is now replaced with two
+We informally refer to the output of `obs` as "observations" (see "[The `obs`
+contract](@ref)" below). The previous core `fit` signature is now replaced with two
 methods - one to handle "regular" input, and one to handle the pre-processed data
 (observations) which appears first below:
 
 ```@example anatomy2
-function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=1)
+function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=LearnAPI.default_verbosity())
 
     lambda = learner.lambda
 
@@ -545,13 +566,10 @@ LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
     predict(model, Point(), obs(model, Xnew))
 ```
 
-### `features` and `target` methods
+### Data deconstructors: `features` and `target`
 
-Two methods [`LearnAPI.features`](@ref) and [`LearnAPI.target`](@ref) articulate how
-features and target can be extracted from `data` consumed by LearnAPI.jl
-methods. Fallbacks provided by LearnAPI.jl sufficed in our basic implementation
-above. Here we must explicitly overload them, so that they also handle the output of
-`obs(learner, data)`:
+These methods must be able to handle any `data` supported by `fit`, which includes the
+output of `obs(learner, data)`:
 
 ```@example anatomy2
 LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
@@ -573,7 +591,7 @@ LearnAPI.target(learner::Ridge, data) = LearnAPI.target(learner, obs(learner, da
 
 Since LearnAPI.jl provides fallbacks for `obs` that simply return the unadulterated data
 argument, overloading `obs` is optional. This is provided data in publicized
-`fit`/`predict` signatures already consists only of objects implement the
+`fit`/`predict` signatures already consists only of objects implementing the
 [`LearnAPI.RandomAccess`](@ref) interface (most tables¹, arrays³, and tuples thereof).
 
 To opt out of supporting the MLCore.jl interface altogether, an implementation must
````
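For context, the naive ridge computation that the tutorial's `Ridge` learner performs can be sketched in plain Julia using only the standard library. The function name `ridge_coefficients` is illustrative, not from the package:

```julia
using LinearAlgebra

# Ridge regression with no intercept: given an n×p feature matrix A and
# target vector y, solve the regularized normal equations
#   (A'A + λI) c = A'y
# for the coefficient vector c.
function ridge_coefficients(A::AbstractMatrix, y::AbstractVector; lambda=0.1)
    p = size(A, 2)
    return (A'A + lambda*I(p)) \ (A'y)
end

A = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y = [1.0, 2.0, 3.0]
ridge_coefficients(A, y; lambda=0.0)  # λ = 0 recovers least squares: [1.0, 2.0]
```

A LearnAPI.jl implementation wraps this computation in `fit`, stores `c` in the model struct, and computes `predict` as `Anew * c`.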

docs/src/common_implementation_patterns.md

Lines changed: 4 additions & 2 deletions

```diff
@@ -9,8 +9,10 @@ This guide is intended to be consulted after reading [Anatomy of an Implementati
 which introduces the main interface objects and terminology.
 
 Although an implementation is defined purely by the methods and traits it implements, many
-implementations fall into one (or more) of the following informally understood patterns or
-tasks:
+implementations fall into one (or more) of the informally understood patterns or tasks
+below. While some generally fall into one of the core `Descriminative`, `Generative` or
+`Static` patterns detailed [here](@id kinds_of_learner), there are exceptions (such as
+clustering, which has both `Descriminative` and `Static` variations).
 
 - [Regression](@ref): Supervised learners for continuous targets
```

docs/src/examples.md

Lines changed: 12 additions & 5 deletions

```diff
@@ -24,7 +24,6 @@ end
 Instantiate a ridge regression learner, with regularization of `lambda`.
 """
 Ridge(; lambda=0.1) = Ridge(lambda)
-LearnAPI.constructor(::Ridge) = Ridge
 
 # struct for output of `fit`
 struct RidgeFitted{T,F}
@@ -33,7 +32,7 @@ struct RidgeFitted{T,F}
     named_coefficients::F
 end
 
-function LearnAPI.fit(learner::Ridge, data; verbosity=1)
+function LearnAPI.fit(learner::Ridge, data; verbosity=LearnAPI.default_verbosity())
     X, y = data
 
     # data preprocessing:
@@ -58,6 +57,10 @@ end
 LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
     Tables.matrix(Xnew)*model.coefficients
 
+# data deconstructors:
+LearnAPI.target(learner::Ridge, (X, y)) = y
+LearnAPI.features(learner::Ridge, (X, y)) = X
+
 # accessor functions:
 LearnAPI.learner(model::RidgeFitted) = model.learner
 LearnAPI.coefficients(model::RidgeFitted) = model.named_coefficients
@@ -126,7 +129,11 @@ function LearnAPI.obs(::Ridge, data)
 end
 LearnAPI.obs(::Ridge, observations::RidgeFitObs) = observations
 
-function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=1)
+function LearnAPI.fit(
+    learner::Ridge,
+    observations::RidgeFitObs;
+    verbosity=LearnAPI.default_verbosity(),
+)
 
     lambda = learner.lambda
 
@@ -160,7 +167,7 @@ LearnAPI.predict(model::RidgeFitted, ::Point, observations::AbstractMatrix) =
 LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
     predict(model, Point(), obs(model, Xnew))
 
-# methods to deconstruct training data:
+# training data deconstructors:
 LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
 LearnAPI.target(::Ridge, observations::RidgeFitObs) = observations.y
 LearnAPI.features(learner::Ridge, data) = LearnAPI.features(learner, obs(learner, data))
@@ -223,7 +230,7 @@ frontend = FrontEnds.Saffron()
 LearnAPI.obs(learner::Ridge, data) = FrontEnds.fitobs(learner, data, frontend)
 LearnAPI.obs(model::RidgeFitted, data) = obs(model, data, frontend)
 
-function LearnAPI.fit(learner::Ridge, observations::FrontEnds.Obs; verbosity=1)
+function LearnAPI.fit(learner::Ridge, observations::FrontEnds.Obs; verbosity=LearnAPI.default_verbosity())
 
     lambda = learner.lambda
```

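A change recurring across this commit replaces hard-coded `verbosity=1` keyword defaults with `verbosity=LearnAPI.default_verbosity()`. The following plain-Julia sketch illustrates the general pattern of a package-wide, overridable verbosity default; all names here are hypothetical, and the package itself appears to back its version with Preferences.jl (per the Project.toml change), which this sketch does not:

```julia
# A package-wide default verbosity that individual calls can still override.
const DEFAULT_VERBOSITY = Ref(1)

default_verbosity() = DEFAULT_VERBOSITY[]
default_verbosity!(level::Integer) = (DEFAULT_VERBOSITY[] = level)

# A `fit`-like function whose keyword default reads the global setting
# at call time, rather than baking in `verbosity=1`:
function noisy_fit(; verbosity::Integer=default_verbosity())
    verbosity > 0 && @info "training..."
    return :model
end

default_verbosity!(0)    # silence all fits by default
noisy_fit()              # logs nothing
noisy_fit(verbosity=1)   # caller override: logs "training..."
```

The benefit over `verbosity=1` is that a single setting silences or un-silences every learner at once, without threading a keyword through every call site.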