"""PYTTB shared utilities across tensor types."""
# Copyright 2024 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.
from __future__ import annotations
from enum import Enum
from math import prod
from typing import (
Any,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
cast,
get_args,
overload,
)
import numpy as np
from scipy import sparse
import pyttb as ttb
# Tensor shape given either as a single integer or as an iterable of integers
# (normalized to a tuple by parse_shape).
Shape = Union[int, Iterable[int]]
# Scalar, iterable of numbers, or numpy array accepted wherever a vector is
# expected (normalized to a numpy array by parse_one_d).
OneDArray = Union[int, float, Iterable[int], Iterable[float], np.ndarray]
# Memory layout flag forwarded as the ``order`` argument of numpy routines:
# "F" or "C".
MemoryLayout = Union[Literal["F"], Literal["C"]]
[docs]
def tt_union_rows(MatrixA: np.ndarray, MatrixB: np.ndarray) -> np.ndarray:
    """Reproduce functionality of MATLABS union(a,b,'rows').

    Unlike MATLAB the result is not sorted: it holds the unique rows of
    ``MatrixB`` that do not occur in ``MatrixA`` (in order of first
    appearance) followed by the unique rows of ``MatrixA`` (in order of first
    appearance).

    Parameters
    ----------
    MatrixA:
        First matrix.
    MatrixB:
        Second matrix.

    Returns
    -------
    union:
        Rows that appear in either matrix, without duplicates.

    Examples
    --------
    >>> a = np.array([[1, 2], [3, 4]])
    >>> b = np.array([[0, 0], [1, 2], [3, 4], [0, 0]])
    >>> tt_union_rows(a, b)
    array([[0, 0],
           [1, 2],
           [3, 4]])

    Rows are looked up by their original position, so the second matrix does
    not need to be sorted:

    >>> tt_union_rows(np.array([[1, 2]]), np.array([[5, 5], [1, 2]]))
    array([[5, 5],
           [1, 2]])
    """
    # TODO ismember and unique are very similar in function
    if MatrixA.size > 0:
        _, idxA = np.unique(MatrixA, axis=0, return_index=True)
    else:
        # Placeholder with a compatible shape so the empty selection below can
        # be stacked; it inherits the dtype so the result is not promoted
        MatrixA = np.empty(shape=MatrixB.shape, dtype=MatrixB.dtype)
        idxA = np.array([], dtype=int)
    if MatrixB.size > 0:
        _, idxB = np.unique(MatrixB, axis=0, return_index=True)
    else:
        MatrixB = np.empty(shape=MatrixA.shape, dtype=MatrixA.dtype)
        idxB = np.array([], dtype=int)

    # np.unique reports first-occurrence indices in sorted-row order; sorting
    # the indices themselves yields the unique rows in order of appearance
    rows_a = MatrixA[np.sort(idxA)]
    rows_b = MatrixB[np.sort(idxB)]

    # location[k] < 0 marks rows_b[k] as absent from rows_a. The mask is
    # applied to rows_b directly so both use the same (appearance) ordering.
    _, location = tt_ismember_rows(rows_b, rows_a)
    return np.vstack((rows_b[location < 0], rows_a))
@overload
def tt_dimscheck(
    N: int,
    M: None = None,
    dims: Optional[OneDArray] = None,
    exclude_dims: Optional[OneDArray] = None,
) -> Tuple[np.ndarray, None]: ...  # pragma: no cover see coveragepy/issues/970


@overload
def tt_dimscheck(
    N: int,
    M: int,
    dims: Optional[OneDArray] = None,
    exclude_dims: Optional[OneDArray] = None,
) -> Tuple[np.ndarray, np.ndarray]: ...  # pragma: no cover see coveragepy/issues/970


def tt_dimscheck(  # noqa: PLR0912
    N: int,
    M: Optional[int] = None,
    dims: Optional[OneDArray] = None,
    exclude_dims: Optional[OneDArray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Preprocess dimensions for tensor operations.

    Parameters
    ----------
    N: Tensor order
    M: Num of multiplicands
    dims: Dimensions to check
    exclude_dims: Check all dimensions but these. (Mutually exclusive with dims)

    Returns
    -------
    sdims: New dimensions
    vidx: Index into the multiplicands (if M defined).

    Raises
    ------
    ValueError
        If both ``dims`` and ``exclude_dims`` are given, if ``exclude_dims``
        contains a value outside ``0..N-1``, or if a dimension is negative.
    AssertionError
        If ``M`` exceeds ``N`` or matches neither ``N`` nor the number of
        selected dimensions.

    Examples
    --------
    # Default captures all dims and no index
    >>> rdims, _ = tt_dimscheck(6)
    >>> np.array_equal(rdims, np.arange(6))
    True

    # Exclude single dim and still no index
    >>> rdims, _ = tt_dimscheck(6, exclude_dims=np.array([5]))
    >>> np.array_equal(rdims, np.arange(5))
    True

    # Exclude single dim and number of multiplicands equals resulting size
    >>> rdims, ridx = tt_dimscheck(6, 5, exclude_dims=np.array([0]))
    >>> np.array_equal(rdims, np.array([1, 2, 3, 4, 5]))
    True
    >>> np.array_equal(ridx, np.arange(0, 5))
    True
    """
    if dims is not None and exclude_dims is not None:
        raise ValueError("Either specify dims to include or exclude, but not both")

    if exclude_dims is not None:
        # Explicit exclude to resolve ambiguous -0
        excluded = parse_one_d(exclude_dims)
        # Check that all members in range
        in_range = np.isin(excluded, np.arange(0, N))
        if not np.all(in_range):
            raise ValueError(
                f"Exclude dims provided: {excluded} "
                f"but, {excluded[np.logical_not(in_range)]} were out of valid range "
                f"[0,{N})"
            )
        dim_array = np.setdiff1d(np.arange(0, N), excluded)
    elif dims is not None:
        dim_array = parse_one_d(dims)
    else:
        # Nothing requested: operate on every dimension
        dim_array = np.arange(0, N)

    # Catch minus case to avoid silent errors
    if np.any(dim_array < 0):
        raise ValueError(
            "Negative dims aren't allowed in pyttb, see exclude_dims argument instead"
        )

    # Save dimensions of dims
    P = len(dim_array)

    # Reorder dims from smallest to largest (this matters in particular for the vector
    # multiplicand case, where the order affects the result)
    sidx = np.argsort(dim_array)
    sdims = dim_array[sidx]

    if M is None:
        return sdims, None

    # Explicit raises (rather than ``assert False``) keep the validation active
    # when Python runs with -O while preserving the exception type.
    # Can't have more multiplicands than dimensions
    if M > N:
        raise AssertionError("Cannot have more multiplicands than dimensions")
    # Check that the number of multiplicands must either be full dimensional or
    # equal to the specified dimensions (M==N) or M(==P) respectively
    if M not in (N, P):
        raise AssertionError("Invalid number of multiplicands")

    # Check sizes to determine how to index multiplicands
    if P == M:
        # Case 1: Number of items in dims and number of multiplicands are equal;
        # therefore, index in order of sdims
        vidx = sidx
    else:
        # Case 2: Number of multiplicands is equal to the number of dimensions of
        # tensor; therefore, index multiplicands by dimensions in dims argument.
        vidx = sdims
    return sdims, vidx
[docs]
def tt_setdiff_rows(MatrixA: np.ndarray, MatrixB: np.ndarray) -> np.ndarray:
    """Reproduce functionality of MATLABS setdiff(a,b,'rows').

    Parameters
    ----------
    MatrixA:
        First matrix.
    MatrixB:
        Second matrix.

    Returns
    -------
    Sorted indices into ``MatrixA`` (first occurrence of each unique row) of
    the rows that do not appear in ``MatrixB``.

    Examples
    --------
    >>> a = np.array([[1, 2], [1, 2], [3, 4]])
    >>> b = np.array([[3, 4]])
    >>> tt_setdiff_rows(a, b)
    array([0])
    """
    # TODO intersect and setdiff are very similar in function
    # np.unique reports first-occurrence indices in sorted-row order; sorting
    # them gives the positions of the unique rows in order of appearance
    if MatrixA.size > 0:
        _, idxA = np.unique(MatrixA, axis=0, return_index=True)
        first_a = np.sort(idxA)
    else:
        first_a = np.array([], dtype=int)
    if MatrixB.size > 0:
        _, idxB = np.unique(MatrixB, axis=0, return_index=True)
        first_b = np.sort(idxB)
    else:
        first_b = np.array([], dtype=int)

    valid, location = tt_ismember_rows(MatrixB[first_b], MatrixA[first_a])
    # location indexes the de-duplicated rows of MatrixA; map it back to
    # positions in MatrixA before removing the shared rows
    return np.setdiff1d(first_a, first_a[location[valid]])
[docs]
def tt_intersect_rows(MatrixA: np.ndarray, MatrixB: np.ndarray) -> np.ndarray:
    """Reproduce functionality of MATLABS intersect(a,b,'rows').

    Parameters
    ----------
    MatrixA:
        First matrix.
    MatrixB:
        Second matrix.

    Returns
    -------
    location:
        Indices into ``MatrixA`` (first occurrence of each row) of the rows
        that also appear in ``MatrixB``, listed in the order the rows first
        appear in ``MatrixB``.

    Examples
    --------
    >>> a = np.array([[1, 2], [3, 4]])
    >>> b = np.array([[0, 0], [1, 2], [3, 4], [0, 0]])
    >>> tt_intersect_rows(a, b)
    array([0, 1])
    >>> tt_intersect_rows(b, a)
    array([1, 2])

    Duplicate rows in the first matrix do not shift the reported positions:

    >>> tt_intersect_rows(np.array([[0, 0], [0, 0], [1, 2]]), np.array([[1, 2]]))
    array([2])
    """
    # TODO ismember and unique are very similar in function
    # np.unique reports first-occurrence indices in sorted-row order; sorting
    # them gives the positions of the unique rows in order of appearance
    if MatrixA.size > 0:
        _, idxA = np.unique(MatrixA, axis=0, return_index=True)
        first_a = np.sort(idxA)
    else:
        first_a = np.array([], dtype=int)
    if MatrixB.size > 0:
        _, idxB = np.unique(MatrixB, axis=0, return_index=True)
        first_b = np.sort(idxB)
    else:
        first_b = np.array([], dtype=int)

    valid, location = tt_ismember_rows(MatrixB[first_b], MatrixA[first_a])
    # location indexes the de-duplicated rows of MatrixA; translate it back to
    # row positions in MatrixA itself
    return first_a[location[valid]]
[docs]
def tt_irenumber(
    t: ttb.sptensor, shape: Tuple[int, ...], number_range: Sequence[IndexType]
) -> np.ndarray:
    """Renumber indices for sptensor __setitem__.

    Parameters
    ----------
    t:
        Sptensor we are trying to assign from
    shape:
        Shape of destination tensor
    number_range:
        Key from __setitem__ for destination tensor. Slices contribute their
        ``start``/``stop`` only (``step`` is not consulted); integer scalars
        (Python or numpy) add a constant column; any other entry is used as a
        lookup table.

    Returns
    -------
    Subscripts for sptensor assignment
    """
    if t.nnz == 0:
        return np.array([])

    newsubs = t.subs.astype(int)
    for i, r in enumerate(number_range):
        if isinstance(r, slice):
            start = r.start or 0
            stop = r.stop or shape[i]
            # Lookup table translating source positions to destination indices
            newsubs[:, i] = np.arange(start, stop + 1)[newsubs[:, i]]
        elif isinstance(r, (int, np.integer)):
            # A scalar key pins this destination dimension: insert a new
            # column filled with that index. numpy integers are accepted too,
            # otherwise they would be indexed like an array and fail.
            newsubs = np.insert(newsubs, obj=i, values=r, axis=1)
        else:
            lookup = r if isinstance(r, np.ndarray) else np.array(r)
            newsubs[:, i] = lookup[newsubs[:, i]]
    return newsubs
[docs]
def tt_renumber(
    subs: np.ndarray, shape: Tuple[int, ...], number_range: Sequence[IndexType]
) -> Tuple[np.ndarray, Tuple[int, ...]]:
    """Renumber indices for sptensor __getitem__.

    [NEWSUBS,NEWSZ] = RENUMBER(SUBS,SZ,RANGE) takes a set of
    original subscripts SUBS with entries from a tensor of size
    SZ. All the entries in SUBS are assumed to be within the
    specified RANGE. These subscripts are then renumbered so that,
    in dimension i, the numbers range from 1:numel(RANGE(i)).

    Parameters
    ----------
    subs:
        Original subscripts for source tensor. Left unmodified.
    shape:
        Shape of source tensor.
    number_range:
        Key from __getitem__ for tensor.

    Returns
    -------
    newsubs:
        Updated subscripts.
    newshape:
        Resulting shape.
    """
    newshape = np.array(shape)
    # Work on a copy so the caller's subscripts are not overwritten
    newsubs = subs.copy()
    everything = slice(None, None, None)
    for i in range(0, len(shape)):
        range_i = number_range[i]
        # Only compare slices against the full slice: comparing a numpy array
        # with a slice is elementwise and has no single truth value
        if isinstance(range_i, slice) and range_i == everything:
            continue
        if subs.size != 0:
            newsubs[:, i], newshape[i] = tt_renumberdim(subs[:, i], shape[i], range_i)
        elif isinstance(range_i, slice):
            # TODO get this length without generating the range
            newshape[i] = len(range(0, shape[i])[range_i])
        elif isinstance(range_i, (int, float, np.integer)):
            # NOTE(review): stores the key value itself, whereas the non-empty
            # path (tt_renumberdim) reports 0 for a scalar key - confirm intent
            newshape[i] = range_i
        else:
            newshape[i] = len(range_i)
    return newsubs, tuple(newshape)
[docs]
def tt_renumberdim(
    idx: np.ndarray, shape: int, number_range: IndexType
) -> Tuple[np.ndarray, int]:
    """Renumber a single dimension.

    Helper function for RENUMBER.

    Parameters
    ----------
    idx:
        Integer subscripts along one dimension.
    shape:
        Size of that dimension in the source tensor.
    number_range:
        Key selecting entries of the dimension: an integer, a slice, or a
        sequence/array of integers.

    Returns
    -------
    newidx:
        Integer array of ``idx`` translated to positions within
        ``number_range``.
    newshape:
        Number of entries selected by ``number_range``; 0 for a scalar key.

    Raises
    ------
    ValueError
        If ``number_range`` is none of the supported key types.
    """
    # Determine the size of the new range
    if isinstance(number_range, (int, np.integer)):
        selected: Any = [int(number_range)]
        newshape = 0
    elif isinstance(number_range, slice):
        selected = range(0, shape)[number_range]
        newshape = len(selected)
    elif isinstance(number_range, (Sequence, np.ndarray)):
        selected = number_range
        newshape = len(number_range)
    else:
        raise ValueError(f"Bad number range: {number_range}")

    # Create map from old range to the new range; integer valued since the
    # entries are subscripts. Assigned one at a time so that a repeated entry
    # deterministically keeps its last position.
    idx_map = np.zeros(shape=shape, dtype=int)
    for new_position, old_position in zip(range(newshape), selected):
        idx_map[old_position] = new_position

    # Do the mapping
    return idx_map[idx], newshape
# TODO make more efficient
# https://stackoverflow.com/questions/22699756/python-version-of-ismember-with-rows-and-index
# For thoughts on how to speed this up
[docs]
def tt_ismember_rows(
    search: np.ndarray, source: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Locate each row of ``search`` inside ``source``.

    Parameters
    ----------
    search:
        Rows to look for.
    source:
        Rows to look in.

    Returns
    -------
    matched:
        Boolean array with one entry per search row, True where the row was
        found in ``source``.
    results:
        Integer array with one entry per search row holding the index of a
        matching row of ``source``, or -1 when there is no match.

    Examples
    --------
    >>> a = np.array([[4, 6], [1, 9], [2, 6]])
    >>> b = np.array([[2, 6], [2, 1], [2, 4], [4, 6], [4, 7], [5, 9], [5, 2], [5, 1]])
    >>> matched, results = tt_ismember_rows(a, b)
    >>> print(results)
    [ 3 -1  0]
    >>> print(matched)
    [ True False  True]
    """
    num_rows = search.shape[0]
    matched = np.full(num_rows, False)
    results = np.full(num_rows, -1, dtype=int)

    # With either side empty nothing can match
    if search.size != 0 and source.size != 0:
        # Compare every search row with every source row via broadcasting;
        # entry (i, j) is True when search[i] equals source[j]
        rows_equal = (source == search[:, np.newaxis]).all(axis=2)
        search_rows, source_rows = np.nonzero(rows_equal)
        results[search_rows] = source_rows
        matched[search_rows] = True
    return matched, results
[docs]
def tt_ind2sub(
    shape: Tuple[int, ...],
    idx: np.ndarray,
    order: MemoryLayout = "F",
) -> np.ndarray:
    """
    Multiple subscripts from linear indices.

    Parameters
    ----------
    shape: Shape of tensor indexing into.
    idx: Array of linear indices into the tensor. Negative entries count from
        the end. The array is not modified.
    order: Memory layout the linear indices refer to.

    Returns
    -------
    Multi-dimensional indices for the tensor.

    Example
    -------
    >>> shape = (2, 2, 2)
    >>> linear_indices = np.array([0, 1])
    >>> tt_ind2sub(shape, linear_indices)
    array([[0, 0, 0],
           [1, 0, 0]])
    """
    if idx.size == 0:
        return np.empty(shape=(0, len(shape)), dtype=int)
    # Handle negative indexing as simply as possible; np.where builds a new
    # array so the caller's indices stay untouched
    wrapped = np.where(idx < 0, idx + prod(shape), idx)
    return np.array(np.unravel_index(wrapped, shape, order=order)).transpose()
[docs]
def tt_subsubsref(obj: np.ndarray, s: Any) -> Union[float, np.ndarray]:
    """Helper function for tensor toolbox subsref.

    Parameters
    ----------
    obj:
        Tensor Data Structure
    s:
        Reference into tensor (currently not consulted)

    Returns
    -------
    The lone element as a Python scalar when ``obj`` is a numpy array holding
    exactly one element, otherwise ``obj`` itself.
    """  # noqa: D401
    # TODO figure out when subsref yields key of length>1; until then the key
    # is ignored and the object is returned (unwrapped if it is a single value)
    holds_single_value = isinstance(obj, np.ndarray) and obj.size == 1
    if not holds_single_value:
        return obj
    # TODO: Globally figure out why typing thinks item is a string
    return cast(float, obj.item())
[docs]
def tt_sub2ind(
    shape: Tuple[int, ...],
    subs: np.ndarray,
    order: MemoryLayout = "F",
) -> np.ndarray:
    """Convert multidimensional subscripts to linear indices.

    Parameters
    ----------
    shape:
        Shape of tensor
    subs:
        Subscripts for tensor
    order:
        Memory layout

    Returns
    -------
    Integer array with one linear index per row of ``subs``; empty when
    ``subs`` is empty.

    Examples
    --------
    >>> shape = (2, 2, 2)
    >>> full_indices = np.array([[0, 0, 0], [1, 0, 0]], dtype=int)
    >>> tt_sub2ind(shape, full_indices)
    array([0, 1])

    See Also
    --------
    :func:`tt_ind2sub`:
    """
    if subs.size == 0:
        # Integer dtype so that the empty result is still usable as an index
        return np.array([], dtype=int)
    return np.ravel_multi_index(tuple(subs.transpose()), shape, order=order)
[docs]
def tt_sizecheck(shape: Tuple[int, ...], nargout: bool = True) -> bool:
    """
    TT_SIZECHECK Checks that the shape is valid.

    TT_SIZECHECK(S) throws an error if S is not a valid shape tuple,
    which means that it is a row vector with strictly positive,
    real-valued, finite integer values. An empty shape is accepted.

    Parameters
    ----------
    shape:
        Shape of tensor
    nargout:
        Controls if response returned or just acts as assert

    Returns
    -------
    bool

    Raises
    ------
    AssertionError
        If the shape is invalid and ``nargout`` is False.

    Examples
    --------
    >>> tt_sizecheck((0, -1, 2))
    False
    >>> tt_sizecheck((1, 1, 1))
    True

    See Also
    --------
    :func:`tt_subscheck`:
    """
    siz = np.array(shape)
    if siz.size == 0:
        ok = True
    else:
        # The dtype test comes first so the numeric checks are never applied
        # to non-numeric data (np.isfinite rejects e.g. strings)
        ok = bool(
            siz.ndim == 1
            and issubclass(siz.dtype.type, np.integer)
            and np.all(np.isfinite(siz))
            and np.all(siz > 0)
        )
    if not ok and not nargout:
        # Raised explicitly so the check survives running Python with -O
        raise AssertionError("Size must be a row vector of real positive integers")
    return ok
[docs]
def tt_subscheck(subs: np.ndarray, nargout: bool = True) -> bool:
    """
    TT_SUBSCHECK Checks for valid subscripts.

    TT_SUBSCHECK(S) throws an error if S is not a valid subscript
    array, which means that S is a matrix of real-valued, finite,
    non-negative, integer subscripts. An empty array is accepted.

    Parameters
    ----------
    subs:
        Subs of tensor
    nargout:
        Controls if response returned or just acts as assert

    Returns
    -------
    bool

    Raises
    ------
    AssertionError
        If the subscripts are invalid and ``nargout`` is False.

    Examples
    --------
    >>> tt_subscheck(np.array([[2, 2], [3, 3]]))
    True
    >>> tt_subscheck(np.array([[2, 2], [3, -1]]))
    False

    See Also
    --------
    :func:`tt_sizecheck`:
    :func:`tt_valscheck`:
    """
    if subs.size == 0:
        ok = True
    else:
        # The dtype test comes first so the numeric checks are never applied
        # to non-numeric data (np.isfinite rejects e.g. strings)
        ok = bool(
            subs.ndim == 2
            and issubclass(subs.dtype.type, np.integer)
            and np.isfinite(subs).all()
            and (subs >= 0).all()
        )
    if not ok and not nargout:
        # Raised explicitly so the check survives running Python with -O
        raise AssertionError("Subscripts must be a matrix of real positive integers")
    return ok
[docs]
def tt_valscheck(vals: np.ndarray, nargout: bool = True) -> bool:
    """
    TT_VALSCHECK Checks for valid values.

    TT_VALSCHECK(S) throws an error if S is not a valid values
    array, which means that S is a column array. An empty array is accepted.

    Parameters
    ----------
    vals:
        Values of tensor
    nargout:
        Controls if response returned or just acts as assert

    Returns
    -------
    bool

    Raises
    ------
    AssertionError
        If the values are invalid and ``nargout`` is False.

    Examples
    --------
    >>> tt_valscheck(np.array([[1], [2]]))
    True
    >>> tt_valscheck(np.array([[1, 2, 3], [2, 2, 2]]))
    False

    See Also
    --------
    :func:`tt_sizecheck`:
    :func:`tt_subscheck`:
    """
    is_column = vals.ndim == 2 and vals.shape[1] == 1
    ok = bool(vals.size == 0 or is_column)
    if not ok and not nargout:
        # Raised explicitly so the check survives running Python with -O
        raise AssertionError(f"Values must be in array but got {vals}")
    return ok
[docs]
def isrow(v: np.ndarray) -> bool:
    """
    ISROW Checks if array is a row vector.

    ISROW(V) returns True if V is two-dimensional with exactly one row and at
    least one column; otherwise returns False.

    Parameters
    ----------
    v:
        Vector input

    Examples
    --------
    >>> isrow(np.array([[1, 2]]))
    True
    >>> isrow(np.array([[1, 2], [3, 4]]))
    False
    """
    if v.ndim != 2:
        return False
    num_rows, num_cols = v.shape
    return num_rows == 1 and num_cols >= 1
[docs]
def isvector(a: np.ndarray) -> bool:
    """
    ISVECTOR Checks if array is a vector.

    ISVECTOR(A) returns True if A is one-dimensional, or two-dimensional with
    a single row or a single column; otherwise returns False.

    Parameters
    ----------
    a:
        Array input

    Returns
    -------
    bool
    """
    if a.ndim == 1:
        return True
    # For matrices one of the two extents has to be exactly one
    return a.ndim == 2 and 1 in a.shape
# TODO: this is a challenge, since it may need to apply to either Python built in types
# or numpy types
[docs]
def islogical(a: np.ndarray) -> bool:
    """
    ISLOGICAL Checks if input is logical (boolean).

    ISLOGICAL(A) returns True if A is a Python bool, a numpy boolean scalar or
    a numpy array of boolean dtype; otherwise returns False.

    Parameters
    ----------
    a:
        Value or array to inspect.

    Returns
    -------
    bool
    """
    if isinstance(a, (bool, np.bool_)):
        return True
    # Arrays are logical when they store booleans, whatever their shape
    return isinstance(a, np.ndarray) and a.dtype == np.bool_
# Adding all sorts of index support here, might consider splitting out to
# more specific file later
[docs]
class IndexVariant(Enum):
    """Methods for indexing entries of tensors."""

    # Fallback when a key matches none of the recognized forms
    UNKNOWN = 0
    # Linear indices: scalars, slices, 1-D arrays or flat integer sequences
    # (see get_index_variant)
    LINEAR = 1
    # Tuple keys (see get_index_variant)
    SUBTENSOR = 2
    # Arrays that are not one-dimensional (see get_index_variant)
    SUBSCRIPTS = 3
# We probably want to create a specific file for utility types
# Keys that get_index_variant always classifies as linear indexing.
LinearIndexType = Union[int, np.integer, slice]
# Any single key entry accepted by the indexing helpers in this module.
IndexType = Union[LinearIndexType, Sequence[int], np.ndarray]
[docs]
def get_index_variant(indices: IndexType) -> IndexVariant:
    """Decide on intended indexing variant. No correctness checks.

    See getitem or setitem in :class:`pyttb.tensor` for elaboration of the
    various indexing options.

    Parameters
    ----------
    indices:
        Key to classify.

    Returns
    -------
    ``LINEAR`` for integers, slices, one-dimensional arrays and flat sequences
    starting with an int; ``SUBSCRIPTS`` for any other array; ``SUBTENSOR``
    for tuples; ``UNKNOWN`` otherwise (including empty sequences).
    """
    if isinstance(indices, get_args(LinearIndexType)):
        return IndexVariant.LINEAR
    if isinstance(indices, np.ndarray):
        # TODO this is technically slightly stricter than what
        # we currently have but probably clearer
        if indices.ndim == 1:
            return IndexVariant.LINEAR
        return IndexVariant.SUBSCRIPTS
    if isinstance(indices, tuple):
        return IndexVariant.SUBTENSOR
    # The length guard keeps empty sequences from failing on indices[0]
    if (
        isinstance(indices, Sequence)
        and len(indices) > 0
        and isinstance(indices[0], int)
    ):
        # TODO this is slightly redundant/inefficient
        key = np.array(indices)
        if len(key.shape) == 1 or key.shape[1] == 1:
            return IndexVariant.LINEAR
    return IndexVariant.UNKNOWN
[docs]
def get_mttkrp_factors(
    U: Union[ttb.ktensor, Sequence[np.ndarray]], n: Union[int, np.integer], ndims: int
) -> Sequence[np.ndarray]:
    """Apply standard checks and type conversions for mttkrp factors.

    Parameters
    ----------
    U:
        Factor matrices, either as a sequence of arrays or wrapped in a
        ktensor. A ktensor is copied before its weights are redistributed.
    n:
        Mode that is skipped; the weights are never absorbed into this factor.
    ndims:
        Expected number of factor matrices.

    Returns
    -------
    The factor matrices.

    Raises
    ------
    AssertionError
        If ``U`` is neither a sequence nor a ktensor, or if the number of
        factor matrices differs from ``ndims``.
    """
    if isinstance(U, ttb.ktensor):
        U = U.copy()
        # Absorb lambda into one of the factors but not the one that is skipped
        U.redistribute(1 if n == 0 else 0)
        # Extract the factor matrices
        U = U.factor_matrices

    # Raised explicitly (instead of via assert statements) so the validation
    # survives running Python with -O while keeping the exception type
    if not isinstance(U, (Sequence, np.ndarray)):
        raise AssertionError(
            "Second argument must be a sequence of numpy.ndarray's or a ktensor"
        )
    if len(U) != ndims:
        raise AssertionError("List of factor matrices is the wrong length")
    return U
[docs]
def gather_wrap_dims(
    ndims: int,
    rdims: Optional[np.ndarray] = None,
    cdims: Optional[np.ndarray] = None,
    cdims_cyclic: Optional[Union[Literal["fc"], Literal["bc"], Literal["t"]]] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Extract tensor modes mapped to rows and columns for matricized tensors.

    Parameters
    ----------
    ndims:
        Number of dimensions.
    rdims:
        Mapping of row indices.
    cdims:
        Mapping of column indices.
    cdims_cyclic:
        When only rdims is specified maps a single rdim to the rows and
        the remaining dimensions span the columns. _fc_ (forward cyclic[1]_)
        in the order range(rdims+1, ndims) followed by range(0, rdims).
        _bc_ (backward cyclic[2]_) range(rdims-1, -1, -1) then
        range(ndims-1, rdims, -1). _t_ swaps the roles: the single rdim spans
        the columns and all remaining dimensions span the rows.

    Returns
    -------
    rdims:
        Dimensions mapped to the rows.
    cdims:
        Dimensions mapped to the columns.

    Raises
    ------
    AssertionError
        If ``cdims_cyclic`` is not a recognized pattern, or if neither
        ``rdims`` nor ``cdims`` is provided.

    Notes
    -----
    Forward cyclic is defined by Kiers [1]_ and backward cyclic is defined by
    De Lathauwer, De Moor, and Vandewalle [2]_.

    References
    ----------
    .. [1] KIERS, H. A. L. 2000. Towards a standardized notation and terminology
           in multiway analysis. J. Chemometrics 14, 105-122.
    .. [2] DE LATHAUWER, L., DE MOOR, B., AND VANDEWALLE, J. 2000b. On the best
           rank-1 and rank-(R1, R2, ... , RN ) approximation of higher-order
           tensors. SIAM J. Matrix Anal. Appl. 21, 4, 1324-1342.
    """
    alldims = np.arange(ndims)

    if rdims is not None and cdims is None:
        if len(rdims) == 1 and cdims_cyclic is not None:
            # Single row mapping
            # TODO we should be able to remove this since we can just specify
            # cdims alone
            pivot = rdims[0]
            if cdims_cyclic == "t":
                cdims = rdims
                rdims = np.setdiff1d(alldims, rdims)
            elif cdims_cyclic == "fc":
                cdims = np.array(list(range(pivot + 1, ndims)) + list(range(pivot)))
            elif cdims_cyclic == "bc":
                cdims = np.array(
                    list(range(pivot - 1, -1, -1))
                    + list(range(ndims - 1, pivot, -1))
                )
            else:
                # Raised explicitly so the check survives running Python with -O
                raise AssertionError(
                    "Unrecognized value for cdims_cyclic pattern, "
                    'must be "fc" or "bc".'
                )
        else:
            # Multiple row mapping
            cdims = np.setdiff1d(alldims, rdims)
    elif rdims is None and cdims is not None:
        rdims = np.setdiff1d(alldims, cdims)

    if rdims is None or cdims is None:
        raise AssertionError("At least one of rdims or cdims must be specified")
    return rdims.astype(int), cdims.astype(int)
[docs]
def np_to_python(
    iterable: Iterable,
) -> Iterable:
    """Replace numpy scalars inside a container with plain python values.

    Mostly useful for prettier printing post numpy 2.0.

    Parameters
    ----------
    iterable:
        Structure potentially containing numpy scalars. The result is built
        by calling its type on the converted elements.
    """

    def _plain(element: Any) -> Any:
        # numpy scalars expose item() to obtain the python equivalent
        if isinstance(element, np.generic):
            return element.item()
        return element

    container = type(iterable)
    return container(map(_plain, iterable))  # type: ignore [call-arg]
[docs]
def parse_shape(shape: Shape) -> Tuple[int, ...]:
    """Parse flexible type into shape tuple.

    Every entry of the result is a plain python ``int``, whichever input form
    was used.

    Raises
    ------
    ValueError
        If a numpy array is not integer valued or has more than one
        non-trivial dimension, or if an iterable contains non-integers.

    Examples
    --------
    >>> integer_shape = 4
    >>> parse_shape(integer_shape)
    (4,)
    >>> flat_numpy_shape = np.ones((4,), dtype=int)
    >>> parse_shape(flat_numpy_shape)
    (1, 1, 1, 1)
    >>> stacked_numpy_shape = np.ones((4, 1, 1), dtype=int)
    >>> parse_shape(stacked_numpy_shape)
    (1, 1, 1, 1)
    >>> list_shape = [1, 1, 1, 1]
    >>> parse_shape(list_shape)
    (1, 1, 1, 1)
    """
    if isinstance(shape, (int, np.integer)):
        return (int(shape),)
    if isinstance(shape, np.ndarray):
        if not np.issubdtype(shape.dtype, np.integer):
            raise ValueError("Numpy arrays used as shapes must be integer valued")
        squeezed_shape = shape.squeeze()
        if squeezed_shape.ndim == 0:
            # If it's an array containing a single scalar
            return (int(squeezed_shape),)
        if squeezed_shape.ndim > 1:
            raise ValueError(
                "Numpy arrays used as shapes can only have one non-trivial dimension"
            )
        return tuple(map(int, squeezed_shape))
    # Materialize once: the iterable might only be consumable a single time
    entries = tuple(shape)
    if not all(isinstance(ele, (int, np.integer)) for ele in entries):
        raise ValueError("Shapes entries must be integers")
    # Map numpy integers to python ints, matching the array branch above
    return tuple(int(ele) for ele in entries)
[docs]
def parse_one_d(maybe_vector: OneDArray) -> np.ndarray:
    """Parse flexible type into numpy array.

    Scalars are wrapped into a one element array, numpy arrays are squeezed
    down to a single dimension and anything else is handed to ``np.array``.

    Raises
    ------
    ValueError
        If a numpy array has more than one non-trivial dimension.

    Examples
    --------
    >>> int_scalar = 1
    >>> parse_one_d(int_scalar)
    array([1])
    >>> np_int_scalar = np.int8(1)
    >>> parse_one_d(np_int_scalar)
    array([1], dtype=int8)
    >>> float_scalar = 1.0
    >>> parse_one_d(float_scalar)
    array([1.])
    >>> np_float_scalar = np.float64(1.0)
    >>> parse_one_d(np_float_scalar)
    array([1.])
    >>> example_list = [1.0, 1.0]
    >>> parse_one_d(example_list)
    array([1., 1.])
    >>> extra_dims = np.array([[1, 1]])
    >>> parse_one_d(extra_dims)
    array([1, 1])
    """
    scalar_types = (int, float, np.integer, np.floating)
    if isinstance(maybe_vector, scalar_types):
        return np.array([maybe_vector])
    if not isinstance(maybe_vector, np.ndarray):
        return np.array(maybe_vector)

    flattened = maybe_vector.squeeze()
    if flattened.ndim > 1:
        raise ValueError(
            "Vector can have at most one non-trivial dimension but "
            f"had shape {maybe_vector.shape}"
        )
    if flattened.ndim == 0:
        # Squeezed to scalar so force vector
        return flattened[None]
    return flattened
@overload
def to_memory_order(
    array: np.ndarray, order: MemoryLayout, copy: bool = False
) -> np.ndarray: ...


@overload
def to_memory_order(
    array: sparse.coo_matrix, order: MemoryLayout, copy: bool = False
) -> sparse.coo_matrix: ...


def to_memory_order(
    array: Union[np.ndarray, sparse.coo_matrix], order: MemoryLayout, copy: bool = False
) -> Union[np.ndarray, sparse.coo_matrix]:
    """Convert an array to the specified memory layout.

    Parameters
    ----------
    array: Data to ensure matches memory order. Sparse COO matrices are
        returned without any layout conversion.
    order: Desired memory order.
    copy: Whether to force a copy even if data already in supported memory order.

    Raises
    ------
    ValueError
        If a dense array is given together with an order other than "F" or "C".

    Examples
    --------
    >>> c_order = np.arange(16).reshape((2, 2, 2, 2))
    >>> c_order.flags["C_CONTIGUOUS"]
    True
    >>> to_memory_order(c_order, "F").flags["F_CONTIGUOUS"]
    True
    """
    if isinstance(array, sparse.coo_matrix):
        return array.copy() if copy else array

    if order == "F":
        as_layout = np.asfortranarray
    elif order == "C":
        as_layout = np.ascontiguousarray
    else:
        raise ValueError(f"Unsupported order {order}")

    if copy:
        # Copy straight into the requested layout so the conversion below
        # never has to copy a second time
        array = np.array(array, order=order, copy=True)
    return as_layout(array)
# Run the doctests embedded in this module's docstrings when the file is
# executed directly as a script.
if __name__ == "__main__":
    import doctest  # pragma: no cover

    doctest.testmod()  # pragma: no cover