Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ include("composition/models/network_composite_types.jl")
include("composition/models/network_composite.jl")
include("composition/models/pipelines.jl")
include("composition/models/transformed_target_model.jl")
include("composition/models/freezable.jl")

include("operations.jl")

Expand Down Expand Up @@ -303,6 +304,7 @@ export machines, sources, Stack,
StaticPipeline, IntervalPipeline

export TransformedTargetModel
export Freezable

# resampling.jl:
export ResamplingStrategy, InSample, Holdout, CV, StratifiedCV, TimeSeriesCV,
Expand Down
234 changes: 234 additions & 0 deletions src/composition/models/freezable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
# Names of the abstract model types that `Freezable` is able to wrap:
const FREEZABLE_SUPPORTED_ATOMS =
    (:Deterministic, :Probabilistic, :Interval, :Unsupervised)

# Each supported atomic type gets its own wrapper type, which must have an appropriate
# supertype. Both lookups below are keyed on the atomic type name and store *symbols*
# (not types), for use in code generation:
const FREEZABLE_TYPE_GIVEN_ATOM = Dict(
    atom => Symbol("Freezable", atom) for atom in FREEZABLE_SUPPORTED_ATOMS
)
const FREEZABLE_SUPER_GIVEN_ATOM = Dict(
    atom => Symbol(atom, "NetworkComposite") for atom in FREEZABLE_SUPPORTED_ATOMS
)

# Type definitions: generate one mutable wrapper struct per supported atomic type. The
# struct name and its supertype are looked up (as symbols) in the two dictionaries
# above, and the wrapped model is constrained to the corresponding atomic type.
for atom in FREEZABLE_SUPPORTED_ATOMS
    wrapper = FREEZABLE_TYPE_GIVEN_ATOM[atom]
    wrapper_super = FREEZABLE_SUPER_GIVEN_ATOM[atom]
    @eval mutable struct $wrapper{M <: $atom} <: $wrapper_super
        model::M
        frozen::Bool
        cache::Bool
    end
end

# Dict whose keys and values are types instead of symbols: each key is a supported
# atomic type and each value is the matching `Freezable` wrapper type.
const freezable_type_given_atom = Dict{Type,Type}()
for atom in FREEZABLE_SUPPORTED_ATOMS
    wrapper = FREEZABLE_TYPE_GIVEN_ATOM[atom]
    # `atom` and `wrapper` are symbols; interpolating them into the evaluated
    # expression makes them resolve to the types bound to those names:
    @eval freezable_type_given_atom[$atom] = $wrapper
end

# not exported:

# wrapper types, and the symbols naming them:
const FREEZABLE_TYPES = values(freezable_type_given_atom)
const FREEZABLE_TYPE_EXS = values(FREEZABLE_TYPE_GIVEN_ATOM)

# union of all wrapper types:
const SomeFreezable = Union{FREEZABLE_TYPES...}

# union of the wrapper types for the three supervised atomic types:
const SupervisedFreezable = Union{
    (freezable_type_given_atom[T] for T in (Deterministic, Probabilistic, Interval))...
}

# union of the atomic types that can be wrapped:
const FreezableSupported = Union{keys(freezable_type_given_atom)...}

# Errors thrown by the `Freezable` constructor:
const ERR_FREEZABLE_MODEL_UNSPECIFIED =
    ArgumentError("Expecting atomic model as argument. None specified.")
const ERR_FREEZABLE_TOO_MANY_ARGUMENTS =
    ArgumentError("At most one non-keyword argument, a model, allowed.")

# Back-ticked list of the supported atomic types, for use in error messages:
const PRETTY_FREEZABLE_SUPPORT_OPTIONS =
    join(("`$opt`" for opt in FREEZABLE_SUPPORTED_ATOMS), ", ", ", and ")

# Return the `ArgumentError` reporting that `model` has a type which cannot be wrapped:
function err_freezable_unsupported(model)
    return ArgumentError(
        "Only these model supertypes support `Freezable` wrapping: " *
        "$PRETTY_FREEZABLE_SUPPORT_OPTIONS. " *
        "Model provided has type `$(typeof(model))`."
    )
end

"""
Freezable(model; frozen=true, cache=true)

Wrap `model` so `fit!` is a no-op after the first training pass. Place the wrapper inside
a `Pipeline`, `Stack`, `TunedModel`, or other `NetworkComposite` model, and the
inner component skips retraining even when
the parent rebuilds its learning network on a row change.

Set `frozen=false` to allow normal retraining. Use [`freeze!`](@ref) and [`thaw!`](@ref)
to toggle after construction. Set `cache=false` to prioritize memory over speed.

### Example 1: Freezing a single model

This example and the next assume you have `MLJDecisionTreeInterface` in your environment.

```julia
using MLJ # or `using MLJBase, MLJModels`

X, y = make_regression(100)
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree

model = Freezable(DecisionTreeRegressor()) # frozen=true by default
mach = machine(model, X, y)

fit!(mach) # first fit trains
fit!(mach, rows=1:50) # no-op while frozen
thaw!(model)
fit!(mach, rows=1:50) # retrains
```

### Example 2: Freezing a component inside a pipeline

```julia
using MLJ # or `using MLJBase, MLJModels, MLJTransforms`

X, y = make_blobs(200)
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree

pipe = Pipeline(
scaler = Freezable(Standardizer()),
clf = DecisionTreeClassifier(),
)
mach = machine(pipe, X, y)
fit!(mach, rows=1:100) # both components train
fit!(mach, rows=101:200) # only clf retrains; scaler is frozen
```

"""
function Freezable(
    args...;
    model=nothing,
    frozen::Bool=true,
    cache::Bool=true,
)
    # at most one positional argument (the atomic model) is allowed:
    length(args) > 1 && throw(ERR_FREEZABLE_TOO_MANY_ARGUMENTS)

    # the atomic model may be given positionally or by keyword; a positional
    # argument wins, with a warning if both are supplied:
    if isempty(args)
        isnothing(model) && throw(ERR_FREEZABLE_MODEL_UNSPECIFIED)
        atom = model
    else
        atom = only(args)
        if !isnothing(model)
            @warn "Using `model=$atom`. Ignoring specification `model=$model`. "
        end
    end

    # reject models whose type cannot be wrapped:
    atom isa FreezableSupported || throw(err_freezable_unsupported(atom))
    abstract_atom = MMI.abstract_type(atom)
    if !haskey(freezable_type_given_atom, abstract_atom)
        throw(err_freezable_unsupported(atom))
    end

    # construct the wrapper matching the abstract type of the atomic model:
    Wrapper = freezable_type_given_atom[abstract_atom]
    metamodel = Wrapper(atom, frozen, cache)

    # surface any warning generated by `clean!`:
    message = clean!(metamodel)
    isempty(message) || @warn message

    return metamodel
end

# No hyper-parameter validation is performed for `Freezable` wrappers, so the warning
# message returned is always empty:
clean!(model::SomeFreezable) = ""

"""
    freeze!(model::SomeFreezable)

Mark `model` as frozen, by setting `model.frozen = true`, and return `model`. Once a
machine wrapping this model has completed its first training pass, later `fit!` calls on
that machine become no-ops.

See also [`thaw!`](@ref).
"""
function freeze!(model::SomeFreezable)
    model.frozen = true
    return model
end

"""
    thaw!(model::SomeFreezable)

Mark `model` as not frozen, by setting `model.frozen = false`, and return `model`. The
next `fit!` call on a machine wrapping this model retrains.

See also [`freeze!`](@ref).
"""
function thaw!(model::SomeFreezable)
    model.frozen = false
    return model
end


# Prefit methods
# Learning network for a supervised wrapper: `X` and `y` are wrapped as source nodes and
# a machine is bound to the `:model` field of the wrapper. The returned interface
# exposes `predict` and `transform` nodes, both taking the input source as argument.
function prefit(model::SupervisedFreezable, verbosity, X, y)
    input = source(X)
    target = source(y)
    atomic_mach = machine(:model, input, target; cache=model.cache)
    return (
        predict=predict(atomic_mach, input),
        transform=transform(atomic_mach, input),
    )
end

# Learning network for an unsupervised wrapper: `X` is wrapped as a source node and a
# machine is bound to the `:model` field of the wrapper. The returned interface exposes
# `transform` and `inverse_transform` nodes, both taking the input source as argument.
function prefit(model::FreezableUnsupervised, verbosity, X)
    input = source(X)
    atomic_mach = machine(:model, input; cache=model.cache)
    return (
        transform=transform(atomic_mach, input),
        inverse_transform=inverse_transform(atomic_mach, input),
    )
end

# Forwarded methods: the wrapper should be transparent to per-fit operations
# the inner model supports.

# NOTE(review): this is a plain `String`, so `throw` below raises a `String` rather
# than an `Exception` subtype - confirm this is intended.
const ERR_FREEZABLE_MISSING_REPORT =
    "Cannot find report for the atomic model wrapped by `Freezable`. "

# Training losses are those of the wrapped model, computed from the `model` entry of
# the wrapper's report. Throws `ERR_FREEZABLE_MISSING_REPORT` if there is no such entry.
function MMI.training_losses(composite::SomeFreezable, freezable_report)
    if !hasproperty(freezable_report, :model)
        throw(ERR_FREEZABLE_MISSING_REPORT)
    end
    atomic_report = freezable_report.model
    return training_losses(composite.model, atomic_report)
end

# Feature importances are those of the wrapped model. The unique machine associated
# with the `:model` field is located via the network's `predict` node, and its
# `fitresult` and the `:fit` entry of its report are passed on to the wrapped model.
function MMI.feature_importances(composite::SupervisedFreezable, fitresult, report)
    machines_by_field = MLJBase.machines_given_model(fitresult.interface.predict)
    atomic_mach = only(machines_by_field[:model])
    return feature_importances(
        composite.model,
        atomic_mach.fitresult,
        atomic_mach.report[:fit],
    )
end


# Model traits
# Traits taking the same value for every `Freezable` wrapper type:
MMI.package_name(::Type{T}) where {T<:SomeFreezable} = "MLJBase"
MMI.is_wrapper(::Type{T}) where {T<:SomeFreezable} = true
MMI.load_path(::Type{T}) where {T<:SomeFreezable} = "MLJBase.Freezable"
MMI.constructor(::Type{T}) where {T<:SomeFreezable} = Freezable

# Traits forwarded from the atomic model type `M` to each wrapper type. The iteration
# parameter is the exception: the atomic model's value is prefixed with `:model`, using
# `MLJBase.prepend`.
for wrapper in FREEZABLE_TYPE_EXS
    @eval function MMI.iteration_parameter(::Type{<:$wrapper{M}}) where M
        return MLJBase.prepend(:model, iteration_parameter(M))
    end

    forwarded_traits = (
        :input_scitype,
        :output_scitype,
        :target_scitype,
        :fit_data_scitype,
        :predict_scitype,
        :transform_scitype,
        :inverse_transform_scitype,
        :is_pure_julia,
        :supports_weights,
        :supports_class_weights,
        :supports_online,
        :supports_training_losses,
        :reports_feature_importances,
        :is_supervised,
        :prediction_type,
    )
    for trait in forwarded_traits
        @eval MMI.$trait(::Type{<:$wrapper{M}}) where M = MMI.$trait(M)
    end
end
38 changes: 38 additions & 0 deletions src/composition/models/network_composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ function MLJModelInterface.update(
start_over = MLJBase.start_over(composite, old_composite, greatest_lower_bound)
start_over && return MLJModelInterface.fit(composite, verbosity, data...)

# If the composite has frozen descendants, rebuild a fresh network on `data` and
# transfer trained state from the old inner machines for the frozen children. Fresh
# non-frozen children train from scratch; frozen children short-circuit via the
Copy link
Copy Markdown
Member

@ablaom ablaom May 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"train from scratch" is too strong. It may be that some non-frozen components only need an update; see my comment below on the example of a RandomForest warm restart.

# A/B/C/D checks in `fit_only!` because their state survived the transfer.
if has_frozen_descendants(composite)
return _rebuild_preserving_frozen(composite, verbosity, fitresult, data...)
end

# retrain the network:
fit!(greatest_lower_bound; verbosity, composite)

Expand All @@ -82,6 +90,36 @@ function MLJModelInterface.update(
return fitresult, cache, report
end

"""
    _rebuild_preserving_frozen(composite, verbosity, fitresult, data...)

Private method.

Rebuild the learning network of `composite` on `data`, by calling `prefit`, and copy the
learned state of machines in the old network (the one underlying `fitresult`) to the
corresponding machines in the new network, for every component of `composite` that is
frozen. The new network is then trained.

State is transferred for *all* machines associated with a frozen component, not just
the first, because a single model may control several machines.

Returns `(new_fitresult, cache, report)`, where `cache` is a deep copy of `composite`.

"""
function _rebuild_preserving_frozen(composite, verbosity, fitresult, data...)
    old_glb = MLJBase.glb(fitresult)
    old_machs = MLJBase.machines_given_model(old_glb)

    new_fitresult = prefit(composite, verbosity, data...) |> MLJBase.Signature
    new_glb = MLJBase.glb(new_fitresult)
    new_machs = MLJBase.machines_given_model(new_glb)

    for (sym, machs) in old_machs
        # only frozen components present in both networks are considered:
        sym in propertynames(composite) || continue
        frozen(getproperty(composite, sym)) || continue
        haskey(new_machs, sym) || continue

        # NOTE(review): old and new machines are paired by position, which assumes
        # `machines_given_model` lists them in the same order for two networks built
        # by the same `prefit` - confirm.
        for (old_m, new_m) in zip(machs, new_machs[sym])
            # nothing to transfer from a machine that was never trained:
            isdefined(old_m, :fitresult) || continue

            new_m.fitresult = old_m.fitresult
            isdefined(old_m, :cache) && (new_m.cache = old_m.cache)
            isdefined(old_m, :report) && (new_m.report = old_m.report)
            new_m.state = old_m.state
            new_m.old_model = deepcopy(getproperty(composite, sym))
            new_m.old_upstream_state = MLJBase.upstream(new_m)
        end
    end

    fit!(new_glb; verbosity, composite)

    report = MLJBase.report(new_fitresult)
    cache = deepcopy(composite)
    return new_fitresult, cache, report
end

MLJModelInterface.fitted_params(composite::NetworkComposite, signature) =
fitted_params(signature)

Expand Down
Loading
Loading