boosters that should be training identically are not #6711
The random number engine is not reset during each run.
If you remove the

```python
booster_0 = train({...}, dtrain, num_boost_round=1)
booster_1 = train({...}, dtrain, num_boost_round=1, xgb_model=booster_0)
```

should be equal to:

```python
booster = train({...}, dtrain, num_boost_round=2)
```

If the random number engine were reset during each run, the above wouldn't hold.
Yes, resetting the random number engine would be a bad idea, because you want to be able to interrupt training and reach the same results. In other words, we are in agreement that resuming training from an existing booster (the two-call sequence above) should always come up with the same results as training the same number of rounds in a single call.

However, the current code does NOT guarantee this, as my post demonstrates: the guarantee breaks if the user makes any calls to `xgb.train()` on a different booster in between. In other words, keeping the random state with the booster instead of the module is the simple change that will keep the two paths in agreement.
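A minimal sketch of the failure mode being described (the data, parameter values, and comparison are illustrative assumptions, not code from this thread):

```python
import numpy as np
import xgboost as xgb

# Illustrative synthetic data; any fixed dataset works for the comparison.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 10))
y = rng.normal(size=500)
dtrain = xgb.DMatrix(X, label=y)

# colsample_* < 1.0 makes training consume the random number engine.
params = {"max_depth": 3, "eta": 0.1, "colsample_bytree": 0.5, "seed": 0}

# Resume-style training, but with an unrelated booster trained in between.
booster_0 = xgb.train(params, dtrain, num_boost_round=1)
other = xgb.train(params, dtrain, num_boost_round=1)  # per this report, this perturbs the shared random state
booster_1 = xgb.train(params, dtrain, num_boost_round=1, xgb_model=booster_0)

# Uninterrupted two-round training for comparison.
booster = xgb.train(params, dtrain, num_boost_round=2)

# Per this report, the intervening call makes these predictions diverge.
print(np.allclose(booster_1.predict(dtrain), booster.predict(dtrain)))
```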
It 100% should not be the case that interrupting training on one booster by training another changes the outcome on the first booster. And yet it does, because that call to `train` on the other booster advances the shared random state.

If you cannot train in bursts (let's say two bursts of 10 rounds each) and achieve the same results as if you trained in a single burst of 20, then that seems like a bug that should be fixed. Otherwise, the only way to guarantee the same model outcome in xgboost is to always start training from the very start, which is untenable for boosters that require hundreds of rounds of training. And yet, in the current code, that is exactly what users must do if they want repeatable results (assuming they use any of the `subsample`/`colsample_*` parameters).

My suggestion is not to reset the random number engine with each call to `xgb.train()`, but rather to keep its state with each booster. I am not familiar with the internals of the code, so I am not clear on whether my fix is easy or even possible. But right now anyone who trains boosters in bursts using xgboost, and who works with more than one booster at a time, is not guaranteed to get the same results between runs unless they avoid the parameters that adjust the RNG state!
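A sketch of the burst-style workflow in question, assuming nothing else touches the random state between bursts (data and parameter values are again illustrative):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.normal(size=(500, 10)), label=rng.normal(size=500))
params = {"max_depth": 3, "eta": 0.1, "subsample": 0.8, "seed": 0}

# Two bursts of 10 rounds, the second resuming from the first...
burst = xgb.train(params, dtrain, num_boost_round=10)
burst = xgb.train(params, dtrain, num_boost_round=10, xgb_model=burst)

# ...versus a single burst of 20 rounds.
single = xgb.train(params, dtrain, num_boost_round=20)

# The reproducibility this comment asks to be guaranteed: these should match.
print(np.allclose(burst.predict(dtrain), single.predict(dtrain)))
```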
That's true. Making the state local to the booster is possible with a few changes. But saving it might be more involved. Let me figure something out.
Thank you!
Still an outstanding issue as far as I know.
> On Thu, May 23, 2024 at 8:27 PM, andrew-esteban-imc wrote:
>
> Hi there.
>
> Just wondering if there has been any progress on this issue? We have found that resuming hist training from a checkpoint when any of subsample, colsample_bytree, colsample_bylevel, and colsample_bynode are set below 1 results in a non-deterministic result. With our current setup, it's fairly common for training jobs to be interrupted (they are running in a contested cluster without guaranteed resources), so we rely on checkpointing to reduce time.
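For reference, the checkpoint-and-resume workflow described in that message might look roughly like this; the file name, data, and parameter values are illustrative assumptions:

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.normal(size=(1000, 20)), label=rng.normal(size=1000))

params = {
    "tree_method": "hist",
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "seed": 0,
}

# First part of the job: train some rounds and checkpoint to disk.
booster = xgb.train(params, dtrain, num_boost_round=50)
booster.save_model("checkpoint.json")

# After an interruption, a new process resumes from the checkpoint file.
resumed = xgb.train(params, dtrain, num_boost_round=50, xgb_model="checkpoint.json")

# Per this thread, the resumed model is not guaranteed to match a model
# trained for 100 uninterrupted rounds while subsample/colsample_* < 1.
```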
Hello -

We've found an issue where models that should end up identical after training (assuming a deterministic random number generator) do not under all conditions. See the following code, in which we train 4 boosters with identical parameter sets on the same data for 3 rounds each.

Two are trained for 3 rounds each in series, i.e., we train booster 1 for three rounds before moving on to booster 2. Their predictions are identical and always appear to come out that way (as expected, because they start with the same seed, even though there is random sampling going on with the `colsample_*` parameters).

However, two are trained in an "interleaved" fashion, i.e., we train one booster for one round, then train the second booster for one round, then repeat until both have been trained for 3 rounds. Those boosters do not have predictions identical to either of the first two, or to each other, even though they too were trained with the same seed, params, and training data.
It appears that any one of the `colsample_*` parameters being < 1.0 will trigger this effect. We suspect that a random seed in the module, instead of in the booster, is preserving its state between calls to `xgb.train()`, which puts the training out of sync when you train on different boosters.

If you uncomment the two lines in our code featuring the variable `clear_the_pipes` and re-run, you will see that calling `xgb.train` with `xgb_model=None` does seem to reset things a bit for the interleaved boosters. Indeed, this is the only explanation for why the first two models train identically: if the random state were not reset after the first 3 calls to `xgb.train()`, the second model would come out differently.

We feel that even with `colsample_*` params impacting the training outcomes, the results should be repeatable, and all four of our boosters should end up the same. Perhaps a random seed should be kept on a per-booster basis so that every call to `xgb.train()` will set the seed accordingly.

Thanks in advance for your help with this one!
This is all in Python 3.7.9, xgboost 1.3.3.