"""Evaluate Function And Gradient Handles."""
# Copyright 2024 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.
from __future__ import annotations
from typing import List, Literal, Optional, Tuple, Union, overload
import numpy as np
import pyttb as ttb
from pyttb.gcp.fg_setup import function_type


@overload
def evaluate(
    model: ttb.ktensor,
    data: Union[ttb.tensor, ttb.sptensor],
    weights: Optional[np.ndarray],
    function_handle: Literal[None],
    gradient_handle: function_type,
) -> List[np.ndarray]: ...  # pragma: no cover see coveragepy/issues/970


@overload
def evaluate(
    model: ttb.ktensor,
    data: Union[ttb.tensor, ttb.sptensor],
    weights: Optional[np.ndarray],
    function_handle: function_type,
    gradient_handle: Literal[None],
) -> float: ...  # pragma: no cover see coveragepy/issues/970


@overload
def evaluate(
    model: ttb.ktensor,
    data: Union[ttb.tensor, ttb.sptensor],
    weights: Optional[np.ndarray],
    function_handle: function_type,
    gradient_handle: function_type,
) -> Tuple[float, List[np.ndarray]]: ...  # pragma: no cover see coveragepy/issues/970


def evaluate(
    model: ttb.ktensor,
    data: Union[ttb.tensor, ttb.sptensor],
    weights: Optional[np.ndarray] = None,
    function_handle: Optional[function_type] = None,
    gradient_handle: Optional[function_type] = None,
) -> Union[float, List[np.ndarray], Tuple[float, List[np.ndarray]]]:
"""Evaluate an objective function and/or gradient function.
Parameters
----------
model:
Current decomposition.
data:
Source tensor to decompose.
weights:
Weighted values for returned tensor. Can be used as a mask.
function_handle:
Objective function.
gradient_handle:
Gradient definition.
Returns
-------
Objective function value and/or gradient function value with respect to model.
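
    Examples
    --------
    Illustrative sketch only: the elementwise lambdas below stand in for
    handles that would normally be built with ``pyttb.gcp.fg_setup``
    (here, a Gaussian/least-squares loss and its gradient).

    >>> import numpy as np
    >>> import pyttb as ttb
    >>> data = ttb.tenones((2, 2))
    >>> model = ttb.ktensor([np.ones((2, 1)), np.ones((2, 1))])
    >>> f = lambda x, m: (m - x) ** 2  # elementwise objective
    >>> g = lambda x, m: 2.0 * (m - x)  # elementwise gradient
    >>> F, G = evaluate(model, data, function_handle=f, gradient_handle=g)
    >>> F
    0.0
    >>> len(G)  # one gradient array per factor matrix
    2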
"""
    if function_handle is None and gradient_handle is None:
        raise ValueError(
            "Either a function handle or a gradient handle must be provided."
        )
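    # NOTE: the handles operate elementwise on dense numpy arrays,
    # so sparse input is densified first.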
    if isinstance(data, ttb.sptensor):
        data = data.full()
    full_model = model.full()
    # TODO should we early check shapes?
    # I don't think we always get vectorization for free in python
    # we should be able to operate on underlying np arrays directly though
    F: Optional[float] = None
    G: Optional[List[np.ndarray]] = None
    if function_handle is not None:
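        # Elementwise objective value for every tensor entry, optionally
        # weighted/masked, then summed to a single scalar.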
        Y = function_handle(data.data, full_model.data)
        if weights is not None:
            Y *= weights
        F = float(np.sum(Y))
    if gradient_handle is not None:
        Y = gradient_handle(data.data, full_model.data)
        if weights is not None:
            Y *= weights
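        # A sequence of MTTKRPs of the elementwise gradient tensor against the
        # model's factor matrices yields the gradient with respect to each factor.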
        G = ttb.tensor(Y, copy=False).mttkrps(model.factor_matrices)
    if F is not None and G is not None:
        return F, G
    if F is not None:
        return F
    if G is not None:
        return G
    raise ValueError(
        "No valid outputs for either function or gradient handles"
    )  # pragma: no cover