From 18b1340e43934ac3c0a142bd63fb910d85e90f05 Mon Sep 17 00:00:00 2001 From: Jose Esparza Date: Thu, 16 Apr 2026 15:14:04 -0500 Subject: [PATCH 1/8] Adding `Freezable` wrapper for models --- src/MLJBase.jl | 2 + src/composition/models/freezable.jl | 322 +++++++++++++++++++++ test/composition/models/freezable.jl | 409 +++++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 734 insertions(+) create mode 100644 src/composition/models/freezable.jl create mode 100644 test/composition/models/freezable.jl diff --git a/src/MLJBase.jl b/src/MLJBase.jl index 22bb354f..c305f0df 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -167,6 +167,7 @@ include("composition/models/network_composite_types.jl") include("composition/models/network_composite.jl") include("composition/models/pipelines.jl") include("composition/models/transformed_target_model.jl") +include("composition/models/freezable.jl") include("operations.jl") @@ -303,6 +304,7 @@ export machines, sources, Stack, StaticPipeline, IntervalPipeline export TransformedTargetModel +export Freezable # resampling.jl: export ResamplingStrategy, InSample, Holdout, CV, StratifiedCV, TimeSeriesCV, diff --git a/src/composition/models/freezable.jl b/src/composition/models/freezable.jl new file mode 100644 index 00000000..02cc5c35 --- /dev/null +++ b/src/composition/models/freezable.jl @@ -0,0 +1,322 @@ +const FREEZABLE_SUPPORTED_ATOMS = ( + :Deterministic, + :Probabilistic, + :Interval, + :Unsupervised, +) + +# Each supported atomic type gets its own wrapper which must have appropriate supertype: +const FREEZABLE_TYPE_GIVEN_ATOM = + Dict(atom => + Symbol("Freezable$atom") for atom in FREEZABLE_SUPPORTED_ATOMS) +const FREEZABLE_SUPER_GIVEN_ATOM = + Dict(atom => + Symbol("$(atom)NetworkComposite") for atom in FREEZABLE_SUPPORTED_ATOMS) + +# Type definitions: +for From in FREEZABLE_SUPPORTED_ATOMS + New = FREEZABLE_TYPE_GIVEN_ATOM[From] + To = FREEZABLE_SUPER_GIVEN_ATOM[From] + ex = quote + mutable struct $New{M <: $From} <: $To + model::M + frozen::Bool + cache::Bool + end + end + eval(ex) +end + +# dict whose keys and values are now types instead of symbols: +const freezable_type_given_atom = Dict() +for atom in FREEZABLE_SUPPORTED_ATOMS + atom_str = string(atom) + type = FREEZABLE_TYPE_GIVEN_ATOM[atom] + @eval(freezable_type_given_atom[$atom] = $type) +end + +# not exported: +const FREEZABLE_TYPES = values(freezable_type_given_atom) +const FREEZABLE_TYPE_EXS = values(FREEZABLE_TYPE_GIVEN_ATOM) +const SomeFreezable = Union{FREEZABLE_TYPES...} +const SupervisedFreezable = Union{ + freezable_type_given_atom[Deterministic], + freezable_type_given_atom[Probabilistic], + freezable_type_given_atom[Interval], +} +const FreezableSupported = Union{keys(freezable_type_given_atom)...} + +const ERR_FREEZABLE_MODEL_UNSPECIFIED = ArgumentError( + "Expecting atomic model as argument. None specified." +) +const ERR_FREEZABLE_TOO_MANY_ARGUMENTS = ArgumentError( + "At most one non-keyword argument, a model, allowed." +) +const PRETTY_FREEZABLE_SUPPORT_OPTIONS = + join([string("`", opt, "`") for opt in FREEZABLE_SUPPORTED_ATOMS], + ", ", + ", and ") +const err_freezable_unsupported(model) = ArgumentError( + "Only these model supertypes support `Freezable` wrapping: "* + "$PRETTY_FREEZABLE_SUPPORT_OPTIONS. "* + "Model provided has type `$(typeof(model))`." +) + +""" + Freezable(model; frozen=true, cache=true) + +Wrap the atomic `model` in a `Freezable` wrapper. When `frozen=true`, +training is skipped after initial fit, even if training rows change. +This is useful for avoiding expensive recomputation during +cross-validation or hyperparameter tuning, at the cost of data +hygiene. + +Unlike `freeze!(mach)`, which operates on an already-constructed +machine, `Freezable` operates at the model level. This means the +freeze semantics compose: a `Freezable`-wrapped model can be placed +inside a `Pipeline`, `Stack`, or `TunedModel`, and the inner +component will automatically skip retraining without the user needing +access to the internal machines that the composite creates. + +Set `frozen=false` to allow normal retraining. The `frozen` field can +be toggled after construction. + +Specify `cache=false` to prioritize memory over speed, or to guarantee +data anonymity. + +### Example 1: Freezing a single model + +```julia +using MLJBase + +X, y = make_regression(100) + +model = Freezable(DecisionTreeRegressor()) # frozen=true by default +mach = machine(model, X, y) + +fit!(mach) # initial training always proceeds +predict(mach, X) # works normally + +fit!(mach, rows=1:50) # no-op: frozen, so retraining is skipped + +thaw!(model) # or equivalently: model.frozen = false +fit!(mach, rows=1:50) # retrains on the new rows +``` + +### Example 2: Freezing a component inside a pipeline + +The main use case for `Freezable` is inside composites. Here a +`Standardizer` is frozen so it is trained once on the first fold and +then reused across all subsequent folds, while the classifier +retrains normally on each fold: + +```julia +using MLJBase + +X, y = make_blobs(200) + +pipe = Pipeline( + scaler = Freezable(Standardizer()), # trained once, then frozen + clf = DecisionTreeClassifier(), # retrains on every fold +) + +mach = machine(pipe, X, y) +fit!(mach, rows=1:100) # both components train +fit!(mach, rows=101:200) # only clf retrains; scaler is frozen +``` + +""" +function Freezable( + args...; + model=nothing, + frozen::Bool=true, + cache::Bool=true, +) + length(args) < 2 || throw(ERR_FREEZABLE_TOO_MANY_ARGUMENTS) + + if length(args) === 1 + atom = first(args) + model === nothing || + @warn "Using `model=$atom`. Ignoring specification `model=$model`. " + else + model === nothing && throw(ERR_FREEZABLE_MODEL_UNSPECIFIED) + atom = model + end + atom isa FreezableSupported || throw(err_freezable_unsupported(atom)) + + abstract_atom = MMI.abstract_type(atom) + haskey(freezable_type_given_atom, abstract_atom) || + throw(err_freezable_unsupported(atom)) + + metamodel = + freezable_type_given_atom[abstract_atom](atom, + frozen, + cache) + message = clean!(metamodel) + isempty(message) || @warn message + return metamodel +end + +function clean!(model::SomeFreezable) + message = "" + return message +end + +""" + freeze!(model::SomeFreezable) + +Set `model.frozen = true`. Subsequent `fit!` calls on a machine +wrapping this model will be no-ops (after initial training). + +See also [`thaw!`](@ref). +""" +freeze!(model::SomeFreezable) = (model.frozen = true; model) + +""" + thaw!(model::SomeFreezable) + +Set `model.frozen = false`. The next `fit!` call on a machine +wrapping this model will retrain normally. + +See also [`freeze!`](@ref). +""" +thaw!(model::SomeFreezable) = (model.frozen = false; model) + + +# Prefit methods +function prefit(model::SupervisedFreezable, verbosity, X, y) + Xs = source(X) + ys = source(y) + mach = machine(:model, Xs, ys; cache=model.cache) + (predict=predict(mach, Xs), transform=transform(mach, Xs)) +end + +function prefit(model::FreezableUnsupervised, verbosity, X) + Xs = source(X) + mach = machine(:model, Xs; cache=model.cache) + (transform=transform(mach, Xs), inverse_transform=inverse_transform(mach, Xs)) +end + +function MLJModelInterface.fit(composite::SomeFreezable, verbosity, data...) + # Build the learning network (inner machine starts unfrozen): + fitresult = prefit(composite, verbosity, data...) |> MLJBase.Signature + + # Train the network (initial training always proceeds): + greatest_lower_bound = MLJBase.glb(fitresult) + acceleration = MLJBase.acceleration(fitresult) + fit!(greatest_lower_bound; verbosity, composite, acceleration) + + # After initial training, freeze the inner machine if frozen=true: + if composite.frozen + d = MLJBase.machines_given_model(greatest_lower_bound) + if haskey(d, :model) + for mach in d[:model] + freeze!(mach) + end + end + end + + report = MLJBase.report(fitresult) + + # for passing to `update` so changes in `composite` can be detected: + cache = deepcopy(composite) + + return fitresult, cache, report +end + +function MLJModelInterface.update( + composite::SomeFreezable, + verbosity, + fitresult, + old_composite, + data..., +) + greatest_lower_bound = MLJBase.glb(fitresult) + + # Synchronize frozen state on the inner machine(s): + d = MLJBase.machines_given_model(greatest_lower_bound) + if haskey(d, :model) + for mach in d[:model] + composite.frozen ? freeze!(mach) : thaw!(mach) + end + end + + # Check if non-model, non-frozen hyperparameters changed (e.g., `cache`). + # Changes to `frozen` are handled above via freeze!/thaw! and should not + # trigger a full refit: + non_frozen_changed = any(propertynames(composite)) do field + field in MLJBase.models(greatest_lower_bound) && return false + field === :frozen && return false + old_value = getproperty(old_composite, field) + value = getproperty(composite, field) + value != old_value + end + non_frozen_changed && return MLJModelInterface.fit(composite, verbosity, data...) + + # retrain the network: + fit!(greatest_lower_bound; verbosity, composite) + + report = MLJBase.report(fitresult) + + # for passing to `update` so changes in `composite` can be detected: + cache = deepcopy(composite) + + return fitresult, cache, report +end + + +# When the Freezable model has frozen=true and the outer machine has already been +# trained (state > 0), skip retraining entirely. This prevents the NetworkComposite +# from rebuilding the learning network when training rows change. +# +# We synchronize the outer machine's `frozen` flag with the model's `frozen` field +# before each fit. The standard `fit_only!` checks `mach.frozen` at the top and +# returns immediately if true. Initial training (state == 0) always proceeds. +function MLJBase.fit!(mach::Machine{<:SomeFreezable}; kwargs...) + # Synchronize outer machine frozen flag: + if mach.model.frozen && mach.state > 0 + mach.frozen = true + else + mach.frozen = false + end + + # Delegate to the standard fit! logic: + glb_node = MLJBase.glb(mach.args...) + fit!(glb_node; kwargs...) + MLJBase.fit_only!(mach; kwargs...) +end + + +# Model traits +MMI.package_name(::Type{<:SomeFreezable}) = "MLJBase" +MMI.is_wrapper(::Type{<:SomeFreezable}) = true +MMI.load_path(::Type{<:SomeFreezable}) = "MLJBase.Freezable" +MMI.constructor(::Type{<:SomeFreezable}) = Freezable + +for New in FREEZABLE_TYPE_EXS + quote + MMI.iteration_parameter(::Type{<:$New{M}}) where M = + MLJBase.prepend(:model, iteration_parameter(M)) + end |> eval + for trait in [ + :input_scitype, + :output_scitype, + :target_scitype, + :fit_data_scitype, + :predict_scitype, + :transform_scitype, + :inverse_transform_scitype, + :is_pure_julia, + :supports_weights, + :supports_class_weights, + :supports_online, + :supports_training_losses, + :reports_feature_importances, + :is_supervised, + :prediction_type, + ] + quote + MMI.$trait(::Type{<:$New{M}}) where M = MMI.$trait(M) + end |> eval + end +end diff --git a/test/composition/models/freezable.jl b/test/composition/models/freezable.jl new file mode 100644 index 00000000..8230bfbf --- /dev/null +++ b/test/composition/models/freezable.jl @@ -0,0 +1,409 @@ +module TestFreezable + +using MLJBase +using Test +using ..Models +using StableRNGs +const MMI = MLJBase.MLJModelInterface + +@testset "Freezable types and constructor" begin + @testset "deterministic model wrapping" begin + atom = DeterministicConstantRegressor() + model = Freezable(atom) + @test model isa MLJBase.FreezableDeterministic + @test model isa MLJBase.DeterministicNetworkComposite + end + + @testset "probabilistic model wrapping" begin + atom = ConstantRegressor() + model = Freezable(atom) + @test model isa MLJBase.FreezableProbabilistic + @test model isa MLJBase.ProbabilisticNetworkComposite + end + + @testset "unsupervised model wrapping" begin + atom = UnivariateStandardizer() + model = Freezable(atom) + @test model isa MLJBase.FreezableUnsupervised + @test model isa MLJBase.UnsupervisedNetworkComposite + end + + @testset "default values" begin + atom = DeterministicConstantRegressor() + model = Freezable(atom) + @test model.frozen == true + @test model.cache == true + end + + @testset "custom keyword arguments" begin + atom = DeterministicConstantRegressor() + model = Freezable(atom; frozen=false, cache=false) + @test model.frozen == false + @test model.cache == false + end + + @testset "error: no model" begin + @test_throws( + MLJBase.ERR_FREEZABLE_MODEL_UNSPECIFIED, + Freezable(), + ) + end + + @testset "error: unsupported model type (Static)" begin + static_model = Averager(mix=0.5) + @test_throws(ArgumentError, Freezable(static_model)) + end + + @testset "error: too many arguments" begin + @test_throws( + MLJBase.ERR_FREEZABLE_TOO_MANY_ARGUMENTS, + Freezable(DeterministicConstantRegressor(), ConstantRegressor()), + ) + end + + @testset "model field identity" begin + atom = DeterministicConstantRegressor() + model = Freezable(atom) + @test model.model === atom + end + +end + +@testset "constructor type correctness across model types" begin + EXPECTED_SUPER = Dict( + MLJBase.Deterministic => MLJBase.DeterministicNetworkComposite, + MLJBase.Probabilistic => MLJBase.ProbabilisticNetworkComposite, + MLJBase.Unsupervised => MLJBase.UnsupervisedNetworkComposite, + ) + atoms = [ + DeterministicConstantRegressor(), + ConstantRegressor(), + ConstantClassifier(), + DecisionTreeClassifier(), + DecisionTreeRegressor(), + UnivariateStandardizer(), + Standardizer(), + ] + for atom in atoms + wrapped = Freezable(atom) + abstract_atom = MMI.abstract_type(atom) + expected_super = EXPECTED_SUPER[abstract_atom] + @test wrapped isa expected_super + @test wrapped.model === atom + end +end + +# Define iteration_parameter for DeterministicConstantRegressor within this module +# so we can test that Freezable prepends :model to the path. +MMI.iteration_parameter(::Type{DeterministicConstantRegressor}) = :n + +@testset "trait delegation" begin + atom_det = DeterministicConstantRegressor() + wrapped_det = Freezable(atom_det) + + atom_clf = DecisionTreeClassifier() + wrapped_clf = Freezable(atom_clf) + + @testset "input_scitype delegation" begin + @test MLJBase.input_scitype(typeof(wrapped_det)) == + MLJBase.input_scitype(typeof(atom_det)) + @test MLJBase.input_scitype(typeof(wrapped_clf)) == + MLJBase.input_scitype(typeof(atom_clf)) + end + + @testset "target_scitype delegation" begin + @test MLJBase.target_scitype(typeof(wrapped_det)) == + MLJBase.target_scitype(typeof(atom_det)) + @test MLJBase.target_scitype(typeof(wrapped_clf)) == + MLJBase.target_scitype(typeof(atom_clf)) + end + + @testset "is_wrapper" begin + @test MLJBase.is_wrapper(typeof(wrapped_det)) == true + @test MLJBase.is_wrapper(typeof(wrapped_clf)) == true + end + + @testset "package_name" begin + @test MLJBase.package_name(typeof(wrapped_det)) == "MLJBase" + end + + @testset "load_path" begin + @test MLJBase.load_path(typeof(wrapped_det)) == "MLJBase.Freezable" + end + + @testset "constructor" begin + @test MLJBase.constructor(typeof(wrapped_det)) == Freezable + end + + @testset "iteration_parameter prepends :model" begin + @test MLJBase.iteration_parameter(typeof(wrapped_det)) == :(model.n) + end +end + +@testset "trait delegation across model types" begin + DELEGATED_TRAITS = [ + MLJBase.input_scitype, + MLJBase.target_scitype, + MLJBase.fit_data_scitype, + MLJBase.predict_scitype, + MLJBase.transform_scitype, + MLJBase.is_pure_julia, + MLJBase.supports_weights, + MLJBase.supports_class_weights, + MLJBase.supports_training_losses, + MLJBase.is_supervised, + MLJBase.prediction_type, + ] + atoms = [ + DeterministicConstantRegressor(), + ConstantRegressor(), + ConstantClassifier(), + DecisionTreeClassifier(), + DecisionTreeRegressor(), + UnivariateStandardizer(), + Standardizer(), + ] + for atom in atoms + wrapped = Freezable(atom) + for trait_fn in DELEGATED_TRAITS + @test trait_fn(typeof(wrapped)) == trait_fn(typeof(atom)) + end + end +end + +@testset "Freezable export" begin + @test :Freezable in names(MLJBase) +end + +@testset "supervised Freezable end-to-end" begin + rng = StableRNG(42) + X = MLJBase.table(randn(rng, 20, 3)) + y = randn(rng, 20) + + # Freezable-wrapped model (frozen=false so it behaves like unwrapped) + fmodel = Freezable(DeterministicConstantRegressor(); frozen=false) + fmach = machine(fmodel, X, y) + fit!(fmach; verbosity=0) + fpreds = predict(fmach, X) + + # Unwrapped model + atom = DeterministicConstantRegressor() + amach = machine(atom, X, y) + fit!(amach; verbosity=0) + apreds = predict(amach, X) + + @test fpreds ≈ apreds + + fp = fitted_params(fmach) + @test :model in keys(fp) + + rep = report(fmach) + @test rep === nothing || rep isa NamedTuple +end + +@testset "unsupervised Freezable end-to-end" begin + rng = StableRNG(123) + v = randn(rng, 50) + + fmodel = Freezable(UnivariateStandardizer(); frozen=false) + fmach = machine(fmodel, v) + fit!(fmach; verbosity=0) + ftransformed = transform(fmach, v) + + atom = UnivariateStandardizer() + amach = machine(atom, v) + fit!(amach; verbosity=0) + atransformed = transform(amach, v) + + @test ftransformed ≈ atransformed +end + +@testset "initial training always proceeds when frozen" begin + rng = StableRNG(77) + X = MLJBase.table(randn(rng, 20, 3)) + y = randn(rng, 20) + + model = Freezable(DeterministicConstantRegressor(); frozen=true) + mach = machine(model, X, y) + @test mach.state == 0 + + fit!(mach; verbosity=0) + @test mach.state > 0 + + preds = predict(mach, X) + @test length(preds) == 20 + @test all(isfinite, preds) +end + +@testset "frozen training behavior" begin + rng = StableRNG(99) + + # Rows 1:10 have mean 1.0, rows 11:20 have mean 100.0. + # DeterministicConstantRegressor predicts the mean, so we can detect + # whether retraining happened by checking the predicted value. + X = MLJBase.table(randn(rng, 20, 2)) + y = vcat(fill(1.0, 10), fill(100.0, 10)) + + model = Freezable(DeterministicConstantRegressor(), frozen=true) + @test model.frozen == true + + # Initial fit with rows 1:10 — should proceed even though frozen=true + mach = machine(model, X, y) + fit!(mach; rows=1:10, verbosity=0) + @test mach.state > 0 + preds_initial = predict(mach, X) + @test all(p -> p ≈ 1.0, preds_initial) + + # Second fit with rows 11:20 — should be a no-op (frozen skip) + fit!(mach; rows=11:20, verbosity=0) + preds_frozen = predict(mach, X) + @test preds_frozen == preds_initial + + # Thaw and retrain + model.frozen = false + @test model.frozen == false + fit!(mach; rows=11:20, verbosity=0) + preds_thawed = predict(mach, X) + @test all(p -> p ≈ 100.0, preds_thawed) + @test preds_thawed != preds_initial +end + +@testset "freeze! and thaw! on models" begin + model = Freezable(DeterministicConstantRegressor(); frozen=false) + @test model.frozen == false + + # freeze! sets frozen=true and returns the model + ret = freeze!(model) + @test model.frozen == true + @test ret === model + + # thaw! sets frozen=false and returns the model + ret = thaw!(model) + @test model.frozen == false + @test ret === model +end + +@testset "thaw! triggers retraining" begin + rng = StableRNG(88) + X = MLJBase.table(randn(rng, 20, 3)) + y = randn(rng, 20) + + model = Freezable(DeterministicConstantRegressor(); frozen=true) + mach = machine(model, X, y) + fit!(mach; verbosity=0) + state_after_initial = mach.state + @test state_after_initial > 0 + + thaw!(model) + fit!(mach; verbosity=0) + @test mach.state > state_after_initial +end + +@testset "frozen skip with different data sizes" begin + for n in [5, 10, 30, 50] + rng = StableRNG(hash(n)) + X = MLJBase.table(randn(rng, n, 2)) + y = randn(rng, n) + + model = Freezable(DeterministicConstantRegressor(); frozen=true) + mach = machine(model, X, y) + fit!(mach; verbosity=0) + preds1 = predict(mach, X) + + half = max(1, n ÷ 2) + fit!(mach; rows=1:half, verbosity=0) + preds2 = predict(mach, X) + + @test preds1 == preds2 + end +end + +@testset "pipeline with frozen component" begin + rng = StableRNG(555) + + X_part1 = fill(1.0, 10, 2) + X_part2 = fill(100.0, 10, 2) + X = MLJBase.table(vcat(X_part1, X_part2)) + y = vcat(fill(10.0, 10), fill(200.0, 10)) + + frozen_std = Freezable(Standardizer(), frozen=true) + pipe = Pipeline( + std = frozen_std, + reg = DeterministicConstantRegressor(), + ) + + mach = machine(pipe, X, y) + fit!(mach; rows=1:10, verbosity=0) + preds_first = predict(mach, X) + @test all(p -> p ≈ 10.0, preds_first) + + # Retrain with rows 11:20: frozen Standardizer skips, predictor retrains + fit!(mach; rows=11:20, verbosity=0) + preds_second = predict(mach, X) + @test all(p -> p ≈ 200.0, preds_second) + @test length(preds_second) == 20 +end + +@testset "pipeline frozen component skip with varying sizes" begin + for n in [10, 20, 40] + rng = StableRNG(hash(n)) + half = n ÷ 2 + + X = MLJBase.table(randn(rng, n, 3)) + y = vcat(fill(1.0, half), fill(100.0, n - half)) + + frozen_std = Freezable(Standardizer(), frozen=true) + pipe = Pipeline( + std = frozen_std, + reg = DeterministicConstantRegressor(), + ) + + mach = machine(pipe, X, y) + fit!(mach; rows=1:half, verbosity=0) + preds_first = predict(mach, X) + @test all(p -> p ≈ 1.0, preds_first) + + fit!(mach; rows=(half+1):n, verbosity=0) + preds_second = predict(mach, X) + @test all(p -> p ≈ 100.0, preds_second) + @test length(preds_second) == n + end +end + +@testset "thawed supervised equivalence" begin + rng = StableRNG(200) + X = MLJBase.table(randn(rng, 30, 4)) + y = randn(rng, 30) + + fmodel = Freezable(DeterministicConstantRegressor(); frozen=false) + fmach = machine(fmodel, X, y) + fit!(fmach; verbosity=0) + fpreds = predict(fmach, X) + + atom = DeterministicConstantRegressor() + amach = machine(atom, X, y) + fit!(amach; verbosity=0) + apreds = predict(amach, X) + + @test fpreds ≈ apreds +end + +@testset "thawed unsupervised equivalence" begin + rng = StableRNG(300) + v = randn(rng, 40) + + fmodel = Freezable(UnivariateStandardizer(); frozen=false) + fmach = machine(fmodel, v) + fit!(fmach; verbosity=0) + ftransformed = transform(fmach, v) + + atom = UnivariateStandardizer() + amach = machine(atom, v) + fit!(amach; verbosity=0) + atransformed = transform(amach, v) + + @test ftransformed ≈ atransformed +end + +end # module +true diff --git a/test/runtests.jl b/test/runtests.jl index 0c5593af..5632818f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,6 +64,7 @@ end @test include("composition/models/inspection.jl") @test include("composition/models/pipelines.jl") @test include("composition/models/transformed_target_model.jl") + @test include("composition/models/freezable.jl") @test include("composition/models/stacking.jl") @test include("composition/models/static_transformers.jl") end From 75dd7b7da21e0e2da84dede80dce828df49e5d0d Mon Sep 17 00:00:00 2001 From: Jose Esparza <28990958+pebeto@users.noreply.github.com> Date: Tue, 28 Apr 2026 09:51:06 -0500 Subject: [PATCH 2/8] Update src/composition/models/freezable.jl Co-authored-by: Anthony Blaom, PhD --- src/composition/models/freezable.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/composition/models/freezable.jl b/src/composition/models/freezable.jl index 02cc5c35..c135c1a6 100644 --- a/src/composition/models/freezable.jl +++ b/src/composition/models/freezable.jl @@ -87,7 +87,7 @@ data anonymity. ### Example 1: Freezing a single model ```julia -using MLJBase +using MLJ # or `using MLJBase, MLJModels` X, y = make_regression(100) From 66038510cbc52deeacb9fc3093b642add4f6016d Mon Sep 17 00:00:00 2001 From: Jose Esparza <28990958+pebeto@users.noreply.github.com> Date: Tue, 28 Apr 2026 09:51:15 -0500 Subject: [PATCH 3/8] Update src/composition/models/freezable.jl Co-authored-by: Anthony Blaom, PhD --- src/composition/models/freezable.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/composition/models/freezable.jl b/src/composition/models/freezable.jl index c135c1a6..7afc8964 100644 --- a/src/composition/models/freezable.jl +++ b/src/composition/models/freezable.jl @@ -85,7 +85,7 @@ Specify `cache=false` to prioritize memory over speed, or to guarantee data anonymity. ### Example 1: Freezing a single model - +This example and the next assume you have MLJDecisionTreeInterface in your package environment. ```julia using MLJ # or `using MLJBase, MLJModels` From a63be5bb7defd10b328f1479b67447eb5e27a5df Mon Sep 17 00:00:00 2001 From: Jose Esparza <28990958+pebeto@users.noreply.github.com> Date: Tue, 28 Apr 2026 09:51:37 -0500 Subject: [PATCH 4/8] Update src/composition/models/freezable.jl Co-authored-by: Anthony Blaom, PhD --- src/composition/models/freezable.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/composition/models/freezable.jl b/src/composition/models/freezable.jl index 7afc8964..ca5daa26 100644 --- a/src/composition/models/freezable.jl +++ b/src/composition/models/freezable.jl @@ -91,6 +91,7 @@ using MLJ # or `using MLJBase, MLJModels` X, y = make_regression(100) +DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree model = Freezable(DecisionTreeRegressor()) # frozen=true by default mach = machine(model, X, y) From 72e3fc9831648f9f2fc6a1e12417eb6f10247118 Mon Sep 17 00:00:00 2001 From: Jose Esparza <28990958+pebeto@users.noreply.github.com> Date: Tue, 28 Apr 2026 09:51:45 -0500 Subject: [PATCH 5/8] Update src/composition/models/freezable.jl Co-authored-by: Anthony Blaom, PhD --- src/composition/models/freezable.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/composition/models/freezable.jl b/src/composition/models/freezable.jl index ca5daa26..4fb46402 100644 --- a/src/composition/models/freezable.jl +++ b/src/composition/models/freezable.jl @@ -112,7 +112,7 @@ then reused across all subsequent folds, while the classifier retrains normally on each fold: ```julia -using MLJBase +using MLJ # or `using MLJBase, MLJModels, using MLJTransforms` X, y = make_blobs(200) From 09c4862e314c4971c4fe82608ce3b8dcd6e96ba6 Mon Sep 17 00:00:00 2001 From: Jose Esparza <28990958+pebeto@users.noreply.github.com> Date: Tue, 28 Apr 2026 09:51:53 -0500 Subject: [PATCH 6/8] Update src/composition/models/freezable.jl Co-authored-by: Anthony Blaom, PhD --- src/composition/models/freezable.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/composition/models/freezable.jl b/src/composition/models/freezable.jl index 4fb46402..c9f48a8d 100644 --- a/src/composition/models/freezable.jl +++ b/src/composition/models/freezable.jl @@ -115,7 +115,7 @@ retrains normally on each fold: using MLJ # or `using MLJBase, MLJModels, using MLJTransforms` X, y = make_blobs(200) - +DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree pipe = Pipeline( scaler = Freezable(Standardizer()), # trained once, then frozen clf = DecisionTreeClassifier(), # retrains on every fold From b80aa10fa1c55b9011c6b1c3ec4bf4fb8133a5cf Mon Sep 17 00:00:00 2001 From: Jose Esparza Date: Sun, 3 May 2026 00:21:26 -0500 Subject: [PATCH 7/8] Adding a stash mechanism to enable a better cache strategy --- src/composition/models/freezable.jl | 69 ++++++++++++++++++++++++++-- src/machines.jl | 3 +- test/composition/models/freezable.jl | 55 ++++++++++++++++++++++ test/machines.jl | 2 +- 4 files changed, 121 insertions(+), 8 deletions(-) diff --git a/src/composition/models/freezable.jl b/src/composition/models/freezable.jl index c9f48a8d..24b83176 100644 --- a/src/composition/models/freezable.jl +++ b/src/composition/models/freezable.jl @@ -13,7 +13,12 @@ const FREEZABLE_SUPER_GIVEN_ATOM = Dict(atom => Symbol("$(atom)NetworkComposite") for atom in FREEZABLE_SUPPORTED_ATOMS) -# Type definitions: +# Type definitions. The `_stash` field is a 0-or-1-element vector that, when populated, +# holds the (fitresult, cache, report) tuple from a prior successful fit. It lets +# `MLJModelInterface.fit` short-circuit when the model is `frozen=true` and is invoked +# again as a fresh fit (e.g. when a parent composite rebuilds its learning network on a +# row change). The field is hidden from `propertynames` below so it does not participate +# in model equality, display, or trait machinery. for From in FREEZABLE_SUPPORTED_ATOMS New = FREEZABLE_TYPE_GIVEN_ATOM[From] To = FREEZABLE_SUPER_GIVEN_ATOM[From] @@ -22,6 +27,7 @@ for From in FREEZABLE_SUPPORTED_ATOMS model::M frozen::Bool cache::Bool + _stash::Vector{Any} end end eval(ex) @@ -46,6 +52,11 @@ const SupervisedFreezable = Union{ } const FreezableSupported = Union{keys(freezable_type_given_atom)...} +# Hide private `_stash` from equality (`is_same_except`), display, and any other +# property-iteration code paths. +const _FREEZABLE_PUBLIC_PROPS = (:model, :frozen, :cache) +Base.propertynames(::SomeFreezable) = _FREEZABLE_PUBLIC_PROPS + const ERR_FREEZABLE_MODEL_UNSPECIFIED = ArgumentError( "Expecting atomic model as argument. None specified." ) @@ -152,7 +163,8 @@ function Freezable( metamodel = freezable_type_given_atom[abstract_atom](atom, frozen, - cache) + cache, + Any[]) message = clean!(metamodel) isempty(message) || @warn message return metamodel @@ -181,7 +193,11 @@ wrapping this model will retrain normally. See also [`freeze!`](@ref). """ -thaw!(model::SomeFreezable) = (model.frozen = false; model) +function thaw!(model::SomeFreezable) + model.frozen = false + empty!(model._stash) + return model +end # Prefit methods @@ -199,6 +215,18 @@ function prefit(model::FreezableUnsupervised, verbosity, X) end function MLJModelInterface.fit(composite::SomeFreezable, verbosity, data...) + # If the model has been fit before AND is currently frozen, reuse the cached + # fit output verbatim. This handles the case where a parent composite (e.g. a + # `Pipeline`) rebuilds its learning network on a row change: from the parent's + # perspective this is a fresh fit, but the freeze contract says we must not + # retrain. The `Signature` carried in the cached fitresult contains the original + # trained inner machine; subsequent `predict`/`transform` calls clone the + # signature and rebind source nodes to new data, so reusing the cached fitresult + # is correct. + if composite.frozen && !isempty(composite._stash) + return composite._stash[end] + end + # Build the learning network (inner machine starts unfrozen): fitresult = prefit(composite, verbosity, data...) |> MLJBase.Signature @@ -222,7 +250,17 @@ function MLJModelInterface.fit(composite::SomeFreezable, verbosity, data...) # for passing to `update` so changes in `composite` can be detected: cache = deepcopy(composite) - return fitresult, cache, report + output = (fitresult, cache, report) + _refresh_stash!(composite, output) + return output +end + +# Maintain at most one entry: the latest fit output, only while the model is frozen. +# A thawed model holds no stash, so the next `fit` call retrains from scratch. +function _refresh_stash!(composite::SomeFreezable, output) + empty!(composite._stash) + composite.frozen && push!(composite._stash, output) + return composite end function MLJModelInterface.update( @@ -262,7 +300,9 @@ function MLJModelInterface.update( # for passing to `update` so changes in `composite` can be detected: cache = deepcopy(composite) - return fitresult, cache, report + output = (fitresult, cache, report) + _refresh_stash!(composite, output) + return output end @@ -288,6 +328,25 @@ function MLJBase.fit!(mach::Machine{<:SomeFreezable}; kwargs...) end +# Forwarded methods: the wrapper should be transparent to per-fit operations +# the inner model supports. + +const ERR_FREEZABLE_MISSING_REPORT = + "Cannot find report for the atomic model wrapped by `Freezable`. " + +function MMI.training_losses(composite::SomeFreezable, freezable_report) + hasproperty(freezable_report, :model) || throw(ERR_FREEZABLE_MISSING_REPORT) + atomic_report = getproperty(freezable_report, :model) + return training_losses(composite.model, atomic_report) +end + +function MMI.feature_importances(composite::SupervisedFreezable, fitresult, report) + predict_node = fitresult.interface.predict + mach = only(MLJBase.machines_given_model(predict_node)[:model]) + return feature_importances(composite.model, mach.fitresult, mach.report[:fit]) +end + + # Model traits MMI.package_name(::Type{<:SomeFreezable}) = "MLJBase" MMI.is_wrapper(::Type{<:SomeFreezable}) = true diff --git a/src/machines.jl b/src/machines.jl index 0eaf2779..cb5efe9c 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -493,8 +493,6 @@ end function fitlog(mach, action::Symbol, verbosity) if verbosity < -1000 put!(MACHINE_CHANNEL, (action, mach)) - elseif verbosity > -1 && action == :frozen - @warn "$mach not trained as it is frozen." elseif verbosity > 0 action == :train && (@info "Training $mach."; return) action == :update && (@info "Updating $mach."; return) @@ -502,6 +500,7 @@ function fitlog(mach, action::Symbol, verbosity) @info "Not retraining $mach. Use `force=true` to force." return end + action == :frozen && (@info "Not retraining $mach as it is frozen."; return) end end diff --git a/test/composition/models/freezable.jl b/test/composition/models/freezable.jl index 8230bfbf..13d1c597 100644 --- a/test/composition/models/freezable.jl +++ b/test/composition/models/freezable.jl @@ -3,6 +3,7 @@ module TestFreezable using MLJBase using Test using ..Models +using ..TestUtilities using StableRNGs const MMI = MLJBase.MLJModelInterface @@ -344,6 +345,60 @@ end @test length(preds_second) == 20 end +@testset "pipeline frozen component is not retrained on row change" begin + # Regression test: when a `Freezable` component is inside a pipeline and the + # outer machine is re-fitted with new rows, the parent composite rebuilds + # its learning network. The frozen inner model must NOT be retrained — its + # fitted_params must match the first fit byte-for-byte. + rng = StableRNG(777) + X = (x1 = randn(rng, 200), x2 = randn(rng, 200)) + y = randn(rng, 200) + + frozen_std = Freezable(Standardizer(), frozen=true) + pipe = Pipeline( + scaler = frozen_std, + reg = DeterministicConstantRegressor(), + ) + + mach = machine(pipe, X, y) + fit!(mach; rows=1:100, verbosity=0) + fp_first = fitted_params(mach).scaler.model + + fit!(mach; rows=101:200, verbosity=0) + fp_second = fitted_params(mach).scaler.model + + @test fp_first == fp_second +end + +@testset "pipeline frozen component fit/skip sequence" begin + # Use @test_mach_sequence to assert the inner Standardizer's machine is + # trained once and not retrained on a subsequent row-change refit. + rng = StableRNG(778) + X = (x1 = randn(rng, 60), x2 = randn(rng, 60)) + y = randn(rng, 60) + + frozen_std = Freezable(Standardizer(), frozen=true) + pipe = Pipeline( + scaler = frozen_std, + reg = DeterministicConstantRegressor(), + ) + mach = machine(pipe, X, y) + + # Drive the channel once to discover the machine objects, then assert. + fit!(mach; rows=1:30, verbosity=-5000) + seq1 = MLJBase.flush!(MLJBase.MACHINE_CHANNEL) + @test any(t -> t[1] === :train && t[2].model isa Symbol && t[2].model === :model, seq1) + + # Second fit on different rows must not produce a :train event for the inner + # Standardizer machine (the one whose symbolic model is `:model`, owned by the + # FreezableUnsupervised composite). + fit!(mach; rows=31:60, verbosity=-5000) + seq2 = MLJBase.flush!(MLJBase.MACHINE_CHANNEL) + inner_train_events = + filter(t -> t[1] === :train && t[2].model isa Symbol && t[2].model === :model, seq2) + @test isempty(inner_train_events) +end + @testset "pipeline frozen component skip with varying sizes" begin for n in [10, 20, 40] rng = StableRNG(hash(n)) diff --git a/test/machines.jl b/test/machines.jl index 95080b37..bbe3c168 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -88,7 +88,7 @@ end # test a frozen Machine stand = machine(Standardizer(), source((x1=rand(10),))) freeze!(stand) - @test_logs (:warn, r"not trained as it is frozen\.$") fit!(stand) + @test_logs (:info, r"Not retraining .* as it is frozen\.$") fit!(stand) end @testset "machine instantiation warnings" begin From f417b7f8aa11690299593582e79506dca6dc2ea0 Mon Sep 17 00:00:00 2001 From: Jose Esparza Date: Tue, 26 May 2026 16:11:15 -0500 Subject: [PATCH 8/8] Move freeze dispatch into `fit_only!` and rebuild-with-state-transfer in `NetworkComposite` --- src/composition/models/freezable.jl | 196 +++----------------- src/composition/models/network_composite.jl | 38 ++++ src/machines.jl | 83 +++++++-- 3 files changed, 129 insertions(+), 188 deletions(-) diff --git a/src/composition/models/freezable.jl b/src/composition/models/freezable.jl index 24b83176..a1feb1c4 100644 --- a/src/composition/models/freezable.jl +++ b/src/composition/models/freezable.jl @@ -13,12 +13,7 @@ const FREEZABLE_SUPER_GIVEN_ATOM = Dict(atom => Symbol("$(atom)NetworkComposite") for atom in FREEZABLE_SUPPORTED_ATOMS) -# Type definitions. The `_stash` field is a 0-or-1-element vector that, when populated, -# holds the (fitresult, cache, report) tuple from a prior successful fit. It lets -# `MLJModelInterface.fit` short-circuit when the model is `frozen=true` and is invoked -# again as a fresh fit (e.g. when a parent composite rebuilds its learning network on a -# row change). The field is hidden from `propertynames` below so it does not participate -# in model equality, display, or trait machinery. +# Type definitions: for From in FREEZABLE_SUPPORTED_ATOMS New = FREEZABLE_TYPE_GIVEN_ATOM[From] To = FREEZABLE_SUPER_GIVEN_ATOM[From] @@ -27,7 +22,6 @@ for From in FREEZABLE_SUPPORTED_ATOMS model::M frozen::Bool cache::Bool - _stash::Vector{Any} end end eval(ex) @@ -52,11 +46,6 @@ const SupervisedFreezable = Union{ } const FreezableSupported = Union{keys(freezable_type_given_atom)...} -# Hide private `_stash` from equality (`is_same_except`), display, and any other -# property-iteration code paths. -const _FREEZABLE_PUBLIC_PROPS = (:model, :frozen, :cache) -Base.propertynames(::SomeFreezable) = _FREEZABLE_PUBLIC_PROPS - const ERR_FREEZABLE_MODEL_UNSPECIFIED = ArgumentError( "Expecting atomic model as argument. None specified." ) @@ -76,62 +65,44 @@ const err_freezable_unsupported(model) = ArgumentError( """ Freezable(model; frozen=true, cache=true) -Wrap the atomic `model` in a `Freezable` wrapper. When `frozen=true`, -training is skipped after initial fit, even if training rows change. -This is useful for avoiding expensive recomputation during -cross-validation or hyperparameter tuning, at the cost of data -hygiene. +Wrap `model` so `fit!` is a no-op after the first training pass. Place the wrapper inside +a `Pipeline`, `Stack`, or `TunedModel` and the inner component skips retraining even when +the parent rebuilds its learning network on a row change. -Unlike `freeze!(mach)`, which operates on an already-constructed -machine, `Freezable` operates at the model level. This means the -freeze semantics compose: a `Freezable`-wrapped model can be placed -inside a `Pipeline`, `Stack`, or `TunedModel`, and the inner -component will automatically skip retraining without the user needing -access to the internal machines that the composite creates. +Set `frozen=false` to allow normal retraining. Use [`freeze!`](@ref) and [`thaw!`](@ref) +to toggle after construction. Set `cache=false` to prioritize memory over speed. -Set `frozen=false` to allow normal retraining. The `frozen` field can -be toggled after construction. +### Example 1: Freezing a single model -Specify `cache=false` to prioritize memory over speed, or to guarantee -data anonymity. +This example and the next assume you have `MLJDecisionTreeInterface` in your environment. -### Example 1: Freezing a single model -This example and the next assume you have MLJDecisionTreeInterface in your package environment. ```julia using MLJ # or `using MLJBase, MLJModels` X, y = make_regression(100) - DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree -model = Freezable(DecisionTreeRegressor()) # frozen=true by default -mach = machine(model, X, y) - -fit!(mach) # initial training always proceeds -predict(mach, X) # works normally -fit!(mach, rows=1:50) # no-op: frozen, so retraining is skipped +model = Freezable(DecisionTreeRegressor()) # frozen=true by default +mach = machine(model, X, y) -thaw!(model) # or equivalently: model.frozen = false -fit!(mach, rows=1:50) # retrains on the new rows +fit!(mach) # first fit trains +fit!(mach, rows=1:50) # no-op while frozen +thaw!(model) +fit!(mach, rows=1:50) # retrains ``` ### Example 2: Freezing a component inside a pipeline -The main use case for `Freezable` is inside composites. Here a -`Standardizer` is frozen so it is trained once on the first fold and -then reused across all subsequent folds, while the classifier -retrains normally on each fold: - ```julia -using MLJ # or `using MLJBase, MLJModels, using MLJTransforms` +using MLJ # or `using MLJBase, MLJModels, MLJTransforms` X, y = make_blobs(200) DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree + pipe = Pipeline( - scaler = Freezable(Standardizer()), # trained once, then frozen - clf = DecisionTreeClassifier(), # retrains on every fold + scaler = Freezable(Standardizer()), + clf = DecisionTreeClassifier(), ) - mach = machine(pipe, X, y) fit!(mach, rows=1:100) # both components train fit!(mach, rows=101:200) # only clf retrains; scaler is frozen @@ -163,8 +134,7 @@ function Freezable( metamodel = freezable_type_given_atom[abstract_atom](atom, frozen, - cache, - Any[]) + cache) message = clean!(metamodel) isempty(message) || @warn message return metamodel @@ -178,8 +148,8 @@ end """ freeze!(model::SomeFreezable) -Set `model.frozen = true`. Subsequent `fit!` calls on a machine -wrapping this model will be no-ops (after initial training). +Set `model.frozen = true`. After the first training pass, subsequent `fit!` calls on a +machine wrapping this model become no-ops. See also [`thaw!`](@ref). """ @@ -188,16 +158,12 @@ freeze!(model::SomeFreezable) = (model.frozen = true; model) """ thaw!(model::SomeFreezable) -Set `model.frozen = false`. The next `fit!` call on a machine -wrapping this model will retrain normally. +Set `model.frozen = false`. The next `fit!` call on a machine wrapping this model +retrains. See also [`freeze!`](@ref). """ -function thaw!(model::SomeFreezable) - model.frozen = false - empty!(model._stash) - return model -end +thaw!(model::SomeFreezable) = (model.frozen = false; model) # Prefit methods @@ -214,120 +180,6 @@ function prefit(model::FreezableUnsupervised, verbosity, X) (transform=transform(mach, Xs), inverse_transform=inverse_transform(mach, Xs)) end -function MLJModelInterface.fit(composite::SomeFreezable, verbosity, data...) - # If the model has been fit before AND is currently frozen, reuse the cached - # fit output verbatim. This handles the case where a parent composite (e.g. a - # `Pipeline`) rebuilds its learning network on a row change: from the parent's - # perspective this is a fresh fit, but the freeze contract says we must not - # retrain. The `Signature` carried in the cached fitresult contains the original - # trained inner machine; subsequent `predict`/`transform` calls clone the - # signature and rebind source nodes to new data, so reusing the cached fitresult - # is correct. - if composite.frozen && !isempty(composite._stash) - return composite._stash[end] - end - - # Build the learning network (inner machine starts unfrozen): - fitresult = prefit(composite, verbosity, data...) |> MLJBase.Signature - - # Train the network (initial training always proceeds): - greatest_lower_bound = MLJBase.glb(fitresult) - acceleration = MLJBase.acceleration(fitresult) - fit!(greatest_lower_bound; verbosity, composite, acceleration) - - # After initial training, freeze the inner machine if frozen=true: - if composite.frozen - d = MLJBase.machines_given_model(greatest_lower_bound) - if haskey(d, :model) - for mach in d[:model] - freeze!(mach) - end - end - end - - report = MLJBase.report(fitresult) - - # for passing to `update` so changes in `composite` can be detected: - cache = deepcopy(composite) - - output = (fitresult, cache, report) - _refresh_stash!(composite, output) - return output -end - -# Maintain at most one entry: the latest fit output, only while the model is frozen. -# A thawed model holds no stash, so the next `fit` call retrains from scratch. -function _refresh_stash!(composite::SomeFreezable, output) - empty!(composite._stash) - composite.frozen && push!(composite._stash, output) - return composite -end - -function MLJModelInterface.update( - composite::SomeFreezable, - verbosity, - fitresult, - old_composite, - data..., -) - greatest_lower_bound = MLJBase.glb(fitresult) - - # Synchronize frozen state on the inner machine(s): - d = MLJBase.machines_given_model(greatest_lower_bound) - if haskey(d, :model) - for mach in d[:model] - composite.frozen ? freeze!(mach) : thaw!(mach) - end - end - - # Check if non-model, non-frozen hyperparameters changed (e.g., `cache`). - # Changes to `frozen` are handled above via freeze!/thaw! and should not - # trigger a full refit: - non_frozen_changed = any(propertynames(composite)) do field - field in MLJBase.models(greatest_lower_bound) && return false - field === :frozen && return false - old_value = getproperty(old_composite, field) - value = getproperty(composite, field) - value != old_value - end - non_frozen_changed && return MLJModelInterface.fit(composite, verbosity, data...) - - # retrain the network: - fit!(greatest_lower_bound; verbosity, composite) - - report = MLJBase.report(fitresult) - - # for passing to `update` so changes in `composite` can be detected: - cache = deepcopy(composite) - - output = (fitresult, cache, report) - _refresh_stash!(composite, output) - return output -end - - -# When the Freezable model has frozen=true and the outer machine has already been -# trained (state > 0), skip retraining entirely. This prevents the NetworkComposite -# from rebuilding the learning network when training rows change. -# -# We synchronize the outer machine's `frozen` flag with the model's `frozen` field -# before each fit. The standard `fit_only!` checks `mach.frozen` at the top and -# returns immediately if true. Initial training (state == 0) always proceeds. -function MLJBase.fit!(mach::Machine{<:SomeFreezable}; kwargs...) - # Synchronize outer machine frozen flag: - if mach.model.frozen && mach.state > 0 - mach.frozen = true - else - mach.frozen = false - end - - # Delegate to the standard fit! logic: - glb_node = MLJBase.glb(mach.args...) - fit!(glb_node; kwargs...) - MLJBase.fit_only!(mach; kwargs...) -end - - # Forwarded methods: the wrapper should be transparent to per-fit operations # the inner model supports. diff --git a/src/composition/models/network_composite.jl b/src/composition/models/network_composite.jl index dcec4f83..923eeb1e 100644 --- a/src/composition/models/network_composite.jl +++ b/src/composition/models/network_composite.jl @@ -71,6 +71,14 @@ function MLJModelInterface.update( start_over = MLJBase.start_over(composite, old_composite, greatest_lower_bound) start_over && return MLJModelInterface.fit(composite, verbosity, data...) + # If the composite has frozen descendants, rebuild a fresh network on `data` and + # transfer trained state from the old inner machines for the frozen children. Fresh + # non-frozen children train from scratch; frozen children short-circuit via the + # A/B/C/D checks in `fit_only!` because their state survived the transfer. + if has_frozen_descendants(composite) + return _rebuild_preserving_frozen(composite, verbosity, fitresult, data...) + end + # retrain the network: fit!(greatest_lower_bound; verbosity, composite) @@ -82,6 +90,36 @@ function MLJModelInterface.update( return fitresult, cache, report end +function _rebuild_preserving_frozen(composite, verbosity, fitresult, data...) + old_glb = MLJBase.glb(fitresult) + old_machs = MLJBase.machines_given_model(old_glb) + + new_fitresult = prefit(composite, verbosity, data...) |> MLJBase.Signature + new_glb = MLJBase.glb(new_fitresult) + new_machs = MLJBase.machines_given_model(new_glb) + + for (sym, machs) in old_machs + sym in propertynames(composite) || continue + frozen(getproperty(composite, sym)) || continue + haskey(new_machs, sym) || continue + old_m = first(machs) + isdefined(old_m, :fitresult) || continue + new_m = first(new_machs[sym]) + new_m.fitresult = old_m.fitresult + isdefined(old_m, :cache) && (new_m.cache = old_m.cache) + isdefined(old_m, :report) && (new_m.report = old_m.report) + new_m.state = old_m.state + new_m.old_model = deepcopy(getproperty(composite, sym)) + new_m.old_upstream_state = MLJBase.upstream(new_m) + end + + fit!(new_glb; verbosity, composite) + + report = MLJBase.report(new_fitresult) + cache = deepcopy(composite) + return new_fitresult, cache, report +end + MLJModelInterface.fitted_params(composite::NetworkComposite, signature) = fitted_params(signature) diff --git a/src/machines.jl b/src/machines.jl index cb5efe9c..d6851517 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -105,6 +105,34 @@ more detail, see the discussion of training logic at [`fit_only!`](@ref). """ age(mach::Machine) = mach.state +""" + frozen(model) + +Return `true` if `model` has a `frozen` property set to `true`. `fit_only!` treats such a +model as a no-op once `age(mach) >= 1`. + +""" +frozen(model) = :frozen in propertynames(model) && getproperty(model, :frozen)::Bool + +""" + has_frozen_descendants(model) + +Return `true` if any field of `model` is itself a model with `frozen=true`, recursively. +`fit_only!` consults this on a row-change refit to keep the existing learning network in +place when a composite contains a frozen subcomponent. + +""" +function has_frozen_descendants(model) + model isa Model || return false + for name in propertynames(model) + child = getproperty(model, name) + child isa Model || continue + frozen(child) && return true + has_frozen_descendants(child) && return true + end + return false +end + """ replace(mach::Machine, field1 => value1, field2 => value2, ...) @@ -568,22 +596,33 @@ the true model given by `getproperty(composite, model)`. See also [`machine`](@r ### Training action logic -For the action to be a no-operation, either `mach.frozen == true` or -or none of the following apply: +The action is a no-operation precisely when one or more of the following apply: + +A. `mach.frozen == true`. + +B. `mach.model` is a `Model` (not a symbol), `frozen(mach.model) == true`, and + `age(mach) >= 1`. + +C. `mach.model` is a symbol, `frozen(getproperty(composite, mach.model)) == true`, and + `age(mach) >= 1`. + +D. None of the numbered conditions below apply. + +Otherwise the action is selected from the numbered conditions: 1. `mach` has never been trained (`mach.state == 0`). 2. `force == true`. -3. The `state` of some other machine on which `mach` depends has - changed since the last time `mach` was trained (ie, the last time - `mach.state` was last incremented). +3. The `state` of some other machine on which `mach` depends has changed since the last + time `mach` was trained. -4. The specified `rows` have changed since the last retraining and - `mach.model` does not have `Static` type. +4. The specified `rows` have changed since the last retraining and `mach.model` does not + have `Static` type. If the resolved model has frozen descendants, this routes to an + update instead of an ab initio fit, so the existing learning network survives. -5. `mach.model` is a `Model` (i.e, not a symbol) and is different from the last model used - for training (but has the same type). +5. `mach.model` is a `Model` (not a symbol) and is different from the last model used for + training (but has the same type). 6. `mach.model` is a `Model` but has a type different from the last model used for training. @@ -594,8 +633,8 @@ or none of the following apply: 8. `mach.model` is a symbol and `getproperty(composite, mach.model)` has a different type from the last model used for training. -In any of the cases (1) - (4), (6), or (8), `mach` is trained ab initio. -If (5) or (7) is true, then a training update is applied. +Cases (1), (2), (3), (6), (8), and (4) without frozen descendants train ab initio. +Cases (5), (7), and (4) with frozen descendants apply a training update. To freeze or unfreeze `mach`, use `freeze!(mach)` or `thaw!(mach)`. @@ -621,8 +660,8 @@ function fit_only!( composite=nothing, ) where cache_data + # Condition A: the machine itself is frozen. if mach.frozen - # no-op; do not increment `state`. fitlog(mach, :frozen, verbosity) return mach end @@ -633,8 +672,7 @@ function fit_only!( "cannot be trained. ") end - # If `mach.model` is a symbol, then we want to replace it with the bone fide model - # `getproperty(composite, mach.model)`: + # Resolve a symbolic model to the actual model living on the parent composite. model = if mach.model isa Symbol isnothing(composite) && throw(err_no_real_model(mach)) mach.model in propertynames(composite) || @@ -644,6 +682,13 @@ function fit_only!( mach.model end + # Conditions B and C: the resolved model is frozen and `mach` has been trained at + # least once. Treat it as a no-op without consulting any numbered condition below. + if frozen(model) && age(mach) >= 1 + fitlog(mach, :frozen, verbosity) + return mach + end + # neither `old_model` nor `model` are symbols here: modeltype_changed = !isdefined(mach, :old_model) ? true : typeof(model) === typeof(mach.old_model) ? false : @@ -660,6 +705,12 @@ function fit_only!( condition_4 = rows_is_new && !(mach.model isa Static) + # A composite with frozen descendants must not rebuild on a row change. Rebuilding + # would create fresh inner machines at `age = 0` and the freeze contract would break. + # Route condition (4) to the update path instead so the existing network survives. + rebuild_on_rows = condition_4 && !has_frozen_descendants(model) + update_on_rows = condition_4 && has_frozen_descendants(model) + upstream_has_changed = mach.old_upstream_state != upstream_state data_is_valid = isdefined(mach, :data) && !upstream_has_changed @@ -680,7 +731,7 @@ function fit_only!( if mach.state == 0 || # condition (1) force == true || # condition (2) upstream_has_changed || # condition (3) - condition_4 || # condition (4) + rebuild_on_rows || # condition (4) without frozen descendants modeltype_changed # conditions (6) or (7) isdefined(mach, :report) || (mach.report = LittleDict{Symbol,Any}()) @@ -709,7 +760,7 @@ function fit_only!( rethrow() end - elseif model != mach.old_model # condition (5) + elseif update_on_rows || model != mach.old_model # condition (5) or row-change update # update the model: fitlog(mach, :update, verbosity)