# Source code for pyttb.pyttb_utils

"""PYTTB shared utilities across tensor types."""

# Copyright 2024 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.

from __future__ import annotations

from enum import Enum
from math import prod
from typing import (
    Any,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    get_args,
    overload,
)

import numpy as np
from scipy import sparse

import pyttb as ttb

# Anything accepted as a tensor shape: a single int or an iterable of ints
# (normalized by ``parse_shape``).
Shape = Union[int, Iterable[int]]
# Scalars or one-dimensional collections of numbers (normalized to a 1-D numpy
# array by ``parse_one_d``).
OneDArray = Union[int, float, Iterable[int], Iterable[float], np.ndarray]
# Supported memory layouts: "F" (Fortran/column-major) or "C" (row-major).
MemoryLayout = Union[Literal["F"], Literal["C"]]


def tt_union_rows(MatrixA: np.ndarray, MatrixB: np.ndarray) -> np.ndarray:
    """Reproduce functionality of MATLABS union(a,b,'rows').

    Rows of ``MatrixB`` that do not appear in ``MatrixA`` come first, followed
    by the distinct rows of ``MatrixA``. Within each group rows keep the order
    of their first occurrence in the corresponding input.

    Parameters
    ----------
    MatrixA:
        First matrix.
    MatrixB:
        Second matrix.

    Returns
    -------
    union:
        Matrix holding every distinct row found in either input.

    Examples
    --------
    >>> a = np.array([[1, 2], [3, 4]])
    >>> b = np.array([[0, 0], [1, 2], [3, 4], [0, 0]])
    >>> tt_union_rows(a, b)
    array([[0, 0],
           [1, 2],
           [3, 4]])
    >>> tt_union_rows(np.array([[0, 0]]), np.array([[5, 5], [0, 0]]))
    array([[5, 5],
           [0, 0]])
    """
    # TODO ismember and unique are very similar in function
    if MatrixA.size > 0:
        _, idxA = np.unique(MatrixA, axis=0, return_index=True)
    else:
        # Zero-row stand-in with the other matrix's trailing shape and dtype so
        # the final vstack neither fails nor upcasts the result.
        MatrixA = np.empty(shape=(0,) + MatrixB.shape[1:], dtype=MatrixB.dtype)
        idxA = np.array([], dtype=int)
    if MatrixB.size > 0:
        _, idxB = np.unique(MatrixB, axis=0, return_index=True)
    else:
        MatrixB = np.empty(shape=(0,) + MatrixA.shape[1:], dtype=MatrixA.dtype)
        idxB = np.array([], dtype=int)

    # Row index of the first occurrence of each distinct row, in input order.
    firstA = np.sort(idxA)
    firstB = np.sort(idxB)
    _, location = tt_ismember_rows(MatrixB[firstB], MatrixA[firstA])
    # ``location`` is aligned with ``firstB`` (not with ``idxB``), so the
    # "not found in A" mask must be applied to ``firstB``.
    union = np.vstack((MatrixB[firstB[location < 0]], MatrixA[firstA]))
    return union
@overload
def tt_dimscheck(
    N: int,
    M: None = None,
    dims: Optional[OneDArray] = None,
    exclude_dims: Optional[OneDArray] = None,
) -> Tuple[np.ndarray, None]: ...  # pragma: no cover see coveragepy/issues/970


@overload
def tt_dimscheck(
    N: int,
    M: int,
    dims: Optional[OneDArray] = None,
    exclude_dims: Optional[OneDArray] = None,
) -> Tuple[np.ndarray, np.ndarray]: ...  # pragma: no cover see coveragepy/issues/970


def tt_dimscheck(  # noqa: PLR0912
    N: int,
    M: Optional[int] = None,
    dims: Optional[OneDArray] = None,
    exclude_dims: Optional[OneDArray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Preprocess dimensions for tensor operations.

    Parameters
    ----------
    N:
        Tensor order
    M:
        Num of multiplicands
    dims:
        Dimensions to check
    exclude_dims:
        Check all dimensions but these. (Mutually exclusive with dims)

    Returns
    -------
    sdims:
        New dimensions, sorted ascending.
    vidx:
        Index into the multiplicands (if M defined, otherwise None).

    Raises
    ------
    ValueError
        If both ``dims`` and ``exclude_dims`` are given, if an excluded
        dimension is outside ``[0, N)``, or if any dimension is negative.

    Examples
    --------
    # Default captures all dims and no index

    >>> rdims, _ = tt_dimscheck(6)
    >>> np.array_equal(rdims, np.arange(6))
    True

    # Exclude single dim and still no index

    >>> rdims, _ = tt_dimscheck(6, exclude_dims=np.array([5]))
    >>> np.array_equal(rdims, np.arange(5))
    True

    # Exclude single dim and number of multiplicands equals resulting size

    >>> rdims, ridx = tt_dimscheck(6, 5, exclude_dims=np.array([0]))
    >>> np.array_equal(rdims, np.array([1, 2, 3, 4, 5]))
    True
    >>> np.array_equal(ridx, np.arange(0, 5))
    True
    """
    if dims is not None and exclude_dims is not None:
        raise ValueError("Either specify dims to include or exclude, but not both")

    if dims is not None:
        dims = parse_one_d(dims)
    if exclude_dims is not None:
        exclude_dims = parse_one_d(exclude_dims)

    dim_array: np.ndarray = np.empty((1,))

    # Explicit exclude to resolve ambiguous -0
    if exclude_dims is not None:
        # Check that all members in range
        valid_indices = np.isin(exclude_dims, np.arange(0, N))
        if not np.all(valid_indices):
            invalid_indices = np.logical_not(valid_indices)
            raise ValueError(
                f"Exclude dims provided: {exclude_dims} "
                f"but, {exclude_dims[invalid_indices]} were out of valid range "
                f"[0, {N})"
            )
        dim_array = np.setdiff1d(np.arange(0, N), exclude_dims)

    # Neither dims nor exclude_dims given: use every dimension
    if dims is None and exclude_dims is None:
        dim_array = np.arange(0, N)
    elif isinstance(dims, np.ndarray):
        dim_array = dims

    # Catch minus case to avoid silent errors
    if np.any(dim_array < 0):
        raise ValueError(
            "Negative dims aren't allowed in pyttb, see exclude_dims argument instead"
        )

    # Save dimensions of dims
    P = len(dim_array)

    # Reorder dims from smallest to largest (this matters in particular for the vector
    # multiplicand case, where the order affects the result)
    sidx = np.argsort(dim_array)
    sdims = dim_array[sidx]

    vidx = None
    if M is not None:
        # Can't have more multiplicands than dimensions
        if M > N:
            assert False, "Cannot have more multiplicands than dimensions"

        # Check that the number of multiplicands must either be full dimensional or
        # equal to the specified dimensions (M==N) or M(==P) respectively
        if M not in (N, P):
            assert False, "Invalid number of multiplicands"

        # Check sizes to determine how to index multiplicands
        if P == M:
            # Case 1: Number of items in dims and number of multiplicands are equal;
            # therefore, index in order of sdims
            vidx = sidx
        else:
            # Case 2: Number of multiplicands is equal to the number of dimensions of
            # tensor; therefore, index multiplicands by dimensions in dims argument.
            vidx = sdims

    return sdims, vidx
def tt_setdiff_rows(MatrixA: np.ndarray, MatrixB: np.ndarray) -> np.ndarray:
    """Reproduce functionality of MATLABS setdiff(a,b,'rows').

    Parameters
    ----------
    MatrixA:
        First matrix.
    MatrixB:
        Second matrix.

    Returns
    -------
    Sorted row indices into ``MatrixA`` (first occurrence of each distinct
    row) of the rows that do not appear in ``MatrixB``.

    Examples
    --------
    >>> a = np.array([[1, 1], [1, 1], [0, 0]])
    >>> b = np.array([[0, 0]])
    >>> tt_setdiff_rows(a, b)
    array([0])
    """
    # TODO intersect and setdiff are very similar in function
    if MatrixA.size > 0:
        MatrixAUnique, idxA = np.unique(MatrixA, axis=0, return_index=True)
    else:
        MatrixAUnique = idxA = np.array([], dtype=int)
    if MatrixB.size > 0:
        MatrixBUnique, idxB = np.unique(MatrixB, axis=0, return_index=True)
    else:
        MatrixBUnique = idxB = np.array([], dtype=int)

    # Order the distinct rows of A by first appearance in MatrixA.
    orderA = np.argsort(idxA)
    valid, location = tt_ismember_rows(
        MatrixBUnique[np.argsort(idxB)], MatrixAUnique[orderA]
    )
    # ``location`` holds positions in that first-appearance ordering; map them
    # back to row indices of MatrixA before comparing against ``idxA``.
    return np.setdiff1d(idxA, idxA[orderA][location[valid]])
def tt_intersect_rows(MatrixA: np.ndarray, MatrixB: np.ndarray) -> np.ndarray:
    """Reproduce functionality of MATLABS intersect(a,b,'rows').

    Parameters
    ----------
    MatrixA:
        First matrix.
    MatrixB:
        Second matrix.

    Returns
    -------
    location:
        Row indices into ``MatrixA`` (first occurrence of each distinct row)
        of the rows that also appear in ``MatrixB``, ordered by their first
        appearance in ``MatrixB``.

    Examples
    --------
    >>> a = np.array([[1, 2], [3, 4]])
    >>> b = np.array([[0, 0], [1, 2], [3, 4], [0, 0]])
    >>> tt_intersect_rows(a, b)
    array([0, 1])
    >>> tt_intersect_rows(b, a)
    array([1, 2])
    >>> tt_intersect_rows(np.array([[0, 0], [0, 0], [1, 2]]), np.array([[1, 2]]))
    array([2])
    """
    # TODO ismember and unique are very similar in function
    if MatrixA.size > 0:
        MatrixAUnique, idxA = np.unique(MatrixA, axis=0, return_index=True)
    else:
        MatrixAUnique = idxA = np.array([], dtype=int)
    if MatrixB.size > 0:
        MatrixBUnique, idxB = np.unique(MatrixB, axis=0, return_index=True)
    else:
        MatrixBUnique = idxB = np.array([], dtype=int)

    # Order the distinct rows of A by first appearance in MatrixA.
    orderA = np.argsort(idxA)
    valid, location = tt_ismember_rows(
        MatrixBUnique[np.argsort(idxB)], MatrixAUnique[orderA]
    )
    # ``location`` holds positions in that first-appearance ordering; translate
    # them into actual row indices of MatrixA.
    return idxA[orderA][location[valid]]
def tt_irenumber(
    t: ttb.sptensor, shape: Tuple[int, ...], number_range: Sequence[IndexType]
) -> np.ndarray:
    """Renumber indices for sptensor __setitem__.

    Maps the subscripts of the source sptensor ``t`` into the coordinates of
    the destination region selected by ``number_range``.

    Parameters
    ----------
    t:
        Sptensor we are trying to assign from
    shape:
        Shape of destination tensor
    number_range:
        Key from __setitem__ for destination tensor. Each entry is a slice,
        an integer (Python or numpy), or a sequence/array of indices.

    Returns
    -------
    Subscripts for sptensor assignment (an empty array when ``t`` has no
    nonzeros).
    """
    nz = t.nnz
    if nz == 0:
        newsubs = np.array([])
        return newsubs

    newsubs = t.subs.astype(int)
    for i, r in enumerate(number_range):
        if isinstance(r, slice):
            # NOTE(review): the slice step is not taken into account here.
            start = r.start or 0
            stop = r.stop or shape[i]
            newsubs[:, i] = np.arange(start, stop + 1)[newsubs[:, i]]
        elif isinstance(r, (int, np.integer)):
            # An integer key fixes this coordinate; presumably the source has no
            # corresponding dimension, so insert a constant column at position i.
            newsubs = np.insert(newsubs, obj=i, values=r, axis=1)
        else:
            if not isinstance(r, np.ndarray):
                r = np.array(r)  # noqa: PLW2901
            newsubs[:, i] = r[newsubs[:, i]]
    return newsubs
def tt_renumber(
    subs: np.ndarray, shape: Tuple[int, ...], number_range: Sequence[IndexType]
) -> Tuple[np.ndarray, Tuple[int, ...]]:
    """Renumber indices for sptensor __getitem__.

    [NEWSUBS,NEWSZ] = RENUMBER(SUBS,SZ,RANGE) takes a set of
    original subscripts SUBS with entries from a tensor of size
    SZ. All the entries in SUBS are assumed to be within the
    specified RANGE. These subscripts are then renumbered so that,
    in dimension i, the numbers range from 1:numel(RANGE(i)).

    The input ``subs`` is not modified; renumbering happens on a copy.

    Parameters
    ----------
    subs:
        Original subscripts for source tensor.
    shape:
        Shape of source tensor.
    number_range:
        Key from __getitem__ for tensor.

    Returns
    -------
    newsubs:
        Updated subscripts.
    newshape:
        Resulting shape.
    """
    newshape = np.array(shape)
    # Work on a copy so the caller's subscripts are left untouched.
    newsubs = subs.copy()
    for i in range(0, len(shape)):
        number_range_i = number_range[i]
        # A full slice leaves the dimension unchanged. Checking the type first
        # avoids an elementwise (and ambiguous) comparison for numpy arrays.
        if isinstance(number_range_i, slice) and number_range_i == slice(
            None, None, None
        ):
            continue
        if subs.size == 0:
            if isinstance(number_range_i, (int, float, np.integer)):
                newshape[i] = number_range_i
            elif isinstance(number_range_i, slice):
                # TODO get this length without generating the range
                newshape[i] = len(range(0, shape[i])[number_range_i])
            else:
                newshape[i] = len(number_range_i)
        else:
            newsubs[:, i], newshape[i] = tt_renumberdim(
                subs[:, i], shape[i], number_range_i
            )

    return newsubs, tuple(newshape)
def tt_renumberdim(
    idx: np.ndarray, shape: int, number_range: IndexType
) -> Tuple[np.ndarray, int]:
    """Renumber a single dimension.

    Helper function for RENUMBER.

    Parameters
    ----------
    idx:
        Original subscripts along this dimension.
    shape:
        Size of this dimension in the source tensor.
    number_range:
        Key selecting the region along this dimension: an integer, a slice,
        or a sequence/array of indices.

    Returns
    -------
    newidx:
        Integer subscripts renumbered to positions within ``number_range``.
    newshape:
        Size of the dimension after renumbering (0 for an integer key).

    Raises
    ------
    ValueError
        If ``number_range`` is of an unsupported type.

    Examples
    --------
    >>> tt_renumberdim(np.array([1, 3]), 5, slice(1, 4))
    (array([0, 2]), 3)
    """
    # Determine the size of the new range
    if isinstance(number_range, (int, np.integer)):
        # Integer keys map no entries (newshape 0); presumably the caller
        # drops this dimension afterwards.
        number_range = [int(number_range)]
        newshape = 0
    elif isinstance(number_range, slice):
        number_range = range(0, shape)[number_range]
        newshape = len(number_range)
    elif isinstance(number_range, (Sequence, np.ndarray)):
        newshape = len(number_range)
    else:
        raise ValueError(f"Bad number range: {number_range}")

    # Create map from old range to the new range. Integer dtype so the result
    # can be used directly as subscripts.
    idx_map = np.zeros(shape=shape, dtype=int)
    for i in range(0, newshape):
        idx_map[number_range[i]] = i

    # Do the mapping
    newidx = idx_map[idx]
    return newidx, newshape
# TODO make more efficient
# For thoughts on how to speed this up, see
# https://stackoverflow.com/questions/22699756/python-version-of-ismember-with-rows-and-index
def tt_ismember_rows(
    search: np.ndarray, source: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Find location of search rows in source array.

    Rows are compared by value (after numpy dtype promotion). When a row
    occurs several times in ``source``, the index of its last occurrence is
    reported. A 1-D ``source`` is treated as a single row.

    Parameters
    ----------
    search:
        Array to match to source array.
    source:
        Array to be matched against.

    Returns
    -------
    matched: len(results)==len(matched)
        Boolean for indexing matched results.
    results: search.size==results.size,
        if search[0,:] == source[3,:], then results[0] = 3
        if exists i such that search[i,:] != source[j,:] for any j, then results[i] = -1

    Raises
    ------
    ValueError
        If a non-empty ``search`` is not two-dimensional.

    Examples
    --------
    >>> a = np.array([[4, 6], [1, 9], [2, 6]])
    >>> b = np.array([[2, 6], [2, 1], [2, 4], [4, 6], [4, 7], [5, 9], [5, 2], [5, 1]])
    >>> matched, results = tt_ismember_rows(a, b)
    >>> print(results)
    [ 3 -1  0]
    >>> print(matched)
    [ True False  True]
    """
    num_search = search.shape[0]
    matched = np.zeros(shape=num_search, dtype=bool)
    results = np.full(num_search, -1, dtype=int)
    if search.size == 0 or source.size == 0:
        return matched, results
    if search.ndim != 2:
        raise ValueError(
            f"search must be a 2-D array of rows but had shape {search.shape}"
        )
    source = np.atleast_2d(source)
    num_source = source.shape[0]

    # Label every source and search row with the id of its distinct value.
    # Sorting-based np.unique keeps this O((n + m) log(n + m)) instead of
    # materializing an n x m x k elementwise comparison.
    distinct, labels = np.unique(
        np.vstack((source, search)), axis=0, return_inverse=True
    )
    # ravel guards against numpy versions that return a non-1-D inverse.
    labels = np.asarray(labels).ravel()
    source_labels = labels[:num_source]
    search_labels = labels[num_source:]

    # For each distinct row, keep the last source row holding it (-1 if none).
    last_source_row = np.full(distinct.shape[0], -1, dtype=int)
    np.maximum.at(last_source_row, source_labels, np.arange(num_source))

    results = last_source_row[search_labels]
    matched = results >= 0
    return matched, results
def tt_ind2sub(
    shape: Tuple[int, ...],
    idx: np.ndarray,
    order: MemoryLayout = "F",
) -> np.ndarray:
    """
    Multiple subscripts from linear indices.

    Parameters
    ----------
    shape:
        Shape of tensor indexing into.
    idx:
        Array of linear indices into the tensor. Negative indices count back
        from the end, as in numpy. The array is not modified.
    order:
        Memory layout used to interpret the linear indices ("F" or "C").

    Returns
    -------
    Multi-dimensional indices for the tensor, one row per linear index.

    Example
    -------
    >>> shape = (2, 2, 2)
    >>> linear_indices = np.array([0, 1])
    >>> tt_ind2sub(shape, linear_indices)
    array([[0, 0, 0],
           [1, 0, 0]])
    >>> tt_ind2sub(shape, np.array([-1]))
    array([[1, 1, 1]])
    """
    idx = np.asarray(idx)
    if idx.size == 0:
        return np.empty(shape=(0, len(shape)), dtype=int)

    # Wrap negative indices without mutating the caller's array.
    idx = np.where(idx < 0, idx + prod(shape), idx)
    return np.array(np.unravel_index(idx, shape, order=order)).transpose()
def tt_subsubsref(obj: np.ndarray, s: Any) -> Union[float, np.ndarray]:
    """Helper function for tensor toolbox subsref.

    Parameters
    ----------
    obj:
        Tensor Data Structure
    s:
        Reference into tensor (currently unused).

    Returns
    -------
    The single element as a Python scalar when ``obj`` is a one-element numpy
    array; otherwise ``obj`` itself, unchanged.
    """  # noqa: D401
    # TODO figure out when subsref yields key of length>1; for now ``s`` is
    # ignored (the MATLAB-style logic would return obj[s[1:]] for longer keys).
    holds_single_value = isinstance(obj, np.ndarray) and obj.size == 1
    if not holds_single_value:
        return obj
    # TODO: Globally figure out why typing thinks item is a string
    return cast(float, obj.item())
def tt_sub2ind(
    shape: Tuple[int, ...],
    subs: np.ndarray,
    order: MemoryLayout = "F",
) -> np.ndarray:
    """Convert multidimensional subscripts to linear indices.

    Parameters
    ----------
    shape:
        Shape of tensor
    subs:
        Subscripts for tensor, one row per entry
    order:
        Memory layout

    Returns
    -------
    Integer array of linear indices (empty, but still integer-typed, when
    ``subs`` is empty so it can be used directly for indexing).

    Examples
    --------
    >>> shape = (2, 2, 2)
    >>> full_indices = np.array([[0, 0, 0], [1, 0, 0]], dtype=int)
    >>> tt_sub2ind(shape, full_indices)
    array([0, 1])

    See Also
    --------
    :func:`tt_ind2sub`:
    """
    if subs.size == 0:
        # Integer dtype: a float empty array cannot be used as an index.
        return np.array([], dtype=int)
    idx = np.ravel_multi_index(tuple(subs.transpose()), shape, order=order)
    return idx
def tt_sizecheck(shape: Tuple[int, ...], nargout: bool = True) -> bool:
    """
    TT_SIZECHECK Checks that the shape is valid.

    TT_SIZECHECK(S) throws an error if S is not a valid shape tuple,
    which means that it is a row vector with strictly positive,
    real-valued, finite integer values. An empty shape is accepted.

    Parameters
    ----------
    shape:
        Shape of tensor
    nargout:
        Controls if response returned or just acts as assert

    Returns
    -------
    bool

    Examples
    --------
    >>> tt_sizecheck((0, -1, 2))
    False
    >>> tt_sizecheck((1, 1, 1))
    True

    See Also
    --------
    :func:`tt_subscheck`:
    """
    siz = np.array(shape)
    if siz.size == 0:
        ok = True
    else:
        # Test the dtype before any numeric checks: integers are always
        # finite, and non-numeric entries (None, strings) must yield False
        # rather than a TypeError from numpy.
        ok = (
            siz.ndim == 1
            and issubclass(siz.dtype.type, np.integer)
            and bool(np.all(siz > 0))
        )

    if not ok and not nargout:
        assert False, "Size must be a row vector of real positive integers"
    return ok
def tt_subscheck(subs: np.ndarray, nargout: bool = True) -> bool:
    """
    TT_SUBSCHECK Checks for valid subscripts.

    TT_SUBSCHECK(S) throws an error if S is not a valid subscript
    array, which means that S is a matrix of real-valued, finite,
    nonnegative (zero-based), integer subscripts. An empty array is accepted.

    Parameters
    ----------
    subs:
        Subs of tensor
    nargout:
        Controls if response returned or just acts as assert

    Returns
    -------
    bool

    Examples
    --------
    >>> tt_subscheck(np.array([[2, 2], [3, 3]]))
    True
    >>> tt_subscheck(np.array([[2, 2], [3, -1]]))
    False

    See Also
    --------
    :func:`tt_sizecheck`:
    :func:`tt_valscheck`:
    """
    if subs.size == 0:
        ok = True
    else:
        # Test the dtype before any numeric checks: integers are always
        # finite, and non-numeric arrays must yield False rather than a
        # TypeError from numpy.
        ok = (
            subs.ndim == 2
            and issubclass(subs.dtype.type, np.integer)
            and bool((subs >= 0).all())
        )

    if not ok and not nargout:
        assert False, "Subscripts must be a matrix of real positive integers"
    return ok
def tt_valscheck(vals: np.ndarray, nargout: bool = True) -> bool:
    """
    TT_VALSCHECK Checks for valid values.

    TT_VALSCHECK(S) throws an error if S is not a valid values array,
    which means that S is a column array (shape ``(n, 1)``) or is empty.

    Parameters
    ----------
    vals:
        Values of tensor
    nargout:
        Controls if response returned or just acts as assert

    Returns
    -------
    bool

    Examples
    --------
    >>> tt_valscheck(np.array([[1], [2]]))
    True
    >>> tt_valscheck(np.array([[1, 2, 3], [2, 2, 2]]))
    False

    See Also
    --------
    :func:`tt_sizecheck`:
    :func:`tt_subscheck`:
    """
    is_column = vals.ndim == 2 and vals.shape[1] == 1
    ok = vals.size == 0 or is_column
    if not (ok or nargout):
        assert False, f"Values must be in array but got {vals}"
    return ok
def isrow(v: np.ndarray) -> bool:
    """
    ISROW Checks if vector is a row vector.

    ISROW(V) returns True if V is a two-dimensional array with exactly one
    row and at least one column; otherwise returns False.

    Parameters
    ----------
    v:
        Vector input

    Examples
    --------
    >>> isrow(np.array([[1, 2]]))
    True
    >>> isrow(np.array([[1, 2], [3, 4]]))
    False
    """
    if v.ndim != 2:
        return False
    num_rows, num_cols = v.shape
    return num_rows == 1 and num_cols >= 1
def isvector(a: np.ndarray) -> bool:
    """
    ISVECTOR Checks if an array is a vector.

    ISVECTOR(A) returns True if A is one-dimensional, or two-dimensional with
    a single row or a single column; otherwise returns False.

    Parameters
    ----------
    a:
        Array to test.

    Returns
    -------
    bool

    Examples
    --------
    >>> isvector(np.array([[1], [2]]))
    True
    >>> isvector(np.array([[1, 2], [3, 4]]))
    False
    """
    if a.ndim == 1:
        return True
    if a.ndim != 2:
        return False
    return 1 in a.shape
# TODO: this is a challenge, since it may need to apply to either Python built-in
# types or numpy types
def islogical(a: np.ndarray) -> bool:
    """
    ISLOGICAL Checks if vector is a logical vector.

    ISLOGICAL(A) returns True if A is a logical (boolean) value or array;
    otherwise returns False. Python ``bool`` values as well as numpy boolean
    scalars and boolean-dtype arrays are recognized.

    Parameters
    ----------
    a:
        Value or array to test.

    Returns
    -------
    bool

    Examples
    --------
    >>> islogical(True)
    True
    >>> islogical(np.array([True, False]))
    True
    >>> islogical(np.array([1, 0]))
    False
    """
    if isinstance(a, (bool, np.bool_)):
        return True
    return bool(isinstance(a, np.ndarray) and a.dtype == np.bool_)
# Adding all sorts of index support here; we might consider splitting this out
# into a more specific file later.
class IndexVariant(Enum):
    """Kinds of keys that can be used to index entries of a tensor."""

    # The key could not be classified.
    UNKNOWN = 0
    # Linear indexing: an int, a slice, or a 1-D array/list of integers.
    LINEAR = 1
    # Subtensor indexing: a tuple of per-dimension keys.
    SUBTENSOR = 2
    # Subscript indexing: a numpy array that is not 1-D (presumably one row of
    # subscripts per entry).
    SUBSCRIPTS = 3
# We probably want to create a specific file for utility types LinearIndexType = Union[int, np.integer, slice] IndexType = Union[LinearIndexType, Sequence[int], np.ndarray]
def get_index_variant(indices: IndexType) -> IndexVariant:
    """Decide on intended indexing variant. No correctness checks.

    See getitem or setitem in :class:`pyttb.tensor` for elaboration of the
    various indexing options.

    A non-empty, non-tuple sequence whose first entry is an integer (Python or
    numpy) is treated as linear indexing. An empty sequence yields
    ``IndexVariant.UNKNOWN`` instead of raising.
    """
    variant = IndexVariant.UNKNOWN
    if isinstance(indices, get_args(LinearIndexType)):
        variant = IndexVariant.LINEAR
    elif isinstance(indices, np.ndarray):
        # TODO this is technically slightly stricter than what
        # we currently have but probably clearer
        if len(indices.shape) == 1:
            variant = IndexVariant.LINEAR
        else:
            variant = IndexVariant.SUBSCRIPTS
    elif isinstance(indices, tuple):
        variant = IndexVariant.SUBTENSOR
    elif (
        isinstance(indices, Sequence)
        and len(indices) > 0
        and isinstance(indices[0], (int, np.integer))
    ):
        # TODO this is slightly redundant/inefficient
        key = np.array(indices)
        if len(key.shape) == 1 or key.shape[1] == 1:
            variant = IndexVariant.LINEAR
    return variant
def get_mttkrp_factors(
    U: Union[ttb.ktensor, Sequence[np.ndarray]], n: Union[int, np.integer], ndims: int
) -> Sequence[np.ndarray]:
    """Apply standard checks and type conversions for mttkrp factors.

    Parameters
    ----------
    U:
        Either a ktensor or a sequence of factor matrices.
    n:
        Mode skipped by the MTTKRP.
    ndims:
        Expected number of factor matrices.

    Returns
    -------
    The factor matrices. For a ktensor input, a copy is taken and its weights
    are redistributed into a factor other than mode ``n`` first.
    """
    if isinstance(U, ttb.ktensor):
        kt = U.copy()
        # Absorb lambda into mode 0, unless that is the skipped mode (then mode 1)
        kt.redistribute(1 if n == 0 else 0)
        U = kt.factor_matrices

    assert isinstance(
        U, (Sequence, np.ndarray)
    ), "Second argument must be a sequence of numpy.ndarray's or a ktensor"
    assert len(U) == ndims, "List of factor matrices is the wrong length"
    return U
def gather_wrap_dims(
    ndims: int,
    rdims: Optional[np.ndarray] = None,
    cdims: Optional[np.ndarray] = None,
    cdims_cyclic: Optional[Union[Literal["fc"], Literal["bc"], Literal["t"]]] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Extract tensor modes mapped to rows and columns for matricized tensors.

    Parameters
    ----------
    ndims:
        Number of dimensions.
    rdims:
        Mapping of row indices (array-like of ints).
    cdims:
        Mapping of column indices (array-like of ints).
    cdims_cyclic:
        When only a single rdim is specified, selects how the remaining
        dimensions are mapped:

        * ``"fc"`` (forward cyclic [1]_): columns are
          ``range(rdims[0] + 1, ndims)`` followed by ``range(0, rdims[0])``.
        * ``"bc"`` (backward cyclic [2]_): columns are
          ``range(rdims[0] - 1, -1, -1)`` followed by
          ``range(ndims - 1, rdims[0], -1)``.
        * ``"t"``: the given dimension becomes the column and all other
          dimensions become the rows.

    Returns
    -------
    rdims, cdims:
        Integer arrays of the modes mapped to rows and to columns.

    Notes
    -----
    Forward cyclic is defined by Kiers [1]_ and backward cyclic is defined by
    De Lathauwer, De Moor, and Vandewalle [2]_.

    References
    ----------
    .. [1] KIERS, H. A. L. 2000. Towards a standardized notation and terminology
           in multiway analysis. J. Chemometrics 14, 105-122.
    .. [2] DE LATHAUWER, L., DE MOOR, B., AND VANDEWALLE, J. 2000b. On the best
           rank-1 and rank-(R1, R2, ... , RN ) approximation of higher-order
           tensors. SIAM J. Matrix Anal. Appl. 21, 4, 1324-1342.
    """
    alldims = np.arange(ndims)
    # Accept any array-like so lists work as well as numpy arrays.
    if rdims is not None:
        rdims = np.asarray(rdims)
    if cdims is not None:
        cdims = np.asarray(cdims)

    if rdims is not None and cdims is None:
        if len(rdims) == 1 and cdims_cyclic is not None:
            # Single row mapping
            # TODO we should be able to remove this since we can just specify
            # cdims alone
            r = rdims[0]
            if cdims_cyclic == "t":
                cdims = rdims
                rdims = np.setdiff1d(alldims, rdims)
            elif cdims_cyclic == "fc":
                cdims = np.array(list(range(r + 1, ndims)) + list(range(r)))
            elif cdims_cyclic == "bc":
                cdims = np.array(
                    list(range(r - 1, -1, -1)) + list(range(ndims - 1, r, -1))
                )
            else:
                assert False, (
                    "Unrecognized value for cdims_cyclic pattern, "
                    'must be "fc", "bc", or "t".'
                )
        else:
            # Multiple row mapping
            cdims = np.setdiff1d(alldims, rdims)
    elif rdims is None and cdims is not None:
        rdims = np.setdiff1d(alldims, cdims)

    assert rdims is not None and cdims is not None
    return rdims.astype(int), cdims.astype(int)
def np_to_python(
    iterable: Iterable,
) -> Iterable:
    """Convert a structure containing numpy scalars to pure python types.

    Mostly useful for prettier printing post numpy 2.0.

    Parameters
    ----------
    iterable:
        Structure potentially containing numpy scalars. The result is rebuilt
        with the same container type as the input (e.g. list or tuple).

    Examples
    --------
    >>> np_to_python((np.int64(1), 2.5))
    (1, 2.5)
    """

    def _as_builtin(value: Any) -> Any:
        # numpy scalars expose ``item`` to recover the equivalent builtin value
        return value.item() if isinstance(value, np.generic) else value

    container_type = type(iterable)
    converted = (_as_builtin(value) for value in iterable)
    return container_type(converted)  # type: ignore [call-arg]
def parse_shape(shape: Shape) -> Tuple[int, ...]:
    """Parse flexible type into shape tuple.

    Every entry of the returned tuple is a Python ``int``, including when the
    input holds numpy integers.

    Raises
    ------
    ValueError
        If entries are not integers, or a numpy array has more than one
        non-trivial dimension.

    Examples
    --------
    >>> integer_shape = 4
    >>> parse_shape(integer_shape)
    (4,)
    >>> flat_numpy_shape = np.ones((4,), dtype=int)
    >>> parse_shape(flat_numpy_shape)
    (1, 1, 1, 1)
    >>> stacked_numpy_shape = np.ones((4, 1, 1), dtype=int)
    >>> parse_shape(stacked_numpy_shape)
    (1, 1, 1, 1)
    >>> list_shape = [1, 1, 1, 1]
    >>> parse_shape(list_shape)
    (1, 1, 1, 1)
    >>> parse_shape([np.int64(2), 3])
    (2, 3)
    """
    if isinstance(shape, (int, np.integer)):
        return (int(shape),)
    if isinstance(shape, np.ndarray):
        if not np.issubdtype(shape.dtype, np.integer):
            raise ValueError("Numpy arrays used as shapes must be integer valued")
        squeezed_shape = shape.squeeze()
        if squeezed_shape.ndim == 0:
            # If it's an array containing a single scalar
            return (int(squeezed_shape),)
        if squeezed_shape.ndim > 1:
            raise ValueError(
                "Numpy arrays used as shapes can only have one non-trivial dimension"
            )
        return tuple(map(int, squeezed_shape))
    shape = tuple(shape)
    if not all(isinstance(ele, (int, np.integer)) for ele in shape):
        raise ValueError("Shapes entries must be integers")
    # Normalize numpy integers to builtin ints, matching the array branch.
    return tuple(int(ele) for ele in shape)
def parse_one_d(maybe_vector: OneDArray) -> np.ndarray:
    """Parse flexible type into numpy array.

    Scalars become length-one arrays. Array-likes (numpy arrays, lists,
    tuples, ...) are squeezed and must have at most one non-trivial dimension.

    Raises
    ------
    ValueError
        If the input has more than one non-trivial dimension.

    Examples
    --------
    >>> int_scalar = 1
    >>> parse_one_d(int_scalar)
    array([1])
    >>> np_int_scalar = np.int8(1)
    >>> parse_one_d(np_int_scalar)
    array([1], dtype=int8)
    >>> float_scalar = 1.0
    >>> parse_one_d(float_scalar)
    array([1.])
    >>> np_float_scalar = np.float64(1.0)
    >>> parse_one_d(np_float_scalar)
    array([1.])
    >>> example_list = [1.0, 1.0]
    >>> parse_one_d(example_list)
    array([1., 1.])
    >>> extra_dims = np.array([[1, 1]])
    >>> parse_one_d(extra_dims)
    array([1, 1])
    >>> nested_list = [[1, 2]]
    >>> parse_one_d(nested_list)
    array([1, 2])
    """
    if isinstance(maybe_vector, (int, float, np.integer, np.floating)):
        return np.array([maybe_vector])
    # Treat every array-like the same way as numpy input so that nested lists
    # are squeezed and validated too.
    array = np.asarray(maybe_vector)
    squeezed_vector = array.squeeze()
    if squeezed_vector.ndim == 1:
        return squeezed_vector
    if squeezed_vector.ndim == 0:
        # Squeezed to scalar so force vector
        return squeezed_vector[None]
    raise ValueError(
        "Vector can have at most one non-trivial dimension but "
        f"had shape {array.shape}"
    )
@overload
def to_memory_order(
    array: np.ndarray, order: MemoryLayout, copy: bool = False
) -> np.ndarray: ...


@overload
def to_memory_order(
    array: sparse.coo_matrix, order: MemoryLayout, copy: bool = False
) -> sparse.coo_matrix: ...


def to_memory_order(
    array: Union[np.ndarray, sparse.coo_matrix], order: MemoryLayout, copy: bool = False
) -> Union[np.ndarray, sparse.coo_matrix]:
    """Convert an array to the specified memory layout.

    Parameters
    ----------
    array:
        Data to ensure matches memory order.
    order:
        Desired memory order.
    copy:
        Whether to force a copy even if data already in supported memory order.

    Returns
    -------
    Array in the requested layout. Without ``copy`` the input may be returned
    as-is when it already satisfies the layout. ``sparse.coo_matrix`` input is
    returned unchanged (or copied when ``copy`` is set).

    Raises
    ------
    ValueError
        If ``order`` is neither "F" nor "C" (dense input only).

    Examples
    --------
    >>> c_order = np.arange(16).reshape((2, 2, 2, 2))
    >>> c_order.flags["C_CONTIGUOUS"]
    True
    >>> to_memory_order(c_order, "F").flags["F_CONTIGUOUS"]
    True
    """
    if isinstance(array, sparse.coo_matrix):
        # COO storage has no dense memory layout to convert; only copy on request.
        return array.copy() if copy else array
    if order not in ("F", "C"):
        raise ValueError(f"Unsupported order {order}")
    if copy:
        # Copy straight into the requested layout so at most one copy occurs.
        # ndmin=1 mirrors np.asfortranarray / np.ascontiguousarray.
        return np.array(array, order=order, copy=True, ndmin=1)
    if order == "F":
        return np.asfortranarray(array)
    return np.ascontiguousarray(array)
if __name__ == "__main__":
    # Run the docstring examples of this module as a quick self-test.
    from doctest import testmod  # pragma: no cover

    testmod()  # pragma: no cover