Adding Freezable wrapper for models #1048
|
Started a review and playing around with the example. Very cool. I think the current behaviour, where warnings are issued when you attempt to train a frozen machine, could get annoying. (For me this never came up before, for the machine-level
julia> fit!(mach, rows=1:50, verbosity=0)
┌ Warning: machine(FreezableDeterministic(model = DecisionTreeRegressor(max_depth = -1, …), …), …) not trained as it is frozen.
└ @ MLJBase ~/MLJ/MLJBase/src/machines.jl:497
trained Machine; does not cache data
model: FreezableDeterministic(model = DecisionTreeRegressor(max_depth = -1, …), …)
args:
1: Source @444 ⏎ Table{AbstractVector{Continuous}}
2: Source @672 ⏎ AbstractVector{Continuous}
To freeze a machine requires intervention, so I would now regard this as deserving only of The relevant code to update is here Edit: This might break some tests of logging, but I wouldn't view this as breaking. |
|
There seems to be a problem revealed in Example 2: when I increase the verbosity to 2, I see that the scaler appears to be retrained the second time around:
julia> fit!(mach, rows=1:100, verbosity=2) # both components train
[ Info: Training machine(ProbabilisticPipeline(scaler = FreezableUnsupervised(model = Standardizer(features = Symbol[], …), …), …), …).
[ Info: Training machine(:scaler, …).
[ Info: Training machine(:model, …).
[ Info: Features standarized:
[ Info: :x1 mu=-2.4868861209579043 sigma=6.1068776340664686
[ Info: :x2 mu=-2.336433496692005 sigma=6.970932861948223
[ Info: Training machine(:clf, …).
Feature 1 < 0.9348 ?
├─ Feature 2 < -0.7489 ?
├─ Feature 1 < -0.717 ?
├─ Feature 2 < -0.7979 ?
├─ 3 : 2/2
└─ 1 : 1/1
└─ 1 : 19/19
└─ Feature 1 < -0.6649 ?
├─ 3 : 31/31
└─ Feature 2 < -0.6387 ?
├─ 1 : 3/3
└─ Feature 1 < 0.1482 ?
├─
└─ 1 : 1/1
└─ 2 : 30/30
trained Machine; does not cache data
model: ProbabilisticPipeline(scaler = FreezableUnsupervised(model = Standardizer(features = Symbol[], …), …), …)
args:
1: Source @688 ⏎ Table{AbstractVector{Continuous}}
2: Source @846 ⏎ AbstractVector{Multiclass{3}}
julia> fit!(mach, rows=101:200, verbosity=2) # only clf retrains; scaler is frozen
[ Info: Training machine(ProbabilisticPipeline(scaler = FreezableUnsupervised(model = Standardizer(features = Symbol[], …), …), …), …).
[ Info: Training machine(:scaler, …).
[ Info: Training machine(:model, …).
[ Info: Features standarized:
[ Info: :x1 mu=-1.6075949590445386 sigma=6.245230837501581
[ Info: :x2 mu=-1.766075605094719 sigma=7.420282309033924
[ Info: Training machine(:clf, …).
Feature 2 < 0.03832 ?
├─ Feature 2 < -0.6508 ?
├─ Feature 2 < -0.8413 ?
├─ 1 : 20/20
└─ Feature 1 < -0.7213 ?
├─ Feature 2 < -0.831 ?
├─ 3 : 2/2
└─
└─ Feature 1 < -0.513 ?
├─ 1 : 11/11
└─
└─ 3 : 17/17
└─ 2 : 36/36
trained Machine; does not cache data
model: ProbabilisticPipeline(scaler = FreezableUnsupervised(model = Standardizer(features = Symbol[], …), …), …)
args:
1: Source @688 ⏎ Table{AbstractVector{Continuous}}
2: Source @846 ⏎ AbstractVector{Multiclass{3}}
By the way, there is some test tooling that we used to test this kind of logic - the |
|
Note to self: I have only scanned the type setup, reviewed the docstring, and tried out the examples there. There will be some issues to explain about wrapping supervised models with |
ablaom
left a comment
This is a valuable contribution and will have required considerable effort to familiarise yourself with the composition codebase, the most challenging part of MLJ. This is enormously appreciated.
I'm pausing my review for now. I think there is a bit to sort out with the Example 2 issue.
|
The docstring is very clear and concise, by the way. |
Co-authored-by: Anthony Blaom, PhD <anthony.blaom@gmail.com>
|
Some ideas:
I will take some time to review how to solve the retraining bug. |
|
@ablaom, the root cause of the retraining bug was related to when the parent pipeline's outer machine sees new rows, A solution was added in the last commit. However, it adds more information to the model. |
|
Thanks @pebeto for looking into the failure. I have to say I'm not too keen on the workaround. We have come to embrace the general principle that an MLJ model (configuration struct) is never mutated. (We have allowed this in the special case of RNGs, but in retrospect I think this was a mistake and in newer models we encourage Is there no way you can pass the information you need in the |
|
@ablaom I previously tried
The only information persisted across the rebuild is the model instance, which is why I ended up saving the information there. Another idea is to modify |
|
In ( |
|
I'm not sure I understand. Do you mean if Is the problem that you can see no way to set/unset the |
|
Okay, I missed your earlier clarification. Let me think about this and get back to you. |
|
Okay, I agree that, not wanting to mutate model structs, we need to break into the I would make two main points:
Here's the logic we have now: Let's say we define a new private method Then we replace the first phrase, "For the action to be a no-operation, either What do you think? |
|
@ablaom point 1 seems okay, since generalizing freezability to any model is cleaner and lets future models opt in by adding the In point 2 we are still facing the issue when the inner I think we still need to suppress condition 4. |
|
Okay, I think it's just too long since I was in the weeds with this composite business, and I appreciate your patience. Your suggestions are probably right, but it would be neglectful of me not to make sure I understand this. Can you help me better understand what you mean by "during the pipeline rebuild"? I'm looking at |
|
By rebuild I mean the call to
fitresult = prefit(composite, verbosity, data...) |> MLJBase.Signature
condition_4 = rows_is_new && !(mach.model isa Static)
if mach.state == 0 ||            # condition (1)
    force == true ||             # condition (2)
    upstream_has_changed ||      # condition (3)
    condition_4 ||               # condition (4)
    modeltype_changed            # conditions (6) or (7)
    # fit path
elseif model != mach.old_model   # condition (5)
    # update path
end
The chain for Example 2 is (in order):
|
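The condition chain quoted above can be mimicked with a small, self-contained function. This is only a toy sketch in plain Julia: `ToyMachine`, `decide` and the keyword names are hypothetical and do not reproduce MLJBase's actual `fit_only!`. It also assumes, following the PR description, that a frozen machine is only skipped once its state is greater than zero. Under those assumptions it shows why new rows (condition 4) send an unfrozen machine down the fit path, and why a machine whose state has been reset to 0 is refit even if it is flagged as frozen.

```julia
# Toy model of the decision quoted above; every name here is hypothetical.
struct ToyMachine
    state::Int        # 0 means "never trained"
    frozen::Bool
    is_static::Bool   # stands in for `mach.model isa Static`
end

function decide(mach::ToyMachine; force=false, upstream_has_changed=false,
                rows_is_new=false, modeltype_changed=false, model_changed=false)
    # Assumption: a frozen, already-trained machine is skipped before any other test.
    mach.frozen && mach.state > 0 && return :noop
    condition_4 = rows_is_new && !mach.is_static
    if mach.state == 0 || force || upstream_has_changed || condition_4 || modeltype_changed
        return :fit
    elseif model_changed
        return :update
    else
        return :noop
    end
end

trained = ToyMachine(1, false, false)
frozen  = ToyMachine(1, true, false)
fresh   = ToyMachine(0, true, false)   # state reset to 0, as after a rebuild

decide(trained; rows_is_new=true)   # :fit, because of condition (4)
decide(frozen; rows_is_new=true)    # :noop
decide(fresh; rows_is_new=true)     # :fit, the frozen flag cannot help at state 0
```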
|
In this gist, I wrote up my understanding of how MLJ training works, based on what I observed while implementing this feature. |
|
Thanks for your detailed update. I think my confusion about "rebuilding" is a red herring. I suspect I have not conveyed my proposal clearly. What I'm suggesting is that the test for A, B, C, D appears at the very top of |
|
@ablaom rebuilding is the problem.
The rebuild zeroes the age, meaning that C never catches it. To make C fire, keep the original To make things clearer, I'm going to implement the changes in the PR. |
… in `NetworkComposite`
| Wrap `model` so `fit!` is a no-op after the first training pass. Place the wrapper inside | ||
| a `Pipeline`, `Stack`, or `TunedModel` and the inner component skips retraining even when |
| Wrap `model` so `fit!` is a no-op after the first training pass. Place the wrapper inside | |
| a `Pipeline`, `Stack`, `TunedModel`, or other `NetworkComposite` model, and the | |
| inner component skips retraining even when |
|
|
||
| # If the composite has frozen descendants, rebuild a fresh network on `data` and | ||
| # transfer trained state from the old inner machines for the frozen children. Fresh | ||
| # non-frozen children train from scratch; frozen children short-circuit via the |
"train from scratch" is too strong. It may be that some non-frozen components only need an update; see my comment below on the example of a RandomForest warm restart.
| new_m.old_model = deepcopy(getproperty(composite, sym)) | ||
| new_m.old_upstream_state = MLJBase.upstream(new_m) | ||
| end | ||
|
|
It's not sufficient to update only the first machine in the value associated with a given sym. You need to update all of them.
The basic example of multiple machines associated with a single model is a homogeneous ensemble, such as a random forest. Since a single model controls the whole ensemble, freezing the atomic model ought to freeze the ensemble of associated machines. (Actually, we don't implement EnsembleModel using NetworkComposite, so I can't point to a concrete example off the top of my head, sorry.)
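To make the point about several machines per model concrete, here is a toy sketch in plain Julia. `ToyMach`, `machines_given_sym` and `freeze_all!` are hypothetical names chosen for illustration, not MLJBase API; the sketch only shows looping over every machine bound to a symbol instead of touching the first one.

```julia
# Hypothetical stand-in for a machine that carries a frozen flag.
mutable struct ToyMach
    frozen::Bool
end

# Hypothetical lookup: component name => all machines bound to that component's model.
machines_given_sym = Dict(
    :scaler => [ToyMach(false)],
    :tree   => [ToyMach(false), ToyMach(false), ToyMach(false)],  # e.g. an ensemble
)

# Freeze every machine associated with `sym`, not just `first(dict[sym])`.
function freeze_all!(dict, sym)
    for m in dict[sym]
        m.frozen = true
    end
    return dict
end

freeze_all!(machines_given_sym, :tree)
all(m -> m.frozen, machines_given_sym[:tree])     # true
any(m -> m.frozen, machines_given_sym[:scaler])   # false
```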
|
Ahhhh. Finally, I understand your main point. Thanks for your patient explanations, and for the explicit implementation of your new idea. There remains an issue, which is that the presence of a freezable component in a composite model now precludes warm restart (
using MLJBase, MLJModels, MLJTransforms
forest = (@load RandomForestRegressor pkg=DecisionTree)()
X, y = @load_boston
# demonstration of warm restart:
mach = machine(forest, X, y)
fit!(mach)
# [ Info: Training machine(RandomForestRegressor(max_depth = -1, …), …).
# trained Machine; caches model-specific representations of data
forest.n_trees += 10
fit!(mach)
# [ Info: Updating machine(RandomForestRegressor(max_depth = -1, …), …).
# [ Info: Adding 10 trees to the ensemble. <---- WARM RESTART
# put a forest in a pipeline with a freezable scaler:
pipe = Freezable(Standardizer()) |> forest
# initial train:
mach = machine(pipe, X, y)
fit!(mach, verbosity=2)
# [ Info: Training machine(DeterministicPipeline(freezable_unsupervised = FreezableUnsupervised(model = Standardizer(features = Symbol[], …), …), …), …).
# [ Info: Training machine(:freezable_unsupervised, …).
# [ Info: Training machine(:model, …).
# [ Info: Features standarized:
# [ Info: :Crim mu=3.6135235573122535 sigma=8.60154510533249
# [ Info: :Zn mu=11.363636363636363 sigma=23.322452994515135
# [ Info: :Indus mu=11.136778656126486 sigma=6.8603529408975845
# [ Info: :NOx mu=0.5546950592885372 sigma=0.11587767566755595
# [ Info: :Rm mu=6.284634387351778 sigma=0.7026171434153233
# [ Info: :Age mu=68.57490118577076 sigma=28.148861406903617
# [ Info: :Dis mu=3.795042687747036 sigma=2.105710126627611
# [ Info: :Rad mu=9.549407114624506 sigma=8.707259384239368
# [ Info: :Tax mu=408.2371541501976 sigma=168.537116054959
# [ Info: :PTRatio mu=18.45553359683795 sigma=2.1649455237144406
# [ Info: :Black mu=356.67403162055325 sigma=91.29486438415783
# [ Info: :LStat mu=12.653063241106722 sigma=7.141061511348569
# [ Info: Training machine(:random_forest_regressor, …).
# trained Machine; does not cache data
# add trees and train again; expect a warm restart of the forest here, but instead the
# forest is retrained from scratch (while the scaler is correctly ignored):
forest.n_trees += 10
fit!(mach, verbosity=2)
# [ Info: Updating machine(DeterministicPipeline(freezable_unsupervised = FreezableUnsupervised(model = Standardizer(features = Symbol[], …), …), …), …).
# [ Info: Not retraining machine(:freezable_unsupervised, …) as it is frozen.
# [ Info: Training machine(:random_forest_regressor, …). <-- COLD RESTART
# trained Machine; does not cache data
Without the |
This PR implements the `Freezable` model wrapper (#1016) following the `TransformedTargetModel` pattern.
- `FreezableDeterministic`, `FreezableProbabilistic`, `FreezableInterval`, `FreezableUnsupervised`, each subtyping the corresponding NetworkComposite
- `Freezable(model; frozen=true, cache=true)` constructor that dispatches on the model's abstract supertype
- `prefit` methods for supervised (returns `predict`/`transform` nodes) and unsupervised (returns `transform`/`inverse_transform` nodes) models
- `fit` override: initial training always proceeds, then freezes the inner machine if `frozen=true`
- `update` override: synchronizes `freeze!`/`thaw!` on the inner machine before delegating to standard NetworkComposite logic
- `fit!` override on `Machine{<:SomeFreezable}`: synchronizes the outer machine's frozen flag so `fit_only!` skips when `frozen=true` and state > 0
- `freeze!`/`thaw!` overloads for `SomeFreezable` models, consistent with the existing machine-level API
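
A constructor that dispatches on the model's abstract supertype can be sketched without any MLJ dependency. Everything below is hypothetical (the `Toy*` types and `toy_freezable` are illustration only and are not the PR's code); only the keyword defaults `frozen=true, cache=true` are taken from the description above.

```julia
# Toy type hierarchy standing in for the MLJ model supertypes (hypothetical names).
abstract type ToyModel end
abstract type ToyDeterministic <: ToyModel end
abstract type ToyUnsupervised <: ToyModel end

struct ToyTree <: ToyDeterministic end
struct ToyScaler <: ToyUnsupervised end

# One wrapper type per supertype, each keeping the wrapped model's kind.
struct ToyFreezableDeterministic{M<:ToyDeterministic} <: ToyDeterministic
    model::M
    frozen::Bool
    cache::Bool
end
struct ToyFreezableUnsupervised{M<:ToyUnsupervised} <: ToyUnsupervised
    model::M
    frozen::Bool
    cache::Bool
end

# A single user-facing constructor; dispatch selects the wrapper type.
toy_freezable(model::ToyDeterministic; frozen=true, cache=true) =
    ToyFreezableDeterministic(model, frozen, cache)
toy_freezable(model::ToyUnsupervised; frozen=true, cache=true) =
    ToyFreezableUnsupervised(model, frozen, cache)

toy_freezable(ToyScaler()) isa ToyFreezableUnsupervised   # true
toy_freezable(ToyTree(); frozen=false).frozen             # false
```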