
[Guideline] Training API Enhancement and Refactor: Use Callbacks #892

Closed · 3 tasks · tqchen opened this issue Feb 28, 2016 · 15 comments

@tqchen (Member) commented Feb 28, 2016

There has been a series of changes to enhance the training and cross-validation APIs in Python/R. Examples of these changes include:

  • Early stopping based on evaluation statistics.
  • Whether to save the evaluation results from cross validation or training.
  • Whether to print the evaluation results.
  • Whether to save and return the best model from cv or training.
  • Adapting the learning rate during training.

Currently, each of these proposals involves a change to the core training API: one argument needs to be added for each requirement. We need a better way to handle these cases, otherwise the training API will become extremely hard to maintain.

Use Callbacks to Handle These Cases

import sys

class StopTraining(Exception):
    """Raised by a callback to stop the training loop early."""
    pass

def early_stop_maximize(stopping_rounds, metric, verbose=True):
    """Example: early stopping that maximizes a metric."""
    # Mutable state shared with the closure (a namedtuple would be
    # immutable, so a plain dict is used instead).
    state = {"best_score": float("-inf"), "best_score_i": 0}

    def callback(iteration, booster, evaluation_results):
        """Callback function to perform early stopping.

        iteration: int
            Current iteration number, equal to the total number of trees
            so far (when continuing from an existing model, counting
            starts from that model's trees).
        booster: Booster
            Current booster model.
        evaluation_results: list of (str, float)
            Evaluation results from the watchlist.
        """
        score = dict(evaluation_results)[metric]
        if score > state["best_score"]:
            state["best_score"] = score
            state["best_score_i"] = iteration
        if iteration - state["best_score_i"] > stopping_rounds:
            booster.best_iteration = state["best_score_i"]
            if verbose:
                sys.stderr.write("Stopping at round %d\n" % iteration)
            raise StopTraining()
    return callback

def train(param, num_boost_round, callbacks):
    ...  # set up bst, dtrain, evaluation_results
    for i in range(num_boost_round):
        bst.update(dtrain, i)
        try:
            for callback in callbacks:
                callback(i, bst, evaluation_results)
        except StopTraining:
            break

# call training
bst = train(param, num_boost_round,
            callbacks=[early_stop_maximize(3, 'test-auc')])

TODO List

  • Add the callback API to the training and cv APIs.
    • This includes Python/R/Julia/JVM.
  • Add a callback function module to xgboost.
    • In the future we will only accept improvements as callbacks, and be more careful about changes to the training API.
    • Add callbacks to support early_stop, logging, best_model_save.
  • Use the callbacks to keep backward compatibility (see the sketch after this list).
    • For example, when early_stop_rounds is detected, add early_stop_maximize to the callback list at the beginning of the function.
    • Mark the newly added arguments as deprecated, and give a deprecation warning asking users to use the callback API instead.
    • We will consider removing some of the less important arguments after two major releases.
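
A rough sketch of the backward-compatibility shim described above (the early_stopping_rounds argument name and the placeholder metric are illustrative, not the final API):

import warnings

def train(param, num_boost_round, early_stopping_rounds=None, callbacks=None):
    callbacks = list(callbacks or [])
    if early_stopping_rounds is not None:
        warnings.warn("early_stopping_rounds is deprecated; "
                      "pass an early-stopping callback instead",
                      DeprecationWarning)
        # A real shim would derive the metric from the watchlist;
        # 'test-auc' is just a placeholder here.
        callbacks.append(early_stop_maximize(early_stopping_rounds, 'test-auc'))
    ...  # rest of the training loop as above
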
@tqchen (Member, Author) commented Feb 28, 2016

@terrytangyuan @hetong007 @CodingCat please reply to discuss the API you would like, and whether you would like to add the callbacks and do the refactoring for part of the language bindings (Python/R/JVM).

@hetong007 (Member) commented

Is "Adapt learning rate during training" independent from this change? I think it is more related to bst.update().

@CodingCat (Member) commented

Is there any restriction on the signature of the callbacks? Or shall we distinguish between pre-iteration and post-iteration callbacks, i.e. where to put the following lines?

for callback in callbacks:
    callback(i, bst, evaluation_results)

@tqchen (Member, Author) commented Feb 28, 2016

@hetong007 the learning rate change can be done by setting parameters post-iteration, i.e. calling bst.set_param in the callback.
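
For example, a minimal sketch of such a callback (the exponential decay schedule and the function name are illustrative):

def reset_learning_rate(initial_eta, decay=0.99):
    """Sketch: shrink the learning rate after each boosting iteration."""
    def callback(iteration, booster, evaluation_results):
        booster.set_param('eta', initial_eta * (decay ** iteration))
    return callback

# e.g. train(param, num_boost_round,
#            callbacks=[reset_learning_rate(0.3)])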

@tqchen (Member, Author) commented Feb 28, 2016

@CodingCat This is a proposal. Most of the applications so far seem to be post-iteration callbacks, so we can use post-iteration for now. But it might be interesting to make the distinction explicit.

@terrytangyuan (Member) commented Feb 29, 2016

@tqchen Yeah, the API evolved a lot and became complicated. I'll look into this for the Python package when I get a chance. I've been quite busy recently, so if anyone else wants to do it, you are very welcome to!

@tqchen (Member, Author) commented May 20, 2016

Added the callback API to Python in #1211.

@tqchen (Member, Author) commented May 20, 2016

The callback API for Python is checked in via #1211. I am looking for volunteers to contribute the R counterpart. Please reply if you want to do this. @hetong007 @khotilov

@hetong007 (Member) commented

I will look into it. @khotilov You are also welcome to do that.

@khotilov (Member) commented

I'll give it some time this weekend.

@terrytangyuan (Member) commented

@tqchen Awesome. Looks great!

@khotilov (Member) commented

I've coded all the R callbacks and xgb.train, but haven't finished debugging yet.

@tqchen (Member, Author) commented May 23, 2016

@khotilov sounds good, let us know when it is ready. Note that R can directly pass environments, so as long as the variable naming is proper, it is even more powerful than Python.

@khotilov (Member) commented

@tqchen: yes, that was the mechanism I was exploiting.

BTW, even though it's possible in R to throw and catch a specific exception ("signal a condition" in R-speak), I'm inclined towards simply setting a stop-flag variable for early stopping, so that it doesn't have to be the last callback. I think that might make sense to do in Python as well.

Also, I've added a finalizer type of callback in addition to the pre- and post-iteration ones. E.g., record_evaluation does very simple evaluation-history collection at post-iteration, and then its finalizer runs at the end to do some fast in-bulk transformation of the collected data. I assume that a finalizer would always be coupled with either a pre or post callback, so it doesn't need to be a separate callback function, and it could be invoked by, e.g., passing a finalize=TRUE option. It seems like it's usually easy to detect the condition for calling init, but for finalizing it would be more reliable and clean to call it explicitly. Does that make sense?
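
In Python terms, the stop-flag idea could look roughly like this (a sketch; the shared env object and its stop_training attribute are hypothetical, not an agreed API):

def early_stop_flag(stopping_rounds, metric):
    """Sketch: signal early stopping via a flag instead of an exception."""
    state = {"best_score": float("-inf"), "best_score_i": 0}

    def callback(iteration, booster, evaluation_results, env):
        score = dict(evaluation_results)[metric]
        if score > state["best_score"]:
            state["best_score"] = score
            state["best_score_i"] = iteration
        if iteration - state["best_score_i"] > stopping_rounds:
            env.stop_training = True  # later callbacks still run this round
    return callback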


@tqchen (Member, Author) commented May 23, 2016

  • I think the same logic could apply with an exception, by finishing calling the rest of the callbacks. The reason an exception is raised, as opposed to returning a condition, was to make the code more explicit about what the condition is, and possibly to stay compatible with future conditions.
  • I think an optional finalizer can be supported. We can support it by an optional attribute on the callback, i.e. the same as the flag for pre-iteration. So if the field exists, the finalizer will be invoked (see the sketch below).
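
A minimal sketch of how such attribute-based dispatch could look in the training loop (the before_iteration and finalize attribute names are illustrative, not a settled API):

def train(param, num_boost_round, callbacks=()):
    ...  # set up bst, dtrain, evaluation_results
    try:
        for i in range(num_boost_round):
            for cb in callbacks:
                if getattr(cb, 'before_iteration', False):
                    cb(i, bst, evaluation_results)
            bst.update(dtrain, i)
            for cb in callbacks:
                if not getattr(cb, 'before_iteration', False):
                    cb(i, bst, evaluation_results)
    except StopTraining:
        pass
    # Run finalizers once at the end, if a callback declares one.
    for cb in callbacks:
        if getattr(cb, 'finalize', None):
            cb.finalize(bst)
    return bst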

hetong007 added a commit that referenced this issue Jul 3, 2016
tqchen closed this as completed Jul 29, 2016