Callbacks

Training callbacks for control and monitoring.

Available Callbacks

EarlyStopping

EarlyStopping(
    patience=50,
    min_delta=0.0,
    restore_best=True,
    verbose=False,
)

Bases: Callback

Stop training when validation metric stops improving.

Works with ANY model that provides val_loss in TrainingState. Requires eval_set to be passed to fit().

PARAMETER DESCRIPTION
patience

Number of rounds without improvement before stopping.

TYPE: int DEFAULT: 50

min_delta

Minimum change to qualify as an improvement.

TYPE: float DEFAULT: 0.0

restore_best

If True, restore model to best iteration after stopping.

TYPE: bool DEFAULT: True

verbose

If True, print message when stopping.

TYPE: bool DEFAULT: False

Attributes (after training):

- best_score: Best validation score achieved.
- best_round: Round at which the best score was achieved.
- stopped_round: Round at which training was stopped (or None).

Example

```python
callback = EarlyStopping(patience=50, min_delta=1e-4)
model.fit(X, y, callbacks=[callback], eval_set=[(X_val, y_val)])
print(f"Best round: {model.best_iteration_}")
```
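The stopping rule itself is simple enough to sketch in isolation. The helper below is an illustrative re-implementation of the patience logic, not the library's actual code: training stops once `patience` consecutive rounds fail to improve the best validation loss by at least `min_delta`.

```python
def early_stop_round(val_losses, patience=50, min_delta=0.0):
    """Return the 0-indexed round at which training would stop, or None.

    Illustrative sketch only; the real EarlyStopping callback tracks the
    same state incrementally via on_round_end.
    """
    best = float("inf")
    rounds_without_improvement = 0
    for r, loss in enumerate(val_losses):
        if loss < best - min_delta:  # improvement of at least min_delta
            best = loss
            rounds_without_improvement = 0
        else:
            rounds_without_improvement += 1
            if rounds_without_improvement >= patience:
                return r
    return None
```

With `restore_best=True`, the model would then be rolled back to the round where the best score was recorded.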

on_train_begin

on_train_begin(state)

Reset state at start of training.

on_round_end

on_round_end(state)

Check if we should stop training.

on_train_end

on_train_end(state)

Restore best model if requested.

Logger

Logger(period=1, show_train=True, show_val=True)

Bases: Callback

Log training progress to stdout.

PARAMETER DESCRIPTION
period

Print every N rounds (default: 1).

TYPE: int DEFAULT: 1

show_train

Include training loss in output.

TYPE: bool DEFAULT: True

show_val

Include validation loss in output.

TYPE: bool DEFAULT: True

Example

```python
callback = Logger(period=10)  # Log every 10 rounds
model.fit(X, y, callbacks=[callback], eval_set=[(X_val, y_val)])
```

Output:

```
[0]  train: 0.5234  valid: 0.5456
[10] train: 0.2134  valid: 0.2345
[20] train: 0.1234  valid: 0.1456
...
```

on_round_end

on_round_end(state)

Print progress if at logging period.

ModelCheckpoint

ModelCheckpoint(
    filepath, save_best_only=True, verbose=False
)

Bases: Callback

Save model periodically or when validation score improves.

PARAMETER DESCRIPTION
filepath

Path to save model (use .pkl extension).

TYPE: str

save_best_only

If True, only save when validation improves.

TYPE: bool DEFAULT: True

verbose

If True, print message when saving.

TYPE: bool DEFAULT: False

Example

```python
callback = ModelCheckpoint('best_model.pkl', save_best_only=True)
model.fit(X, y, callbacks=[callback], eval_set=[(X_val, y_val)])
```

on_round_end

on_round_end(state)

Save model if conditions are met.

LearningRateScheduler

LearningRateScheduler(schedule)

Bases: Callback

Adjust learning rate during training.

PARAMETER DESCRIPTION
schedule

Function (round_idx) -> learning_rate_multiplier, applied to the initial learning rate.

Example

```python
# Decay learning rate by 0.95 each round
scheduler = LearningRateScheduler(lambda r: 0.95 ** r)
model.fit(X, y, callbacks=[scheduler])

# Step decay: halve the learning rate at rounds 50 and 100
def step_decay(r):
    if r < 50:
        return 1.0
    elif r < 100:
        return 0.5
    else:
        return 0.25

scheduler = LearningRateScheduler(step_decay)
```
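Under the multiplier convention above, the effective learning rate at round `r` is the initial rate times `schedule(r)`, assuming the multiplier is applied to the rate stored at `on_train_begin` (`base_lr` below is a hypothetical value for illustration):

```python
base_lr = 0.1  # hypothetical initial learning rate

def schedule(r):
    # Exponential decay, as in the first example above
    return 0.95 ** r

# Effective learning rate at selected rounds
effective = {r: base_lr * schedule(r) for r in (0, 1, 10)}
```

At round 0 the multiplier is 1.0, so the base rate is unchanged; later rounds shrink it geometrically.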

on_train_begin

on_train_begin(state)

Store initial learning rate.

on_round_begin

on_round_begin(state)

Update learning rate for this round.

HistoryCallback

HistoryCallback()

Bases: Callback

Record training history (losses per round).

Attributes (after training):

- history: Dict with 'train_loss' and 'val_loss' lists.

Example

```python
history = HistoryCallback()
model.fit(X, y, callbacks=[history], eval_set=[(X_val, y_val)])

plt.plot(history.history['train_loss'], label='train')
plt.plot(history.history['val_loss'], label='valid')
plt.legend()
```

on_train_begin

on_train_begin(state)

Reset history.

on_round_end

on_round_end(state)

Record losses.

Base Classes

Callback

Bases: ABC

Base class for training callbacks.

Subclass this to create custom callbacks for training hooks. All methods are optional; override only what you need.

Example (custom callback):

```python
class GradientTracker(Callback):
    def __init__(self):
        self.grad_norms = []

    def on_round_end(self, state):
        if 'grad_norm' in state.extra:
            self.grad_norms.append(state.extra['grad_norm'])
        return True

tracker = GradientTracker()
model.fit(X, y, callbacks=[tracker])
plt.plot(tracker.grad_norms)
```

on_train_begin

on_train_begin(state)

Called at the start of training.

PARAMETER DESCRIPTION
state

Current training state.

TYPE: TrainingState

on_train_end

on_train_end(state)

Called at the end of training.

PARAMETER DESCRIPTION
state

Current training state.

TYPE: TrainingState

on_round_begin

on_round_begin(state)

Called at the start of each boosting round.

PARAMETER DESCRIPTION
state

Current training state.

TYPE: TrainingState

on_round_end

on_round_end(state)

Called at the end of each boosting round.

PARAMETER DESCRIPTION
state

Current training state.

TYPE: TrainingState

RETURNS DESCRIPTION
bool

True to continue training, False to stop early.

TrainingState dataclass

TrainingState(
    model,
    round_idx=0,
    n_rounds=0,
    train_loss=None,
    val_loss=None,
    extra=dict(),
)

Shared state passed to callbacks during training.

This object is passed to all callbacks at each training event, allowing them to inspect and modify training behavior.

ATTRIBUTE DESCRIPTION
model

The model being trained (modified in place).

TYPE: Any

round_idx

Current boosting round (0-indexed).

TYPE: int

n_rounds

Total number of rounds requested.

TYPE: int

train_loss

Training loss for current round (if computed).

TYPE: float | None

val_loss

Validation loss for current round (if eval_set provided).

TYPE: float | None

extra

Dict for custom data (research callbacks can use this).

TYPE: dict
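A minimal sketch of such a state object (field names follow the table above; the library's actual definition may differ). Note that the rendered signature's `extra=dict()` corresponds to a `default_factory` in a real dataclass, since mutable defaults need a factory:

```python
from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class TrainingState:
    model: Any
    round_idx: int = 0
    n_rounds: int = 0
    train_loss: Optional[float] = None
    val_loss: Optional[float] = None
    # Mutable default: each instance gets its own fresh dict
    extra: dict = field(default_factory=dict)

state = TrainingState(model=None, round_idx=3, n_rounds=100, val_loss=0.21)
state.extra["grad_norm"] = 1.7  # custom data for research callbacks
```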

CallbackManager

CallbackManager(callbacks=None)

Orchestrates multiple callbacks.

Used internally by training loops to manage callback execution.

PARAMETER DESCRIPTION
callbacks

List of Callback instances.

TYPE: list[Callback] | None DEFAULT: None

on_train_begin

on_train_begin(state)

Call on_train_begin for all callbacks.

on_round_begin

on_round_begin(state)

Call on_round_begin for all callbacks.

on_round_end

on_round_end(state)

Call on_round_end for all callbacks.

RETURNS DESCRIPTION
bool

True if training should continue, False if any callback wants to stop.

on_train_end

on_train_end(state)

Call on_train_end for all callbacks.
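The dispatch logic can be sketched as follows (an illustrative stand-in, not the library's source). The key detail is `on_round_end`: a callback that returns `False` stops training, while `None` from a callback with no explicit return counts as "continue":

```python
class CallbackManager:
    """Illustrative sketch of the orchestration described above."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def on_train_begin(self, state):
        for cb in self.callbacks:
            cb.on_train_begin(state)

    def on_round_begin(self, state):
        for cb in self.callbacks:
            cb.on_round_begin(state)

    def on_round_end(self, state):
        # False from any callback stops training; None means "continue".
        return all(cb.on_round_end(state) is not False
                   for cb in self.callbacks)

    def on_train_end(self, state):
        for cb in self.callbacks:
            cb.on_train_end(state)
```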