Adding Freezable wrapper for models #1048

Open
pebeto wants to merge 8 commits into JuliaAI:dev from pebeto:dev

Conversation

@pebeto
Member

@pebeto pebeto commented Apr 16, 2026

This PR implements the Freezable model wrapper (#1016) following the TransformedTargetModel pattern.

  • Four mutable structs via metaprogramming (FreezableDeterministic, FreezableProbabilistic, FreezableInterval, FreezableUnsupervised), each subtyping the corresponding NetworkComposite
  • Freezable(model; frozen=true, cache=true) constructor that dispatches on the model's abstract supertype
  • prefit methods for supervised (returns predict/transform nodes) and unsupervised (returns transform/inverse_transform nodes) models
  • Custom fit override: initial training always proceeds, then freezes the inner machine if frozen=true
  • Custom update override: synchronizes freeze!/thaw! on the inner machine before delegating to standard NetworkComposite logic
  • Custom fit! override on Machine{<:SomeFreezable}: synchronizes the outer machine's frozen flag so fit_only! skips when frozen=true and state > 0
  • freeze!/thaw! overloads for SomeFreezable models, consistent with the existing machine-level API
  • Trait delegation for MLJ model traits via metaprogramming

@pebeto pebeto requested a review from ablaom April 16, 2026 20:15
@pebeto pebeto self-assigned this Apr 16, 2026
@pebeto pebeto added the enhancement New feature or request label Apr 16, 2026
@ablaom
Member

ablaom commented Apr 27, 2026

Started a review and playing around with the example. Very cool.

I think the current behaviour, where warnings are issued when you attempt to train a frozen machine, could get annoying. (For me this never came up before, for the machine-level freeze!/thaw! never seemed that convenient.)

julia> fit!(mach, rows=1:50, verbosity=0)
┌ Warning: machine(FreezableDeterministic(model = DecisionTreeRegressor(max_depth = -1, ), ), ) not trained as it is frozen.
└ @ MLJBase ~/MLJ/MLJBase/src/machines.jl:497
trained Machine; does not cache data
  model: FreezableDeterministic(model = DecisionTreeRegressor(max_depth = -1, ), )
  args: 
    1:  Source @444 ⏎ Table{AbstractVector{Continuous}}
    2:  Source @672 ⏎ AbstractVector{Continuous}

To freeze a machine requires intervention, so I would now regard this as deserving only of @info, like the other training logging. Like those, the logging should disappear when verbosity <= 0. I'd be happy to change this. What do you think?

The relevant code to update is here

Edit: This might break some logging tests, but I wouldn't view this as breaking.

@ablaom
Member

ablaom commented Apr 28, 2026

There seems to be a problem revealed in Example 2: when I increase the verbosity to 2, I see that the scaler appears to be retrained the second time around:

julia> fit!(mach, rows=1:100, verbosity=2)        # both components train
[ Info: Training machine(ProbabilisticPipeline(scaler = FreezableUnsupervised(model = Standardizer(features = Symbol[], ), ), ), ).
[ Info: Training machine(:scaler, ).
[ Info: Training machine(:model, ).
[ Info: Features standarized: 
[ Info:   :x1    mu=-2.4868861209579043  sigma=6.1068776340664686
[ Info:   :x2    mu=-2.336433496692005  sigma=6.970932861948223
[ Info: Training machine(:clf, ).
Feature 1 < 0.9348 ?
├─ Feature 2 < -0.7489 ?
    ├─ Feature 1 < -0.717 ?
        ├─ Feature 2 < -0.7979 ?
            ├─ 3 : 2/2
            └─ 1 : 1/1
        └─ 1 : 19/19
    └─ Feature 1 < -0.6649 ?
        ├─ 3 : 31/31
        └─ Feature 2 < -0.6387 ?
            ├─ 1 : 3/3
            └─ Feature 1 < 0.1482 ?
                ├─ 
                └─ 1 : 1/1
└─ 2 : 30/30
trained Machine; does not cache data
  model: ProbabilisticPipeline(scaler = FreezableUnsupervised(model = Standardizer(features = Symbol[], ), ), )
  args: 
    1:  Source @688 ⏎ Table{AbstractVector{Continuous}}
    2:  Source @846 ⏎ AbstractVector{Multiclass{3}}


julia> fit!(mach, rows=101:200, verbosity=2)      # only clf retrains; scaler is frozen
[ Info: Training machine(ProbabilisticPipeline(scaler = FreezableUnsupervised(model = Standardizer(features = Symbol[], ), ), ), ).
[ Info: Training machine(:scaler, ).
[ Info: Training machine(:model, ).
[ Info: Features standarized: 
[ Info:   :x1    mu=-1.6075949590445386  sigma=6.245230837501581
[ Info:   :x2    mu=-1.766075605094719  sigma=7.420282309033924
[ Info: Training machine(:clf, ).
Feature 2 < 0.03832 ?
├─ Feature 2 < -0.6508 ?
    ├─ Feature 2 < -0.8413 ?
        ├─ 1 : 20/20
        └─ Feature 1 < -0.7213 ?
            ├─ Feature 2 < -0.831 ?
                ├─ 3 : 2/2
                └─ 
            └─ Feature 1 < -0.513 ?
                ├─ 1 : 11/11
                └─ 
    └─ 3 : 17/17
└─ 2 : 36/36
trained Machine; does not cache data
  model: ProbabilisticPipeline(scaler = FreezableUnsupervised(model = Standardizer(features = Symbol[], ), ), )
  args: 
    1:  Source @688 ⏎ Table{AbstractVector{Continuous}}
    2:  Source @846 ⏎ AbstractVector{Multiclass{3}}

By the way, there is some test tooling that we use to test this kind of logic: the @test_mach_sequence macro, which is in test/test_utilities.jl. Unfortunately, it's undocumented, although there are lots of examples in test/composition/nodes.jl. It might be useful for tests that catch issues like the one above.

@ablaom
Member

ablaom commented Apr 28, 2026

Note to self: I have only scanned the type setup, reviewed the docstring, and tried out the examples there. There will be some issues to explain about wrapping supervised models with transform, etc, and probably some missing traits to forward to the wrapper (feature_importances ??).

Member

@ablaom ablaom left a comment

This is a valuable contribution and will have required considerable effort to familiarise yourself with the composition codebase, the most challenging part of MLJ. This is enormously appreciated.

I'm pausing my review for now. I think there is a bit to sort out with the Example 2 issue.

Comment thread src/composition/models/freezable.jl Outdated
Comment thread src/composition/models/freezable.jl
Comment thread src/composition/models/freezable.jl Outdated
Comment thread src/composition/models/freezable.jl Outdated
Comment thread src/composition/models/freezable.jl
@ablaom
Member

ablaom commented Apr 28, 2026

The docstring is very clear and concise, by the way.

pebeto and others added 5 commits April 28, 2026 09:51
Co-authored-by: Anthony Blaom, PhD <anthony.blaom@gmail.com>
Co-authored-by: Anthony Blaom, PhD <anthony.blaom@gmail.com>
Co-authored-by: Anthony Blaom, PhD <anthony.blaom@gmail.com>
Co-authored-by: Anthony Blaom, PhD <anthony.blaom@gmail.com>
Co-authored-by: Anthony Blaom, PhD <anthony.blaom@gmail.com>
@pebeto
Member Author

pebeto commented Apr 28, 2026

Some ideas:

  • The frozen-machine warning is a straightforward change to perform, so it will be easy to implement.
  • The idea behind training_losses and feature_importances is to mirror the pattern in transformed_target_model.jl but for the special freezable types

I will take some time to review how to solve the retraining bug.

@pebeto
Member Author

pebeto commented May 3, 2026

@ablaom, the root cause of the retraining bug is this: when the parent pipeline's outer machine sees new rows, condition_4 in fit_only! triggers a full network rebuild via MMI.fit(::Pipeline), which creates fresh inner machines. The old MMI.fit(::SomeFreezable) always trained, then froze.

A solution was added in the last commit. However, it's adding more information to the model.

@ablaom
Copy link
Copy Markdown
Member

ablaom commented May 3, 2026

Thanks @pebeto for looking into the fail.

I have to say I'm not too keen on the workaround. We have come to embrace the general principle that an MLJ model (configuration struct) is never mutated. (We have allowed this in the special case of RNGs, but in retrospect I think this was a mistake, and in newer models we encourage fit to make copies of RNGs before using them.) I haven't thought through all the consequences, but my guess is that this design choice will have unintended ones.

Is there no way you can pass the information you need in the cache or fitresult output of the model's fit and update methods (in addition to that of the atomic model)? This is the normal way to handle model "state", ie., externally.
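
To illustrate the "external state" convention referred to here, a minimal self-contained sketch follows. It does not use MLJModelInterface; the names `ToyStandardizer`, `toy_fit` and `toy_update` are hypothetical stand-ins that only mirror the `(fitresult, cache, report)` return convention of a model's fit and update methods.

```julia
# Hypothetical stand-in for a model: an immutable configuration struct.
# Nothing learned during training is ever written back into it.
struct ToyStandardizer end

# Mirrors the (fitresult, cache, report) return convention of `fit`.
function toy_fit(::ToyStandardizer, x::Vector{Float64})
    mu = sum(x) / length(x)
    fitresult = (mu = mu,)     # learned parameters live here
    cache = (n_fits = 1,)      # bookkeeping that `update` may reuse
    return fitresult, cache, nothing
end

# Mirrors `update`: it receives the previous fitresult and cache, so any
# state it needs arrives as arguments rather than through the model struct.
function toy_update(model::ToyStandardizer, old_fitresult, old_cache, x)
    fitresult, _, report = toy_fit(model, x)
    cache = (n_fits = old_cache.n_fits + 1,)
    return fitresult, cache, report
end
```

In this sketch the only carrier of state between calls is the pair `(fitresult, cache)` held by the caller, which is the role a machine plays.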

@pebeto
Member Author

pebeto commented May 3, 2026

@ablaom I previously tried fitresult/cache but they don't survive the rebuild.

  1. When the outer pipeline machine sees new rows, it rebuilds the entire network from scratch
  2. The old inner :scaler machine holding the wrapper's fitresult and cache is destroyed
  3. A fresh inner machine is set at state == 0, which routes straight to fit (not update), so the wrapper's previous output is gone before any code can read it.

The only information persisted across the rebuild is the model instance, which is why I ended up saving the information there.

Another idea is to modify fit_only! so that composites with frozen children follow update instead of rebuild (then fitresult/cache can survive). I don't know how risky that would be, since fit_only! is used widely.

@pebeto
Member Author

pebeto commented May 3, 2026

In condition_4, we could add a check to avoid rebuild from scratch when rows change.

(condition_4 passing means full rebuild)

@ablaom
Member

ablaom commented May 3, 2026

I'm not sure I understand. Do you mean if model isa Freezable then condition 4 is modified? Certainly we do want retraining from scratch otherwise for row changes.

Is the problem that you can see no way to set/unset the frozen flag of the atomic model's machine?

@ablaom
Member

ablaom commented May 3, 2026

okay, i missed your earlier clarification. let me think about this and get back to you.

@ablaom
Member

ablaom commented May 4, 2026

Okay, I agree that, not wanting to mutate model structs, we need to break into the fit_only! logic. Thank you for explaining this to me.

I would make two main points:

  1. I think we should now proceed with a slightly more general idea, namely that freezability applies to any model that has frozen as a property (:frozen in propertynames(model)). In that way, we can eventually make the wrapper unnecessary for common use cases, such as TunedModel, by adding frozen as a new property. This is convenient, because we can freeze by just changing an existing hyperparameter, without wrapping, which may require rebuilding an already trained pipeline, etc.

  2. Next, regarding the modified logic, I'm not sure I agree we just modify condition 4. Let me propose a different change, which I find easier to reason about.

Here's the logic we have now:

For the action to be a no-operation, either `mach.frozen == true` or
none of the following apply:

1. `mach` has never been trained (`mach.state == 0`).

2. `force == true`.

3. The `state` of some other machine on which `mach` depends has
   changed since the last time `mach` was trained (ie, the last time
   `mach.state` was last incremented).

4. The specified `rows` have changed since the last retraining and
   `mach.model` does not have `Static` type.

5. `mach.model` is a `Model` (i.e, not a symbol) and is different from the last model used
   for training (but has the same type).

6. `mach.model` is a `Model` but has a type different from the last model used for
   training.

7. `mach.model` is a symbol and `getproperty(composite, mach.model)` is different from the
   last model used for training (but has the same type).

8. `mach.model` is a symbol and `getproperty(composite, mach.model)` has a different type
   from the last model used for training.

In any of the cases (1) - (4), (6), or (8), `mach` is trained ab initio.
If (5) or (7) is true, then a training update is applied.

Let's say we define a new private method frozen(model) = :frozen in propertynames(model) && getproperty(model, :frozen).

Then we replace the first phrase, "For the action to be a no-operation, either mach.frozen == true or none of the following apply:", with:

The action will be a no-operation precisely when one or more of the following apply:

A.  `mach.frozen == true` .

B. `mach.model` is a `Model` (i.e., not a symbol), `frozen(mach.model)` is `true`, and `age(mach) >= 1`.

C. `mach.model` is a symbol, `frozen(getproperty(composite, mach.model))` is `true`, and `age(mach) >= 1`.

D. None of the numbered conditions below apply.
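
A minimal sketch of how the A-D test might look, using stand-in types rather than MLJBase's real `Machine`. The structs `ToyScaler`, `ToyTree`, `ToyMachine` and the function `is_noop` are hypothetical; `age` is just a stored counter here, and case C is folded into B by assuming the symbol has already been resolved to a model.

```julia
# Stand-in types for illustration only; not MLJBase's Machine or real models.
struct ToyScaler
    frozen::Bool
end

struct ToyTree
    max_depth::Int
end

mutable struct ToyMachine{M}
    model::M
    frozen::Bool   # machine-level flag, as set by freeze!/thaw! (condition A)
    age::Int       # how many times the machine has been trained
end

# The proposed private helper: true only when the model has a `frozen`
# property and that property is `true`.
frozen(model) = :frozen in propertynames(model) && getproperty(model, :frozen)

# `numbered_condition_applies` stands in for "any of conditions (1)-(8) holds";
# condition D is its negation.
function is_noop(mach::ToyMachine; numbered_condition_applies::Bool)
    mach.frozen && return true                          # A
    frozen(mach.model) && mach.age >= 1 && return true  # B (and C, once resolved)
    return !numbered_condition_applies                  # D
end
```

Under these assumptions, a trained machine with a frozen model is skipped even when the rows change, while a machine with `age == 0` is never skipped by B or C.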

What do you think?

@pebeto
Member Author

pebeto commented May 5, 2026

@ablaom point 1 seems okay, since generalizing freezability to any model is cleaner and lets future models opt in by adding the frozen property.

In point 2 we still face the issue that the inner :scaler machine is created fresh by prefit during the pipeline rebuild (state == 0 and age(mach) < 1). Conditions B and C won't fire and the fresh inner machine will be trained anyway.

I think we still need to suppress condition 4.

@ablaom
Member

ablaom commented May 8, 2026

Okay, I think it's just too long since I was in the weeds with this composite business, and I appreciate your patience. Your suggestions are probably right, but it would be neglectful of me not to make sure I understand this.

Can you help me better understand what you mean by "during the pipeline rebuild"? I'm looking at update(::NetworkComposite, ...) (the operative update for a pipeline, no?) and I don't see "rebuilding" happening there. What am I missing? Why shouldn't B and C fire?

@pebeto
Member Author

pebeto commented May 8, 2026

By rebuild I mean the call to prefit inside MMI.fit(::NetworkComposite). That's where a fresh Signature with fresh inner machines is built.

network_compose.jl:26

fitresult = prefit(composite, verbosity, data...) |> MLJBase.Signature

update(::NetworkComposite) isn't reached because when the outer pipeline machine sees new rows, condition 4 in fit_only! fires and routes to fit (not update). MMI.fit(::Pipeline) runs, calls prefit, and builds the network from scratch. The inner :scaler machine that comes out is state == 0 and age == 0, that's why B and C don't fire.

machines.jl:fit_only!

condition_4 = rows_is_new && !(mach.model isa Static)

if mach.state == 0 ||       # condition (1)
    force == true ||        # condition (2)
    upstream_has_changed || # condition (3)
    condition_4 ||          # condition (4)
    modeltype_changed       # conditions (6) or (8)
    # fit path
elseif model != mach.old_model # conditions (5) or (7)
    # update path
end

The chain for Example 2 is (in order):

  1. Condition 4 fires
  2. Fit path is picked
  3. MMI.fit(::NetworkComposite) called
  4. Prefit rebuilds
  5. Trained inner machines are gone
  6. Fresh inner :scaler mach has state=0
  7. Its own fit_only! takes the fit path
  8. MMI.fit(::FreezableUnsupervised) retrains the standardizer
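
The chain above can be mimicked with a toy model, sketched here under heavy simplification. None of these names exist in MLJBase: `ToyInner`, `toy_prefit`, `toy_fit_only!` and `toy_outer_fit` are hypothetical stand-ins for an inner machine, `prefit`, the inner `fit_only!`, and the outer fit path.

```julia
# Toy model of the chain; these are illustrative stand-ins only.
mutable struct ToyInner
    state::Int          # 0 means "never trained"
    train_count::Int    # how many times training actually ran
end

# Stand-in for `prefit`: every call builds brand-new inner machines.
toy_prefit() = Dict(:scaler => ToyInner(0, 0), :clf => ToyInner(0, 0))

# Stand-in for the inner fit_only!: a frozen machine is skipped only if it
# has been trained before, so a machine with state == 0 always trains.
function toy_fit_only!(m::ToyInner; frozen::Bool)
    if frozen && m.state > 0
        return m        # no-op
    end
    m.state += 1
    m.train_count += 1
    return m
end

# Stand-in for the outer fit path: rebuild, then train each inner machine.
function toy_outer_fit()
    network = toy_prefit()
    toy_fit_only!(network[:scaler]; frozen=true)
    toy_fit_only!(network[:clf]; frozen=false)
    return network
end

first_pass = toy_outer_fit()    # e.g. rows = 1:100
second_pass = toy_outer_fit()   # rows changed, so the fit path runs again
```

In this toy, the scaler in `second_pass` is a different object from the one in `first_pass`, and it has been trained once despite being marked frozen, which is the behaviour seen in Example 2.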

@pebeto
Member Author

pebeto commented May 8, 2026

In this gist, I wrote up my understanding of how MLJ training works, based on what I observed while working on this feature.

@ablaom
Member

ablaom commented May 12, 2026

Thanks for your detailed update. I think my confusion about "rebuilding" is a red herring. I suspect I have not conveyed my proposal clearly. What I'm suggesting is that the test for A, B, C, D appears at the very top of fit_only! (at the moment we just check A). If any of these hold, we immediately commit to a no-operation. (No fit or update, as we have presently, when A fires.) In this scenario, we never get to condition (iv) as it has been deemed irrelevant.

@pebeto
Member Author

pebeto commented May 26, 2026

@ablaom rebuilding is the problem.

  • Outer pipeline mach: A/B/C don't apply (Pipeline has no frozen property). D fails because condition (iv) fires (rows changed). Falls through, MMI.fit(::Pipeline) runs, prefit rebuilds the network with fresh inner machines at state = 0.
  • New inner :scaler: A, B don't apply. C fails because age == 0 (just created). D fails because condition (i) fires. Trains.

The rebuild zeroes the age, meaning that C never catches it. To make C fire, we need to keep the original :scaler (state = 1) in the Signature. That means suppressing condition (iv) on the outer mach so it takes the update path.

To make things clearer, I'm going to implement the changes in the PR.

Comment on lines +68 to +69
Wrap `model` so `fit!` is a no-op after the first training pass. Place the wrapper inside
a `Pipeline`, `Stack`, or `TunedModel` and the inner component skips retraining even when
Member

Suggested change
Wrap `model` so `fit!` is a no-op after the first training pass. Place the wrapper inside
a `Pipeline`, `Stack`, or `TunedModel` and the inner component skips retraining even when
Wrap `model` so `fit!` is a no-op after the first training pass. Place the wrapper inside
a `Pipeline`, `Stack`, `TunedModel`, or other `NetworkComposite` model, and the
inner component skips retraining even when


# If the composite has frozen descendants, rebuild a fresh network on `data` and
# transfer trained state from the old inner machines for the frozen children. Fresh
# non-frozen children train from scratch; frozen children short-circuit via the
Member

@ablaom ablaom May 28, 2026

"train from scratch" is too strong. It may be that some non-frozen components only need an update; see my comment below on the example of a RandomForest warm restart.

new_m.old_model = deepcopy(getproperty(composite, sym))
new_m.old_upstream_state = MLJBase.upstream(new_m)
end

Member

It's not sufficient to update only the first machine in the value associated with a given sym. You need to update all of them.

The basic example of multiple machines associated with a single model is a homogeneous ensemble, such as a random forest. Since a single model controls the whole ensemble, freezing the atomic model ought to freeze the ensemble of associated machines. (Actually, we don't implement EnsembleModel using NetworkComposite, so I can't point to a concrete example off the top of my head, sorry.)
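
A rough sketch of the "update all of them" point, with stand-in types. `ToyMach` and `transfer_state!` are hypothetical, and the sketch assumes each symbol maps to a vector of machines whose order corresponds between the old and new networks, which the real code would need to establish.

```julia
# Illustrative stand-in for a machine; not MLJBase's Machine.
mutable struct ToyMach
    old_model::Any
    state::Int
end

# Assumed layout: each model symbol maps to *all* machines bound to it.
function transfer_state!(new_machs::Dict{Symbol,Vector{ToyMach}},
                         old_machs::Dict{Symbol,Vector{ToyMach}},
                         sym::Symbol)
    news, olds = new_machs[sym], old_machs[sym]
    length(news) == length(olds) || error("machine counts differ for :$sym")
    for (new_m, old_m) in zip(news, olds)   # every machine, not just the first
        new_m.old_model = deepcopy(old_m.old_model)
        new_m.state = old_m.state
    end
    return new_machs
end
```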

@ablaom
Member

ablaom commented May 28, 2026

Ahhhh. Finally, I understand your main point. Thanks for your patient explanations, and for the explicit implementation of your new idea.

There remains an issue, which is that the presence of a freezable component in a composite model now precludes warm restart (update) of any other component. I've added a comment at the relevant point in the code. Here's a MWE to demonstrate the problem:

using MLJBase, MLJModels, MLJTransforms

forest = (@load RandomForestRegressor pkg=DecisionTree)()

X, y = @load_boston

# demonstration of warm restart:
mach = machine(forest, X, y)
fit!(mach)
# [ Info: Training machine(RandomForestRegressor(max_depth = -1, …), …).
# trained Machine; caches model-specific representations of data
forest.n_trees += 10
fit!(mach)
# [ Info: Updating machine(RandomForestRegressor(max_depth = -1, …), …).
# [ Info: Adding 10 trees to the ensemble.    <----  WARM  RESTART 

# put a forest in a pipeline with a freezable scaler:
pipe = Freezable(Standardizer()) |> forest

# initial train:
mach = machine(pipe, X, y)
fit!(mach, verbosity=2)
# [ Info: Training machine(DeterministicPipeline(freezable_unsupervised = FreezableUnsupervised(model = Standardizer(features = Symbol[], …), …), …), …).
# [ Info: Training machine(:freezable_unsupervised, …).
# [ Info: Training machine(:model, …).
# [ Info: Features standarized:
# [ Info:   :Crim    mu=3.6135235573122535  sigma=8.60154510533249
# [ Info:   :Zn    mu=11.363636363636363  sigma=23.322452994515135
# [ Info:   :Indus    mu=11.136778656126486  sigma=6.8603529408975845
# [ Info:   :NOx    mu=0.5546950592885372  sigma=0.11587767566755595
# [ Info:   :Rm    mu=6.284634387351778  sigma=0.7026171434153233
# [ Info:   :Age    mu=68.57490118577076  sigma=28.148861406903617
# [ Info:   :Dis    mu=3.795042687747036  sigma=2.105710126627611
# [ Info:   :Rad    mu=9.549407114624506  sigma=8.707259384239368
# [ Info:   :Tax    mu=408.2371541501976  sigma=168.537116054959
# [ Info:   :PTRatio    mu=18.45553359683795  sigma=2.1649455237144406
# [ Info:   :Black    mu=356.67403162055325  sigma=91.29486438415783
# [ Info:   :LStat    mu=12.653063241106722  sigma=7.141061511348569
# [ Info: Training machine(:random_forest_regressor, …).
# trained Machine; does not cache data

# add trees and train again; expect a warm restart of the forest here, but instead the
# forest is retrained from scratch (while the scaler is correctly ignored):
forest.n_trees += 10
fit!(mach, verbosity=2)
# [ Info: Updating machine(DeterministicPipeline(freezable_unsupervised = FreezableUnsupervised(model = Standardizer(features = Symbol[], …), …), …), …).
# [ Info: Not retraining machine(:freezable_unsupervised, …) as it is frozen.
# [ Info: Training machine(:random_forest_regressor, …).  <-- COLD RESTART
# trained Machine; does not cache data

Without the Freezable wrapper here, you get a warm restart of the forest.

