Objectives¶
Automatic objective functions for XGBoost and LightGBM using JAX autodiff.
How it works
JAXBoost uses JAX automatic differentiation to compute gradients and Hessians from your loss function. You write the loss, JAX computes the derivatives—no manual math required.
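To make this concrete, here is a minimal sketch in plain JAX (not JAXBoost internals) of how per-sample gradients and Hessians fall out of a scalar loss:

```python
import jax
import jax.numpy as jnp

# A per-sample loss: scalar prediction, scalar label -> scalar loss
def squared_error(y_pred, y_true):
    return (y_pred - y_true) ** 2

grad_fn = jax.grad(squared_error, argnums=0)   # d(loss)/d(y_pred)
hess_fn = jax.grad(grad_fn, argnums=0)         # second derivative

# vmap applies both across a batch: one gradient and one Hessian per
# sample, which is exactly what a boosting objective has to return
y_pred = jnp.array([0.1, 0.8, 0.3])
y_true = jnp.array([0.0, 1.0, 1.0])
grads = jax.vmap(grad_fn)(y_pred, y_true)      # 2 * (y_pred - y_true)
hesses = jax.vmap(hess_fn)(y_pred, y_true)     # constant 2.0
```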
Core Classes¶
These are the building blocks for creating custom objectives.
AutoObjective¶
The primary class for scalar loss functions (binary classification, regression).
AutoObjective ¶
Automatically generate XGBoost/LightGBM objective functions.
Uses JAX automatic differentiation to compute gradients and Hessians, eliminating the need for manual derivation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fn` | `LossFunction` | A loss function that takes `(y_pred, y_true, **kwargs)` and returns a scalar loss. Should operate on single samples. | *required* |
Example

```python
@AutoObjective
def my_loss(y_pred, y_true, alpha=0.5):
    return alpha * (y_pred - y_true) ** 2

# Use with XGBoost
model = xgb.train(params, dtrain, obj=my_loss.xgb_objective)

# Use with custom parameters
model = xgb.train(params, dtrain, obj=my_loss.get_xgb_objective(alpha=0.7))
```
Source code in src/jaxboost/objective/auto.py
xgb_objective property ¶
xgb_objective: Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]
XGBoost-compatible objective function using default parameters.
Returns:

| Type | Description |
|---|---|
| `Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]` | XGBoost objective function: `(y_pred, dtrain) -> (grad, hess)` |
Example
model = xgb.train(params, dtrain, obj=my_loss.xgb_objective)
lgb_objective property ¶
lgb_objective: Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]
LightGBM-compatible objective function using default parameters.
Returns:

| Type | Description |
|---|---|
| `Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]` | LightGBM objective function: `(y_pred, dataset) -> (grad, hess)` |
Example
model = lgb.train(params, dtrain, fobj=my_loss.lgb_objective)
sklearn_objective property ¶
sklearn_objective: Callable[[NDArray[floating[Any]], NDArray[floating[Any]]], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]
Sklearn-compatible objective for XGBClassifier/XGBRegressor.
Returns:

| Type | Description |
|---|---|
| `Callable[[NDArray[floating[Any]], NDArray[floating[Any]]], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]` | Sklearn objective function: `(labels, predt) -> (grad, hess)` |
Example

```python
from xgboost import XGBClassifier

clf = XGBClassifier(objective=focal_loss.sklearn_objective)
clf.fit(X_train, y_train)
```
__call__ ¶
Compute the loss value for a batch.
Source code in src/jaxboost/objective/auto.py
gradient ¶
gradient(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]
Compute gradient of loss w.r.t. y_pred for each sample.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples,) | *required* |
| `y_true` | `NDArray[floating[Any]]` | True labels, shape (n_samples,) | *required* |
| `**kwargs` | `Any` | Additional arguments passed to the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `NDArray[floating[Any]]` | Gradients, shape (n_samples,) |
Source code in src/jaxboost/objective/auto.py
hessian ¶
hessian(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]
Compute Hessian (second derivative) of loss w.r.t. y_pred for each sample.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples,) | *required* |
| `y_true` | `NDArray[floating[Any]]` | True labels, shape (n_samples,) | *required* |
| `**kwargs` | `Any` | Additional arguments passed to the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `NDArray[floating[Any]]` | Hessians (diagonal), shape (n_samples,) |
Source code in src/jaxboost/objective/auto.py
grad_hess ¶
grad_hess(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], sample_weight: NDArray[floating[Any]] | None = None, **kwargs: Any) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]
Compute both gradient and Hessian efficiently.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples,) | *required* |
| `y_true` | `NDArray[floating[Any]]` | True labels, shape (n_samples,) | *required* |
| `sample_weight` | `NDArray[floating[Any]] \| None` | Optional sample weights, shape (n_samples,) | `None` |
| `**kwargs` | `Any` | Additional arguments passed to the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `tuple[NDArray[floating[Any]], NDArray[floating[Any]]]` | Tuple of (gradients, hessians), each shape (n_samples,) |
Source code in src/jaxboost/objective/auto.py
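For illustration, `grad_hess` can also be called directly on NumPy arrays; this sketch assumes `my_loss` is the `AutoObjective` from the class example above:

```python
import numpy as np

y_pred = np.array([0.2, -0.1, 1.5])
y_true = np.array([0.0, 0.0, 1.0])

# keyword arguments (here alpha) are forwarded to the loss function
grad, hess = my_loss.grad_hess(y_pred, y_true, alpha=0.7)
assert grad.shape == hess.shape == (3,)
```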
get_xgb_objective ¶
get_xgb_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]
Get an XGBoost-compatible objective function with custom parameters.
Automatically handles:

- `sample_weight`: Sample weights from the DMatrix
- `label_lower_bound`: Lower bound labels for interval/survival regression
- `label_upper_bound`: Upper bound labels for interval/survival regression

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Parameters to pass to the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]` | XGBoost objective function: `(y_pred, dtrain) -> (grad, hess)` |
Source code in src/jaxboost/objective/auto.py
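A hedged sketch of how weights attached to the `DMatrix` are picked up automatically (the `X`, `y`, `w`, and `my_loss` names are placeholders, not part of the API):

```python
import xgboost as xgb

# per-sample weights stored in the DMatrix are read by the
# generated objective and applied to gradients and Hessians
dtrain = xgb.DMatrix(X, label=y, weight=w)

params = {"max_depth": 4, "eta": 0.1}
model = xgb.train(params, dtrain, obj=my_loss.get_xgb_objective(alpha=0.7))
```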
get_lgb_objective ¶
get_lgb_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]
Get a LightGBM-compatible objective function with custom parameters.
Automatically handles sample weights if set in the Dataset.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Parameters to pass to the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]` | LightGBM objective function: `(y_pred, dataset) -> (grad, hess)` |
Source code in src/jaxboost/objective/auto.py
with_params ¶
Create a new AutoObjective with default parameters set.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Default parameters for the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `AutoObjective` | New AutoObjective instance with default parameters |
Example

```python
focal = focal_loss.with_params(gamma=3.0, alpha=0.75)
model = xgb.train(params, dtrain, obj=focal.xgb_objective)
```
Source code in src/jaxboost/objective/auto.py
get_sklearn_objective ¶
get_sklearn_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], NDArray[np.floating[Any]]], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]
Get a scikit-learn compatible objective function for XGBClassifier/XGBRegressor.
The sklearn interface expects (labels, predt) -> (grad, hess) instead of (predt, dtrain) -> (grad, hess).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Parameters to pass to the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `Callable[[NDArray[floating[Any]], NDArray[floating[Any]]], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]` | Sklearn-compatible objective: `(labels, predt) -> (grad, hess)` |
Example

```python
from xgboost import XGBClassifier

clf = XGBClassifier(objective=focal_loss.sklearn_objective)
clf.fit(X_train, y_train)
```
Source code in src/jaxboost/objective/auto.py
MultiClassObjective¶
For multi-class classification problems where predictions are logit vectors.
MultiClassObjective ¶
Objective function wrapper for multi-class classification.
Handles the specific requirements of multi-class classification with XGBoost/LightGBM:

- Predictions are logits of shape (n_samples, n_classes)
- Labels are integer class indices
- Computes gradients w.r.t. each class logit
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fn` | `Callable[..., Array] \| None` | A loss function that takes `(logits, label, **kwargs)`, where `logits` has shape (n_classes,) with raw scores for each class and `label` is an integer class index (0 to n_classes-1). Returns a scalar loss. | `None` |
| `n_classes` | `int` | Number of classes | `3` |
Example

```python
@MultiClassObjective(n_classes=3)
def my_multiclass_loss(logits, label):
    probs = jax.nn.softmax(logits)
    return -jnp.log(probs[label] + 1e-10)

params = {'num_class': 3}
model = xgb.train(params, dtrain, obj=my_multiclass_loss.xgb_objective)
```
Source code in src/jaxboost/objective/multiclass.py
__call__ ¶
Allow use as a decorator with arguments.
Source code in src/jaxboost/objective/multiclass.py
gradient ¶
gradient(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]
Compute gradient for multi-class predictions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples, n_classes) or flattened | *required* |
| `y_true` | `NDArray[floating[Any]]` | True labels, shape (n_samples,), integer class indices | *required* |
| `**kwargs` | `Any` | Additional arguments | `{}` |

Returns:

| Type | Description |
|---|---|
| `NDArray[floating[Any]]` | Gradients, shape (n_samples, n_classes) |
Source code in src/jaxboost/objective/multiclass.py
hessian ¶
hessian(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]
Compute diagonal Hessian for multi-class predictions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples, n_classes) or flattened | *required* |
| `y_true` | `NDArray[floating[Any]]` | True labels, shape (n_samples,), integer class indices | *required* |
| `**kwargs` | `Any` | Additional arguments | `{}` |

Returns:

| Type | Description |
|---|---|
| `NDArray[floating[Any]]` | Diagonal Hessians, shape (n_samples, n_classes) |
Source code in src/jaxboost/objective/multiclass.py
grad_hess ¶
grad_hess(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], sample_weight: NDArray[floating[Any]] | None = None, **kwargs: Any) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]
Compute both gradient and Hessian.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples, n_classes) | *required* |
| `y_true` | `NDArray[floating[Any]]` | True labels, shape (n_samples,) | *required* |
| `sample_weight` | `NDArray[floating[Any]] \| None` | Optional sample weights, shape (n_samples,) | `None` |
| `**kwargs` | `Any` | Additional arguments | `{}` |

Returns:

| Type | Description |
|---|---|
| `tuple[NDArray[floating[Any]], NDArray[floating[Any]]]` | Tuple of (gradients, hessians), each shape (n_samples, n_classes) |
Source code in src/jaxboost/objective/multiclass.py
get_xgb_objective ¶
get_xgb_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]
Get an XGBoost-compatible objective function for multi-class.
Use with XGBoost params: {'num_class': n_classes}
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Parameters to pass to the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]` | XGBoost objective function |
Source code in src/jaxboost/objective/multiclass.py
with_params ¶
Create a new instance with default parameters set.
Source code in src/jaxboost/objective/multiclass.py
MultiOutputObjective¶
For multi-output predictions like uncertainty estimation (mean + variance).
MultiOutputObjective ¶
Objective for multi-output/multi-task learning.
Handles cases where each sample has multiple predictions, such as:

- Multi-target regression
- Parametric models learning multiple parameters
- Uncertainty estimation (mean + variance)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fn` | `Callable[..., Array] \| None` | A loss function that takes `(y_pred, y_true, **kwargs)`, where `y_pred` has shape (n_outputs,) for a single sample and `y_true` has shape (n_outputs,) or is a scalar for a single sample. Returns a scalar loss. | `None` |
| `n_outputs` | `int` | Number of outputs per sample | `2` |
Example

```python
@MultiOutputObjective(n_outputs=2)
def parametric_loss(params, y_true, t=None):
    # params = [A, B] for model y = A + B*t
    A, B = params[0], params[1]
    y_pred = A + B * t
    return (y_pred - y_true) ** 2

model = xgb.train(params, dtrain, obj=parametric_loss.xgb_objective)
```
Source code in src/jaxboost/objective/multi_output.py
__call__ ¶
Allow use as a decorator with arguments.
Source code in src/jaxboost/objective/multi_output.py
gradient ¶
gradient(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]
Compute gradient for multi-output predictions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples, n_outputs) or flattened | *required* |
| `y_true` | `NDArray[floating[Any]]` | True labels | *required* |
| `**kwargs` | `Any` | Additional arguments | `{}` |

Returns:

| Type | Description |
|---|---|
| `NDArray[floating[Any]]` | Gradients, shape (n_samples * n_outputs,) for XGBoost compatibility |
Source code in src/jaxboost/objective/multi_output.py
hessian ¶
hessian(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]
Compute diagonal Hessian for multi-output predictions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples, n_outputs) or flattened | *required* |
| `y_true` | `NDArray[floating[Any]]` | True labels | *required* |
| `**kwargs` | `Any` | Additional arguments | `{}` |

Returns:

| Type | Description |
|---|---|
| `NDArray[floating[Any]]` | Diagonal Hessians, shape (n_samples * n_outputs,) for XGBoost |
Source code in src/jaxboost/objective/multi_output.py
grad_hess ¶
grad_hess(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], sample_weight: NDArray[floating[Any]] | None = None, **kwargs: Any) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]
Compute both gradient and Hessian.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples * n_outputs,) | *required* |
| `y_true` | `NDArray[floating[Any]]` | True labels | *required* |
| `sample_weight` | `NDArray[floating[Any]] \| None` | Optional sample weights, shape (n_samples,) | `None` |
| `**kwargs` | `Any` | Additional arguments | `{}` |

Returns:

| Type | Description |
|---|---|
| `tuple[NDArray[floating[Any]], NDArray[floating[Any]]]` | Tuple of (gradients, hessians), each shape (n_samples * n_outputs,) |
Source code in src/jaxboost/objective/multi_output.py
get_xgb_objective ¶
get_xgb_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]
Get an XGBoost-compatible objective function for multi-output.
Note: Requires XGBoost with multi-output support.
Set multi_strategy='multi_output_tree' and num_target=n_outputs in params.
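For a 2-output objective, the XGBoost params would look roughly like this (mirroring the note above; `tree_method` is an assumption, since multi-output trees typically require the hist method):

```python
params = {
    "multi_strategy": "multi_output_tree",  # one tree predicts all outputs
    "num_target": 2,                        # n_outputs
    "tree_method": "hist",
}
```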
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Parameters to pass to the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]` | XGBoost objective function |
Source code in src/jaxboost/objective/multi_output.py
with_params ¶
Create a new instance with default parameters set.
Source code in src/jaxboost/objective/multi_output.py
MaskedMultiTaskObjective¶
For multi-task learning with potentially missing labels.
MaskedMultiTaskObjective ¶
MaskedMultiTaskObjective(task_loss_fn: Callable[[Array, Array], Array] | None = None, n_tasks: int = 2, task_weights: list[float] | NDArray[floating[Any]] | None = None)
Multi-task objective with support for missing labels.
Key features:

- Handles arbitrary label missingness patterns
- Gradients are 0 for missing labels (no update)
- Supports per-task loss functions
- Automatic gradient/Hessian computation via JAX
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_tasks` | `int` | Number of tasks (outputs per sample) | `2` |
| `task_loss_fn` | `Callable[[Array, Array], Array] \| None` | Loss function for each task, with signature `(y_pred, y_true) -> scalar loss`. Default is squared error. | `None` |
| `task_weights` | `list[float] \| NDArray[floating[Any]] \| None` | Optional weights for each task, shape (n_tasks,) | `None` |
Example

```python
# Basic usage
obj = MaskedMultiTaskObjective(n_tasks=3)

# Custom per-task loss
@MaskedMultiTaskObjective(n_tasks=3)
def my_mtl_loss(y_pred, y_true):
    return (y_pred - y_true) ** 2

# Different weights per task
obj = MaskedMultiTaskObjective(n_tasks=3, task_weights=[1.0, 2.0, 0.5])
```
Source code in src/jaxboost/objective/multi_task.py
xgb_objective property ¶
xgb_objective: Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]
XGBoost-compatible objective function (no mask).
For missing label support, use get_xgb_objective(mask=...) instead.
lgb_objective property ¶
lgb_objective: Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]
LightGBM-compatible objective function (no mask).
__call__ ¶
Allow use as decorator.
Source code in src/jaxboost/objective/multi_task.py
gradient ¶
gradient(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], mask: NDArray[floating[Any]] | None = None, **kwargs: Any) -> NDArray[np.floating[Any]]
Compute gradients with missing label support.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples * n_tasks,) or (n_samples, n_tasks) | *required* |
| `y_true` | `NDArray[floating[Any]]` | Labels, shape (n_samples * n_tasks,) or (n_samples, n_tasks) | *required* |
| `mask` | `NDArray[floating[Any]] \| None` | Label mask, shape (n_samples * n_tasks,) or (n_samples, n_tasks). 1 = valid label, 0 = missing label. Default: all valid. | `None` |
| `**kwargs` | `Any` | Additional arguments (unused) | `{}` |

Returns:

| Type | Description |
|---|---|
| `NDArray[floating[Any]]` | Gradients, shape (n_samples * n_tasks,) |
Source code in src/jaxboost/objective/multi_task.py
hessian ¶
hessian(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], mask: NDArray[floating[Any]] | None = None, **kwargs: Any) -> NDArray[np.floating[Any]]
Compute Hessians with missing label support.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples * n_tasks,) | *required* |
| `y_true` | `NDArray[floating[Any]]` | Labels, shape (n_samples * n_tasks,) | *required* |
| `mask` | `NDArray[floating[Any]] \| None` | Label mask, shape (n_samples * n_tasks,) | `None` |
| `**kwargs` | `Any` | Additional arguments | `{}` |

Returns:

| Type | Description |
|---|---|
| `NDArray[floating[Any]]` | Diagonal Hessians, shape (n_samples * n_tasks,) |
Source code in src/jaxboost/objective/multi_task.py
grad_hess ¶
grad_hess(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], mask: NDArray[floating[Any]] | None = None, sample_weight: NDArray[floating[Any]] | None = None, **kwargs: Any) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]
Compute both gradient and Hessian efficiently.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `NDArray[floating[Any]]` | Predictions, shape (n_samples * n_tasks,) | *required* |
| `y_true` | `NDArray[floating[Any]]` | Labels, shape (n_samples * n_tasks,) | *required* |
| `mask` | `NDArray[floating[Any]] \| None` | Label mask, shape (n_samples * n_tasks,) | `None` |
| `sample_weight` | `NDArray[floating[Any]] \| None` | Optional sample weights, shape (n_samples,) | `None` |
| `**kwargs` | `Any` | Additional arguments | `{}` |

Returns:

| Type | Description |
|---|---|
| `tuple[NDArray[floating[Any]], NDArray[floating[Any]]]` | Tuple of (gradients, hessians) |
Source code in src/jaxboost/objective/multi_task.py
get_xgb_objective ¶
get_xgb_objective(mask: NDArray[floating[Any]] | None = None, mask_key: str | None = None, **kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]
Get an XGBoost-compatible objective function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `mask` | `NDArray[floating[Any]] \| None` | Label mask, shape (n_samples * n_tasks,) or (n_samples, n_tasks). 1 = valid label, 0 = missing label. If None, all labels are valid. | `None` |
| `mask_key` | `str \| None` | Key to retrieve the mask from the DMatrix via `get_float_info()`. Use this for distributed training (e.g., Ray XGBoost) where data is partitioned across workers; the mask is read per worker, ensuring correct alignment with local data. | `None` |
| `**kwargs` | `Any` | Additional parameters for the loss function | `{}` |

Returns:

| Type | Description |
|---|---|
| `Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]` | XGBoost objective function: `(y_pred, dtrain) -> (grad, hess)` |
Example

```python
# Single-node: pass mask directly
y_true = np.array([[1.0, np.nan, 0.5],
                   [np.nan, 2.0, 1.0]])
mask = ~np.isnan(y_true)
y_true_filled = np.nan_to_num(y_true, nan=0.0)
dtrain = xgb.DMatrix(X, label=y_true_filled.flatten())
obj = MaskedMultiTaskObjective(n_tasks=3)
model = xgb.train(params, dtrain, obj=obj.get_xgb_objective(mask=mask))

# Distributed (Ray XGBoost): store mask in DMatrix
dtrain.set_float_info("label_mask", mask.astype(np.float32).flatten())
model = xgb.train(params, dtrain, obj=obj.get_xgb_objective(mask_key="label_mask"))
```
Source code in src/jaxboost/objective/multi_task.py
get_lgb_objective ¶
get_lgb_objective(mask: NDArray[floating[Any]] | None = None, mask_key: str | None = None, **kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]
Get a LightGBM-compatible objective function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `mask` | `NDArray[floating[Any]] \| None` | Label mask for missing labels. | `None` |
| `mask_key` | `str \| None` | Key to retrieve the mask from the Dataset. Use for distributed training. | `None` |
| `**kwargs` | `Any` | Additional parameters. | `{}` |
Note: LightGBM's multi-output support is more limited than XGBoost's.
Source code in src/jaxboost/objective/multi_task.py
with_params ¶
Create a new instance with default parameters set.
Source code in src/jaxboost/objective/multi_task.py
Binary Classification¶
Objectives for binary classification tasks (labels in {0, 1}).
focal_loss ¶
Focal Loss for imbalanced binary classification.
Down-weights well-classified examples and focuses on hard examples. Particularly useful for highly imbalanced datasets.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Raw prediction (logit), will be passed through sigmoid | *required* |
| `y_true` | `Array` | Binary label (0 or 1) | *required* |
| `gamma` | `float` | Focusing parameter. Higher = more focus on hard examples. | `2.0` |
| `alpha` | `float` | Class balance weight for the positive class. | `0.25` |
Reference
Lin et al. "Focal Loss for Dense Object Detection" (2017) https://arxiv.org/abs/1708.02002
Example

```python
model = xgb.train(params, dtrain, obj=focal_loss.xgb_objective)

# With custom parameters:
model = xgb.train(params, dtrain, obj=focal_loss.get_xgb_objective(gamma=3.0))
```
Source code in src/jaxboost/objective/binary.py
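For reference, a hand-rolled version of this loss with `AutoObjective` might look like the sketch below (illustrative only, following the Lin et al. formulation, not the library's source):

```python
import jax.numpy as jnp
from jax.nn import sigmoid

@AutoObjective
def my_focal_loss(y_pred, y_true, gamma=2.0, alpha=0.25):
    # y_pred is a raw logit; convert to a probability
    p = sigmoid(y_pred)
    eps = 1e-7
    # down-weight easy examples via the (1 - p_t)^gamma modulating factor
    pos = -alpha * (1.0 - p) ** gamma * y_true * jnp.log(p + eps)
    neg = -(1.0 - alpha) * p ** gamma * (1.0 - y_true) * jnp.log(1.0 - p + eps)
    return pos + neg
```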
binary_crossentropy ¶
Binary Cross-Entropy Loss.
Standard binary classification loss. Included for testing and as a baseline.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Raw prediction (logit), will be passed through sigmoid | *required* |
| `y_true` | `Array` | Binary label (0 or 1) | *required* |
Example
model = xgb.train(params, dtrain, obj=binary_crossentropy.xgb_objective)
Source code in src/jaxboost/objective/binary.py
weighted_binary_crossentropy ¶
Weighted Binary Cross-Entropy Loss.
Applies a weight to the positive class to handle class imbalance.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Raw prediction (logit), will be passed through sigmoid | *required* |
| `y_true` | `Array` | Binary label (0 or 1) | *required* |
| `pos_weight` | `float` | Weight for the positive class. | `1.0` |
Example

```python
# 10x weight for positive class
obj = weighted_binary_crossentropy.with_params(pos_weight=10.0)
model = xgb.train(params, dtrain, obj=obj.xgb_objective)
```
Source code in src/jaxboost/objective/binary.py
hinge_loss ¶
Smooth Hinge Loss for binary classification.
SVM-style loss with smooth approximation for non-zero Hessians.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Raw prediction score | *required* |
| `y_true` | `Array` | Binary label (0 or 1), will be converted to {-1, +1} | *required* |
| `margin` | `float` | Margin parameter. | `1.0` |
Example
model = xgb.train(params, dtrain, obj=hinge_loss.xgb_objective)
Source code in src/jaxboost/objective/binary.py
Regression¶
Objectives for continuous target prediction.
Standard¶
mse ¶
Mean Squared Error Loss.
Standard squared error loss. Included for testing and as a baseline.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Predicted value | *required* |
| `y_true` | `Array` | True value | *required* |
Example
model = xgb.train(params, dtrain, obj=mse.xgb_objective)
Source code in src/jaxboost/objective/regression.py
huber ¶
Huber Loss for robust regression.
Combines MSE for small errors and MAE for large errors, making it robust to outliers while maintaining smoothness near zero.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Predicted value | *required* |
| `y_true` | `Array` | True value | *required* |
| `delta` | `float` | Threshold where the loss transitions from quadratic to linear. | `1.0` |
Example

```python
model = xgb.train(params, dtrain, obj=huber.xgb_objective)

# With custom delta:
model = xgb.train(params, dtrain, obj=huber.get_xgb_objective(delta=0.5))
```
Source code in src/jaxboost/objective/regression.py
pseudo_huber ¶
Pseudo-Huber Loss.
A smooth approximation to the Huber loss that is differentiable everywhere.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Predicted value | *required* |
| `y_true` | `Array` | True value | *required* |
| `delta` | `float` | Scale parameter controlling the transition. | `1.0` |
Example
model = xgb.train(params, dtrain, obj=pseudo_huber.xgb_objective)
Source code in src/jaxboost/objective/regression.py
log_cosh ¶
Log-Cosh Loss for smooth robust regression.
Similar to Huber loss but smoother everywhere. Twice differentiable, which can lead to better optimization behavior.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Predicted value | *required* |
| `y_true` | `Array` | True value | *required* |
Example
model = xgb.train(params, dtrain, obj=log_cosh.xgb_objective)
Source code in src/jaxboost/objective/regression.py
mae_smooth ¶
Smooth Mean Absolute Error Loss.
A smooth approximation to MAE with a non-zero Hessian. Uses sqrt(error^2 + beta^2) - beta as the approximation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Predicted value | *required* |
| `y_true` | `Array` | True value | *required* |
| `beta` | `float` | Smoothing parameter. Smaller = closer to true MAE. | `0.1` |
Example
model = xgb.train(params, dtrain, obj=mae_smooth.xgb_objective)
Source code in src/jaxboost/objective/regression.py
Quantile & Asymmetric¶
quantile ¶
Smooth Quantile Loss (Pinball Loss) for quantile regression.
Asymmetric loss that penalizes under-prediction and over-prediction differently, allowing prediction of specific quantiles.
Uses a smooth approximation to ensure non-zero Hessians.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Predicted value | *required* |
| `y_true` | `Array` | True value | *required* |
| `q` | `float` | Target quantile in (0, 1). q=0.1 for the 10th percentile (conservative), q=0.5 for the median, q=0.9 for the 90th percentile (aggressive). | `0.5` |
| `alpha` | `float` | Smoothing parameter for regularization. | `0.01` |
Example

```python
# Predict the 90th percentile
q90 = quantile.with_params(q=0.9)
model = xgb.train(params, dtrain, obj=q90.xgb_objective)
```
Source code in src/jaxboost/objective/regression.py
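For intuition, the (non-smoothed) pinball loss that this objective approximates can be written as follows; the built-in version adds smoothing so the Hessian is non-zero (sketch only):

```python
import jax.numpy as jnp

def pinball(y_pred, y_true, q=0.5):
    err = y_true - y_pred
    # q * err when under-predicting, (q - 1) * err when over-predicting
    return jnp.maximum(q * err, (q - 1.0) * err)
```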
asymmetric ¶
Asymmetric Loss for different penalties on under/over-prediction.
Useful when the cost of under-prediction differs from over-prediction, e.g., inventory management, demand forecasting.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Predicted value | *required* |
| `y_true` | `Array` | True value | *required* |
| `alpha` | `float` | Asymmetry parameter in (0, 1). alpha > 0.5 penalizes under-prediction more; alpha < 0.5 penalizes over-prediction more. | `0.7` |
Example

```python
# Penalize under-prediction heavily (for safety stock)
obj = asymmetric.with_params(alpha=0.9)
model = xgb.train(params, dtrain, obj=obj.xgb_objective)
```
Source code in src/jaxboost/objective/regression.py
Distribution-Based¶
tweedie ¶
Tweedie Loss for zero-inflated and positive continuous data.
Common in insurance claims, rainfall prediction, and other scenarios with many zeros and positive continuous values.
Valid for 1 < p < 2.
For p=1 (Poisson), use poisson objective.
For p=2 (Gamma), use gamma objective.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Raw prediction (will be exponentiated to ensure positivity) | *required* |
| `y_true` | `Array` | True value (must be non-negative) | *required* |
| `p` | `float` | Tweedie power parameter. For 1 < p < 2: compound Poisson-Gamma (most common for insurance). | `1.5` |
Example
model = xgb.train(params, dtrain, obj=tweedie.xgb_objective)
Source code in src/jaxboost/objective/regression.py
poisson ¶
Poisson Negative Log-Likelihood.
For count data. Assumes log-link function (y_pred is log(lambda)). Loss = exp(y_pred) - y_true * y_pred
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Log of the expected count (log(lambda)) | *required* |
| `y_true` | `Array` | True count (must be non-negative) | *required* |
Example
model = xgb.train(params, dtrain, obj=poisson.xgb_objective)
Source code in src/jaxboost/objective/regression.py
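Given the formula above, an equivalent user-defined version is a one-liner with `AutoObjective` (shown for illustration; the built-in `poisson` already provides this):

```python
import jax.numpy as jnp

@AutoObjective
def my_poisson(y_pred, y_true):
    # negative log-likelihood up to a constant, with log-link lambda = exp(y_pred)
    return jnp.exp(y_pred) - y_true * y_pred
```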
gamma ¶
Gamma Negative Log-Likelihood.
For positive continuous data (e.g. insurance claims, wait times). Assumes log-link function (y_pred is log(mean)). Loss = y_pred + y_true / exp(y_pred)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Log of the expected value (log(mean)) | *required* |
| `y_true` | `Array` | True value (must be positive) | *required* |
Example
model = xgb.train(params, dtrain, obj=gamma.xgb_objective)
Source code in src/jaxboost/objective/regression.py
Multi-class Classification¶
Objectives for classification with more than two classes.
softmax_cross_entropy ¶
Softmax Cross-Entropy Loss for multi-class classification.
Standard cross-entropy loss with softmax activation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_classes` | `int` | Number of classes | `3` |

Returns:

| Type | Description |
|---|---|
| `MultiClassObjective` | MultiClassObjective instance |
Example

```python
softmax_loss = softmax_cross_entropy(n_classes=5)
params = {'num_class': 5}
model = xgb.train(params, dtrain, obj=softmax_loss.xgb_objective)
```
Source code in src/jaxboost/objective/multiclass.py
focal_multiclass ¶
focal_multiclass(n_classes: int = 3, gamma: float = 2.0, alpha: float | None = None) -> MultiClassObjective
Focal Loss for multi-class classification.
Extends focal loss to multi-class setting. Down-weights well-classified examples to focus on hard examples.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_classes` | `int` | Number of classes | `3` |
| `gamma` | `float` | Focusing parameter. Higher = more focus on hard examples. | `2.0` |
| `alpha` | `float \| None` | Optional class weight. If None, all classes are weighted equally. | `None` |

Returns:

| Type | Description |
|---|---|
| `MultiClassObjective` | MultiClassObjective instance |
Example

```python
focal_mc = focal_multiclass(n_classes=10, gamma=2.0)
model = xgb.train(params, dtrain, obj=focal_mc.xgb_objective)
```
Reference
Lin et al. "Focal Loss for Dense Object Detection" (2017)
Source code in src/jaxboost/objective/multiclass.py
label_smoothing ¶
Label Smoothing Cross-Entropy Loss.
Softmax cross-entropy with label smoothing for regularization.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_classes` | `int` | Number of classes | `3` |
| `smoothing` | `float` | Smoothing factor in [0, 1]. 0 = no smoothing. | `0.1` |

Returns:

| Type | Description |
|---|---|
| `MultiClassObjective` | MultiClassObjective instance |
Example

```python
smooth_loss = label_smoothing(n_classes=10, smoothing=0.1)
model = xgb.train(params, dtrain, obj=smooth_loss.xgb_objective)
```
Source code in src/jaxboost/objective/multiclass.py
class_balanced ¶
class_balanced(n_classes: int = 3, samples_per_class: NDArray[floating[Any]] | None = None, beta: float = 0.999) -> MultiClassObjective
Class-Balanced Loss for long-tailed distributions.
Re-weights classes based on effective number of samples.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_classes` | `int` | Number of classes | `3` |
| `samples_per_class` | `NDArray[floating[Any]] \| None` | Array of sample counts per class. If None, uniform weights. | `None` |
| `beta` | `float` | Hyperparameter for the effective number of samples. Higher = more aggressive reweighting. | `0.999` |

Returns:

| Type | Description |
|---|---|
| `MultiClassObjective` | MultiClassObjective instance |
Example

```python
cb_loss = class_balanced(
    n_classes=5,
    samples_per_class=np.array([1000, 500, 100, 50, 10]),
    beta=0.999,
)
model = xgb.train(params, dtrain, obj=cb_loss.xgb_objective)
```
Reference
Cui et al. "Class-Balanced Loss Based on Effective Number of Samples" (2019)
Source code in src/jaxboost/objective/multiclass.py
Ordinal Regression¶
Objectives for ordered categorical outcomes (ratings, grades, severity levels).
Cumulative Link Models¶
ordinal_logit ¶
ordinal_probit ¶
QWK-Aligned¶
qwk_ordinal ¶
qwk_ordinal(n_classes: int, link: Literal['probit', 'logit'] = 'logit', alpha: float = 0.0, beta: float = 1.0) -> QWKOrdinalObjective
Create a QWK-aligned ordinal objective.
Uses Expected Quadratic Error (EQE) as a differentiable surrogate for QWK.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_classes` | `int` | Number of ordinal classes | *required* |
| `link` | `Literal['probit', 'logit']` | Link function, 'probit' or 'logit' | `'logit'` |
| `alpha` | `float` | Weight for the NLL loss (0 = no NLL) | `0.0` |
| `beta` | `float` | Weight for the EQE loss (1 = full EQE) | `1.0` |

Returns:

| Type | Description |
|---|---|
| `QWKOrdinalObjective` | QWKOrdinalObjective instance |
Example

```python
# Pure EQE (best QWK alignment)
obj = qwk_ordinal(n_classes=7)

# Hybrid for stability
obj = qwk_ordinal(n_classes=7, alpha=0.5, beta=0.5)
```
Source code in src/jaxboost/objective/ordinal.py
squared_cdf_ordinal ¶
squared_cdf_ordinal(n_classes: int, link: Literal['probit', 'logit'] = 'logit') -> SquaredCDFObjective
Create a Squared CDF (CRPS) ordinal objective.
This minimizes the squared Earth Mover's Distance between distributions, often outperforming EQE/NLL for QWK optimization.
Source code in src/jaxboost/objective/ordinal.py
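Construction mirrors `qwk_ordinal`; a minimal sketch based on the signature above:

```python
obj = squared_cdf_ordinal(n_classes=7, link="probit")
```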
hybrid_ordinal ¶
hybrid_ordinal(n_classes: int, link: Literal['probit', 'logit'] = 'logit', nll_weight: float = 0.7, eqe_weight: float = 0.3) -> QWKOrdinalObjective
Create a hybrid ordinal objective (NLL + EQE).
Combines:

- NLL: a proper probabilistic loss (stable gradients early in training)
- EQE: a QWK-aligned loss (better metric alignment)
Default weights (0.7/0.3) work well in practice.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_classes` | `int` | Number of ordinal classes | *required* |
| `link` | `Literal['probit', 'logit']` | Link function | `'logit'` |
| `nll_weight` | `float` | Weight for the NLL loss | `0.7` |
| `eqe_weight` | `float` | Weight for the EQE loss | `0.3` |

Returns:

| Type | Description |
|---|---|
| `QWKOrdinalObjective` | QWKOrdinalObjective instance |
Source code in src/jaxboost/objective/ordinal.py
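A minimal construction sketch based on the signature above; training then follows the same pattern as the other ordinal objectives:

```python
# 70% NLL for stable early gradients, 30% EQE for QWK alignment
obj = hybrid_ordinal(n_classes=7, link="logit", nll_weight=0.7, eqe_weight=0.3)
```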
SLACE Paper (AAAI 2025)¶
slace_objective ¶
Create SLACE (Soft Labels Accumulating Cross Entropy) objective.
sord_objective ¶
oll_objective ¶
Survival Analysis¶
Objectives for time-to-event modeling.
aft ¶
aft(y_pred: Array, y_true: Array, label_lower_bound: Array | None = None, label_upper_bound: Array | None = None, sigma: float = 1.0) -> jax.Array
Accelerated Failure Time (AFT) Loss for survival analysis.
Models log(T) = y_pred + sigma * epsilon, where epsilon follows a normal distribution.
Handles censored data:

- Uncensored: lower == upper (exact event time)
- Right-censored: upper == inf (the event has not occurred yet)
- Interval-censored: lower < upper (the event falls in a time range)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Predicted log survival time | *required* |
| `y_true` | `Array` | Label (used as the lower bound if bounds are not provided) | *required* |
| `label_lower_bound` | `Array \| None` | Lower bound of the survival time | `None` |
| `label_upper_bound` | `Array \| None` | Upper bound of the survival time | `None` |
| `sigma` | `float` | Scale parameter | `1.0` |
Example

```python
# Survival data: some patients censored (still alive)
lower_bounds = event_times
upper_bounds = np.where(is_censored, np.inf, event_times)
dtrain.set_float_info('label_lower_bound', lower_bounds)
dtrain.set_float_info('label_upper_bound', upper_bounds)
model = xgb.train(params, dtrain, obj=aft.xgb_objective)
```
Source code in src/jaxboost/objective/survival.py
weibull_aft ¶
weibull_aft(y_pred: Array, y_true: Array, label_lower_bound: Array | None = None, label_upper_bound: Array | None = None, k: float = 1.0) -> jax.Array
Weibull AFT (Accelerated Failure Time) Loss.
Uses Weibull distribution for survival times instead of log-normal.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `Array` | Predicted log scale parameter (lambda) | *required* |
| `y_true` | `Array` | Label (used as the lower bound if bounds are not provided) | *required* |
| `label_lower_bound` | `Array \| None` | Lower bound of the survival time | `None` |
| `label_upper_bound` | `Array \| None` | Upper bound of the survival time | `None` |
| `k` | `float` | Shape parameter of the Weibull distribution (1.0 = exponential) | `1.0` |
Example
model = xgb.train(params, dtrain, obj=weibull_aft.xgb_objective)
Source code in src/jaxboost/objective/survival.py
Multi-task Learning¶
Objectives for predicting multiple targets simultaneously.
multi_task_regression ¶
Standard multi-task regression with MSE loss.
Supports missing labels via masking.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_tasks` | `int` | Number of regression tasks | *required* |
Example

```python
obj = multi_task_regression(n_tasks=5)

# Labels shape: (n_samples, 5), some can be NaN
mask = ~np.isnan(y_true)
dtrain.set_float_info('label_mask', mask.flatten())
```
Source code in src/jaxboost/objective/multi_task.py
multi_task_classification ¶
Multi-task binary classification with log loss.
Each task is an independent binary classification. Supports missing labels via masking.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_tasks` | `int` | Number of binary classification tasks | *required* |
Example

```python
# Each task: predict 0 or 1
obj = multi_task_classification(n_tasks=3)
```
Source code in src/jaxboost/objective/multi_task.py
multi_task_huber ¶
Multi-task regression with Huber loss (robust to outliers).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_tasks` | `int` | Number of tasks | *required* |
| `delta` | `float` | Threshold for switching between L1 and L2 | `1.0` |
Example
obj = multi_task_huber(n_tasks=3, delta=1.5)
Source code in src/jaxboost/objective/multi_task.py
multi_task_quantile ¶
Multi-task quantile regression.
Each task predicts a different quantile. Useful for prediction intervals.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_tasks` | `int` | Number of quantiles to predict | *required* |
| `quantiles` | `list[float] \| None` | List of quantile values (default: evenly spaced) | `None` |
Example

```python
# Predict 10th, 50th, 90th percentiles
obj = multi_task_quantile(n_tasks=3, quantiles=[0.1, 0.5, 0.9])
```
Source code in src/jaxboost/objective/multi_task.py
Uncertainty Estimation¶
Multi-output objectives that predict both value and uncertainty.
gaussian_nll ¶
Gaussian Negative Log-Likelihood for uncertainty estimation.
Predicts both mean and log-variance, enabling uncertainty quantification.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_outputs` | `int` | Should be 2 (mean, log_variance) | `2` |

Returns:

| Type | Description |
|---|---|
| `MultiOutputObjective` | MultiOutputObjective instance |
Example

```python
nll = gaussian_nll()
params = {'multi_strategy': 'multi_output_tree', 'num_target': 2}
model = xgb.train(params, dtrain, obj=nll.xgb_objective)
```
Source code in src/jaxboost/objective/multi_output.py
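After training, each sample has two outputs (mean and log-variance). A hedged sketch of turning raw predictions into a point estimate with an uncertainty band, assuming predictions come back with one column per target:

```python
import numpy as np

preds = model.predict(xgb.DMatrix(X_test))   # assumed shape: (n_samples, 2)
mean = preds[:, 0]
std = np.exp(0.5 * preds[:, 1])              # log-variance -> standard deviation

# e.g. a rough 95% interval
lower, upper = mean - 1.96 * std, mean + 1.96 * std
```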
laplace_nll ¶
Laplace Negative Log-Likelihood for robust uncertainty estimation.
Similar to Gaussian NLL but uses Laplace distribution, which is more robust to outliers.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_outputs` | `int` | Should be 2 (location, log_scale) | `2` |

Returns:

| Type | Description |
|---|---|
| `MultiOutputObjective` | MultiOutputObjective instance |
Example

```python
nll = laplace_nll()
model = xgb.train(params, dtrain, obj=nll.xgb_objective)
```