Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Expand Down Expand Up @@ -53,6 +54,7 @@ DataFrames = "1"
DimensionalData = "0.29.24, 0.30"
Downloads = "1.6.0"
ForwardDiff = "1"
Functors = "0.5.2"
GPUArraysCore = "0.2.0"
Hyperopt = "0.5.6"
JLD2 = "0.5.13, 0.6"
Expand Down
1 change: 1 addition & 0 deletions src/EasyHybrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ AxisKeys.axiskeys(da::AbstractDimArray) = Tuple(lookup(da, d) for d in dims(da))
AxisKeys.axiskeys(da::AbstractDimArray, i::Int) = lookup(da, dims(da)[i])
AxisKeys.axiskeys(da::AbstractDimArray, name::Symbol) = lookup(da, name)
using Downloads: Downloads
using Functors: children
using Hyperopt: Hyperopt, Hyperoptimizer
using JLD2: JLD2, jldopen
using LuxCore: LuxCore
Expand Down
93 changes: 93 additions & 0 deletions src/training/mc_dropout.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
function evaluate_mc_dropout(
ghm, x, y, y_no_nan, ps, st, loss_types, training_loss, extra_loss, agg;
n_samples::Int = 100, file_path::Union{String, Nothing} = nothing, train_or_val_name::String = "val"
)

if !has_dropout(ghm)
@info "MC Dropout skipped: no Dropout layers detected in the model.\nFalling back to standard deterministic evaluation."
loss_val, sts, ŷ = evaluate_acc(ghm, x, y, y_no_nan, ps, st, loss_types, training_loss, extra_loss, agg)
return _store_sample(file_path, train_or_val_name, ŷ, loss_val, nothing)
end

st_train = Lux.trainmode(st)

for k in 1:n_samples
loss_k, _, ŷ_k = compute_loss(
ghm, ps, st_train,
(x, (y, y_no_nan)),
logging = LoggingLoss(
train_mode = true,
loss_types = loss_types,
training_loss = training_loss,
extra_loss = extra_loss,
agg = agg
)
)
Comment on lines +15 to +25
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The compute_loss function returns an empty NamedTuple for the stats field (the third return value) when logging.train_mode is set to true. Since this loop explicitly sets train_mode = true to keep dropout active, ŷ_k will be empty, and the model's predictions will not be captured. This prevents the MC Dropout from collecting the necessary samples for uncertainty estimation. You should consider calling the model directly to obtain predictions while in training mode, or adjusting compute_loss to return predictions even when train_mode is true.

_store_sample(file_path, train_or_val_name, ŷ_k, loss_k, k)
end
Comment on lines +14 to +27
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When file_path is nothing, the samples generated in the loop are not collected or returned. The _store_sample function returns the prediction and loss, but these values are ignored by the loop, and evaluate_mc_dropout returns nothing. If the intention is to allow in-memory evaluation, you should accumulate the results in a list and return them at the end of the function.


return nothing
end


function _store_sample(file_path::String, name, ŷ, loss, sample)
return jldopen(file_path, "a+") do file
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Opening and closing the JLD2 file with jldopen in append mode ("a+") inside the loop is inefficient. For a large number of samples, this will cause significant I/O overhead. It is recommended to open the file once before the loop starts and pass the open file handle to the storage function.

key = isnothing(sample) ? name : "$(name)/sample_$(sample)"
file["predictions/$key"] = ŷ
file["losses/$key"] = loss
end
end

function _store_sample(::Nothing, name, ŷ, loss, sample)
return (; ŷ, loss)
end

function _has_dropout(model)
return model isa Lux.Dropout || model isa Lux.AlphaDropout || model isa Lux.VariationalHiddenDropout
end

function _has_dropout(model::Lux.AbstractLuxContainerLayer)
return any(_has_dropout, children(model))
end

function has_dropout(model)
return _has_dropout(model)
end

function mc_dropout_statistics(storage::NamedTuple)
predictions = [s.ŷ for s in storage]
losses = [s.loss for s in storage]

pred_stack = stack(predictions, dims = ndims(first(predictions)) + 1)
mean_pred = mean(pred_stack, dims = ndims(pred_stack))
var_pred = var(pred_stack, dims = ndims(pred_stack))
mean_loss = mean(losses)

return (; mean_pred, var_pred, mean_loss)
end

function mc_dropout_statistics(file_path::String, train_or_val_name::String)
return jldopen(file_path, "r") do file
keys = sort(keys(file["predictions/$train_or_val_name"]), by = k -> parse(Int, split(k, "_")[end]))
losses = [file["losses/$train_or_val_name/$(k)"] for k in keys]

# Welford online algorithm to avoid loading all predictions at once
first_pred = file["predictions/$train_or_val_name/$(keys[1])"]
mean_pred = copy(first_pred)
M2 = zero(first_pred)
mean_loss = first(losses)

for (k, (key, loss)) in enumerate(zip(keys[2:end], losses[2:end]))
ŷ_k = file["predictions/$train_or_val_name/$(key)"]
delta = ŷ_k .- mean_pred
mean_pred .+= delta ./ k
delta2 = ŷ_k .- mean_pred
M2 .+= delta .* delta2
mean_loss += (loss - mean_loss) / k
end

var_pred = M2 ./ (length(keys) - 1)

return (; mean_pred, var_pred, mean_loss)
end
end
Loading