@@ -10,9 +10,9 @@ export loess, predict
1010include (" kd.jl" )
1111
1212
13- mutable struct LoessModel{T <: AbstractFloat }
14- xs:: AbstractMatrix {T} # An n by m predictor matrix containing n observations from m predictors
15- ys:: AbstractVector {T} # A length n response vector
13+ struct LoessModel{T <: AbstractFloat }
14+ xs:: Matrix {T} # An n by m predictor matrix containing n observations from m predictors
15+ ys:: Vector {T} # A length n response vector
1616 predictions_and_gradients:: Dict{Vector{T}, Vector{T}} # kd-tree vertexes mapped to prediction and gradient at each vertex
1717 kdtree:: KDTree{T}
1818end
@@ -44,6 +44,10 @@ function loess(
4444 degree:: Integer = 2 ,
4545 cell:: AbstractFloat = 0.2
4646) where T<: AbstractFloat
47+
48+ Base. require_one_based_indexing (xs)
49+ Base. require_one_based_indexing (ys)
50+
4751 if size (xs, 1 ) != size (ys, 1 )
4852 throw (DimensionMismatch (" Predictor and response arrays must of the same length" ))
4953 end
@@ -80,8 +84,12 @@ function loess(
8084 end
8185
8286 # distance to each point
83- for i in 1 : n
84- ds[i] = euclidean (vec (vert), vec (xs[i,:]))
87+ @inbounds for i in 1 : n
88+ s = zero (T)
89+ for j in 1 : m
90+ s += (xs[i, j] - vert[j])^ 2
91+ end
92+ ds[i] = sqrt (s)
8593 end
8694
8795 # find the q closest points
@@ -128,7 +136,7 @@ function loess(
128136 ]
129137 end
130138
131- LoessModel {T} (xs, ys, predictions_and_gradients, kdtree)
139+ LoessModel (xs, ys, predictions_and_gradients, kdtree)
132140end
133141
134142loess (xs:: AbstractVector{T} , ys:: AbstractVector{T} ; kwargs... ) where {T<: AbstractFloat } =
@@ -153,50 +161,44 @@ end
153161# Returns:
154162# A length n' vector of predicted response values.
155163#
156- function predict (model:: LoessModel , z:: Real )
157- predict (model, [z])
158- end
159-
160- function predict (model:: LoessModel , zs:: AbstractVector )
161-
162- Base. require_one_based_indexing (zs)
164+ function predict (model:: LoessModel{T} , z:: Number ) where T
165+ adjacent_verts = traverse (model. kdtree, (T (z),))
163166
164- m = size (model. xs, 2 )
167+ @assert (length (adjacent_verts) == 2 )
168+ v₁, v₂ = adjacent_verts[1 ][1 ], adjacent_verts[2 ][1 ]
165169
166- # in the univariate case, interpret a non-singleton zs as vector of
167- # ponits, not one point
168- if m == 1 && length (zs) > 1
169- return predict (model, reshape (zs, (length (zs), 1 )))
170+ if z == v₁ || z == v₂
171+ return first (model. predictions_and_gradients[[z]])
170172 end
171173
172- if length (zs) != m
173- error (" $(m) -dimensional model applied to length $(length (zs)) vector" )
174- end
174+ y₁, dy₁ = model. predictions_and_gradients[[v₁]]
175+ y₂, dy₂ = model. predictions_and_gradients[[v₂]]
175176
176- adjacent_verts = traverse (model . kdtree, zs )
177+ b_int = cubic_interpolation (v₁, y₁, dy₁, v₂, y₂, dy₂ )
177178
178- if m == 1
179- @assert (length (adjacent_verts) == 2 )
180- z = zs[1 ]
181- v₁, v₂ = adjacent_verts[1 ][1 ], adjacent_verts[2 ][1 ]
179+ return evalpoly (z, b_int)
180+ end
182181
183- if z == v₁ || z == v₂
184- return first (model. predictions_and_gradients[[z]])
185- end
182+ function predict (model:: LoessModel , zs:: AbstractVector )
183+ if size (model. xs, 2 ) > 1
184+ throw (ArgumentError (" multivariate blending not yet implemented" ))
185+ end
186186
187- y₁, dy₁ = model . predictions_and_gradients[[v₁] ]
188- y₂, dy₂ = model . predictions_and_gradients[[v₂]]
187+ return [ predict (model, z) for z in zs ]
188+ end
189189
190- b_int = cubic_interpolation (v₁, y₁, dy₁, v₂, y₂, dy₂)
190+ function predict (model:: LoessModel , zs:: AbstractMatrix )
191+ if size (model. xs, 2 ) != size (zs, 2 )
192+ throw (DimensionMismatch (" number of columns in input matrix must match the number of columns in the model matrix" ))
193+ end
191194
192- return evalpoly (z, b_int)
195+ if size (zs, 2 ) == 1
196+ return predict (model, vec (zs))
193197 else
194- error ( " Multivariate blending not yet implemented " )
198+ return [ predict (model, row) for row in eachrow (zs)]
195199 end
196200end
197201
198- predict (model:: LoessModel , zs:: AbstractMatrix ) = map (Base. Fix1 (predict, model), eachrow (zs))
199-
200202"""
201203 tricubic(u)
202204
0 commit comments