# API Reference
Welcome to the JAXBoost API documentation.
## What is JAXBoost?
JAXBoost provides automatic objective functions for XGBoost and LightGBM, built on JAX automatic differentiation. You write a loss function, and JAXBoost derives the gradients and Hessians that the boosting libraries' custom-objective interfaces require.
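The core idea is to apply `jax.grad` twice to a per-sample scalar loss and vectorize the result with `jax.vmap`. The sketch below shows this in plain JAX; it is an illustration, not JAXBoost's actual code, and the helper name `grad_hess` is hypothetical. Log-cosh is used because its derivatives, `tanh(r)` and `1 - tanh(r)**2`, are easy to check by hand.

```python
import jax
import jax.numpy as jnp


def log_cosh(y_pred, y_true):
    # Numerically stable log(cosh(r)) = logaddexp(r, -r) - log(2) for one sample.
    r = y_pred - y_true
    return jnp.logaddexp(r, -r) - jnp.log(2.0)


def grad_hess(loss_fn, y_pred, y_true):
    """Hypothetical helper: per-sample first and second derivatives w.r.t. y_pred."""
    first = jax.grad(loss_fn)    # d loss / d y_pred (first positional argument)
    second = jax.grad(first)     # derivative of the derivative
    grad = jax.vmap(first)(y_pred, y_true)   # map over samples
    hess = jax.vmap(second)(y_pred, y_true)
    return grad, hess


y_pred = jnp.array([0.5, 2.0, -1.0])
y_true = jnp.array([0.0, 1.0, 1.0])
grad, hess = grad_hess(log_cosh, y_pred, y_true)
# Analytically: grad = tanh(r), hess = 1 - tanh(r)**2, with r = y_pred - y_true.
```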
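## Quick Example

```python
import numpy as np
import xgboost as xgb
from jaxboost import auto_objective, focal_loss

# Load your data (synthetic binary labels here so the snippet runs as-is)
rng = np.random.default_rng(0)
X_train = rng.normal(size=(500, 5))
y_train = (X_train[:, 0] + 0.5 * rng.normal(size=500) > 0).astype(float)
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}

# Use a built-in objective
model = xgb.train(params, dtrain, num_boost_round=100, obj=focal_loss.xgb_objective)

# Create a custom objective: write the loss, get gradients and Hessians automatically
@auto_objective
def my_loss(y_pred, y_true):
    return (y_pred - y_true) ** 2

model = xgb.train(params, dtrain, num_boost_round=100, obj=my_loss.xgb_objective)
```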
## Core API

| Class/Decorator | Description |
|---|---|
| `@auto_objective` | Decorator for scalar loss functions |
| `AutoObjective` | Base class for custom objectives |
| `MultiClassObjective` | Multi-class classification objectives |
| `MultiOutputObjective` | Multi-output objectives (e.g., uncertainty estimation) |
| `MaskedMultiTaskObjective` | Multi-task objectives with missing labels |
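An `@auto_objective`-style decorator can be pictured as wrapping a per-sample loss in an object that exposes an XGBoost-compatible callable. The toy below is hypothetical: `ToyObjective` and `toy_auto_objective` are illustrative names, not JAXBoost classes. It relies only on XGBoost's custom-objective contract, where the callable passed as `obj` receives the current predictions and the training `DMatrix` and returns per-sample gradient and Hessian arrays. It also uses `DMatrix.get_label()`.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyObjective:
    """Hypothetical stand-in for an object built by an @auto_objective-style decorator."""

    def __init__(self, loss_fn):
        self.loss_fn = loss_fn
        # Per-sample gradient and Hessian w.r.t. the first argument (the prediction).
        self._grad = jax.jit(jax.vmap(jax.grad(loss_fn)))
        self._hess = jax.jit(jax.vmap(jax.grad(jax.grad(loss_fn))))

    def xgb_objective(self, preds, dtrain):
        # XGBoost passes raw predictions and the training DMatrix,
        # and expects (grad, hess) arrays with one entry per sample.
        y_pred = jnp.asarray(preds)
        y_true = jnp.asarray(dtrain.get_label())
        grad = np.asarray(self._grad(y_pred, y_true))
        hess = np.asarray(self._hess(y_pred, y_true))
        return grad, hess


def toy_auto_objective(loss_fn):
    """Hypothetical decorator: turn a scalar per-sample loss into a ToyObjective."""
    return ToyObjective(loss_fn)


@toy_auto_objective
def squared_error(y_pred, y_true):
    return (y_pred - y_true) ** 2
```

Because `squared_error.xgb_objective` is a bound method with the `(preds, dtrain)` signature, it can be passed as `obj=` to `xgb.train` in the same way as in the Quick Example.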
## Built-in Objectives

### Binary Classification

| Objective | Description |
|---|---|
| `focal_loss` | Focal loss for imbalanced data |
| `binary_crossentropy` | Standard binary cross-entropy |
| `weighted_binary_crossentropy` | Weighted binary cross-entropy |
| `hinge_loss` | SVM-style hinge loss |
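The following is a standalone sketch of the standard binary focal loss. It is written on raw margins `z`, and JAX supplies the gradient and Hessian. The function name `binary_focal_loss` is hypothetical. The `gamma=2.0` and `alpha=0.25` values are commonly used settings, not necessarily JAXBoost's defaults or parameterization.

```python
import functools

import jax
import jax.numpy as jnp


def binary_focal_loss(z, y, gamma=2.0, alpha=0.25):
    """Binary focal loss for one sample; z is a raw margin, y is 0.0 or 1.0."""
    p = jax.nn.sigmoid(z)
    log_p = jax.nn.log_sigmoid(z)      # log(p), computed stably
    log_1mp = jax.nn.log_sigmoid(-z)   # log(1 - p), computed stably
    pos = -alpha * (1.0 - p) ** gamma * log_p
    neg = -(1.0 - alpha) * p ** gamma * log_1mp
    return y * pos + (1.0 - y) * neg


def grad_hess(loss_fn, z, y):
    # Per-sample first and second derivatives with respect to the margin z.
    g = jax.vmap(jax.grad(loss_fn))(z, y)
    h = jax.vmap(jax.grad(jax.grad(loss_fn)))(z, y)
    return g, h


# Bind hyperparameters before vmapping so they are not mapped over.
loss = functools.partial(binary_focal_loss, gamma=2.0, alpha=0.25)
z = jnp.array([-2.0, 0.0, 3.0])
y = jnp.array([1.0, 0.0, 1.0])
grad, hess = grad_hess(loss, z, y)
```

With `gamma=0.0` and `alpha=0.5`, the loss reduces to half the binary cross-entropy, so the gradient becomes `0.5 * (sigmoid(z) - y)`. This gives a quick sanity check of the derivatives.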
### Regression

| Objective | Description |
|---|---|
| `mse` | Mean squared error |
| `huber` | Huber loss (robust to outliers) |
| `pseudo_huber` | Smooth approximation of Huber |
| `log_cosh` | Log-cosh loss |
| `mae_smooth` | Smooth approximation of MAE |
| `quantile` | Quantile regression |
| `asymmetric` | Asymmetric loss |
| `tweedie` | Tweedie deviance |
| `poisson` | Poisson deviance |
| `gamma` | Gamma deviance |
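Second-order boosting needs an informative Hessian. The quantile (pinball) loss, for example, is piecewise linear, so its second derivative is zero almost everywhere. That is one reason smooth surrogates such as `pseudo_huber` are useful. The sketch below uses the usual pseudo-Huber form `delta**2 * (sqrt(1 + (r / delta)**2) - 1)` with residual `r = y_pred - y_true`. The `delta=1.0` value is an assumption, and JAXBoost's `pseudo_huber` may be parameterized differently.

```python
import functools

import jax
import jax.numpy as jnp


def pseudo_huber(y_pred, y_true, delta=1.0):
    """Pseudo-Huber loss: quadratic near zero, roughly linear for large residuals."""
    r = y_pred - y_true
    return delta ** 2 * (jnp.sqrt(1.0 + (r / delta) ** 2) - 1.0)


loss = functools.partial(pseudo_huber, delta=1.0)
grad_fn = jax.vmap(jax.grad(loss))
hess_fn = jax.vmap(jax.grad(jax.grad(loss)))

y_pred = jnp.array([0.0, 1.0, 10.0])
y_true = jnp.array([0.0, 0.0, 0.0])
grad = grad_fn(y_pred, y_true)   # r / sqrt(1 + (r/delta)**2): magnitude below delta
hess = hess_fn(y_pred, y_true)   # (1 + (r/delta)**2) ** -1.5: positive everywhere
```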
### Multi-class Classification

| Objective | Description |
|---|---|
| `softmax_cross_entropy` | Standard softmax cross-entropy |
| `focal_multiclass` | Focal loss for multi-class |
| `label_smoothing` | Label smoothing regularization |
| `class_balanced` | Class-balanced loss |
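In multi-class objectives, each sample has a vector of K raw scores. A common approach with second-order boosting is to use the per-class gradient together with the diagonal of each sample's Hessian. That this is how `MultiClassObjective` works is an assumption, not a documented detail. The sketch applies the approach to softmax cross-entropy with label smoothing, where the smoothed target is `(1 - eps) * one_hot + eps / K`. The value `eps=0.1` is illustrative, and `smoothed_softmax_ce` and `grad_diag_hess` are hypothetical names.

```python
import functools

import jax
import jax.numpy as jnp


def smoothed_softmax_ce(logits, label, eps=0.1):
    """Softmax cross-entropy against a label-smoothed target, for one sample."""
    k = logits.shape[-1]
    target = (1.0 - eps) * jax.nn.one_hot(label, k) + eps / k
    return -jnp.sum(target * jax.nn.log_softmax(logits))


def grad_diag_hess(loss_fn, logits, labels):
    # Per-sample gradient and Hessian diagonal, each stacked to shape (n, K).
    grad = jax.vmap(jax.grad(loss_fn))(logits, labels)
    hess = jax.vmap(lambda z, y: jnp.diag(jax.hessian(loss_fn)(z, y)))(logits, labels)
    return grad, hess


loss = functools.partial(smoothed_softmax_ce, eps=0.1)
logits = jnp.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
labels = jnp.array([0, 2])
grad, hess = grad_diag_hess(loss, logits, labels)
```

Because the smoothed target sums to one, the gradient works out to `softmax(logits) - target`. The Hessian diagonal is `p * (1 - p)`, the same as for unsmoothed cross-entropy.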
### Ordinal Regression

| Objective | Description |
|---|---|
| `ordinal_logit` | Cumulative Link Model (logit link) |
| `ordinal_probit` | Cumulative Link Model (probit link) |
| `qwk_ordinal` | QWK-aligned Expected Quadratic Error |
| `squared_cdf_ordinal` | CRPS / Ranked Probability Score |
| `hybrid_ordinal` | NLL + EQE hybrid |
| `slace_objective` | SLACE (AAAI 2025) |
| `sord_objective` | SORD: Soft Ordinal |
| `oll_objective` | OLL: Ordinal Log-Loss |
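A logit cumulative link model, the model named for `ordinal_logit`, sets `P(y <= k) = sigmoid(theta_k - f)` for a single raw score `f` and increasing cut-points `theta_0 < ... < theta_{K-2}`. Class probabilities are then differences of adjacent cumulative probabilities. In this sketch, classes are labeled `0..K-1`. The cut-points are fixed, hand-picked values, an assumption since how JAXBoost chooses or learns them is not covered here. The loss is the negative log-likelihood of the observed class, and `ordinal_logit_nll` is a hypothetical name.

```python
import functools

import jax
import jax.numpy as jnp


def ordinal_logit_nll(f, label, thresholds):
    """Negative log-likelihood of an integer class label under a logit CLM."""
    cdf = jax.nn.sigmoid(thresholds - f)                  # P(y <= k), k = 0..K-2
    cdf = jnp.concatenate([jnp.zeros(1), cdf, jnp.ones(1)])
    probs = cdf[1:] - cdf[:-1]                            # P(y = k), k = 0..K-1
    return -jnp.log(probs[label])


# Four ordered classes need three increasing cut-points (illustrative values).
thresholds = jnp.array([-1.0, 0.0, 1.0])
loss = functools.partial(ordinal_logit_nll, thresholds=thresholds)

f = jnp.array([0.0, 0.0, 2.0])
labels = jnp.array([0, 3, 1])
grad = jax.vmap(jax.grad(loss))(f, labels)
hess = jax.vmap(jax.grad(jax.grad(loss)))(f, labels)
```

For the lowest class, the loss is `-log sigmoid(theta_0 - f)`, whose gradient is `sigmoid(f - theta_0)`. This is positive, so the booster pushes `f` down for samples in that class.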
### Survival Analysis

| Objective | Description |
|---|---|
| `aft` | Accelerated failure time |
| `weibull_aft` | Weibull AFT model |
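An accelerated failure time (AFT) model treats the raw score `f` as the location of the log survival time: `log(T) = f + sigma * Z`. The sketch assumes a standard normal `Z`, giving a log-normal AFT, and a fixed scale `sigma=1.0`. Both are assumptions, and `weibull_aft` corresponds to a different error distribution. Observed events contribute the log-density, while right-censored samples contribute the log-survival probability `log Phi(-z)`. The function name `lognormal_aft_nll` is hypothetical.

```python
import functools

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def lognormal_aft_nll(f, log_t, event, sigma=1.0):
    """Log-normal AFT negative log-likelihood for one sample.

    f: raw score (predicted location of log T); log_t: log of the observed time;
    event: 1.0 if the event was observed, 0.0 if right-censored.
    """
    z = (log_t - f) / sigma
    # Density of T at t is phi(z) / (sigma * t); the log_t term does not depend on f.
    nll_event = -(norm.logpdf(z) - jnp.log(sigma) - log_t)
    # Survival P(T > t) = 1 - Phi(z) = Phi(-z).
    nll_censored = -norm.logcdf(-z)
    return event * nll_event + (1.0 - event) * nll_censored


loss = functools.partial(lognormal_aft_nll, sigma=1.0)
f = jnp.array([0.0, 1.0, 0.5])
log_t = jnp.log(jnp.array([2.0, 3.0, 1.0]))
event = jnp.array([1.0, 1.0, 0.0])
grad = jax.vmap(jax.grad(loss))(f, log_t, event)
hess = jax.vmap(jax.grad(jax.grad(loss)))(f, log_t, event)
```

For observed events with `sigma=1`, the gradient is `-(log_t - f)` and the Hessian is constant at 1. For the censored sample, the gradient is negative, since predicting a later failure always lowers the loss.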
### Multi-task Learning

| Objective | Description |
|---|---|
| `multi_task_regression` | Multi-task MSE |
| `multi_task_classification` | Multi-task classification |
| `multi_task_huber` | Multi-task Huber loss |
| `multi_task_quantile` | Multi-task quantile loss |
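Multi-task objectives see one raw score per task. The sketch below is a masked multi-task MSE, in the spirit of `MaskedMultiTaskObjective` from the Core API. It assumes, hypothetically, that missing labels are stored as `NaN`; JAXBoost's actual label encoding may differ. The mask zeroes both the loss and the derivatives for missing entries. `NaN` labels are replaced before the subtraction so they cannot leak into the gradient. The names `masked_multitask_mse` and `grad_diag_hess` are illustrative.

```python
import jax
import jax.numpy as jnp


def masked_multitask_mse(preds, labels):
    """Sum of squared errors over observed tasks for one sample; NaN marks a missing label."""
    observed = ~jnp.isnan(labels)
    safe_labels = jnp.where(observed, labels, 0.0)   # keep NaN out of the arithmetic
    sq_err = (preds - safe_labels) ** 2
    return jnp.sum(jnp.where(observed, sq_err, 0.0))


def grad_diag_hess(loss_fn, preds, labels):
    # Per-sample gradient and Hessian diagonal, each of shape (n, T).
    grad = jax.vmap(jax.grad(loss_fn))(preds, labels)
    hess = jax.vmap(lambda p, y: jnp.diag(jax.hessian(loss_fn)(p, y)))(preds, labels)
    return grad, hess


preds = jnp.array([[0.5, 1.0, 2.0], [0.0, -1.0, 3.0]])
labels = jnp.array([[1.0, jnp.nan, 2.0], [jnp.nan, 0.0, jnp.nan]])
grad, hess = grad_diag_hess(masked_multitask_mse, preds, labels)
```

With zero gradient and Hessian entries, a sample contributes nothing to the split statistics of any task it has no label for.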
### Uncertainty Estimation

| Objective | Description |
|---|---|
| `gaussian_nll` | Gaussian negative log-likelihood |
| `laplace_nll` | Laplace negative log-likelihood |
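Uncertainty objectives such as `gaussian_nll` fall under the multi-output API, which suggests two raw outputs per sample. The sketch assumes these are the mean `mu` and the log standard deviation `log_sigma`, which is an assumption rather than the documented output layout. Predicting `log_sigma` keeps the scale positive without constraints. Constant terms of the NLL are dropped because they do not affect the derivatives.

```python
import jax
import jax.numpy as jnp


def gaussian_nll(outputs, y):
    """Gaussian NLL (up to a constant) for one sample; outputs = [mu, log_sigma]."""
    mu, log_sigma = outputs[0], outputs[1]
    z = (y - mu) * jnp.exp(-log_sigma)        # standardized residual
    return log_sigma + 0.5 * z ** 2


outputs = jnp.array([[0.0, 0.0], [2.0, 0.5]])
y = jnp.array([1.0, 1.0])
grad = jax.vmap(jax.grad(gaussian_nll))(outputs, y)
hess_diag = jax.vmap(lambda o, t: jnp.diag(jax.hessian(gaussian_nll)(o, t)))(outputs, y)
```

Analytically, the gradient is `[-z / sigma, 1 - z**2]` and the Hessian diagonal is `[1 / sigma**2, 2 * z**2]`. The `log_sigma` curvature therefore shrinks toward zero as the residual does.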
## Module Structure

```text
jaxboost/
└── objective/            # Automatic objective functions
    ├── auto.py           # @auto_objective decorator
    ├── binary.py         # Binary classification
    ├── regression.py     # Regression objectives
    ├── multiclass.py     # Multi-class classification
    ├── ordinal.py        # Ordinal regression (CLM)
    ├── multi_output.py   # Multi-output (uncertainty)
    ├── multi_task.py     # Multi-task learning
    └── survival.py       # Survival analysis
```