
Quick Start

This guide will help you get started with JAXBoost in minutes.


Installation

pip install jaxboost

Using uv?

uv add jaxboost

Basic Usage

JAXBoost lets you write custom loss functions and automatically generates the gradients and Hessians needed by XGBoost and LightGBM.
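Conceptually, XGBoost and LightGBM expect a custom objective callback that returns per-sample gradients and Hessians of the loss with respect to the raw predictions. The sketch below shows how such a callback could be derived with JAX autodiff for a plain squared error; it is illustrative only (the function names squared_error and sketch_xgb_objective are made up here), not the library's actual implementation.

import numpy as np
import jax

# Per-sample loss: scalar prediction, scalar label -> scalar loss
def squared_error(y_pred, y_true):
    return (y_pred - y_true) ** 2

# First and second derivatives w.r.t. y_pred, vectorized over samples
grad_fn = jax.vmap(jax.grad(squared_error), in_axes=(0, 0))
hess_fn = jax.vmap(jax.grad(jax.grad(squared_error)), in_axes=(0, 0))

def sketch_xgb_objective(preds, dtrain):
    # Callback shape xgb.train(obj=...) expects: (preds, DMatrix) -> (grad, hess)
    labels = dtrain.get_label()
    return np.asarray(grad_fn(preds, labels)), np.asarray(hess_fn(preds, labels))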

Using Built-in Objectives

import xgboost as xgb
from jaxboost import focal_loss, huber, quantile

# Load your data
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)
params = {"max_depth": 4, "eta": 0.1}

# Focal loss for imbalanced classification
model = xgb.train(params, dtrain, num_boost_round=100, obj=focal_loss.xgb_objective)

# Huber loss for robust regression
model = xgb.train(params, dtrain, num_boost_round=100, obj=huber.xgb_objective)

# Quantile loss for median regression
model = xgb.train(params, dtrain, num_boost_round=100, obj=quantile(0.5).xgb_objective)

The same objectives work with LightGBM:

import lightgbm as lgb
from jaxboost import huber

train_data = lgb.Dataset(X_train, label=y_train)
params = {"max_depth": 4, "learning_rate": 0.1}

model = lgb.train(params, train_data, num_boost_round=100, fobj=huber.lgb_objective)
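
Heads-up for newer LightGBM: the fobj argument was removed from lgb.train() in LightGBM 4.0. If the call above raises a TypeError on your setup, passing the same callable through params should work instead (a sketch, assuming your LightGBM version accepts callable objectives in params):

params = {"max_depth": 4, "learning_rate": 0.1, "objective": huber.lgb_objective}
model = lgb.train(params, train_data, num_boost_round=100)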

Custom Objectives

Create your own objective function with the @auto_objective decorator:

import jax.numpy as jnp
from jaxboost import auto_objective

@auto_objective
def asymmetric_mse(y_pred, y_true, alpha=0.7):
    """Penalize under-predictions more than over-predictions."""
    error = y_true - y_pred
    return jnp.where(error > 0, alpha * error**2, (1 - alpha) * error**2)

# Use with XGBoost
import xgboost as xgb
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}
model = xgb.train(params, dtrain, num_boost_round=100, obj=asymmetric_mse.xgb_objective)

# Use with LightGBM
import lightgbm as lgb
train_data = lgb.Dataset(X_train, label=y_train)
params = {"max_depth": 4, "learning_rate": 0.1}
model = lgb.train(params, train_data, num_boost_round=100, fobj=asymmetric_mse.lgb_objective)

How it works

The @auto_objective decorator wraps your loss function and uses JAX's grad and hessian transforms to compute the first and second derivatives automatically. As illustrated in the sketch below the list, your function should:

  1. Take y_pred and y_true as the first two arguments
  2. Return a scalar loss value per sample
  3. Use JAX-compatible operations (e.g., jax.numpy instead of numpy)
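
For example, a custom log-cosh style loss that follows these three rules could look like this (a minimal sketch with a made-up name; a built-in log_cosh objective is also available, see the table below):

import jax.numpy as jnp
from jaxboost import auto_objective

@auto_objective
def my_log_cosh(y_pred, y_true):
    """Smooth, outlier-robust regression loss: one scalar per sample, JAX ops only."""
    return jnp.log(jnp.cosh(y_pred - y_true))

# Usable exactly like the built-ins:
# model = xgb.train(params, dtrain, num_boost_round=100, obj=my_log_cosh.xgb_objective)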

Passing Custom Parameters

# Default parameters
model = xgb.train(params, dtrain, obj=asymmetric_mse.xgb_objective)

# Custom parameters
model = xgb.train(
    params, dtrain, num_boost_round=100,
    obj=asymmetric_mse.get_xgb_objective(alpha=0.9)
)

Multi-class Classification

For multi-class problems, use @multiclass_objective:

import jax
import jax.numpy as jnp
from jaxboost import multiclass_objective

@multiclass_objective(num_classes=3)
def custom_multiclass(logits, label):
    """Custom multi-class loss."""
    probs = jax.nn.softmax(logits)
    return -jnp.log(probs[label] + 1e-7)

params = {"num_class": 3, "max_depth": 4}
model = xgb.train(params, dtrain, num_boost_round=100, obj=custom_multiclass.xgb_objective)

XGBoost Configuration

For multi-class objectives, make sure num_class in params matches the num_classes value passed to the decorator.
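
With a custom objective, predict() returns raw per-class scores rather than probabilities. A sketch of recovering probabilities and class labels, assuming the scores come back with one column per class (older XGBoost versions may return them flattened):

import jax
import jax.numpy as jnp

raw = model.predict(dtest).reshape(-1, 3)        # raw scores, shape (n_samples, num_class)
probs = jax.nn.softmax(jnp.asarray(raw), axis=-1)
pred_class = jnp.argmax(probs, axis=-1)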


Multi-task Learning

Handle multiple targets with optional missing labels:

import numpy as np
from jaxboost import MaskedMultiTaskObjective

# 3 regression tasks
objective = MaskedMultiTaskObjective(n_tasks=3)

# Create mask for missing labels (1 = valid, 0 = missing)
mask = np.ones_like(y_train)
mask[some_indices] = 0  # Mark missing labels

model = xgb.train(
    params, dtrain, num_boost_round=100,
    obj=objective.get_xgb_objective(mask=mask)
)
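
In practice the mask is often derived from missing entries in the target matrix itself. A sketch, assuming y_train is an (n_samples, 3) array with NaN marking missing labels:

import numpy as np

# 1 where a label is observed, 0 where it is missing
mask = (~np.isnan(y_train)).astype(np.float32)

# Fill missing entries with any placeholder before building the DMatrix;
# the mask keeps them out of the loss, so the placeholder value never matters.
y_filled = np.nan_to_num(y_train, nan=0.0)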

Ordinal Regression

For ordered categorical outcomes (ratings, grades, severity levels):

from jaxboost import ordinal_logit, qwk_ordinal

# Wine quality: 6 ordered classes (3-8 mapped to 0-5)
ordinal = ordinal_logit(n_classes=6)
ordinal.init_thresholds_from_data(y_train)

# Train with XGBoost
model = xgb.train(params, dtrain, num_boost_round=100, obj=ordinal.xgb_objective)

# Or LightGBM
model = lgb.train(params, train_data, num_boost_round=100, fobj=ordinal.lgb_objective)

# Get predictions
latent = model.predict(dtest)
probs = ordinal.predict_proba(latent)  # Class probabilities
classes = ordinal.predict(latent)       # Predicted classes
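
Ordinal problems are often scored with quadratic weighted kappa (presumably what the qwk_ordinal objective targets). One way to evaluate the predictions above, assuming scikit-learn is installed and y_test holds the integer class labels:

from sklearn.metrics import cohen_kappa_score

qwk = cohen_kappa_score(y_test, classes, weights="quadratic")
print(f"Quadratic weighted kappa: {qwk:.3f}")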

Why Ordinal Regression?

Standard approaches lose information:

  • Regression assumes equal intervals (3→4 same as 7→8)
  • Multi-class ignores ordering entirely

Ordinal regression learns proper class thresholds.
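
Concretely, a cumulative-logit model keeps one latent score per sample plus K-1 ordered thresholds and converts them into class probabilities roughly as follows (standard ordered-logit math; the exact parameterization inside ordinal_logit may differ):

import jax
import jax.numpy as jnp

def cumulative_logit_proba(latent, thresholds):
    """P(y = k) from latent scores (n,) and ordered thresholds t_0 < ... < t_{K-2}."""
    # P(y <= k) = sigmoid(t_k - latent); pad with 0 and 1 to cover all classes
    cdf = jax.nn.sigmoid(thresholds[None, :] - latent[:, None])
    zeros = jnp.zeros((latent.shape[0], 1))
    ones = jnp.ones((latent.shape[0], 1))
    cdf = jnp.concatenate([zeros, cdf, ones], axis=1)
    return cdf[:, 1:] - cdf[:, :-1]   # P(y = k) = P(y <= k) - P(y <= k-1)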


Survival Analysis

Built-in objectives for time-to-event modeling:

from jaxboost import aft, weibull_aft

# Accelerated failure time model
model = xgb.train(params, dtrain, obj=aft.xgb_objective)

# Weibull AFT model
model = xgb.train(params, dtrain, obj=weibull_aft.xgb_objective)

Data Format

For survival analysis, your labels must encode both the event time and the censoring indicator; see the API Reference for the exact encoding each survival objective expects.


Uncertainty Estimation

Predict both values and uncertainty using multi-output objectives:

import numpy as np
from jaxboost import gaussian_nll

# Predicts mean and log-variance
model = xgb.train(params, dtrain, obj=gaussian_nll.xgb_objective)

# Get predictions (shape: [n_samples, 2])
preds = model.predict(dtest)
mean = preds[:, 0]
variance = np.exp(preds[:, 1])  # log-variance → variance
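
Because the two outputs parameterize a per-sample Gaussian, they translate directly into prediction intervals (plain normal-quantile arithmetic, not a jaxboost API):

std = np.sqrt(variance)

# Approximate 95% prediction interval: mean +/- 1.96 standard deviations
lower = mean - 1.96 * std
upper = mean + 1.96 * std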

Available Objectives

Regression: mse, huber, quantile, tweedie, asymmetric, log_cosh, pseudo_huber, mae_smooth, poisson, gamma
Binary Classification: focal_loss, binary_crossentropy, weighted_binary_crossentropy, hinge_loss
Multi-class: softmax_cross_entropy, focal_multiclass, label_smoothing, class_balanced
Ordinal Regression: ordinal_logit, ordinal_probit, qwk_ordinal, squared_cdf_ordinal, hybrid_ordinal, slace_objective
Survival: aft, weibull_aft
Multi-task: multi_task_regression, multi_task_classification, multi_task_huber, multi_task_quantile, MaskedMultiTaskObjective
Uncertainty: gaussian_nll, laplace_nll

See the API Reference for detailed documentation of each objective.


Next Steps