Source code for pyttb.gcp.fg_setup

"""Prepare Function and Gradient Handles for GCP OPT."""

# Copyright 2024 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.

from __future__ import annotations

from functools import partial
from typing import Callable, Optional, Tuple, Union

import numpy as np

import pyttb as ttb
from pyttb.gcp import handles
from pyttb.gcp.handles import Objectives

function_type = Callable[[np.ndarray, np.ndarray], np.ndarray]
fg_return = Tuple[function_type, function_type, float]


[docs] def setup( # noqa: PLR0912,PLR0915 objective: Objectives, data: Optional[Union[ttb.tensor, ttb.sptensor]] = None, additional_parameter: Optional[float] = None, ) -> fg_return: """Collect the function and gradient handles for GCP. Parameters ---------- objective: Objective function to gather handles for. data: Tensor to check for consistency with desired objective function. additional_parameter: Additional constant argument provided to objective function if necessary. Returns ------- Function handle, gradient handle, and lower bound. """ if objective == Objectives.GAUSSIAN: function_handle = handles.gaussian gradient_handle = handles.gaussian_grad lower_bound = -np.inf elif objective == Objectives.BERNOULLI_ODDS: if data is not None and not valid_binary(data): raise ValueError(f"{objective.name} requires a binary tensor") function_handle = handles.bernoulli_odds gradient_handle = handles.bernoulli_odds_grad lower_bound = 0.0 elif objective == Objectives.BERNOULLI_LOGIT: if data is not None and not valid_binary(data): raise ValueError(f"{objective.name} requires a binary tensor") function_handle = handles.bernoulli_logit gradient_handle = handles.bernoulli_logit_grad lower_bound = -np.inf elif objective == Objectives.POISSON: if data is not None and not valid_natural(data): raise ValueError(f"{objective.name} requires a count tensor") function_handle = handles.poisson gradient_handle = handles.poisson_grad lower_bound = 0.0 elif objective == Objectives.POISSON_LOG: if data is not None and not valid_natural(data): raise ValueError(f"{objective.name} requires a count tensor") function_handle = handles.poisson_log gradient_handle = handles.poisson_log_grad lower_bound = -np.inf elif objective == Objectives.RAYLEIGH: if data is not None and not valid_nonneg(data): raise ValueError(f"{objective.name} requires a non-negative tensor") function_handle = handles.rayleigh gradient_handle = handles.rayleigh_grad lower_bound = 0.0 elif objective == Objectives.GAMMA: if data is not None and not valid_nonneg(data): raise ValueError(f"{objective.name} requires a non-negative tensor") function_handle = handles.gamma gradient_handle = handles.gamma_grad lower_bound = 0.0 elif objective == Objectives.HUBER: if additional_parameter is None: raise ValueError( f"{objective.name} requires additional parameter for `threshold`" ) function_handle = partial(handles.huber, threshold=additional_parameter) gradient_handle = partial(handles.huber_grad, threshold=additional_parameter) lower_bound = -np.inf elif objective == Objectives.NEGATIVE_BINOMIAL: if data is not None and not valid_nonneg(data): raise ValueError(f"{objective.name} requires a non-negative tensor") if additional_parameter is None: raise ValueError( f"{objective.name} requires additional parameter for `num_trials`" ) function_handle = partial( handles.negative_binomial, num_trials=additional_parameter ) gradient_handle = partial( handles.negative_binomial_grad, num_trials=additional_parameter ) lower_bound = 0 elif objective == Objectives.BETA: if data is not None and not valid_nonneg(data): raise ValueError(f"{objective.name} requires a non-negative tensor") if additional_parameter is None: raise ValueError(f"{objective.name} requires additional parameter for `b`") function_handle = partial(handles.beta, b=additional_parameter) gradient_handle = partial(handles.beta_grad, b=additional_parameter) lower_bound = 0 else: raise ValueError(f" Unknown objective: {objective}") return function_handle, gradient_handle, lower_bound
[docs] def valid_nonneg(data: Union[ttb.tensor, ttb.sptensor]) -> bool: """Check if provided data is valid non-negative tensor.""" if isinstance(data, ttb.sptensor): return bool(np.all(data.vals > 0)) return bool(np.all(data.data > 0))
[docs] def valid_binary(data: Union[ttb.tensor, ttb.sptensor]) -> bool: """Check if provided data is valid binary tensor.""" if isinstance(data, ttb.sptensor): return bool(np.all(data.vals == 1)) return bool(np.all(np.isin(np.unique(data.data), [0, 1])))
[docs] def valid_natural(data: Union[ttb.tensor, ttb.sptensor]) -> bool: """Check if provided data is valid natural number tensor.""" if isinstance(data, ttb.sptensor): vals = data.vals else: vals = data.data return bool(np.all(vals % 1 == 0))