Skip to content

Commit

Permalink
Added max_to_keep parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
glr72 committed Jan 12, 2016
1 parent d1439a1 commit 1924411
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions R-package/R/xgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
#' \code{maximize=TRUE} means the larger the evaluation score the better.
#' @param save_period save the model to the disk in every \code{save_period} rounds, 0 means no such action.
#' @param save_name the name or path for periodically saved model file.
#' @param max_to_keep number of distinct models to keep in the save_name path.
#' @param ... other parameters to pass to \code{params}.
#'
#' @details
Expand Down Expand Up @@ -123,7 +124,7 @@
xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
obj = NULL, feval = NULL, verbose = 1, print.every.n=1L,
early.stop.round = NULL, maximize = NULL,
save_period = 0, save_name = "xgboost.model", ...) {
save_period = 0, save_name = "xgboost.model", max_to_keep = 0,...) {
dtrain <- data
if (typeof(params) != "list") {
stop("xgb.train: first argument params must be list")
Expand All @@ -144,6 +145,7 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
dot.params <- list(...)
nms.params <- names(params)
nms.dot.params <- names(dot.params)
models_kept<-0
if (length(intersect(nms.params,nms.dot.params)) > 0)
stop("Duplicated term in parameters. Please check your list of params.")
params <- append(params, dot.params)
Expand Down Expand Up @@ -220,7 +222,10 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
}
if (save_period > 0) {
if (i %% save_period == 0) {
xgb.save(bst, save_name)
if (models_kept < max_to_keep) {
xgb.save(bst, paste(save_name,i,sep=""))
models_kept <- models_kept + 1
}
}
}
}
Expand Down

0 comments on commit 1924411

Please sign in to comment.