# API Reference

Welcome to the JAXBoost API documentation.

## What is JAXBoost?

JAXBoost provides automatic objective functions for XGBoost and LightGBM using JAX automatic differentiation. Write a loss function, get gradients and Hessians automatically.

## Quick Example

```python
import xgboost as xgb
from jaxboost import auto_objective, focal_loss

# Load your data
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}

# Use a built-in objective
model = xgb.train(params, dtrain, num_boost_round=100, obj=focal_loss.xgb_objective)

# Or create a custom objective
@auto_objective
def my_loss(y_pred, y_true):
    return (y_pred - y_true) ** 2

model = xgb.train(params, dtrain, num_boost_round=100, obj=my_loss.xgb_objective)
```

The same objectives plug into LightGBM:

```python
import lightgbm as lgb
from jaxboost import huber

train_data = lgb.Dataset(X_train, label=y_train)
params = {"max_depth": 4, "learning_rate": 0.1}

# fobj is the LightGBM < 4.0 interface; on 4.0+ pass the callable via
# params["objective"] = huber.lgb_objective instead.
model = lgb.train(params, train_data, num_boost_round=100, fobj=huber.lgb_objective)
```

## Core API

| Class/Decorator | Description |
| --- | --- |
| `@auto_objective` | Decorator for scalar loss functions |
| `AutoObjective` | Base class for custom objectives |
| `MultiClassObjective` | Multi-class classification objectives |
| `MultiOutputObjective` | Multi-output objectives (uncertainty) |
| `MaskedMultiTaskObjective` | Multi-task with missing labels |
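
As a sketch of how far the decorator route stretches beyond the built-ins, the block below defines a LinEx (linear-exponential) loss, which does not appear in the tables that follow. The constant `a`, and the assumption that decorated losses expose an `lgb_objective` handle alongside the `xgb_objective` shown in the quick example, are illustrative rather than documented here.

```python
import jax.numpy as jnp
from jaxboost import auto_objective

a = 0.5  # asymmetry constant; purely illustrative

@auto_objective
def linex(y_pred, y_true):
    # LinEx loss: convex and twice differentiable, so JAX can
    # supply exact gradients and Hessians to the booster.
    r = y_pred - y_true
    return jnp.exp(a * r) - a * r - 1.0

# linex.xgb_objective plugs into xgboost.train(..., obj=...);
# presumably linex.lgb_objective mirrors the built-ins for LightGBM.
```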

## Built-in Objectives

### Binary Classification

| Objective | Description |
| --- | --- |
| `focal_loss` | Focal loss for imbalanced data |
| `binary_crossentropy` | Standard binary cross-entropy |
| `weighted_binary_crossentropy` | Weighted binary cross-entropy |
| `hinge_loss` | SVM-style hinge loss |
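
A hand-rolled counterpart of `weighted_binary_crossentropy` might look like the sketch below. The positive-class weight `w`, and the assumption that predictions arrive as raw margins (hence the log-sigmoid formulation), are illustrative; the built-in's actual signature is not documented on this page.

```python
import jax
from jaxboost import auto_objective

w = 5.0  # illustrative up-weight for the positive class

@auto_objective
def weighted_bce(y_pred, y_true):
    # y_pred treated as a raw margin; log_sigmoid keeps the loss
    # numerically stable for large positive or negative scores
    return -(w * y_true * jax.nn.log_sigmoid(y_pred)
             + (1.0 - y_true) * jax.nn.log_sigmoid(-y_pred))
```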

### Regression

| Objective | Description |
| --- | --- |
| `mse` | Mean squared error |
| `huber` | Huber loss (robust to outliers) |
| `pseudo_huber` | Smooth approximation of Huber |
| `log_cosh` | Log-cosh loss |
| `mae_smooth` | Smooth approximation of MAE |
| `quantile` | Quantile regression |
| `asymmetric` | Asymmetric loss |
| `tweedie` | Tweedie deviance |
| `poisson` | Poisson deviance |
| `gamma` | Gamma deviance |
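
The smooth variants exist because second-order boosting needs informative Hessians: exact MAE has zero curvature everywhere, and Huber has zero curvature in its linear tails. Pseudo-Huber is the standard smooth surrogate (the general formula, not anything specific to jaxboost):

$$
L_\delta(r) = \delta^2\left(\sqrt{1 + (r/\delta)^2} - 1\right), \qquad r = y_{\text{pred}} - y_{\text{true}},
$$

which behaves like $r^2/2$ near zero and like $\delta\,|r|$ for large residuals, with a strictly positive second derivative throughout.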

### Multi-class Classification

| Objective | Description |
| --- | --- |
| `softmax_cross_entropy` | Standard softmax cross-entropy |
| `focal_multiclass` | Focal loss for multi-class |
| `label_smoothing` | Label smoothing regularization |
| `class_balanced` | Class-balanced loss |
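
A plausible usage sketch, assuming the multi-class built-ins expose the same `xgb_objective` handle as the binary and regression objectives in the quick example (`num_class` is XGBoost's own parameter):

```python
import xgboost as xgb
from jaxboost import softmax_cross_entropy

# dtrain as in the quick example above
params = {"max_depth": 4, "eta": 0.1, "num_class": 5}
model = xgb.train(params, dtrain, num_boost_round=100,
                  obj=softmax_cross_entropy.xgb_objective)
```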

### Ordinal Regression

| Objective | Description |
| --- | --- |
| `ordinal_logit` | Cumulative link model (logit link) |
| `ordinal_probit` | Cumulative link model (probit link) |
| `qwk_ordinal` | Expected quadratic error aligned with Quadratic Weighted Kappa (QWK) |
| `squared_cdf_ordinal` | CRPS / ranked probability score |
| `hybrid_ordinal` | NLL + expected-quadratic-error hybrid |
| `slace_objective` | SLACE (AAAI 2025) |
| `sord_objective` | SORD, soft ordinal labels |
| `oll_objective` | OLL, ordinal log-loss |
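
For reference, the cumulative link model behind `ordinal_logit` and `ordinal_probit` is the standard one (how jaxboost parameterizes the thresholds is not documented here). With $K$ ordered classes it models

$$
P(y \le k \mid x) = F\big(\theta_k - f(x)\big), \qquad \theta_1 < \theta_2 < \cdots < \theta_{K-1},
$$

where $f(x)$ is the boosted score and $F$ is the logistic CDF for the logit link or the standard normal CDF $\Phi$ for the probit link.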

### Survival Analysis

| Objective | Description |
| --- | --- |
| `aft` | Accelerated failure time (AFT) |
| `weibull_aft` | Weibull AFT model |

### Multi-task Learning

| Objective | Description |
| --- | --- |
| `multi_task_regression` | Multi-task MSE |
| `multi_task_classification` | Multi-task classification |
| `multi_task_huber` | Multi-task Huber loss |
| `multi_task_quantile` | Multi-task quantile loss |

### Uncertainty Estimation

| Objective | Description |
| --- | --- |
| `gaussian_nll` | Gaussian negative log-likelihood |
| `laplace_nll` | Laplace negative log-likelihood |
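
Both losses fit a location and a scale per example, which is presumably where `MultiOutputObjective` comes in. The Gaussian case minimizes the standard negative log-likelihood

$$
-\log p(y \mid \mu, \sigma) = \tfrac{1}{2}\log(2\pi\sigma^2) + \frac{(y - \mu)^2}{2\sigma^2},
$$

so the model is penalized both for missing the target and for reporting a miscalibrated variance.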

## Module Structure

```
jaxboost/
└── objective/           # Automatic objective functions
    ├── auto.py          # @auto_objective decorator
    ├── binary.py        # Binary classification
    ├── regression.py    # Regression objectives
    ├── multiclass.py    # Multi-class classification
    ├── ordinal.py       # Ordinal regression (CLM)
    ├── multi_output.py  # Multi-output (uncertainty)
    ├── multi_task.py    # Multi-task learning
    └── survival.py      # Survival analysis
```
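
The quick example imports everything from the package top level; if the submodules are importable directly, the tree above suggests paths like the following (an assumption, since this page only confirms the top-level names):

```python
# Confirmed by the examples on this page:
from jaxboost import auto_objective, focal_loss, huber

# Plausible direct module imports implied by the tree (not confirmed here):
from jaxboost.objective import regression, ordinal
```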