
Objectives

Automatic objective functions for XGBoost and LightGBM using JAX autodiff.

How it works

JAXBoost uses JAX automatic differentiation to compute gradients and Hessians from your loss function. You write the loss, JAX computes the derivatives—no manual math required.
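As a minimal sketch of the idea (independent of JAXBoost's actual internals), a per-sample loss plus jax.grad and jax.vmap is enough to produce the per-sample gradients and Hessians a boosting objective needs:

import jax
import jax.numpy as jnp

# Per-sample loss: squared error on a single prediction.
def squared_error(y_pred, y_true):
    return (y_pred - y_true) ** 2

# First and second derivatives w.r.t. y_pred, derived automatically.
grad_fn = jax.grad(squared_error, argnums=0)
hess_fn = jax.grad(grad_fn, argnums=0)

# Map both over a batch of samples.
y_pred = jnp.array([0.5, 2.0])
y_true = jnp.array([1.0, 1.0])
grads = jax.vmap(grad_fn)(y_pred, y_true)     # 2 * (y_pred - y_true) -> [-1., 2.]
hessians = jax.vmap(hess_fn)(y_pred, y_true)  # constant 2 -> [2., 2.]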


Core Classes

These are the building blocks for creating custom objectives.

AutoObjective

The primary class for scalar loss functions (binary classification, regression).

AutoObjective

AutoObjective(loss_fn: LossFunction)

Automatically generate XGBoost/LightGBM objective functions.

Uses JAX automatic differentiation to compute gradients and Hessians, eliminating the need for manual derivation.

Parameters:

Name Type Description Default
loss_fn LossFunction

A loss function that takes (y_pred, y_true, **kwargs) and returns a scalar loss. Should operate on single samples.

required
Example

@AutoObjective
def my_loss(y_pred, y_true, alpha=0.5):
    return alpha * (y_pred - y_true) ** 2

Use with XGBoost

model = xgb.train(params, dtrain, obj=my_loss.xgb_objective)

Use with custom parameters

model = xgb.train(params, dtrain, obj=my_loss.get_xgb_objective(alpha=0.7))

Source code in src/jaxboost/objective/auto.py
def __init__(self, loss_fn: LossFunction) -> None:
    self.loss_fn = loss_fn
    self._name = getattr(loss_fn, "__name__", "custom_objective")
    self._default_kwargs: dict[str, Any] = {}

    # Pre-compile gradient and Hessian functions
    self._grad_fn = jax.grad(self._loss_wrapper, argnums=0)
    self._hess_fn = jax.grad(lambda *args, **kw: self._grad_fn(*args, **kw), argnums=0)

    # Cache for JIT-compiled vmap functions (keyed by array kwargs pattern)
    self._vmap_cache: dict[frozenset[str], tuple[Any, Any]] = {}

xgb_objective property

xgb_objective: Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

XGBoost-compatible objective function using default parameters.

Returns:

Type Description
Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

XGBoost objective function: (y_pred, dtrain) -> (grad, hess)

Example

model = xgb.train(params, dtrain, obj=my_loss.xgb_objective)

lgb_objective property

lgb_objective: Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

LightGBM-compatible objective function using default parameters.

Returns:

Type Description
Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

LightGBM objective function: (y_pred, dataset) -> (grad, hess)

Example

model = lgb.train(params, dtrain, fobj=my_loss.lgb_objective)

sklearn_objective property

sklearn_objective: Callable[[NDArray[floating[Any]], NDArray[floating[Any]]], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

Sklearn-compatible objective for XGBClassifier/XGBRegressor.

Returns:

Type Description
Callable[[NDArray[floating[Any]], NDArray[floating[Any]]], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

Sklearn objective function: (labels, predt) -> (grad, hess)

Example

from xgboost import XGBClassifier

clf = XGBClassifier(objective=focal_loss.sklearn_objective)
clf.fit(X_train, y_train)

__call__

__call__(y_pred: Array, y_true: Array, **kwargs: Any) -> jax.Array

Compute the loss value for a batch.

Source code in src/jaxboost/objective/auto.py
def __call__(self, y_pred: jax.Array, y_true: jax.Array, **kwargs: Any) -> jax.Array:
    """Compute the loss value for a batch."""
    n_samples = len(y_pred)
    scalar_kwargs, array_kwargs, array_keys = self._split_kwargs(kwargs, n_samples)
    in_axes = (0, 0, None, dict.fromkeys(array_keys, 0) if array_keys else None)
    return jax.vmap(self._loss_wrapper, in_axes=in_axes)(
        y_pred, y_true, scalar_kwargs, array_kwargs
    )

gradient

gradient(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]

Compute gradient of loss w.r.t. y_pred for each sample.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples,)

required
y_true NDArray[floating[Any]]

True labels, shape (n_samples,)

required
**kwargs Any

Additional arguments passed to the loss function

{}

Returns:

Type Description
NDArray[floating[Any]]

Gradients, shape (n_samples,)

Source code in src/jaxboost/objective/auto.py
def gradient(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    **kwargs: Any,
) -> NDArray[np.floating[Any]]:
    """
    Compute gradient of loss w.r.t. y_pred for each sample.

    Args:
        y_pred: Predictions, shape (n_samples,)
        y_true: True labels, shape (n_samples,)
        **kwargs: Additional arguments passed to the loss function

    Returns:
        Gradients, shape (n_samples,)
    """
    merged_kwargs = {**self._default_kwargs, **kwargs}
    y_pred_jax = jnp.asarray(y_pred, dtype=jnp.float32)
    y_true_jax = jnp.asarray(y_true, dtype=jnp.float32)

    n_samples = len(y_pred)
    scalar_kwargs, array_kwargs, array_keys = self._split_kwargs(merged_kwargs, n_samples)
    vmap_grad, _ = self._get_vmap_fns(array_keys)

    grads = vmap_grad(y_pred_jax, y_true_jax, scalar_kwargs, array_kwargs)
    return np.asarray(grads, dtype=np.float64)

hessian

hessian(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]

Compute Hessian (second derivative) of loss w.r.t. y_pred for each sample.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples,)

required
y_true NDArray[floating[Any]]

True labels, shape (n_samples,)

required
**kwargs Any

Additional arguments passed to the loss function

{}

Returns:

Type Description
NDArray[floating[Any]]

Hessians (diagonal), shape (n_samples,)

Source code in src/jaxboost/objective/auto.py
def hessian(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    **kwargs: Any,
) -> NDArray[np.floating[Any]]:
    """
    Compute Hessian (second derivative) of loss w.r.t. y_pred for each sample.

    Args:
        y_pred: Predictions, shape (n_samples,)
        y_true: True labels, shape (n_samples,)
        **kwargs: Additional arguments passed to the loss function

    Returns:
        Hessians (diagonal), shape (n_samples,)
    """
    merged_kwargs = {**self._default_kwargs, **kwargs}
    y_pred_jax = jnp.asarray(y_pred, dtype=jnp.float32)
    y_true_jax = jnp.asarray(y_true, dtype=jnp.float32)

    n_samples = len(y_pred)
    scalar_kwargs, array_kwargs, array_keys = self._split_kwargs(merged_kwargs, n_samples)
    _, vmap_hess = self._get_vmap_fns(array_keys)

    hess = vmap_hess(y_pred_jax, y_true_jax, scalar_kwargs, array_kwargs)
    return np.asarray(hess, dtype=np.float64)

grad_hess

grad_hess(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], sample_weight: NDArray[floating[Any]] | None = None, **kwargs: Any) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]

Compute both gradient and Hessian efficiently.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples,)

required
y_true NDArray[floating[Any]]

True labels, shape (n_samples,)

required
sample_weight NDArray[floating[Any]] | None

Optional sample weights, shape (n_samples,)

None
**kwargs Any

Additional arguments passed to the loss function

{}

Returns:

Type Description
tuple[NDArray[floating[Any]], NDArray[floating[Any]]]

Tuple of (gradients, hessians), each shape (n_samples,)

Source code in src/jaxboost/objective/auto.py
def grad_hess(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    sample_weight: NDArray[np.floating[Any]] | None = None,
    **kwargs: Any,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
    """
    Compute both gradient and Hessian efficiently.

    Args:
        y_pred: Predictions, shape (n_samples,)
        y_true: True labels, shape (n_samples,)
        sample_weight: Optional sample weights, shape (n_samples,)
        **kwargs: Additional arguments passed to the loss function

    Returns:
        Tuple of (gradients, hessians), each shape (n_samples,)
    """
    grad = self.gradient(y_pred, y_true, **kwargs)
    hess = self.hessian(y_pred, y_true, **kwargs)

    if sample_weight is not None and len(sample_weight) > 0:
        weight = np.asarray(sample_weight, dtype=np.float64)
        grad = grad * weight
        hess = hess * weight

    return grad, hess

get_xgb_objective

get_xgb_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]

Get an XGBoost-compatible objective function with custom parameters.

Automatically handles:

- sample_weight: Sample weights from DMatrix
- label_lower_bound: Lower bound labels for interval/survival regression
- label_upper_bound: Upper bound labels for interval/survival regression
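For example, interval/survival bounds can be attached to the DMatrix so the objective receives them automatically. A hedged sketch, assuming a loss (here called aft_loss, a hypothetical AutoObjective) that accepts label_lower_bound and label_upper_bound keyword arguments:

dtrain = xgb.DMatrix(X_train, label=y_lower)
dtrain.set_float_info("label_lower_bound", y_lower)   # read back via get_float_info
dtrain.set_float_info("label_upper_bound", y_upper)
model = xgb.train(params, dtrain, obj=aft_loss.get_xgb_objective())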

Parameters:

Name Type Description Default
**kwargs Any

Parameters to pass to the loss function

{}

Returns:

Type Description
Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

XGBoost objective function: (y_pred, dtrain) -> (grad, hess)

Source code in src/jaxboost/objective/auto.py
def get_xgb_objective(
    self, **kwargs: Any
) -> Callable[
    [NDArray[np.floating[Any]], Any],
    tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]],
]:
    """
    Get an XGBoost-compatible objective function with custom parameters.

    Automatically handles:
    - sample_weight: Sample weights from DMatrix
    - label_lower_bound: Lower bound labels for interval/survival regression
    - label_upper_bound: Upper bound labels for interval/survival regression

    Args:
        **kwargs: Parameters to pass to the loss function

    Returns:
        XGBoost objective function: (y_pred, dtrain) -> (grad, hess)
    """

    def objective(
        y_pred: NDArray[np.floating[Any]], dtrain: Any
    ) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
        y_true = dtrain.get_label()
        sample_weight = dtrain.get_weight()

        # Get label bounds for interval/survival regression
        extra_kwargs = dict(kwargs)

        if hasattr(dtrain, "get_float_info"):
            try:
                lower_bound = dtrain.get_float_info("label_lower_bound")
                if len(lower_bound) > 0:
                    extra_kwargs["label_lower_bound"] = lower_bound
            except Exception:
                pass
            try:
                upper_bound = dtrain.get_float_info("label_upper_bound")
                if len(upper_bound) > 0:
                    extra_kwargs["label_upper_bound"] = upper_bound
            except Exception:
                pass

        return self.grad_hess(y_pred, y_true, sample_weight=sample_weight, **extra_kwargs)

    objective.__name__ = f"{self._name}_xgb_objective"
    return objective

get_lgb_objective

get_lgb_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]

Get a LightGBM-compatible objective function with custom parameters.

Automatically handles sample weights if set in the Dataset.

Parameters:

Name Type Description Default
**kwargs Any

Parameters to pass to the loss function

{}

Returns:

Type Description
Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

LightGBM objective function: (y_pred, dataset) -> (grad, hess)

Source code in src/jaxboost/objective/auto.py
def get_lgb_objective(
    self, **kwargs: Any
) -> Callable[
    [NDArray[np.floating[Any]], Any],
    tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]],
]:
    """
    Get a LightGBM-compatible objective function with custom parameters.

    Automatically handles sample weights if set in the Dataset.

    Args:
        **kwargs: Parameters to pass to the loss function

    Returns:
        LightGBM objective function: (y_pred, dataset) -> (grad, hess)
    """

    def objective(
        y_pred: NDArray[np.floating[Any]], dataset: Any
    ) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
        y_true = dataset.get_label()
        sample_weight = dataset.get_weight()
        return self.grad_hess(y_pred, y_true, sample_weight=sample_weight, **kwargs)

    objective.__name__ = f"{self._name}_lgb_objective"
    return objective

with_params

with_params(**kwargs: Any) -> AutoObjective

Create a new AutoObjective with default parameters set.

Parameters:

Name Type Description Default
**kwargs Any

Default parameters for the loss function

{}

Returns:

Type Description
AutoObjective

New AutoObjective instance with default parameters

Example

focal = focal_loss.with_params(gamma=3.0, alpha=0.75)
model = xgb.train(params, dtrain, obj=focal.xgb_objective)

Source code in src/jaxboost/objective/auto.py
def with_params(self, **kwargs: Any) -> AutoObjective:
    """
    Create a new AutoObjective with default parameters set.

    Args:
        **kwargs: Default parameters for the loss function

    Returns:
        New AutoObjective instance with default parameters

    Example:
        >>> focal = focal_loss.with_params(gamma=3.0, alpha=0.75)
        >>> model = xgb.train(params, dtrain, obj=focal.xgb_objective)
    """
    new_instance = AutoObjective(self.loss_fn)
    new_instance._default_kwargs = {**self._default_kwargs, **kwargs}
    return new_instance

get_sklearn_objective

get_sklearn_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], NDArray[np.floating[Any]]], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]

Get a scikit-learn compatible objective function for XGBClassifier/XGBRegressor.

The sklearn interface expects (labels, predt) -> (grad, hess) instead of (predt, dtrain) -> (grad, hess).

Parameters:

Name Type Description Default
**kwargs Any

Parameters to pass to the loss function

{}

Returns:

Type Description
Callable[[NDArray[floating[Any]], NDArray[floating[Any]]], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

Sklearn-compatible objective: (labels, predt) -> (grad, hess)

Example

from xgboost import XGBClassifier

clf = XGBClassifier(objective=focal_loss.sklearn_objective)
clf.fit(X_train, y_train)

Source code in src/jaxboost/objective/auto.py
def get_sklearn_objective(
    self, **kwargs: Any
) -> Callable[
    [NDArray[np.floating[Any]], NDArray[np.floating[Any]]],
    tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]],
]:
    """
    Get a scikit-learn compatible objective function for XGBClassifier/XGBRegressor.

    The sklearn interface expects (labels, predt) -> (grad, hess) instead of
    (predt, dtrain) -> (grad, hess).

    Args:
        **kwargs: Parameters to pass to the loss function

    Returns:
        Sklearn-compatible objective: (labels, predt) -> (grad, hess)

    Example:
        >>> from xgboost import XGBClassifier
        >>> clf = XGBClassifier(objective=focal_loss.sklearn_objective)
        >>> clf.fit(X_train, y_train)
    """

    def objective(
        labels: NDArray[np.floating[Any]], predt: NDArray[np.floating[Any]]
    ) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
        return self.grad_hess(predt, labels, **kwargs)

    objective.__name__ = f"{self._name}_sklearn_objective"
    return objective

MultiClassObjective

For multi-class classification problems where predictions are logit vectors.
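For instance, a focal-style multi-class loss can be written directly against the logit vector; a hedged sketch (not a library built-in, and the import path is assumed):

import jax
import jax.numpy as jnp
from jaxboost.objective import MultiClassObjective  # import path assumed

@MultiClassObjective(n_classes=3)
def multiclass_focal(logits, label, gamma=2.0):
    probs = jax.nn.softmax(logits)
    p_t = probs[label] + 1e-10                   # probability of the true class
    return -((1.0 - p_t) ** gamma) * jnp.log(p_t)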

MultiClassObjective

MultiClassObjective(loss_fn: Callable[..., Array] | None = None, n_classes: int = 3)

Objective function wrapper for multi-class classification.

Handles the specific requirements of multi-class classification with XGBoost/LightGBM:

- Predictions are logits of shape (n_samples, n_classes)
- Labels are integer class indices
- Computes gradients w.r.t. each class logit

Parameters:

Name Type Description Default
loss_fn Callable[..., Array] | None

A loss function that takes (logits, label, **kwargs) where:

- logits: shape (n_classes,), raw scores for each class
- label: integer class index (0 to n_classes - 1)

Returns a scalar loss.

None
n_classes int

Number of classes

3
Example

@MultiClassObjective(n_classes=3)
def my_multiclass_loss(logits, label):
    probs = jax.nn.softmax(logits)
    return -jnp.log(probs[label] + 1e-10)

params = {'num_class': 3}
model = xgb.train(params, dtrain, obj=my_multiclass_loss.xgb_objective)

Source code in src/jaxboost/objective/multiclass.py
def __init__(
    self,
    loss_fn: Callable[..., jax.Array] | None = None,
    n_classes: int = 3,
) -> None:
    self.n_classes = n_classes
    self._default_kwargs: dict[str, Any] = {}

    if loss_fn is not None:
        self._init_with_fn(loss_fn)
    else:
        self.loss_fn = None
        self._name = "multiclass_objective"

xgb_objective property

xgb_objective: Callable[..., Any]

XGBoost-compatible objective function.

__call__

__call__(loss_fn: Callable[..., Array]) -> MultiClassObjective

Allow use as a decorator with arguments.

Source code in src/jaxboost/objective/multiclass.py
def __call__(self, loss_fn: Callable[..., jax.Array]) -> MultiClassObjective:
    """Allow use as a decorator with arguments."""
    new_instance = MultiClassObjective(loss_fn=loss_fn, n_classes=self.n_classes)
    new_instance._default_kwargs = self._default_kwargs.copy()
    return new_instance

gradient

gradient(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]

Compute gradient for multi-class predictions.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples, n_classes) or flattened

required
y_true NDArray[floating[Any]]

True labels, shape (n_samples,) integer class indices

required
**kwargs Any

Additional arguments

{}

Returns:

Type Description
NDArray[floating[Any]]

Gradients, shape (n_samples, n_classes)

Source code in src/jaxboost/objective/multiclass.py
def gradient(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    **kwargs: Any,
) -> NDArray[np.floating[Any]]:
    """
    Compute gradient for multi-class predictions.

    Args:
        y_pred: Predictions, shape (n_samples, n_classes) or flattened
        y_true: True labels, shape (n_samples,) integer class indices
        **kwargs: Additional arguments

    Returns:
        Gradients, shape (n_samples, n_classes)
    """
    merged_kwargs = {**self._default_kwargs, **kwargs}

    # Reshape predictions if needed
    y_pred_arr = np.asarray(y_pred)
    y_pred_2d = y_pred_arr.reshape(-1, self.n_classes) if y_pred_arr.ndim == 1 else y_pred_arr

    y_pred_jax = jnp.asarray(y_pred_2d, dtype=jnp.float32)
    y_true_jax = jnp.asarray(y_true, dtype=jnp.int32)

    # Compute gradients for each sample
    def grad_single(logits: jax.Array, label: jax.Array) -> jax.Array:
        return self._grad_fn(logits, label, merged_kwargs)

    grads = jax.vmap(grad_single)(y_pred_jax, y_true_jax)

    return np.asarray(grads, dtype=np.float64)

hessian

hessian(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]

Compute diagonal Hessian for multi-class predictions.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples, n_classes) or flattened

required
y_true NDArray[floating[Any]]

True labels, shape (n_samples,) integer class indices

required
**kwargs Any

Additional arguments

{}

Returns:

Type Description
NDArray[floating[Any]]

Diagonal Hessians, shape (n_samples, n_classes)

Source code in src/jaxboost/objective/multiclass.py
def hessian(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    **kwargs: Any,
) -> NDArray[np.floating[Any]]:
    """
    Compute diagonal Hessian for multi-class predictions.

    Args:
        y_pred: Predictions, shape (n_samples, n_classes) or flattened
        y_true: True labels, shape (n_samples,) integer class indices
        **kwargs: Additional arguments

    Returns:
        Diagonal Hessians, shape (n_samples, n_classes)
    """
    merged_kwargs = {**self._default_kwargs, **kwargs}

    y_pred_arr = np.asarray(y_pred)
    y_pred_2d = y_pred_arr.reshape(-1, self.n_classes) if y_pred_arr.ndim == 1 else y_pred_arr
    y_pred_jax = jnp.asarray(y_pred_2d, dtype=jnp.float32)
    y_true_jax = jnp.asarray(y_true, dtype=jnp.int32)

    def hess_single(logits: jax.Array, label: jax.Array) -> jax.Array:
        return self._hess_fn(logits, label, merged_kwargs)

    hess = jax.vmap(hess_single)(y_pred_jax, y_true_jax)

    return np.asarray(hess, dtype=np.float64)

grad_hess

grad_hess(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], sample_weight: NDArray[floating[Any]] | None = None, **kwargs: Any) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]

Compute both gradient and Hessian.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples, n_classes)

required
y_true NDArray[floating[Any]]

True labels, shape (n_samples,)

required
sample_weight NDArray[floating[Any]] | None

Optional sample weights, shape (n_samples,)

None
**kwargs Any

Additional arguments

{}

Returns:

Type Description
tuple[NDArray[floating[Any]], NDArray[floating[Any]]]

Tuple of (gradients, hessians), each shape (n_samples, n_classes)

Source code in src/jaxboost/objective/multiclass.py
def grad_hess(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    sample_weight: NDArray[np.floating[Any]] | None = None,
    **kwargs: Any,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
    """
    Compute both gradient and Hessian.

    Args:
        y_pred: Predictions, shape (n_samples, n_classes)
        y_true: True labels, shape (n_samples,)
        sample_weight: Optional sample weights, shape (n_samples,)
        **kwargs: Additional arguments

    Returns:
        Tuple of (gradients, hessians), each shape (n_samples, n_classes)
    """
    grad = self.gradient(y_pred, y_true, **kwargs)
    hess = self.hessian(y_pred, y_true, **kwargs)

    if sample_weight is not None and len(sample_weight) > 0:
        weight = np.asarray(sample_weight, dtype=np.float64)
        grad = grad * weight[:, np.newaxis]
        hess = hess * weight[:, np.newaxis]

    return grad, hess

get_xgb_objective

get_xgb_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]

Get an XGBoost-compatible objective function for multi-class.

Use with XGBoost params: {'num_class': n_classes}

Parameters:

Name Type Description Default
**kwargs Any

Parameters to pass to the loss function

{}

Returns:

Type Description
Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

XGBoost objective function

Source code in src/jaxboost/objective/multiclass.py
def get_xgb_objective(
    self, **kwargs: Any
) -> Callable[
    [NDArray[np.floating[Any]], Any],
    tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]],
]:
    """
    Get an XGBoost-compatible objective function for multi-class.

    Use with XGBoost params: {'num_class': n_classes}

    Args:
        **kwargs: Parameters to pass to the loss function

    Returns:
        XGBoost objective function
    """

    def objective(
        y_pred: NDArray[np.floating[Any]], dtrain: Any
    ) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
        y_true = dtrain.get_label()
        sample_weight = dtrain.get_weight()
        return self.grad_hess(y_pred, y_true, sample_weight=sample_weight, **kwargs)

    objective.__name__ = f"{self._name}_xgb_objective"
    return objective

with_params

with_params(**kwargs: Any) -> MultiClassObjective

Create a new instance with default parameters set.

Source code in src/jaxboost/objective/multiclass.py
def with_params(self, **kwargs: Any) -> MultiClassObjective:
    """Create a new instance with default parameters set."""
    new_instance = MultiClassObjective(loss_fn=self.loss_fn, n_classes=self.n_classes)
    new_instance._default_kwargs = {**self._default_kwargs, **kwargs}
    return new_instance

MultiOutputObjective

For multi-output predictions like uncertainty estimation (mean + variance).
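As an illustration of the mean + variance case, a heteroscedastic Gaussian negative log-likelihood can be written as a two-output loss; a hedged sketch (not a library built-in, import path assumed, and the true target is assumed to be a single scalar per sample):

import jax.numpy as jnp
from jaxboost.objective import MultiOutputObjective  # import path assumed

@MultiOutputObjective(n_outputs=2)
def gaussian_nll(y_pred, y_true):
    mu, log_var = y_pred[0], y_pred[1]            # outputs: mean and log-variance
    return 0.5 * (log_var + (y_true - mu) ** 2 / jnp.exp(log_var))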

MultiOutputObjective

MultiOutputObjective(loss_fn: Callable[..., Array] | None = None, n_outputs: int = 2)

Objective for multi-output/multi-task learning.

Handles cases where each sample has multiple predictions, such as:

- Multi-target regression
- Parametric models learning multiple parameters
- Uncertainty estimation (mean + variance)

Parameters:

Name Type Description Default
loss_fn Callable[..., Array] | None

A loss function that takes (y_pred, y_true, **kwargs) where:

- y_pred: shape (n_outputs,) for a single sample
- y_true: shape (n_outputs,) or scalar for a single sample

Returns a scalar loss.

None
n_outputs int

Number of outputs per sample

2
Example

@MultiOutputObjective(n_outputs=2)
def parametric_loss(params, y_true, t=None):
    # params = [A, B] for model y = A + B*t
    A, B = params[0], params[1]
    y_pred = A + B * t
    return (y_pred - y_true) ** 2

model = xgb.train(params, dtrain, obj=parametric_loss.xgb_objective)

Source code in src/jaxboost/objective/multi_output.py
def __init__(
    self,
    loss_fn: Callable[..., jax.Array] | None = None,
    n_outputs: int = 2,
) -> None:
    self.n_outputs = n_outputs
    self._default_kwargs: dict[str, Any] = {}

    if loss_fn is not None:
        self._init_with_fn(loss_fn)
    else:
        self.loss_fn = None
        self._name = "multi_output_objective"

xgb_objective property

xgb_objective: Callable[..., Any]

XGBoost-compatible objective function.

__call__

__call__(loss_fn: Callable[..., Array]) -> MultiOutputObjective

Allow use as a decorator with arguments.

Source code in src/jaxboost/objective/multi_output.py
def __call__(self, loss_fn: Callable[..., jax.Array]) -> MultiOutputObjective:
    """Allow use as a decorator with arguments."""
    new_instance = MultiOutputObjective(loss_fn=loss_fn, n_outputs=self.n_outputs)
    new_instance._default_kwargs = self._default_kwargs.copy()
    return new_instance

gradient

gradient(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]

Compute gradient for multi-output predictions.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples, n_outputs) or flattened

required
y_true NDArray[floating[Any]]

True labels

required
**kwargs Any

Additional arguments

{}

Returns:

Type Description
NDArray[floating[Any]]

Gradients, shape (n_samples * n_outputs,) for XGBoost compatibility

Source code in src/jaxboost/objective/multi_output.py
def gradient(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    **kwargs: Any,
) -> NDArray[np.floating[Any]]:
    """
    Compute gradient for multi-output predictions.

    Args:
        y_pred: Predictions, shape (n_samples, n_outputs) or flattened
        y_true: True labels
        **kwargs: Additional arguments

    Returns:
        Gradients, shape (n_samples * n_outputs,) for XGBoost compatibility
    """
    merged_kwargs = {**self._default_kwargs, **kwargs}

    # Reshape if flattened
    y_pred_2d = np.asarray(y_pred).reshape(-1, self.n_outputs)
    n_samples = y_pred_2d.shape[0]

    y_pred_jax = jnp.asarray(y_pred_2d, dtype=jnp.float32)
    y_true_jax = jnp.asarray(y_true, dtype=jnp.float32)

    scalar_kwargs, array_kwargs, _ = self._split_kwargs(merged_kwargs, n_samples)

    # Compute gradients for each sample
    def grad_single(
        y_pred_i: jax.Array, y_true_i: jax.Array, array_kwargs_i: dict[str, Any]
    ) -> jax.Array:
        return self._grad_fn(y_pred_i, y_true_i, scalar_kwargs, array_kwargs_i)

    if array_kwargs:
        grads = jax.vmap(
            lambda yp, yt, *arr_vals: grad_single(
                yp, yt, dict(zip(array_kwargs.keys(), arr_vals, strict=False))
            )
        )(y_pred_jax, y_true_jax, *array_kwargs.values())
    else:
        grads = jax.vmap(lambda yp, yt: grad_single(yp, yt, {}))(y_pred_jax, y_true_jax)

    # Flatten for XGBoost
    return np.asarray(grads, dtype=np.float64).flatten()

hessian

hessian(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], **kwargs: Any) -> NDArray[np.floating[Any]]

Compute diagonal Hessian for multi-output predictions.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples, n_outputs) or flattened

required
y_true NDArray[floating[Any]]

True labels

required
**kwargs Any

Additional arguments

{}

Returns:

Type Description
NDArray[floating[Any]]

Diagonal Hessians, shape (n_samples * n_outputs,) for XGBoost

Source code in src/jaxboost/objective/multi_output.py
def hessian(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    **kwargs: Any,
) -> NDArray[np.floating[Any]]:
    """
    Compute diagonal Hessian for multi-output predictions.

    Args:
        y_pred: Predictions, shape (n_samples, n_outputs) or flattened
        y_true: True labels
        **kwargs: Additional arguments

    Returns:
        Diagonal Hessians, shape (n_samples * n_outputs,) for XGBoost
    """
    merged_kwargs = {**self._default_kwargs, **kwargs}

    y_pred_2d = np.asarray(y_pred).reshape(-1, self.n_outputs)
    n_samples = y_pred_2d.shape[0]

    y_pred_jax = jnp.asarray(y_pred_2d, dtype=jnp.float32)
    y_true_jax = jnp.asarray(y_true, dtype=jnp.float32)

    scalar_kwargs, array_kwargs, _ = self._split_kwargs(merged_kwargs, n_samples)

    def hess_single(
        y_pred_i: jax.Array, y_true_i: jax.Array, array_kwargs_i: dict[str, Any]
    ) -> jax.Array:
        return self._hess_fn(y_pred_i, y_true_i, scalar_kwargs, array_kwargs_i)

    if array_kwargs:
        hess = jax.vmap(
            lambda yp, yt, *arr_vals: hess_single(
                yp, yt, dict(zip(array_kwargs.keys(), arr_vals, strict=False))
            )
        )(y_pred_jax, y_true_jax, *array_kwargs.values())
    else:
        hess = jax.vmap(lambda yp, yt: hess_single(yp, yt, {}))(y_pred_jax, y_true_jax)

    return np.asarray(hess, dtype=np.float64).flatten()

grad_hess

grad_hess(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], sample_weight: NDArray[floating[Any]] | None = None, **kwargs: Any) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]

Compute both gradient and Hessian.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples * n_outputs,)

required
y_true NDArray[floating[Any]]

True labels

required
sample_weight NDArray[floating[Any]] | None

Optional sample weights, shape (n_samples,)

None
**kwargs Any

Additional arguments

{}

Returns:

Type Description
tuple[NDArray[floating[Any]], NDArray[floating[Any]]]

Tuple of (gradients, hessians), each shape (n_samples * n_outputs,)

Source code in src/jaxboost/objective/multi_output.py
def grad_hess(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    sample_weight: NDArray[np.floating[Any]] | None = None,
    **kwargs: Any,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
    """
    Compute both gradient and Hessian.

    Args:
        y_pred: Predictions, shape (n_samples * n_outputs,)
        y_true: True labels
        sample_weight: Optional sample weights, shape (n_samples,)
        **kwargs: Additional arguments

    Returns:
        Tuple of (gradients, hessians), each shape (n_samples * n_outputs,)
    """
    grad = self.gradient(y_pred, y_true, **kwargs)
    hess = self.hessian(y_pred, y_true, **kwargs)

    if sample_weight is not None and len(sample_weight) > 0:
        weight = np.asarray(sample_weight, dtype=np.float64)
        weight_expanded = np.repeat(weight, self.n_outputs)
        grad = grad * weight_expanded
        hess = hess * weight_expanded

    return grad, hess

get_xgb_objective

get_xgb_objective(**kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]

Get an XGBoost-compatible objective function for multi-output.

Note: Requires XGBoost with multi-output support. Set multi_strategy='multi_output_tree' and num_target=n_outputs in params.
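A hedged sketch of the corresponding training call, using the parameter names stated in the note above (gaussian_nll stands in for any two-output objective):

params = {
    "multi_strategy": "multi_output_tree",  # per the note above
    "num_target": 2,                        # equal to n_outputs
}
model = xgb.train(params, dtrain, obj=gaussian_nll.xgb_objective)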

Parameters:

Name Type Description Default
**kwargs Any

Parameters to pass to the loss function

{}

Returns:

Type Description
Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

XGBoost objective function

Source code in src/jaxboost/objective/multi_output.py
def get_xgb_objective(
    self, **kwargs: Any
) -> Callable[
    [NDArray[np.floating[Any]], Any],
    tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]],
]:
    """
    Get an XGBoost-compatible objective function for multi-output.

    Note: Requires XGBoost with multi-output support.
    Set `multi_strategy='multi_output_tree'` and `num_target=n_outputs` in params.

    Args:
        **kwargs: Parameters to pass to the loss function

    Returns:
        XGBoost objective function
    """

    def objective(
        y_pred: NDArray[np.floating[Any]], dtrain: Any
    ) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
        y_true = dtrain.get_label()
        sample_weight = dtrain.get_weight()
        return self.grad_hess(y_pred, y_true, sample_weight=sample_weight, **kwargs)

    objective.__name__ = f"{self._name}_xgb_objective"
    return objective

with_params

with_params(**kwargs: Any) -> MultiOutputObjective

Create a new instance with default parameters set.

Source code in src/jaxboost/objective/multi_output.py
def with_params(self, **kwargs: Any) -> MultiOutputObjective:
    """Create a new instance with default parameters set."""
    new_instance = MultiOutputObjective(loss_fn=self.loss_fn, n_outputs=self.n_outputs)
    new_instance._default_kwargs = {**self._default_kwargs, **kwargs}
    return new_instance

MaskedMultiTaskObjective

For multi-task learning with potentially missing labels.

MaskedMultiTaskObjective

MaskedMultiTaskObjective(task_loss_fn: Callable[[Array, Array], Array] | None = None, n_tasks: int = 2, task_weights: list[float] | NDArray[floating[Any]] | None = None)

Multi-task objective with support for missing labels.

Key features:

- Handles arbitrary label missingness patterns
- Gradients are 0 for missing labels (no update)
- Supports per-task loss functions
- Automatic gradient/Hessian computation via JAX
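Conceptually, the per-sample quantity being optimized is the task-weighted, mask-gated sum of per-task losses; a minimal sketch of that idea (the class's actual internals are not shown on this page):

import jax
import jax.numpy as jnp

def masked_sample_loss(y_pred, y_true, mask, task_weights, task_loss_fn):
    # y_pred, y_true, mask, task_weights all have shape (n_tasks,)
    per_task = jax.vmap(task_loss_fn)(y_pred, y_true)  # loss for every task
    return jnp.sum(task_weights * mask * per_task)     # missing labels contribute 0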

Parameters:

Name Type Description Default
n_tasks int

Number of tasks (outputs per sample)

2
task_loss_fn Callable[[Array, Array], Array] | None

Loss function for each task, with signature (y_pred, y_true) -> scalar loss. Defaults to squared error.

None
task_weights list[float] | NDArray[floating[Any]] | None

Optional weights for each task, shape (n_tasks,)

None
Example

Basic usage

obj = MaskedMultiTaskObjective(n_tasks=3)

Custom per-task loss

@MaskedMultiTaskObjective(n_tasks=3)
def my_mtl_loss(y_pred, y_true):
    return (y_pred - y_true) ** 2

Different weights per task

obj = MaskedMultiTaskObjective(n_tasks=3, task_weights=[1.0, 2.0, 0.5])

Source code in src/jaxboost/objective/multi_task.py
def __init__(
    self,
    task_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
    n_tasks: int = 2,
    task_weights: list[float] | NDArray[np.floating[Any]] | None = None,
) -> None:
    self.n_tasks = n_tasks
    self.task_weights = (
        jnp.asarray(task_weights, dtype=jnp.float32)
        if task_weights is not None
        else jnp.ones(n_tasks, dtype=jnp.float32)
    )
    self._default_kwargs: dict[str, Any] = {}

    if task_loss_fn is not None:
        self._init_with_fn(task_loss_fn)
    else:
        # Default: squared error
        self._init_with_fn(lambda y_pred, y_true: (y_pred - y_true) ** 2)

xgb_objective property

xgb_objective: Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

XGBoost-compatible objective function (no mask).

For missing label support, use get_xgb_objective(mask=...) instead.

lgb_objective property

lgb_objective: Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

LightGBM-compatible objective function (no mask).

__call__

__call__(task_loss_fn: Callable[[Array, Array], Array]) -> MaskedMultiTaskObjective

Allow use as decorator.

Source code in src/jaxboost/objective/multi_task.py
def __call__(
    self, task_loss_fn: Callable[[jax.Array, jax.Array], jax.Array]
) -> MaskedMultiTaskObjective:
    """Allow use as decorator."""
    new_instance = MaskedMultiTaskObjective(
        task_loss_fn=task_loss_fn,
        n_tasks=self.n_tasks,
        task_weights=self.task_weights,
    )
    new_instance._default_kwargs = self._default_kwargs.copy()
    return new_instance

gradient

gradient(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], mask: NDArray[floating[Any]] | None = None, **kwargs: Any) -> NDArray[np.floating[Any]]

Compute gradients with missing label support.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples * n_tasks,) or (n_samples, n_tasks)

required
y_true NDArray[floating[Any]]

Labels, shape (n_samples * n_tasks,) or (n_samples, n_tasks)

required
mask NDArray[floating[Any]] | None

Label mask, shape (n_samples * n_tasks,) or (n_samples, n_tasks). 1 = valid label, 0 = missing label. Default: all valid.

None
**kwargs Any

Additional arguments (unused)

{}

Returns:

Type Description
NDArray[floating[Any]]

Gradients, shape (n_samples * n_tasks,)

Source code in src/jaxboost/objective/multi_task.py
def gradient(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    mask: NDArray[np.floating[Any]] | None = None,
    **kwargs: Any,
) -> NDArray[np.floating[Any]]:
    """
    Compute gradients with missing label support.

    Args:
        y_pred: Predictions, shape (n_samples * n_tasks,) or (n_samples, n_tasks)
        y_true: Labels, shape (n_samples * n_tasks,) or (n_samples, n_tasks)
        mask: Label mask, shape (n_samples * n_tasks,) or (n_samples, n_tasks)
              1 = valid label, 0 = missing label. Default: all valid.
        **kwargs: Additional arguments (unused)

    Returns:
        Gradients, shape (n_samples * n_tasks,)
    """
    # Reshape to 2D
    y_pred_2d = np.asarray(y_pred, dtype=np.float64).reshape(-1, self.n_tasks)
    y_true_2d = np.asarray(y_true, dtype=np.float64).reshape(-1, self.n_tasks)
    n_samples = y_pred_2d.shape[0]

    if mask is None:
        mask_2d = np.ones((n_samples, self.n_tasks), dtype=np.float32)
    else:
        mask_2d = np.asarray(mask, dtype=np.float32).reshape(-1, self.n_tasks)

    # Convert to JAX arrays
    y_pred_jax = jnp.asarray(y_pred_2d, dtype=jnp.float32)
    y_true_jax = jnp.asarray(y_true_2d, dtype=jnp.float32)
    mask_jax = jnp.asarray(mask_2d, dtype=jnp.float32)

    # Compute gradients for all samples (vectorized)
    grads, _ = jax.vmap(self._compute_grad_hess_single)(y_pred_jax, y_true_jax, mask_jax)

    return np.asarray(grads, dtype=np.float64).flatten()

hessian

hessian(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], mask: NDArray[floating[Any]] | None = None, **kwargs: Any) -> NDArray[np.floating[Any]]

Compute Hessians with missing label support.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples * n_tasks,)

required
y_true NDArray[floating[Any]]

Labels, shape (n_samples * n_tasks,)

required
mask NDArray[floating[Any]] | None

Label mask, shape (n_samples * n_tasks,)

None
**kwargs Any

Additional arguments

{}

Returns:

Type Description
NDArray[floating[Any]]

Diagonal Hessians, shape (n_samples * n_tasks,)

Source code in src/jaxboost/objective/multi_task.py
def hessian(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    mask: NDArray[np.floating[Any]] | None = None,
    **kwargs: Any,
) -> NDArray[np.floating[Any]]:
    """
    Compute Hessians with missing label support.

    Args:
        y_pred: Predictions, shape (n_samples * n_tasks,)
        y_true: Labels, shape (n_samples * n_tasks,)
        mask: Label mask, shape (n_samples * n_tasks,)
        **kwargs: Additional arguments

    Returns:
        Diagonal Hessians, shape (n_samples * n_tasks,)
    """
    y_pred_2d = np.asarray(y_pred, dtype=np.float64).reshape(-1, self.n_tasks)
    y_true_2d = np.asarray(y_true, dtype=np.float64).reshape(-1, self.n_tasks)
    n_samples = y_pred_2d.shape[0]

    if mask is None:
        mask_2d = np.ones((n_samples, self.n_tasks), dtype=np.float32)
    else:
        mask_2d = np.asarray(mask, dtype=np.float32).reshape(-1, self.n_tasks)

    y_pred_jax = jnp.asarray(y_pred_2d, dtype=jnp.float32)
    y_true_jax = jnp.asarray(y_true_2d, dtype=jnp.float32)
    mask_jax = jnp.asarray(mask_2d, dtype=jnp.float32)

    _, hess = jax.vmap(self._compute_grad_hess_single)(y_pred_jax, y_true_jax, mask_jax)

    return np.asarray(hess, dtype=np.float64).flatten()

grad_hess

grad_hess(y_pred: NDArray[floating[Any]], y_true: NDArray[floating[Any]], mask: NDArray[floating[Any]] | None = None, sample_weight: NDArray[floating[Any]] | None = None, **kwargs: Any) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]

Compute both gradient and Hessian efficiently.

Parameters:

Name Type Description Default
y_pred NDArray[floating[Any]]

Predictions, shape (n_samples * n_tasks,)

required
y_true NDArray[floating[Any]]

Labels, shape (n_samples * n_tasks,)

required
mask NDArray[floating[Any]] | None

Label mask, shape (n_samples * n_tasks,)

None
sample_weight NDArray[floating[Any]] | None

Optional sample weights, shape (n_samples,)

None
**kwargs Any

Additional arguments

{}

Returns:

Type Description
tuple[NDArray[floating[Any]], NDArray[floating[Any]]]

Tuple of (gradients, hessians)

Source code in src/jaxboost/objective/multi_task.py
def grad_hess(
    self,
    y_pred: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    mask: NDArray[np.floating[Any]] | None = None,
    sample_weight: NDArray[np.floating[Any]] | None = None,
    **kwargs: Any,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
    """
    Compute both gradient and Hessian efficiently.

    Args:
        y_pred: Predictions, shape (n_samples * n_tasks,)
        y_true: Labels, shape (n_samples * n_tasks,)
        mask: Label mask, shape (n_samples * n_tasks,)
        sample_weight: Optional sample weights, shape (n_samples,)
        **kwargs: Additional arguments

    Returns:
        Tuple of (gradients, hessians)
    """
    y_pred_2d = np.asarray(y_pred, dtype=np.float64).reshape(-1, self.n_tasks)
    y_true_2d = np.asarray(y_true, dtype=np.float64).reshape(-1, self.n_tasks)
    n_samples = y_pred_2d.shape[0]

    if mask is None:
        mask_2d = np.ones((n_samples, self.n_tasks), dtype=np.float32)
    else:
        mask_2d = np.asarray(mask, dtype=np.float32).reshape(-1, self.n_tasks)

    y_pred_jax = jnp.asarray(y_pred_2d, dtype=jnp.float32)
    y_true_jax = jnp.asarray(y_true_2d, dtype=jnp.float32)
    mask_jax = jnp.asarray(mask_2d, dtype=jnp.float32)

    # JIT-compiled batch computation
    @jax.jit
    def batch_grad_hess(y_pred, y_true, mask):
        return jax.vmap(self._compute_grad_hess_single)(y_pred, y_true, mask)

    grads, hess = batch_grad_hess(y_pred_jax, y_true_jax, mask_jax)

    grads = np.asarray(grads, dtype=np.float64).flatten()
    hess = np.asarray(hess, dtype=np.float64).flatten()

    # Apply sample weights if provided
    if sample_weight is not None and len(sample_weight) > 0:
        weight = np.asarray(sample_weight, dtype=np.float64)
        weight_expanded = np.repeat(weight, self.n_tasks)
        grads = grads * weight_expanded
        hess = hess * weight_expanded

    return grads, hess

get_xgb_objective

get_xgb_objective(mask: NDArray[floating[Any]] | None = None, mask_key: str | None = None, **kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]

Get an XGBoost-compatible objective function.

Parameters:

Name Type Description Default
mask NDArray[floating[Any]] | None

Label mask, shape (n_samples * n_tasks,) or (n_samples, n_tasks). 1 = valid label, 0 = missing label. If None, all labels valid.

None
mask_key str | None

Key to retrieve mask from DMatrix via get_float_info(). Use this for distributed training (e.g., Ray XGBoost) where data is partitioned across workers. The mask will be read per-worker, ensuring correct alignment with local data.

None
**kwargs Any

Additional parameters for the loss function

{}

Returns:

Type Description
Callable[[NDArray[floating[Any]], Any], tuple[NDArray[floating[Any]], NDArray[floating[Any]]]]

XGBoost objective function: (y_pred, dtrain) -> (grad, hess)

Example

Single-node: pass mask directly

y_true = np.array([[1.0, np.nan, 0.5],
                   [np.nan, 2.0, 1.0]])
mask = ~np.isnan(y_true)
y_true_filled = np.nan_to_num(y_true, nan=0.0)
dtrain = xgb.DMatrix(X, label=y_true_filled.flatten())
obj = MaskedMultiTaskObjective(n_tasks=3)
model = xgb.train(params, dtrain, obj=obj.get_xgb_objective(mask=mask))

Distributed (Ray XGBoost): store mask in DMatrix

dtrain.set_float_info("label_mask", mask.astype(np.float32).flatten()) model = xgb.train(params, dtrain, obj=obj.get_xgb_objective(mask_key="label_mask"))

Source code in src/jaxboost/objective/multi_task.py
def get_xgb_objective(
    self,
    mask: NDArray[np.floating[Any]] | None = None,
    mask_key: str | None = None,
    **kwargs: Any,
) -> Callable[
    [NDArray[np.floating[Any]], Any],
    tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]],
]:
    """
    Get an XGBoost-compatible objective function.

    Args:
        mask: Label mask, shape (n_samples * n_tasks,) or (n_samples, n_tasks).
              1 = valid label, 0 = missing label. If None, all labels valid.
        mask_key: Key to retrieve mask from DMatrix via get_float_info().
                  Use this for distributed training (e.g., Ray XGBoost) where
                  data is partitioned across workers. The mask will be read
                  per-worker, ensuring correct alignment with local data.
        **kwargs: Additional parameters for the loss function

    Returns:
        XGBoost objective function: (y_pred, dtrain) -> (grad, hess)

    Example:
        >>> # Single-node: pass mask directly
        >>> y_true = np.array([[1.0, np.nan, 0.5],
        ...                    [np.nan, 2.0, 1.0]])
        >>> mask = ~np.isnan(y_true)
        >>> y_true_filled = np.nan_to_num(y_true, nan=0.0)
        >>> dtrain = xgb.DMatrix(X, label=y_true_filled.flatten())
        >>> obj = MaskedMultiTaskObjective(n_tasks=3)
        >>> model = xgb.train(params, dtrain, obj=obj.get_xgb_objective(mask=mask))

        >>> # Distributed (Ray XGBoost): store mask in DMatrix
        >>> dtrain.set_float_info("label_mask", mask.astype(np.float32).flatten())
        >>> model = xgb.train(params, dtrain, obj=obj.get_xgb_objective(mask_key="label_mask"))
    """
    captured_mask = mask
    n_tasks = self.n_tasks

    def objective(
        y_pred: NDArray[np.floating[Any]], dtrain: Any
    ) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
        y_true = dtrain.get_label()
        sample_weight = dtrain.get_weight()

        if mask_key is not None:
            current_mask = dtrain.get_float_info(mask_key).reshape(-1, n_tasks)
        else:
            current_mask = captured_mask

        return self.grad_hess(
            y_pred, y_true, mask=current_mask, sample_weight=sample_weight, **kwargs
        )

    objective.__name__ = f"{self._name}_xgb_objective"
    return objective

get_lgb_objective

get_lgb_objective(mask: NDArray[floating[Any]] | None = None, mask_key: str | None = None, **kwargs: Any) -> Callable[[NDArray[np.floating[Any]], Any], tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]]

Get a LightGBM-compatible objective function.

Parameters:

Name Type Description Default
mask NDArray[floating[Any]] | None

Label mask for missing labels.

None
mask_key str | None

Key to retrieve mask from Dataset. Use for distributed training.

None
**kwargs Any

Additional parameters.

{}

Note: LightGBM multi-output support is more limited than XGBoost.

Source code in src/jaxboost/objective/multi_task.py
def get_lgb_objective(
    self,
    mask: NDArray[np.floating[Any]] | None = None,
    mask_key: str | None = None,
    **kwargs: Any,
) -> Callable[
    [NDArray[np.floating[Any]], Any],
    tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]],
]:
    """
    Get a LightGBM-compatible objective function.

    Args:
        mask: Label mask for missing labels.
        mask_key: Key to retrieve mask from Dataset. Use for distributed training.
        **kwargs: Additional parameters.

    Note: LightGBM multi-output support is more limited than XGBoost.
    """
    captured_mask = mask
    n_tasks = self.n_tasks

    def objective(
        y_pred: NDArray[np.floating[Any]], dataset: Any
    ) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
        y_true = dataset.get_label()
        sample_weight = dataset.get_weight()

        if mask_key is not None and hasattr(dataset, "get_data"):
            current_mask = dataset.get_data(mask_key).reshape(-1, n_tasks)
        else:
            current_mask = captured_mask

        return self.grad_hess(
            y_pred, y_true, mask=current_mask, sample_weight=sample_weight, **kwargs
        )

    objective.__name__ = f"{self._name}_lgb_objective"
    return objective

with_params

with_params(**kwargs: Any) -> MaskedMultiTaskObjective

Create a new instance with default parameters set.

Source code in src/jaxboost/objective/multi_task.py
def with_params(self, **kwargs: Any) -> MaskedMultiTaskObjective:
    """Create a new instance with default parameters set."""
    new_instance = MaskedMultiTaskObjective(
        task_loss_fn=self.task_loss_fn,
        n_tasks=self.n_tasks,
        task_weights=self.task_weights,
    )
    new_instance._default_kwargs = {**self._default_kwargs, **kwargs}
    return new_instance

Binary Classification

Objectives for binary classification tasks (labels in {0, 1}).

focal_loss

focal_loss(y_pred: Array, y_true: Array, gamma: float = 2.0, alpha: float = 0.25) -> jax.Array

Focal Loss for imbalanced binary classification.

Down-weights well-classified examples and focuses on hard examples. Particularly useful for highly imbalanced datasets.
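In formula form, writing p = sigmoid(y_pred):

$$
p_t = y\,p + (1-y)(1-p), \qquad
\alpha_t = y\,\alpha + (1-y)(1-\alpha), \qquad
\mathrm{FL} = -\,\alpha_t\,(1-p_t)^{\gamma}\,\log(p_t)
$$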

Parameters:

Name Type Description Default
y_pred Array

Raw prediction (logit), will be passed through sigmoid

required
y_true Array

Binary label (0 or 1)

required
gamma float

Focusing parameter. Higher = more focus on hard examples. Default: 2.0

2.0
alpha float

Class balance weight for positive class. Default: 0.25

0.25
Reference

Lin et al. "Focal Loss for Dense Object Detection" (2017) https://arxiv.org/abs/1708.02002

Example

model = xgb.train(params, dtrain, obj=focal_loss.xgb_objective)

With custom parameters:

model = xgb.train(params, dtrain, obj=focal_loss.get_xgb_objective(gamma=3.0))

Source code in src/jaxboost/objective/binary.py
@AutoObjective
def focal_loss(
    y_pred: jax.Array,
    y_true: jax.Array,
    gamma: float = 2.0,
    alpha: float = 0.25,
) -> jax.Array:
    """
    Focal Loss for imbalanced binary classification.

    Down-weights well-classified examples and focuses on hard examples.
    Particularly useful for highly imbalanced datasets.

    Args:
        y_pred: Raw prediction (logit), will be passed through sigmoid
        y_true: Binary label (0 or 1)
        gamma: Focusing parameter. Higher = more focus on hard examples. Default: 2.0
        alpha: Class balance weight for positive class. Default: 0.25

    Reference:
        Lin et al. "Focal Loss for Dense Object Detection" (2017)
        https://arxiv.org/abs/1708.02002

    Example:
        >>> model = xgb.train(params, dtrain, obj=focal_loss.xgb_objective)
        >>> # With custom parameters:
        >>> model = xgb.train(params, dtrain, obj=focal_loss.get_xgb_objective(gamma=3.0))
    """
    p = jax.nn.sigmoid(y_pred)
    # Clip for numerical stability
    p = jnp.clip(p, 1e-7, 1 - 1e-7)

    # Cross-entropy loss
    ce_loss = -y_true * jnp.log(p) - (1 - y_true) * jnp.log(1 - p)

    # Focal weight: (1 - p_t)^gamma where p_t is prob of true class
    p_t = y_true * p + (1 - y_true) * (1 - p)
    alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
    focal_weight = alpha_t * (1 - p_t) ** gamma

    return focal_weight * ce_loss

binary_crossentropy

binary_crossentropy(y_pred: Array, y_true: Array) -> jax.Array

Binary Cross-Entropy Loss.

Standard binary classification loss. Included for testing and as a baseline.

Parameters:

Name Type Description Default
y_pred Array

Raw prediction (logit), will be passed through sigmoid

required
y_true Array

Binary label (0 or 1)

required
Example

model = xgb.train(params, dtrain, obj=binary_crossentropy.xgb_objective)

Source code in src/jaxboost/objective/binary.py
@AutoObjective
def binary_crossentropy(
    y_pred: jax.Array,
    y_true: jax.Array,
) -> jax.Array:
    """
    Binary Cross-Entropy Loss.

    Standard binary classification loss. Included for testing and as a baseline.

    Args:
        y_pred: Raw prediction (logit), will be passed through sigmoid
        y_true: Binary label (0 or 1)

    Example:
        >>> model = xgb.train(params, dtrain, obj=binary_crossentropy.xgb_objective)
    """
    p = jax.nn.sigmoid(y_pred)
    p = jnp.clip(p, 1e-7, 1 - 1e-7)
    return -y_true * jnp.log(p) - (1 - y_true) * jnp.log(1 - p)

weighted_binary_crossentropy

weighted_binary_crossentropy(y_pred: Array, y_true: Array, pos_weight: float = 1.0) -> jax.Array

Weighted Binary Cross-Entropy Loss.

Applies a weight to the positive class to handle class imbalance.

Parameters:

Name Type Description Default
y_pred Array

Raw prediction (logit), will be passed through sigmoid

required
y_true Array

Binary label (0 or 1)

required
pos_weight float

Weight for positive class. Default: 1.0

1.0
Example

10x weight for positive class

obj = weighted_binary_crossentropy.with_params(pos_weight=10.0)
model = xgb.train(params, dtrain, obj=obj.xgb_objective)

Source code in src/jaxboost/objective/binary.py
@AutoObjective
def weighted_binary_crossentropy(
    y_pred: jax.Array,
    y_true: jax.Array,
    pos_weight: float = 1.0,
) -> jax.Array:
    """
    Weighted Binary Cross-Entropy Loss.

    Applies a weight to the positive class to handle class imbalance.

    Args:
        y_pred: Raw prediction (logit), will be passed through sigmoid
        y_true: Binary label (0 or 1)
        pos_weight: Weight for positive class. Default: 1.0

    Example:
        >>> # 10x weight for positive class
        >>> obj = weighted_binary_crossentropy.with_params(pos_weight=10.0)
        >>> model = xgb.train(params, dtrain, obj=obj.xgb_objective)
    """
    p = jax.nn.sigmoid(y_pred)
    p = jnp.clip(p, 1e-7, 1 - 1e-7)
    return -pos_weight * y_true * jnp.log(p) - (1 - y_true) * jnp.log(1 - p)

hinge_loss

hinge_loss(y_pred: Array, y_true: Array, margin: float = 1.0) -> jax.Array

Smooth Hinge Loss for binary classification.

SVM-style loss with smooth approximation for non-zero Hessians.
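Written out, with labels mapped from {0, 1} to {-1, +1}:

$$
z = (2y - 1)\,\hat{y}, \qquad
\mathcal{L}(\hat{y}, y) = \operatorname{softplus}(m - z) = \log\bigl(1 + e^{\,m - z}\bigr)
$$

where m is the margin.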

Parameters:

Name Type Description Default
y_pred Array

Raw prediction score

required
y_true Array

Binary label (0 or 1), will be converted to {-1, +1}

required
margin float

Margin parameter. Default: 1.0

1.0
Example

model = xgb.train(params, dtrain, obj=hinge_loss.xgb_objective)

Source code in src/jaxboost/objective/binary.py
@AutoObjective
def hinge_loss(
    y_pred: jax.Array,
    y_true: jax.Array,
    margin: float = 1.0,
) -> jax.Array:
    """
    Smooth Hinge Loss for binary classification.

    SVM-style loss with smooth approximation for non-zero Hessians.

    Args:
        y_pred: Raw prediction score
        y_true: Binary label (0 or 1), will be converted to {-1, +1}
        margin: Margin parameter. Default: 1.0

    Example:
        >>> model = xgb.train(params, dtrain, obj=hinge_loss.xgb_objective)
    """
    # Convert {0, 1} to {-1, +1}
    y_signed = 2 * y_true - 1

    # Smooth hinge: use softplus for differentiability
    # hinge(z) = max(0, 1 - z) ≈ softplus(1 - z)
    z = y_signed * y_pred
    return jax.nn.softplus(margin - z)
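
The softplus smoothing is what keeps the second derivative strictly positive, so the booster always receives usable curvature. A standalone check in plain JAX (independent of the library):

import jax

def smooth_hinge(score, label, margin=1.0):
    # Same form as the listing above: softplus(margin - y_signed * score)
    return jax.nn.softplus(margin - (2 * label - 1) * score)

hess = jax.grad(jax.grad(smooth_hinge))(0.5, 1.0)
print(hess)  # roughly 0.235: strictly positive, unlike the hard hinge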

Regression

Objectives for continuous target prediction.

Standard

mse

mse(y_pred: Array, y_true: Array) -> jax.Array

Mean Squared Error Loss.

Standard squared error loss. Included for testing and as a baseline.

Parameters:

Name Type Description Default
y_pred Array

Predicted value

required
y_true Array

True value

required
Example

model = xgb.train(params, dtrain, obj=mse.xgb_objective)

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def mse(
    y_pred: jax.Array,
    y_true: jax.Array,
) -> jax.Array:
    """
    Mean Squared Error Loss.

    Standard squared error loss. Included for testing and as a baseline.

    Args:
        y_pred: Predicted value
        y_true: True value

    Example:
        >>> model = xgb.train(params, dtrain, obj=mse.xgb_objective)
    """
    return (y_pred - y_true) ** 2

huber

huber(y_pred: Array, y_true: Array, delta: float = 1.0) -> jax.Array

Huber Loss for robust regression.

Combines MSE for small errors and MAE for large errors, making it robust to outliers while maintaining smoothness near zero.

Parameters:

Name Type Description Default
y_pred Array

Predicted value

required
y_true Array

True value

required
delta float

Threshold where loss transitions from quadratic to linear. Default: 1.0

1.0
Example

model = xgb.train(params, dtrain, obj=huber.xgb_objective)

With custom delta:

model = xgb.train(params, dtrain, obj=huber.get_xgb_objective(delta=0.5))

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def huber(
    y_pred: jax.Array,
    y_true: jax.Array,
    delta: float = 1.0,
) -> jax.Array:
    """
    Huber Loss for robust regression.

    Combines MSE for small errors and MAE for large errors, making it
    robust to outliers while maintaining smoothness near zero.

    Args:
        y_pred: Predicted value
        y_true: True value
        delta: Threshold where loss transitions from quadratic to linear. Default: 1.0

    Example:
        >>> model = xgb.train(params, dtrain, obj=huber.xgb_objective)
        >>> # With custom delta:
        >>> model = xgb.train(params, dtrain, obj=huber.get_xgb_objective(delta=0.5))
    """
    error = y_pred - y_true
    abs_error = jnp.abs(error)
    return jnp.where(
        abs_error <= delta,
        0.5 * error**2,
        delta * (abs_error - 0.5 * delta),
    )

pseudo_huber

pseudo_huber(y_pred: Array, y_true: Array, delta: float = 1.0) -> jax.Array

Pseudo-Huber Loss.

A smooth approximation to the Huber loss that is differentiable everywhere.

Parameters:

Name Type Description Default
y_pred Array

Predicted value

required
y_true Array

True value

required
delta float

Scale parameter controlling the transition. Default: 1.0

1.0
Example

model = xgb.train(params, dtrain, obj=pseudo_huber.xgb_objective)

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def pseudo_huber(
    y_pred: jax.Array,
    y_true: jax.Array,
    delta: float = 1.0,
) -> jax.Array:
    """
    Pseudo-Huber Loss.

    A smooth approximation to the Huber loss that is differentiable everywhere.

    Args:
        y_pred: Predicted value
        y_true: True value
        delta: Scale parameter controlling the transition. Default: 1.0

    Example:
        >>> model = xgb.train(params, dtrain, obj=pseudo_huber.xgb_objective)
    """
    error = y_pred - y_true
    return delta**2 * (jnp.sqrt(1 + (error / delta) ** 2) - 1)
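
For intuition on how close the approximation is, a small numeric comparison against plain Huber at delta = 1 (plain jax.numpy, independent of the library):

import jax.numpy as jnp

err = jnp.array([0.1, 1.0, 5.0])
huber_vals = jnp.where(jnp.abs(err) <= 1.0, 0.5 * err**2, jnp.abs(err) - 0.5)
pseudo_vals = jnp.sqrt(1.0 + err**2) - 1.0
print(huber_vals)   # [0.005  0.5    4.5  ]
print(pseudo_vals)  # [~0.005 ~0.414 ~4.099]  smooth everywhere, same overall shape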

log_cosh

log_cosh(y_pred: Array, y_true: Array) -> jax.Array

Log-Cosh Loss for smooth robust regression.

Similar to Huber loss but smoother everywhere. Twice differentiable, which can lead to better optimization behavior.

Parameters:

Name Type Description Default
y_pred Array

Predicted value

required
y_true Array

True value

required
Example

model = xgb.train(params, dtrain, obj=log_cosh.xgb_objective)

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def log_cosh(
    y_pred: jax.Array,
    y_true: jax.Array,
) -> jax.Array:
    """
    Log-Cosh Loss for smooth robust regression.

    Similar to Huber loss but smoother everywhere. Twice differentiable,
    which can lead to better optimization behavior.

    Args:
        y_pred: Predicted value
        y_true: True value

    Example:
        >>> model = xgb.train(params, dtrain, obj=log_cosh.xgb_objective)
    """
    error = y_pred - y_true
    return jnp.log(jnp.cosh(error))
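
One practical observation (not library-specific guidance): with JAX's default float32, cosh overflows once the residual is large (roughly |error| > 89), while the algebraically equivalent form |x| + log1p(exp(-2|x|)) - log(2) stays finite:

import jax.numpy as jnp

x = jnp.array([0.5, 10.0, 100.0])
naive = jnp.log(jnp.cosh(x))  # last entry overflows to inf in float32
stable = jnp.abs(x) + jnp.log1p(jnp.exp(-2.0 * jnp.abs(x))) - jnp.log(2.0)
print(naive)   # [0.120  9.307     inf]
print(stable)  # [0.120  9.307  99.307]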

mae_smooth

mae_smooth(y_pred: Array, y_true: Array, beta: float = 0.1) -> jax.Array

Smooth Mean Absolute Error Loss.

A smooth approximation to MAE that has non-zero Hessian. Uses sqrt(error^2 + beta^2) - beta as approximation.

Parameters:

Name Type Description Default
y_pred Array

Predicted value

required
y_true Array

True value

required
beta float

Smoothing parameter. Smaller = closer to true MAE. Default: 0.1

0.1
Example

model = xgb.train(params, dtrain, obj=mae_smooth.xgb_objective)

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def mae_smooth(
    y_pred: jax.Array,
    y_true: jax.Array,
    beta: float = 0.1,
) -> jax.Array:
    """
    Smooth Mean Absolute Error Loss.

    A smooth approximation to MAE that has non-zero Hessian.
    Uses sqrt(error^2 + beta^2) - beta as approximation.

    Args:
        y_pred: Predicted value
        y_true: True value
        beta: Smoothing parameter. Smaller = closer to true MAE. Default: 0.1

    Example:
        >>> model = xgb.train(params, dtrain, obj=mae_smooth.xgb_objective)
    """
    error = y_pred - y_true
    return jnp.sqrt(error**2 + beta**2) - beta

Quantile & Asymmetric

quantile

quantile(y_pred: Array, y_true: Array, q: float = 0.5, alpha: float = 0.01) -> jax.Array

Smooth Quantile Loss (Pinball Loss) for quantile regression.

Asymmetric loss that penalizes under-prediction and over-prediction differently, allowing prediction of specific quantiles.

Uses a smooth approximation to ensure non-zero Hessians.

Parameters:

Name Type Description Default
y_pred Array

Predicted value

required
y_true Array

True value

required
q float

Target quantile in (0, 1). Default: 0.5 (median). q=0.1 for 10th percentile (conservative), q=0.5 for median, q=0.9 for 90th percentile (aggressive).

0.5
alpha float

Smoothing parameter for regularization. Default: 0.01

0.01
Example

Predict the 90th percentile

q90 = quantile.with_params(q=0.9)
model = xgb.train(params, dtrain, obj=q90.xgb_objective)

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def quantile(
    y_pred: jax.Array,
    y_true: jax.Array,
    q: float = 0.5,
    alpha: float = 0.01,
) -> jax.Array:
    """
    Smooth Quantile Loss (Pinball Loss) for quantile regression.

    Asymmetric loss that penalizes under-prediction and over-prediction
    differently, allowing prediction of specific quantiles.

    Uses a smooth approximation to ensure non-zero Hessians.

    Args:
        y_pred: Predicted value
        y_true: True value
        q: Target quantile in (0, 1). Default: 0.5 (median).
            q=0.1 for 10th percentile (conservative),
            q=0.5 for median,
            q=0.9 for 90th percentile (aggressive).
        alpha: Smoothing parameter for regularization. Default: 0.01

    Example:
        >>> # Predict the 90th percentile
        >>> q90 = quantile.with_params(q=0.9)
        >>> model = xgb.train(params, dtrain, obj=q90.xgb_objective)
    """
    error = y_true - y_pred
    abs_error = jnp.abs(error)

    # Standard quantile loss + small quadratic regularization
    loss = jnp.where(
        error >= 0,
        q * abs_error,
        (1 - q) * abs_error,
    )

    # Add small quadratic term for Hessian
    loss = loss + alpha * error**2

    return loss
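
A hedged sketch of building a rough 80% prediction interval by training one booster per quantile. params, dtrain, and dtest are assumed placeholders, and num_boost_round is just an illustrative value:

import xgboost as xgb

models = {}
for q in (0.1, 0.5, 0.9):
    obj = quantile.with_params(q=q)
    models[q] = xgb.train(params, dtrain, obj=obj.xgb_objective, num_boost_round=200)

# Lower bound, point estimate, and upper bound for each row of dtest
lower, median, upper = (models[q].predict(dtest) for q in (0.1, 0.5, 0.9))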

asymmetric

asymmetric(y_pred: Array, y_true: Array, alpha: float = 0.7) -> jax.Array

Asymmetric Loss for different penalties on under/over-prediction.

Useful when the cost of under-prediction differs from over-prediction, e.g., inventory management, demand forecasting.

Parameters:

Name Type Description Default
y_pred Array

Predicted value

required
y_true Array

True value

required
alpha float

Asymmetry parameter in (0, 1). Default: 0.7. alpha > 0.5 penalizes under-prediction more; alpha < 0.5 penalizes over-prediction more.

0.7
Example

Penalize under-prediction heavily (for safety stock)

obj = asymmetric.with_params(alpha=0.9)
model = xgb.train(params, dtrain, obj=obj.xgb_objective)

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def asymmetric(
    y_pred: jax.Array,
    y_true: jax.Array,
    alpha: float = 0.7,
) -> jax.Array:
    """
    Asymmetric Loss for different penalties on under/over-prediction.

    Useful when the cost of under-prediction differs from over-prediction,
    e.g., inventory management, demand forecasting.

    Args:
        y_pred: Predicted value
        y_true: True value
        alpha: Asymmetry parameter in (0, 1). Default: 0.7
               - alpha > 0.5: Penalize under-prediction more
               - alpha < 0.5: Penalize over-prediction more

    Example:
        >>> # Penalize under-prediction heavily (for safety stock)
        >>> obj = asymmetric.with_params(alpha=0.9)
        >>> model = xgb.train(params, dtrain, obj=obj.xgb_objective)
    """
    error = y_true - y_pred
    return jnp.where(
        error >= 0,
        alpha * error**2,
        (1 - alpha) * error**2,
    )

Distribution-Based

tweedie

tweedie(y_pred: Array, y_true: Array, p: float = 1.5) -> jax.Array

Tweedie Loss for zero-inflated and positive continuous data.

Common in insurance claims, rainfall prediction, and other scenarios with many zeros and positive continuous values.

Valid for 1 < p < 2. For p=1 (Poisson), use poisson objective. For p=2 (Gamma), use gamma objective.

Parameters:

Name Type Description Default
y_pred Array

Raw prediction (will be exponentiated to ensure positivity)

required
y_true Array

True value (must be non-negative)

required
p float

Tweedie power parameter. Default: 1.5. For 1<p<2: Compound Poisson-Gamma (most common for insurance).

1.5
Example

model = xgb.train(params, dtrain, obj=tweedie.xgb_objective)

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def tweedie(
    y_pred: jax.Array,
    y_true: jax.Array,
    p: float = 1.5,
) -> jax.Array:
    """
    Tweedie Loss for zero-inflated and positive continuous data.

    Common in insurance claims, rainfall prediction, and other scenarios
    with many zeros and positive continuous values.

    Valid for 1 < p < 2.
    For p=1 (Poisson), use `poisson` objective.
    For p=2 (Gamma), use `gamma` objective.

    Args:
        y_pred: Raw prediction (will be exponentiated to ensure positivity)
        y_true: True value (must be non-negative)
        p: Tweedie power parameter. Default: 1.5.
            For 1<p<2: Compound Poisson-Gamma (most common for insurance).

    Example:
        >>> model = xgb.train(params, dtrain, obj=tweedie.xgb_objective)
    """
    # Ensure positive predictions via exp
    mu = jnp.exp(y_pred)
    mu = jnp.clip(mu, 1e-10, 1e10)

    # Tweedie deviance
    return -y_true * jnp.power(mu, 1 - p) / (1 - p) + jnp.power(mu, 2 - p) / (2 - p)
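
A hedged usage sketch: pick the power parameter via get_xgb_objective, and remember that raw predictions are on the log scale, so exponentiate them at inference time. params, dtrain, and dtest are assumed placeholders:

import numpy as np
import xgboost as xgb

# p nearer 1 behaves more like Poisson, nearer 2 more like Gamma
model = xgb.train(params, dtrain, obj=tweedie.get_xgb_objective(p=1.3))

expected_value = np.exp(model.predict(dtest))  # undo the log link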

poisson

poisson(y_pred: Array, y_true: Array) -> jax.Array

Poisson Negative Log-Likelihood.

For count data. Assumes log-link function (y_pred is log(lambda)). Loss = exp(y_pred) - y_true * y_pred

Parameters:

Name Type Description Default
y_pred Array

Log of the expected count (log(lambda))

required
y_true Array

True count (must be non-negative)

required
Example

model = xgb.train(params, dtrain, obj=poisson.xgb_objective)

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def poisson(
    y_pred: jax.Array,
    y_true: jax.Array,
) -> jax.Array:
    """
    Poisson Negative Log-Likelihood.

    For count data. Assumes log-link function (y_pred is log(lambda)).
    Loss = exp(y_pred) - y_true * y_pred

    Args:
        y_pred: Log of the expected count (log(lambda))
        y_true: True count (must be non-negative)

    Example:
        >>> model = xgb.train(params, dtrain, obj=poisson.xgb_objective)
    """
    # Standard Poisson NLL ignoring constant log(y!) term
    return jnp.exp(y_pred) - y_true * y_pred

gamma

gamma(y_pred: Array, y_true: Array) -> jax.Array

Gamma Negative Log-Likelihood.

For positive continuous data (e.g. insurance claims, wait times). Assumes log-link function (y_pred is log(mean)). Loss = y_pred + y_true / exp(y_pred)

Parameters:

Name Type Description Default
y_pred Array

Log of the expected value (log(mean))

required
y_true Array

True value (must be positive)

required
Example

model = xgb.train(params, dtrain, obj=gamma.xgb_objective)

Source code in src/jaxboost/objective/regression.py
@AutoObjective
def gamma(
    y_pred: jax.Array,
    y_true: jax.Array,
) -> jax.Array:
    """
    Gamma Negative Log-Likelihood.

    For positive continuous data (e.g. insurance claims, wait times).
    Assumes log-link function (y_pred is log(mean)).
    Loss = y_pred + y_true / exp(y_pred)

    Args:
        y_pred: Log of the expected value (log(mean))
        y_true: True value (must be positive)

    Example:
        >>> model = xgb.train(params, dtrain, obj=gamma.xgb_objective)
    """
    # Gamma deviance-like loss: -log(y/mu) + (y-mu)/mu
    # With log-link mu = exp(pred):
    # Loss ~ log(mu) + y/mu
    #      = y_pred + y_true * exp(-y_pred)
    return y_pred + y_true * jnp.exp(-y_pred)
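
As a sanity check on the log-link interpretation, the per-sample loss is minimized exactly at y_pred = log(y_true). A standalone check in plain JAX:

import jax
import jax.numpy as jnp

def gamma_nll(y_pred, y_true):
    # Same per-sample form as the listing above
    return y_pred + y_true * jnp.exp(-y_pred)

grad = jax.grad(gamma_nll)(jnp.log(3.0), 3.0)
print(grad)  # approximately 0.0: the optimum sits at y_pred = log(y_true)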

Multi-class Classification

Objectives for classification with more than two classes.

softmax_cross_entropy

softmax_cross_entropy(n_classes: int = 3) -> MultiClassObjective

Softmax Cross-Entropy Loss for multi-class classification.

Standard cross-entropy loss with softmax activation.

Parameters:

Name Type Description Default
n_classes int

Number of classes

3

Returns:

Type Description
MultiClassObjective

MultiClassObjective instance

Example

softmax_loss = softmax_cross_entropy(n_classes=5)
params = {'num_class': 5}
model = xgb.train(params, dtrain, obj=softmax_loss.xgb_objective)

Source code in src/jaxboost/objective/multiclass.py
def softmax_cross_entropy(n_classes: int = 3) -> MultiClassObjective:
    """
    Softmax Cross-Entropy Loss for multi-class classification.

    Standard cross-entropy loss with softmax activation.

    Args:
        n_classes: Number of classes

    Returns:
        MultiClassObjective instance

    Example:
        >>> softmax_loss = softmax_cross_entropy(n_classes=5)
        >>> params = {'num_class': 5}
        >>> model = xgb.train(params, dtrain, obj=softmax_loss.xgb_objective)
    """

    @multiclass_objective(n_classes=n_classes)
    def softmax_ce(logits: jax.Array, label: jax.Array) -> jax.Array:
        """Softmax cross-entropy for a single sample."""
        log_probs = jax.nn.log_softmax(logits)
        return -log_probs[label]

    return softmax_ce

focal_multiclass

focal_multiclass(n_classes: int = 3, gamma: float = 2.0, alpha: float | None = None) -> MultiClassObjective

Focal Loss for multi-class classification.

Extends focal loss to multi-class setting. Down-weights well-classified examples to focus on hard examples.

Parameters:

Name Type Description Default
n_classes int

Number of classes

3
gamma float

Focusing parameter. Higher = more focus on hard examples.

2.0
alpha float | None

Optional class weight. If None, all classes weighted equally.

None

Returns:

Type Description
MultiClassObjective

MultiClassObjective instance

Example

focal_mc = focal_multiclass(n_classes=10, gamma=2.0)
model = xgb.train(params, dtrain, obj=focal_mc.xgb_objective)

Reference

Lin et al. "Focal Loss for Dense Object Detection" (2017)

Source code in src/jaxboost/objective/multiclass.py
def focal_multiclass(
    n_classes: int = 3,
    gamma: float = 2.0,
    alpha: float | None = None,
) -> MultiClassObjective:
    """
    Focal Loss for multi-class classification.

    Extends focal loss to multi-class setting. Down-weights well-classified
    examples to focus on hard examples.

    Args:
        n_classes: Number of classes
        gamma: Focusing parameter. Higher = more focus on hard examples.
        alpha: Optional class weight. If None, all classes weighted equally.

    Returns:
        MultiClassObjective instance

    Example:
        >>> focal_mc = focal_multiclass(n_classes=10, gamma=2.0)
        >>> model = xgb.train(params, dtrain, obj=focal_mc.xgb_objective)

    Reference:
        Lin et al. "Focal Loss for Dense Object Detection" (2017)
    """

    @multiclass_objective(n_classes=n_classes)
    def focal_mc(logits: jax.Array, label: jax.Array) -> jax.Array:
        """Focal loss for a single sample."""
        probs = jax.nn.softmax(logits)
        probs = jnp.clip(probs, 1e-10, 1.0 - 1e-10)

        p_t = probs[label]
        focal_weight = (1 - p_t) ** gamma
        ce = -jnp.log(p_t)

        if alpha is not None:
            return alpha * focal_weight * ce
        return focal_weight * ce

    return focal_mc

label_smoothing

label_smoothing(n_classes: int = 3, smoothing: float = 0.1) -> MultiClassObjective

Label Smoothing Cross-Entropy Loss.

Softmax cross-entropy with label smoothing for regularization.

Parameters:

Name Type Description Default
n_classes int

Number of classes

3
smoothing float

Smoothing factor in [0, 1]. 0 = no smoothing.

0.1

Returns:

Type Description
MultiClassObjective

MultiClassObjective instance

Example

smooth_loss = label_smoothing(n_classes=10, smoothing=0.1)
model = xgb.train(params, dtrain, obj=smooth_loss.xgb_objective)

Source code in src/jaxboost/objective/multiclass.py
def label_smoothing(
    n_classes: int = 3,
    smoothing: float = 0.1,
) -> MultiClassObjective:
    """
    Label Smoothing Cross-Entropy Loss.

    Softmax cross-entropy with label smoothing for regularization.

    Args:
        n_classes: Number of classes
        smoothing: Smoothing factor in [0, 1]. 0 = no smoothing.

    Returns:
        MultiClassObjective instance

    Example:
        >>> smooth_loss = label_smoothing(n_classes=10, smoothing=0.1)
        >>> model = xgb.train(params, dtrain, obj=smooth_loss.xgb_objective)
    """

    @multiclass_objective(n_classes=n_classes)
    def label_smooth(logits: jax.Array, label: jax.Array) -> jax.Array:
        """Label smoothing cross-entropy for a single sample."""
        log_probs = jax.nn.log_softmax(logits)

        smooth_weight = smoothing / (n_classes - 1)
        true_weight = 1.0 - smoothing

        loss = -true_weight * log_probs[label] - smooth_weight * (
            jnp.sum(log_probs) - log_probs[label]
        )

        return loss

    return label_smooth

class_balanced

class_balanced(n_classes: int = 3, samples_per_class: NDArray[floating[Any]] | None = None, beta: float = 0.999) -> MultiClassObjective

Class-Balanced Loss for long-tailed distributions.

Re-weights classes based on effective number of samples.

Parameters:

Name Type Description Default
n_classes int

Number of classes

3
samples_per_class NDArray[floating[Any]] | None

Array of sample counts per class. If None, uniform weights.

None
beta float

Hyperparameter for effective number. Higher = more aggressive reweighting.

0.999

Returns:

Type Description
MultiClassObjective

MultiClassObjective instance

Example

cb_loss = class_balanced(
    n_classes=5,
    samples_per_class=np.array([1000, 500, 100, 50, 10]),
    beta=0.999
)
model = xgb.train(params, dtrain, obj=cb_loss.xgb_objective)

Reference

Cui et al. "Class-Balanced Loss Based on Effective Number of Samples" (2019)

Source code in src/jaxboost/objective/multiclass.py
def class_balanced(
    n_classes: int = 3,
    samples_per_class: NDArray[np.floating[Any]] | None = None,
    beta: float = 0.999,
) -> MultiClassObjective:
    """
    Class-Balanced Loss for long-tailed distributions.

    Re-weights classes based on effective number of samples.

    Args:
        n_classes: Number of classes
        samples_per_class: Array of sample counts per class. If None, uniform weights.
        beta: Hyperparameter for effective number. Higher = more aggressive reweighting.

    Returns:
        MultiClassObjective instance

    Example:
        >>> cb_loss = class_balanced(
        ...     n_classes=5,
        ...     samples_per_class=np.array([1000, 500, 100, 50, 10]),
        ...     beta=0.999
        ... )
        >>> model = xgb.train(params, dtrain, obj=cb_loss.xgb_objective)

    Reference:
        Cui et al. "Class-Balanced Loss Based on Effective Number of Samples" (2019)
    """
    if samples_per_class is not None:
        effective_num = 1.0 - np.power(beta, samples_per_class)
        weights = (1.0 - beta) / effective_num
        weights = weights / np.sum(weights) * n_classes
        weights_jax = jnp.asarray(weights, dtype=jnp.float32)
    else:
        weights_jax = jnp.ones(n_classes, dtype=jnp.float32)

    @multiclass_objective(n_classes=n_classes)
    def cb_ce(logits: jax.Array, label: jax.Array) -> jax.Array:
        """Class-balanced cross-entropy for a single sample."""
        log_probs = jax.nn.log_softmax(logits)
        weight = weights_jax[label]
        return -weight * log_probs[label]

    return cb_ce
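
In practice samples_per_class usually comes straight from the training labels. A hedged sketch, where y_train is an assumed integer label array:

import numpy as np

counts = np.bincount(y_train.astype(int), minlength=5).astype(float)
cb_loss = class_balanced(n_classes=5, samples_per_class=counts, beta=0.999)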

Ordinal Regression

Objectives for ordered categorical outcomes (ratings, grades, severity levels).

ordinal_logit

ordinal_logit(n_classes: int) -> OrdinalObjective

Create an ordinal regression objective with logit link.

Source code in src/jaxboost/objective/ordinal.py
def ordinal_logit(n_classes: int) -> OrdinalObjective:
    """Create an ordinal regression objective with logit link."""
    return OrdinalObjective(n_classes=n_classes, link="logit")

ordinal_probit

ordinal_probit(n_classes: int) -> OrdinalObjective

Create an ordinal regression objective with probit link.

Source code in src/jaxboost/objective/ordinal.py
def ordinal_probit(n_classes: int) -> OrdinalObjective:
    """Create an ordinal regression objective with probit link."""
    return OrdinalObjective(n_classes=n_classes, link="probit")
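
Neither factory ships a usage example here. The sketch below assumes OrdinalObjective exposes the same xgb_objective property as the other objective classes on this page, with params and dtrain as placeholders:

import xgboost as xgb

# Assumption: OrdinalObjective provides .xgb_objective like the classes above
obj = ordinal_logit(n_classes=5)
model = xgb.train(params, dtrain, obj=obj.xgb_objective)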

QWK-Aligned

qwk_ordinal

qwk_ordinal(n_classes: int, link: Literal['probit', 'logit'] = 'logit', alpha: float = 0.0, beta: float = 1.0) -> QWKOrdinalObjective

Create a QWK-aligned ordinal objective.

Uses Expected Quadratic Error (EQE) as a differentiable surrogate for QWK.

Parameters:

Name Type Description Default
n_classes int

Number of ordinal classes

required
link Literal['probit', 'logit']

Link function - 'probit' or 'logit'

'logit'
alpha float

Weight for NLL loss (0 = no NLL)

0.0
beta float

Weight for EQE loss (1 = full EQE)

1.0

Returns:

Type Description
QWKOrdinalObjective

QWKOrdinalObjective instance

Example

Pure EQE (best QWK alignment)

obj = qwk_ordinal(n_classes=7)

Hybrid for stability

obj = qwk_ordinal(n_classes=7, alpha=0.5, beta=0.5)

Source code in src/jaxboost/objective/ordinal.py
def qwk_ordinal(
    n_classes: int,
    link: Literal["probit", "logit"] = "logit",
    alpha: float = 0.0,
    beta: float = 1.0,
) -> QWKOrdinalObjective:
    """
    Create a QWK-aligned ordinal objective.

    Uses Expected Quadratic Error (EQE) as a differentiable surrogate for QWK.

    Args:
        n_classes: Number of ordinal classes
        link: Link function - 'probit' or 'logit'
        alpha: Weight for NLL loss (0 = no NLL)
        beta: Weight for EQE loss (1 = full EQE)

    Returns:
        QWKOrdinalObjective instance

    Example:
        >>> # Pure EQE (best QWK alignment)
        >>> obj = qwk_ordinal(n_classes=7)
        >>>
        >>> # Hybrid for stability
        >>> obj = qwk_ordinal(n_classes=7, alpha=0.5, beta=0.5)
    """
    return QWKOrdinalObjective(n_classes=n_classes, link=link, alpha=alpha, beta=beta)

squared_cdf_ordinal

squared_cdf_ordinal(n_classes: int, link: Literal['probit', 'logit'] = 'logit') -> SquaredCDFObjective

Create a Squared CDF (CRPS) ordinal objective.

This minimizes the squared Earth Mover's Distance between distributions, often outperforming EQE/NLL for QWK optimization.

Source code in src/jaxboost/objective/ordinal.py
def squared_cdf_ordinal(
    n_classes: int,
    link: Literal["probit", "logit"] = "logit",
) -> SquaredCDFObjective:
    """
    Create a Squared CDF (CRPS) ordinal objective.

    This minimizes the squared Earth Mover's Distance between distributions,
    often outperforming EQE/NLL for QWK optimization.
    """
    return SquaredCDFObjective(n_classes=n_classes, link=link)

hybrid_ordinal

hybrid_ordinal(n_classes: int, link: Literal['probit', 'logit'] = 'logit', nll_weight: float = 0.7, eqe_weight: float = 0.3) -> QWKOrdinalObjective

Create a hybrid ordinal objective (NLL + EQE).

Combines:
- NLL: Proper probabilistic loss (stable gradients early in training)
- EQE: QWK-aligned loss (better metric alignment)

Default weights (0.7/0.3) work well in practice.

Parameters:

Name Type Description Default
n_classes int

Number of ordinal classes

required
link Literal['probit', 'logit']

Link function

'logit'
nll_weight float

Weight for NLL loss

0.7
eqe_weight float

Weight for EQE loss

0.3

Returns:

Type Description
QWKOrdinalObjective

QWKOrdinalObjective instance

Source code in src/jaxboost/objective/ordinal.py
def hybrid_ordinal(
    n_classes: int,
    link: Literal["probit", "logit"] = "logit",
    nll_weight: float = 0.7,
    eqe_weight: float = 0.3,
) -> QWKOrdinalObjective:
    """
    Create a hybrid ordinal objective (NLL + EQE).

    Combines:
    - NLL: Proper probabilistic loss (stable gradients early in training)
    - EQE: QWK-aligned loss (better metric alignment)

    Default weights (0.7/0.3) work well in practice.

    Args:
        n_classes: Number of ordinal classes
        link: Link function
        nll_weight: Weight for NLL loss
        eqe_weight: Weight for EQE loss

    Returns:
        QWKOrdinalObjective instance
    """
    return QWKOrdinalObjective(n_classes=n_classes, link=link, alpha=nll_weight, beta=eqe_weight)

SLACE Paper (AAAI 2025)

slace_objective

slace_objective(n_classes: int, alpha: float = 1.0) -> SLACEObjective

Create SLACE (Soft Labels Accumulating Cross Entropy) objective.

Source code in src/jaxboost/objective/ordinal.py
def slace_objective(n_classes: int, alpha: float = 1.0) -> SLACEObjective:
    """Create SLACE (Soft Labels Accumulating Cross Entropy) objective."""
    return SLACEObjective(n_classes=n_classes, alpha=alpha)

sord_objective

sord_objective(n_classes: int, alpha: float = 1.0) -> SORDObjective

Create SORD (Soft Ordinal) objective.

Source code in src/jaxboost/objective/ordinal.py
def sord_objective(n_classes: int, alpha: float = 1.0) -> SORDObjective:
    """Create SORD (Soft Ordinal) objective."""
    return SORDObjective(n_classes=n_classes, alpha=alpha)

oll_objective

oll_objective(n_classes: int, alpha: float = 1.0) -> OLLObjective

Create OLL (Ordinal Log Loss) objective.

Source code in src/jaxboost/objective/ordinal.py
def oll_objective(n_classes: int, alpha: float = 1.0) -> OLLObjective:
    """Create OLL (Ordinal Log Loss) objective."""
    return OLLObjective(n_classes=n_classes, alpha=alpha)

Survival Analysis

Objectives for time-to-event modeling.

aft

aft(y_pred: Array, y_true: Array, label_lower_bound: Array | None = None, label_upper_bound: Array | None = None, sigma: float = 1.0) -> jax.Array

Accelerated Failure Time (AFT) Loss for survival analysis.

Models log(T) = y_pred + sigma * epsilon, where epsilon follows a normal distribution.

Handles censored data:
- Uncensored: lower == upper (exact event time)
- Right-censored: upper == inf (event hasn't occurred yet)
- Interval-censored: lower < upper (event in time range)

Parameters:

Name Type Description Default
y_pred Array

Predicted log survival time

required
y_true Array

Label (used as lower bound if bounds not provided)

required
label_lower_bound Array | None

Lower bound of survival time

None
label_upper_bound Array | None

Upper bound of survival time

None
sigma float

Scale parameter (default 1.0)

1.0
Example

Survival data: some patients censored (still alive)

lower_bounds = event_times
upper_bounds = np.where(is_censored, np.inf, event_times)
dtrain.set_float_info('label_lower_bound', lower_bounds)
dtrain.set_float_info('label_upper_bound', upper_bounds)
model = xgb.train(params, dtrain, obj=aft.xgb_objective)

Source code in src/jaxboost/objective/survival.py
@AutoObjective
def aft(
    y_pred: jax.Array,
    y_true: jax.Array,
    label_lower_bound: jax.Array | None = None,
    label_upper_bound: jax.Array | None = None,
    sigma: float = 1.0,
) -> jax.Array:
    """
    Accelerated Failure Time (AFT) Loss for survival analysis.

    Models log(T) = y_pred + sigma * epsilon, where epsilon follows
    a normal distribution.

    Handles censored data:
    - Uncensored: lower == upper (exact event time)
    - Right-censored: upper == inf (event hasn't occurred yet)
    - Interval-censored: lower < upper (event in time range)

    Args:
        y_pred: Predicted log survival time
        y_true: Label (used as lower bound if bounds not provided)
        label_lower_bound: Lower bound of survival time
        label_upper_bound: Upper bound of survival time
        sigma: Scale parameter (default 1.0)

    Example:
        >>> # Survival data: some patients censored (still alive)
        >>> lower_bounds = event_times
        >>> upper_bounds = np.where(is_censored, np.inf, event_times)
        >>> dtrain.set_float_info('label_lower_bound', lower_bounds)
        >>> dtrain.set_float_info('label_upper_bound', upper_bounds)
        >>> model = xgb.train(params, dtrain, obj=aft.xgb_objective)
    """
    # Use y_true as default bounds if not provided
    lower = y_true if label_lower_bound is None else label_lower_bound
    upper = y_true if label_upper_bound is None else label_upper_bound

    # Log transform (AFT works in log-time space)
    log_lower = jnp.log(jnp.maximum(lower, 1e-10))

    # Standardized residual for lower bound
    z_lower = (log_lower - y_pred) / sigma

    # Check censoring type
    is_uncensored = jnp.abs(upper - lower) < 1e-7
    is_right_censored = upper > 1e10  # inf check

    # Uncensored: -log(pdf) = 0.5*z^2 + log(sigma) + const
    uncensored_loss = 0.5 * z_lower**2 + jnp.log(sigma)

    # Right-censored: -log(survival) = -log(1 - CDF(z))
    cdf_lower = jax.scipy.stats.norm.cdf(z_lower)
    cdf_lower_clipped = jnp.clip(cdf_lower, 0.0, 1.0 - 1e-7)
    right_censored_loss = -jnp.log1p(-cdf_lower_clipped)

    # Interval-censored: -log(CDF(upper) - CDF(lower))
    log_upper = jnp.log(jnp.maximum(upper, 1e-10))
    log_upper_clipped = jnp.clip(log_upper, -100.0, 100.0)
    z_upper = (log_upper_clipped - y_pred) / sigma
    cdf_upper = jax.scipy.stats.norm.cdf(z_upper)
    interval_prob = jnp.maximum(cdf_upper - cdf_lower, 1e-10)
    interval_loss = -jnp.log(interval_prob)

    # Select appropriate loss
    loss = jnp.where(
        is_uncensored,
        uncensored_loss,
        jnp.where(is_right_censored, right_censored_loss, interval_loss),
    )

    return loss

weibull_aft

weibull_aft(y_pred: Array, y_true: Array, label_lower_bound: Array | None = None, label_upper_bound: Array | None = None, k: float = 1.0) -> jax.Array

Weibull AFT (Accelerated Failure Time) Loss.

Uses Weibull distribution for survival times instead of log-normal.

Parameters:

Name Type Description Default
y_pred Array

Predicted log scale parameter (lambda)

required
y_true Array

Label (used as lower bound if bounds not provided)

required
label_lower_bound Array | None

Lower bound of survival time

None
label_upper_bound Array | None

Upper bound of survival time

None
k float

Shape parameter of Weibull distribution (default 1.0 = exponential)

1.0
Example

model = xgb.train(params, dtrain, obj=weibull_aft.xgb_objective)

Source code in src/jaxboost/objective/survival.py
@AutoObjective
def weibull_aft(
    y_pred: jax.Array,
    y_true: jax.Array,
    label_lower_bound: jax.Array | None = None,
    label_upper_bound: jax.Array | None = None,
    k: float = 1.0,
) -> jax.Array:
    """
    Weibull AFT (Accelerated Failure Time) Loss.

    Uses Weibull distribution for survival times instead of log-normal.

    Args:
        y_pred: Predicted log scale parameter (lambda)
        y_true: Label (used as lower bound if bounds not provided)
        label_lower_bound: Lower bound of survival time
        label_upper_bound: Upper bound of survival time
        k: Shape parameter of Weibull distribution (default 1.0 = exponential)

    Example:
        >>> model = xgb.train(params, dtrain, obj=weibull_aft.xgb_objective)
    """
    # Use y_true as default bounds if not provided
    lower = y_true if label_lower_bound is None else label_lower_bound
    upper = y_true if label_upper_bound is None else label_upper_bound

    # Scale parameter from prediction
    lambda_ = jnp.exp(y_pred)
    lambda_ = jnp.clip(lambda_, 1e-10, 1e10)

    # Check censoring type
    is_uncensored = jnp.abs(upper - lower) < 1e-7
    is_right_censored = upper > 1e10

    # Weibull survival function: S(t) = exp(-(t/lambda)^k)
    # Weibull PDF: f(t) = (k/lambda) * (t/lambda)^(k-1) * exp(-(t/lambda)^k)

    t_lower = jnp.maximum(lower, 1e-10)

    # Uncensored: -log(pdf)
    z_lower = (t_lower / lambda_) ** k
    log_pdf = (
        jnp.log(k) - jnp.log(lambda_) + (k - 1) * (jnp.log(t_lower) - jnp.log(lambda_)) - z_lower
    )
    uncensored_loss = -log_pdf

    # Right-censored: -log(survival) = (t/lambda)^k
    right_censored_loss = z_lower

    # Interval-censored: -log(S(lower) - S(upper))
    t_upper = jnp.maximum(upper, 1e-10)
    t_upper = jnp.clip(t_upper, 0, 1e10)
    z_upper = (t_upper / lambda_) ** k
    survival_lower = jnp.exp(-z_lower)
    survival_upper = jnp.exp(-z_upper)
    interval_prob = jnp.maximum(survival_lower - survival_upper, 1e-10)
    interval_loss = -jnp.log(interval_prob)

    loss = jnp.where(
        is_uncensored,
        uncensored_loss,
        jnp.where(is_right_censored, right_censored_loss, interval_loss),
    )

    return loss
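
A hedged inference sketch: the raw prediction is log(lambda), and a Weibull with shape k has median event time lambda * ln(2)^(1/k). params, dtrain, and dtest are assumed placeholders, and k must match the value used in training:

import numpy as np
import xgboost as xgb

k = 1.5
model = xgb.train(params, dtrain, obj=weibull_aft.get_xgb_objective(k=k))

scale = np.exp(model.predict(dtest))             # lambda = exp(raw prediction)
median_time = scale * np.log(2.0) ** (1.0 / k)   # Weibull median: lambda * ln(2)^(1/k)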

Multi-task Learning

Objectives for predicting multiple targets simultaneously.

multi_task_regression

multi_task_regression(n_tasks: int) -> MaskedMultiTaskObjective

Standard multi-task regression with MSE loss.

Supports missing labels via masking.

Parameters:

Name Type Description Default
n_tasks int

Number of regression tasks

required
Example

obj = multi_task_regression(n_tasks=5)

Labels shape: (n_samples, 5), some can be NaN

mask = ~np.isnan(y_true)
dtrain.set_float_info('label_mask', mask.flatten())

Source code in src/jaxboost/objective/multi_task.py
def multi_task_regression(n_tasks: int) -> MaskedMultiTaskObjective:
    """
    Standard multi-task regression with MSE loss.

    Supports missing labels via masking.

    Args:
        n_tasks: Number of regression tasks

    Example:
        >>> obj = multi_task_regression(n_tasks=5)
        >>> # Labels shape: (n_samples, 5), some can be NaN
        >>> mask = ~np.isnan(y_true)
        >>> dtrain.set_float_info('label_mask', mask.flatten())
    """
    return MaskedMultiTaskObjective(
        task_loss_fn=lambda y_pred, y_true: (y_pred - y_true) ** 2,
        n_tasks=n_tasks,
    )
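
A hedged training sketch under two assumptions not stated above: that MaskedMultiTaskObjective exposes an xgb_objective property like the other classes on this page, and that it targets XGBoost's multi-output trees (so num_target equals n_tasks, mirroring the uncertainty objectives further down). y_true, params, and dtrain are placeholders:

import numpy as np
import xgboost as xgb

obj = multi_task_regression(n_tasks=5)

mask = ~np.isnan(y_true)                  # y_true: (n_samples, 5); NaN marks a missing label
dtrain.set_float_info('label_mask', mask.flatten())

params = {'multi_strategy': 'multi_output_tree', 'num_target': 5}
model = xgb.train(params, dtrain, obj=obj.xgb_objective)  # assumed property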

multi_task_classification

multi_task_classification(n_tasks: int) -> MaskedMultiTaskObjective

Multi-task binary classification with log loss.

Each task is an independent binary classification. Supports missing labels via masking.

Parameters:

Name Type Description Default
n_tasks int

Number of binary classification tasks

required
Example

obj = multi_task_classification(n_tasks=3)

Each task: predict 0 or 1

Source code in src/jaxboost/objective/multi_task.py
def multi_task_classification(n_tasks: int) -> MaskedMultiTaskObjective:
    """
    Multi-task binary classification with log loss.

    Each task is an independent binary classification.
    Supports missing labels via masking.

    Args:
        n_tasks: Number of binary classification tasks

    Example:
        >>> obj = multi_task_classification(n_tasks=3)
        >>> # Each task: predict 0 or 1
    """

    def binary_logloss(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
        """Binary cross-entropy loss."""
        # y_pred is raw score (logit), y_true is 0 or 1
        # BCE = -y*log(sigmoid(s)) - (1-y)*log(1-sigmoid(s))
        #     = max(s, 0) - s*y + log(1 + exp(-|s|))
        return jnp.maximum(y_pred, 0) - y_pred * y_true + jnp.log1p(jnp.exp(-jnp.abs(y_pred)))

    return MaskedMultiTaskObjective(
        task_loss_fn=binary_logloss,
        n_tasks=n_tasks,
    )

multi_task_huber

multi_task_huber(n_tasks: int, delta: float = 1.0) -> MaskedMultiTaskObjective

Multi-task regression with Huber loss (robust to outliers).

Parameters:

Name Type Description Default
n_tasks int

Number of tasks

required
delta float

Threshold for switching between L1 and L2

1.0
Example

obj = multi_task_huber(n_tasks=3, delta=1.5)

Source code in src/jaxboost/objective/multi_task.py
def multi_task_huber(n_tasks: int, delta: float = 1.0) -> MaskedMultiTaskObjective:
    """
    Multi-task regression with Huber loss (robust to outliers).

    Args:
        n_tasks: Number of tasks
        delta: Threshold for switching between L1 and L2

    Example:
        >>> obj = multi_task_huber(n_tasks=3, delta=1.5)
    """

    def huber_loss(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
        error = y_pred - y_true
        abs_error = jnp.abs(error)
        quadratic = 0.5 * error**2
        linear = delta * abs_error - 0.5 * delta**2
        return jnp.where(abs_error <= delta, quadratic, linear)

    return MaskedMultiTaskObjective(
        task_loss_fn=huber_loss,
        n_tasks=n_tasks,
    )

multi_task_quantile

multi_task_quantile(n_tasks: int, quantiles: list[float] | None = None) -> MaskedMultiTaskObjective

Multi-task quantile regression.

Each task predicts a different quantile. Useful for prediction intervals.

Parameters:

Name Type Description Default
n_tasks int

Number of quantiles to predict

required
quantiles list[float] | None

List of quantile values (default: evenly spaced)

None
Example

Predict 10th, 50th, 90th percentiles

obj = multi_task_quantile(n_tasks=3, quantiles=[0.1, 0.5, 0.9])

Source code in src/jaxboost/objective/multi_task.py
def multi_task_quantile(
    n_tasks: int, quantiles: list[float] | None = None
) -> MaskedMultiTaskObjective:
    """
    Multi-task quantile regression.

    Each task predicts a different quantile. Useful for prediction intervals.

    Args:
        n_tasks: Number of quantiles to predict
        quantiles: List of quantile values (default: evenly spaced)

    Example:
        >>> # Predict 10th, 50th, 90th percentiles
        >>> obj = multi_task_quantile(n_tasks=3, quantiles=[0.1, 0.5, 0.9])
    """
    if quantiles is None:
        quantiles = list(np.linspace(0.1, 0.9, n_tasks))
    elif len(quantiles) != n_tasks:
        raise ValueError(f"quantiles length ({len(quantiles)}) must match n_tasks ({n_tasks})")

    quantiles_arr = jnp.asarray(quantiles, dtype=jnp.float32)

    # We need a way to pass the quantile to each task
    # Since task_loss_fn operates per-task, we'll use a trick:
    # Store quantiles and use task index

    class QuantileMTL(MaskedMultiTaskObjective):
        """Quantile MTL with per-task quantiles."""

        def __init__(self) -> None:
            super().__init__(n_tasks=n_tasks)
            self.quantiles = quantiles_arr
            self._name = "multi_task_quantile"

        def _compute_grad_hess_single(
            self,
            y_pred: jax.Array,
            y_true: jax.Array,
            mask: jax.Array,
        ) -> tuple[jax.Array, jax.Array]:
            """Quantile loss with per-task quantiles."""
            error = y_true - y_pred

            # Gradient of quantile loss
            # Loss = q * max(error, 0) + (1-q) * max(-error, 0)
            # Grad w.r.t. y_pred:
            #   if error > 0: -q
            #   if error < 0: 1-q
            grads = jnp.where(error > 0, -self.quantiles, 1 - self.quantiles)

            # Hessian is 0 for quantile loss (piecewise linear)
            # Use small constant for XGBoost stability
            hess = jnp.ones_like(grads) * 1.0

            # Apply mask
            grads = grads * mask
            hess = hess * mask

            return grads, hess

    return QuantileMTL()

Uncertainty Estimation

Multi-output objectives that predict both value and uncertainty.

gaussian_nll

gaussian_nll(n_outputs: int = 2) -> MultiOutputObjective

Gaussian Negative Log-Likelihood for uncertainty estimation.

Predicts both mean and log-variance, enabling uncertainty quantification.

Parameters:

Name Type Description Default
n_outputs int

Should be 2 (mean, log_variance)

2

Returns:

Type Description
MultiOutputObjective

MultiOutputObjective instance

Example

nll = gaussian_nll()
params = {'multi_strategy': 'multi_output_tree', 'num_target': 2}
model = xgb.train(params, dtrain, obj=nll.xgb_objective)

Source code in src/jaxboost/objective/multi_output.py
def gaussian_nll(n_outputs: int = 2) -> MultiOutputObjective:
    """
    Gaussian Negative Log-Likelihood for uncertainty estimation.

    Predicts both mean and log-variance, enabling uncertainty quantification.

    Args:
        n_outputs: Should be 2 (mean, log_variance)

    Returns:
        MultiOutputObjective instance

    Example:
        >>> nll = gaussian_nll()
        >>> params = {'multi_strategy': 'multi_output_tree', 'num_target': 2}
        >>> model = xgb.train(params, dtrain, obj=nll.xgb_objective)
    """
    if n_outputs != 2:
        raise ValueError("gaussian_nll requires n_outputs=2 (mean, log_variance)")

    @multi_output_objective(n_outputs=2)
    def gnll(params: jax.Array, y_true: jax.Array) -> jax.Array:
        """Gaussian NLL for a single sample."""
        mean = params[0]
        log_var = params[1]

        # Clip log_var for numerical stability
        log_var = jnp.clip(log_var, -10.0, 10.0)
        var = jnp.exp(log_var)

        # NLL = 0.5 * (log(var) + (y - mean)^2 / var)
        nll = 0.5 * (log_var + (y_true - mean) ** 2 / var)
        return nll

    return gnll

laplace_nll

laplace_nll(n_outputs: int = 2) -> MultiOutputObjective

Laplace Negative Log-Likelihood for robust uncertainty estimation.

Similar to Gaussian NLL but uses Laplace distribution, which is more robust to outliers.

Parameters:

Name Type Description Default
n_outputs int

Should be 2 (location, log_scale)

2

Returns:

Type Description
MultiOutputObjective

MultiOutputObjective instance

Example

nll = laplace_nll()
model = xgb.train(params, dtrain, obj=nll.xgb_objective)

Source code in src/jaxboost/objective/multi_output.py
def laplace_nll(n_outputs: int = 2) -> MultiOutputObjective:
    """
    Laplace Negative Log-Likelihood for robust uncertainty estimation.

    Similar to Gaussian NLL but uses Laplace distribution, which is more
    robust to outliers.

    Args:
        n_outputs: Should be 2 (location, log_scale)

    Returns:
        MultiOutputObjective instance

    Example:
        >>> nll = laplace_nll()
        >>> model = xgb.train(params, dtrain, obj=nll.xgb_objective)
    """
    if n_outputs != 2:
        raise ValueError("laplace_nll requires n_outputs=2 (location, log_scale)")

    @multi_output_objective(n_outputs=2)
    def lnll(params: jax.Array, y_true: jax.Array) -> jax.Array:
        """Laplace NLL for a single sample."""
        loc = params[0]
        log_scale = params[1]

        log_scale = jnp.clip(log_scale, -10.0, 10.0)
        scale = jnp.exp(log_scale)

        # NLL = log(scale) + |y - loc| / scale
        nll = log_scale + jnp.abs(y_true - loc) / scale
        return nll

    return lnll
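
A hedged post-processing sketch for either NLL objective: with XGBoost multi-output trees the booster predicts two columns per row, interpreted here in the order the loss defines them (mean/location first, log-variance/log-scale second). model and dtest are assumed to exist:

import numpy as np

preds = model.predict(dtest)               # shape (n_samples, 2)
mean, log_var = preds[:, 0], preds[:, 1]

std = np.exp(0.5 * log_var)                # Gaussian case: sigma = exp(log_var / 2)
lower, upper = mean - 1.96 * std, mean + 1.96 * std  # ~95% interval under normality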