array.hpp 18.30 KiB
/* usadelndsoc
*
* Copyright © 2022 Pauli Virtanen
* @author Pauli Virtanen <pauli.t.virtanen@jyu.fi>
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
#ifndef ARRAY_HPP_
#define ARRAY_HPP_
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <stdexcept>
#include <type_traits>
#include <vector>

#include <Eigen/Core>
namespace array {
/**
 * Sentinel used as a dimension or stride value meaning "only known at run
 * time". `inline` (C++17) gives a single entity shared by every translation
 * unit that includes this header.
 */
inline constexpr size_t Dynamic = SIZE_MAX;

/*
 * ARRAY_RESTRICT: no-alias qualifier applied to the view's data pointer.
 * Expands to __restrict on GCC-compatible compilers, otherwise to
 * EIGEN_RESTRICT when Eigen defines it, otherwise to nothing.
 */
#if defined(__GNUC__)
#define ARRAY_RESTRICT __restrict
#elif defined(EIGEN_RESTRICT)
#define ARRAY_RESTRICT EIGEN_RESTRICT
#else
#define ARRAY_RESTRICT
#endif

/**
 * Whether run-time shape, index and alignment checks are compiled in.
 * Follows the standard NDEBUG convention: disabled when NDEBUG is defined.
 */
#ifdef NDEBUG
inline constexpr bool bounds_check = false;
#else
inline constexpr bool bounds_check = true;
#endif
/**
* \class Shape
*
* \brief Array dimensions and strides. Either known at compile-time or dynamic.
*
* The Shape template keeps track of the dimensions and strides of an
* N-dimensional row-major array, with the dimensions given in its template
* argument pack. A dimension equal to Dynamic is only known at run time.
*/
template <size_t...>
struct Shape;

/*
 * Rank-0 shape: terminates the recursion over the dimension pack.
 *
 * Members are `static constexpr` (implicitly inline in C++17), so they may be
 * odr-used (e.g. bound to a const reference) without an out-of-class
 * definition.
 */
template <>
struct Shape<>
{
    static constexpr size_t head = 1;
    static constexpr size_t ndim = 0;
    static constexpr size_t size = 1;
    static constexpr bool fixed = true;
    typedef Shape<> Tail;

    /* Past the last axis: dimension 0, stride 1 (recursion base values). */
    static constexpr size_t dim(size_t) { return 0; }
    static constexpr size_t stride(size_t) { return 1; }

    /** Only another rank-0 shape is compatible. */
    template <typename Shape2>
    static constexpr bool compatible = Shape2::ndim == 0;
};

/*
 * Shape with leading dimension Dim0 followed by Dims...; each entry may be a
 * compile-time extent or Dynamic.
 */
template <size_t Dim0, size_t... Dims>
struct Shape<Dim0,Dims...>
{
    static constexpr size_t head = Dim0;
    typedef Shape<Dims...> Tail;
    static constexpr size_t ndim = 1 + sizeof...(Dims);
    /** True when no dimension is Dynamic. */
    static constexpr bool fixed = (Dim0 != Dynamic) && Tail::fixed;
    /** Total element count, or Dynamic if any dimension is Dynamic. */
    static constexpr size_t size = fixed ? Dim0*Tail::size : Dynamic;

    /** Compile-time extent of axis i (Dynamic if run-time; 0 past the last axis). */
    static constexpr size_t dim(size_t i)
    {
        return (i == 0) ? Dim0 : Tail::dim(i-1);
    }

    /**
     * Row-major stride (in elements) of axis i. It is Dynamic when any later
     * dimension is Dynamic; the axis's own extent never matters.
     */
    static constexpr size_t stride(size_t i)
    {
        if (i == 0) {
            constexpr size_t tail_stride = Tail::stride(0);
            if (Tail::head == Dynamic || tail_stride == Dynamic)
                return Dynamic;
            return Tail::head * tail_stride;
        } else {
            return Tail::stride(i-1);
        }
    }

    static constexpr std::array<size_t, ndim> shape() { return {Dim0, Dims...}; }

    static constexpr std::array<size_t, ndim> strides()
    {
        /* Value-initialize: an uninitialized local is not permitted in a
         * constexpr function before C++20. */
        std::array<size_t, ndim> m{};
        for (size_t i = 0; i < ndim; ++i)
            m[i] = stride(i);
        return m;
    }

    /**
     * Same rank, and every pair of dimensions is equal or has at least one
     * Dynamic side.
     */
    template <typename Shape2>
    static constexpr bool compatible = (ndim == Shape2::ndim) &&
        (head == Dynamic || Shape2::head == Dynamic || head == Shape2::head) &&
        Tail::template compatible<typename Shape2::Tail>;
};
namespace detail
{
/** \class TailNth
 *
 * \brief Drop the first I dimensions of a shape.
 *
 * TailNth<I, S>::type follows S::Tail I times. For example,
 * TailNth<2, Shape<1,2,3,4,5>>::type is Shape<3,4,5>.
 */
template <size_t I, typename Shape>
struct TailNth
{
    /* Peel off one leading dimension and recurse on the remainder. */
    using type = typename TailNth<I - 1, typename Shape::Tail>::type;
};

/* Recursion end: nothing left to drop. */
template <typename Shape>
struct TailNth<0, Shape>
{
    using type = Shape;
};
/**
 * Element count of an array of type-level shape Shape and run-time shape
 * `shape`: the compile-time extent is used on every axis where Shape
 * provides one, the run-time value from `shape` where it is Dynamic.
 */
template <typename Shape>
inline size_t size_from_shape(const std::array<size_t, Shape::ndim>& shape)
{
    size_t count = 1;
    size_t axis = 0;
    for (const size_t runtime_dim : shape) {
        const size_t static_dim = Shape::dim(axis++);
        count *= (static_dim != Dynamic) ? static_dim : runtime_dim;
    }
    return count;
}
/**
 * \class fixed_base
 *
 * \brief Shape/stride provider for arrays whose dimensions are all known at compile time
 *
 * Has no data members: dim(), stride(), shape(), strides() and size() all
 * forward to the static members of Shape.
 */
template <typename Shape>
class fixed_base
{
protected:
    /** In bounds-checked builds, throw std::out_of_range when a buffer of
     *  `size` elements is smaller than Shape::size. */
    void data_bounds_check(size_t size) const
    {
        if constexpr(bounds_check) {
            if (size < Shape::size)
                throw std::out_of_range("data array too small");
        }
    }

public:
    fixed_base() {}

    /** Accept a run-time shape. In bounds-checked builds it must equal Shape
     *  on every axis, otherwise std::out_of_range is thrown. */
    fixed_base(const std::array<size_t, Shape::ndim>& shape)
    {
        if constexpr(bounds_check) {
            size_t axis = 0;
            for (const size_t extent : shape) {
                if (extent != Shape::dim(axis))
                    throw std::out_of_range("mismatch with fixed shape");
                ++axis;
            }
        }
    }

    constexpr size_t dim(size_t axis) const { return Shape::dim(axis); }
    constexpr size_t stride(size_t axis) const { return Shape::stride(axis); }
    constexpr std::array<size_t, Shape::ndim> shape() const { return Shape::shape(); }
    constexpr std::array<size_t, Shape::ndim> strides() const { return Shape::strides(); }
    constexpr size_t size() const { return Shape::size; }
};
/**
 * \class dynamic_base
 *
 * \brief Shape/stride provider for arrays with at least one run-time dimension
 *
 * The full shape and strides are stored in member arrays. dim() and stride()
 * still return the compile-time value from Shape whenever it is not Dynamic.
 */
template <typename Shape>
class dynamic_base
{
private:
    std::array<size_t, Shape::ndim> shape_;
    std::array<size_t, Shape::ndim> strides_;

    /*
     * Fill strides_ from the last axis backwards. Compile-time strides are
     * copied from Shape; a Dynamic stride is the next axis's stored extent
     * times that axis's stride (row-major layout).
     */
    void init_strides()
    {
        for (size_t axis = Shape::ndim; axis-- > 0; ) {
            const size_t static_stride = Shape::stride(axis);
            if (static_stride != Dynamic) {
                strides_[axis] = static_stride;
            } else if (axis + 1 == Shape::ndim) {
                strides_[axis] = 1;
            } else {
                strides_[axis] = shape_[axis + 1] * strides_[axis + 1];
            }
        }
    }

protected:
    /**
     * In bounds-checked builds, throw std::out_of_range if the stored shape
     * disagrees with a compile-time dimension of Shape, or if a buffer of
     * `size` elements cannot hold the whole array.
     */
    void data_bounds_check(size_t size) const
    {
        if constexpr(bounds_check) {
            size_t required = 1;
            for (size_t axis = 0; axis < Shape::ndim; ++axis) {
                const size_t extent = dim(axis);
                if (shape_[axis] != extent)
                    throw std::out_of_range("mismatch with fixed shape");
                required *= extent;
            }
            if (size < required)
                throw std::out_of_range("data array too small");
        }
    }

public:
    /** Compile-time dimensions are taken from Shape; Dynamic ones start at 0. */
    dynamic_base()
    {
        for (size_t axis = 0; axis < Shape::ndim; ++axis) {
            const size_t static_dim = Shape::dim(axis);
            shape_[axis] = (static_dim != Dynamic) ? static_dim : 0;
        }
        init_strides();
    }

    /** Store `shape` as given; compile-time axes are validated later by
     *  data_bounds_check(). */
    dynamic_base(const std::array<size_t, Shape::ndim>& shape)
        : shape_(shape)
    {
        init_strides();
    }

    /** Extent of `axis`; throws std::out_of_range for axis >= ndim in
     *  bounds-checked builds. */
    size_t dim(size_t axis) const
    {
        if constexpr(bounds_check) {
            if (axis >= Shape::ndim)
                throw std::out_of_range("array axis must be < ndim");
        }
        const size_t static_dim = Shape::dim(axis);
        if (static_dim != Dynamic)
            return static_dim;
        return shape_[axis];
    }

    /** Stride of `axis` in elements; same axis check as dim(). */
    size_t stride(size_t axis) const
    {
        if constexpr(bounds_check) {
            if (axis >= Shape::ndim)
                throw std::out_of_range("array axis must be < ndim");
        }
        const size_t static_stride = Shape::stride(axis);
        if (static_stride != Dynamic)
            return static_stride;
        return strides_[axis];
    }

    const std::array<size_t, Shape::ndim>& shape() const { return shape_; }
    const std::array<size_t, Shape::ndim>& strides() const { return strides_; }

    /** Product of all dimensions. */
    size_t size() const
    {
        size_t count = 1;
        for (size_t axis = 0; axis < Shape::ndim; ++axis)
            count *= dim(axis);
        return count;
    }
};
template <typename Shape>
using Base = std::conditional_t<Shape::fixed, detail::fixed_base<Shape>, detail::dynamic_base<Shape> >;
template <size_t... Dims>
using ShapeType = Shape<Dims...>;
}
/**
 * \class ArrayView
 *
 * \brief N-dimensional view to 1-dimensional array data
 *
 * Non-owning view: index() maps a multi-index to an offset into the pointer
 * given at construction, using the dimensions and strides supplied by
 * detail::Base<Shape_>. Alignment_ is the byte alignment promised for that
 * pointer; 0 means it is derived from sizeof(Scalar_) (see alignment()).
 */
template <typename Scalar_, typename Shape_, size_t Alignment_=0>
class ArrayView : public detail::Base<Shape_>
{
public:
    typedef Scalar_ Scalar;
    typedef Shape_ Shape;
    static constexpr size_t Alignment = Alignment_;

protected:
    typedef detail::Base<Shape> Base;

    /** First viewed element; not owned by the view. */
    Scalar *ARRAY_RESTRICT data_;

    /** Eigen compile-time extent of axis 0 (rows) or axis 1 (columns). */
    template <size_t axis>
    static constexpr Eigen::Index eigen_shape()
    {
        /* bad Shape handled in matrix() */
        size_t n = (axis == 0) ? Shape::head : Shape::Tail::head;
        return (n == Dynamic) ? Eigen::Dynamic : n;
    }

    /** Eigen::Map alignment flag corresponding to alignment(). */
    static constexpr int eigen_alignment()
    {
        switch (alignment()) {
        case 128: return Eigen::Aligned128;
        case 64: return Eigen::Aligned64;
        case 32: return Eigen::Aligned32;
        case 16: return Eigen::Aligned16;
        case 8: return Eigen::Aligned8;
        default: return Eigen::Unaligned;
        }
    }

    /**
     * Whether a source view with alignment parameter `align` may be converted
     * to this view type at compile time.
     *
     * NOTE: align == 0 passes for every Alignment (0 % Alignment == 0), so this
     * alone does not prove the pointer is suitably aligned; the converting
     * constructors therefore re-check it with alignment_check().
     */
    static constexpr bool compatible_align(size_t align)
    {
        return Alignment == 0 || (align % Alignment) == 0;
    }

    /** Eigen storage order used for the 2-D matrix() map. */
    static constexpr int eigen_major()
    {
        constexpr size_t rows = Shape::head, cols = Shape::Tail::head;
        if (rows != 1 && cols == 1)
            /* Eigen expects Nx1 is column-major; it's compatible with row-major */
            return Eigen::ColMajor;
        else
            return Eigen::RowMajor;
    }

    /** In bounds-checked builds, throw std::runtime_error unless data_ is a
     *  multiple of alignment() bytes. */
    void alignment_check() const
    {
        if constexpr(bounds_check) {
            const std::uintptr_t address = reinterpret_cast<std::uintptr_t>(data_);
            if (address % static_cast<std::uintptr_t>(alignment()) != 0)
                throw std::runtime_error("data is not aligned");
        }
    }

    /** Validate a buffer of `size` elements starting at data_ against the
     *  shape (Base::data_bounds_check) and against alignment(). */
    void data_check(size_t size) const
    {
        Base::data_bounds_check(size);
        alignment_check();
    }

    static_assert(Shape::template compatible<Shape>);

public:
    /*
     * Converting constructors from views with a compatible shape and
     * alignment parameter. The first two deduce the source scalar type C.
     * The last two take a view of the non-const scalar when Scalar is
     * const-qualified. Fixed-shape targets need no run-time shape; otherwise
     * the source's shape is copied. The pointer is re-checked for alignment
     * because compatible_align() accepts sources with Alignment 0.
     */
    template <typename S, size_t A, typename C = Scalar>
    ArrayView(const ArrayView<C,S,A>& other,
              typename std::enable_if_t<(Shape::fixed && S::fixed) && Shape::template compatible<S> && compatible_align(A), void*> dummy = 0)
        : Base(), data_(other.data())
    {
        alignment_check();
    }

    template <typename S, size_t A, typename C = Scalar>
    ArrayView(const ArrayView<C,S,A>& other,
              typename std::enable_if_t<!(Shape::fixed && S::fixed) && Shape::template compatible<S> && compatible_align(A), void*> dummy = 0)
        : Base(other.shape()), data_(other.data())
    {
        alignment_check();
    }

    template <typename S, size_t A, typename C = Scalar>
    ArrayView(const ArrayView<std::remove_const_t<C>,S,A>& other,
              typename std::enable_if_t<(Shape::fixed && S::fixed) && std::is_const<C>::value && Shape::template compatible<S> && compatible_align(A), void*> dummy = 0)
        : Base(), data_(other.data())
    {
        alignment_check();
    }

    template <typename S, size_t A, typename C = Scalar>
    ArrayView(const ArrayView<std::remove_const_t<C>,S,A>& other,
              typename std::enable_if_t<!(Shape::fixed && S::fixed) && std::is_const<C>::value && Shape::template compatible<S> && compatible_align(A), void*> dummy = 0)
        : Base(other.shape()), data_(other.data())
    {
        alignment_check();
    }

    /** View `size` elements at `data` with run-time shape `shape`; checked by data_check(). */
    ArrayView(Scalar *data, size_t size, const std::array<size_t, Shape::ndim> shape)
        : Base(shape), data_(data)
    {
        data_check(size);
    }

    /** View `size` elements at `data` with the compile-time shape; checked by data_check(). */
    ArrayView(Scalar *data, size_t size)
        : Base(), data_(data)
    {
        static_assert(Shape::fixed, "array shape must be compile-time fixed");
        data_check(size);
    }

    /** Element at a full multi-index (one index per dimension). */
    template <typename... Idx>
    Scalar& operator()(Idx... idxs)
    {
        static_assert(sizeof...(Idx) == Shape::ndim,
                      "number of indices must be equal to the number of dimensions");
        return data_[index(idxs...)];
    }

    template <typename... Idx>
    const Scalar& operator()(Idx... idxs) const
    {
        static_assert(sizeof...(Idx) == Shape::ndim,
                      "number of indices must be equal to the number of dimensions");
        return data_[index(idxs...)];
    }

    Scalar* data() const { return data_; }

    /** Effective byte alignment: Alignment if nonzero, else 16 or 8 when
     *  sizeof(Scalar) is a multiple of it, else 1. */
    static constexpr int alignment()
    {
        if (Alignment > 0)
            return Alignment;
        else if (sizeof(Scalar) % 16 == 0)
            return 16;
        else if (sizeof(Scalar) % 8 == 0)
            return 8;
        else
            return 1;
    }

    /**
     * Flat offset of a (possibly partial) leading multi-index: sum of
     * stride(i) * idx_i. In bounds-checked builds throws std::out_of_range
     * for an index >= dim(i).
     */
    template <typename... Idx>
    size_t index(Idx... idxs) const
    {
        static_assert(sizeof...(idxs) <= Shape::ndim,
                      "number of indices must not exceed the number of dimensions");
        const std::array<size_t, Shape::ndim> m{static_cast<size_t>(idxs)...};
        size_t idx = 0;
        for (size_t i = 0; i < sizeof...(idxs); ++i) {
            idx += Base::stride(i) * m[i];
            if constexpr(bounds_check) {
                if (m[i] >= Base::dim(i))
                    throw std::out_of_range("index out of bounds");
            }
        }
        return idx;
    }

    /**
     * Extract sub-array (compile-time fixed size)
     *
     * NOTE(review): data_ + offset need not satisfy Alignment; the
     * constructed view's data_check() throws in bounds-checked builds if it
     * does not.
     */
    template <typename... Idx>
    typename std::enable_if<detail::TailNth<sizeof...(Idx), Shape>::type::fixed,
                            ArrayView<Scalar, typename detail::TailNth<sizeof...(Idx), Shape>::type, Alignment> >::type
    part(Idx... idxs) const
    {
        constexpr size_t nidxs = sizeof...(Idx);
        static_assert(nidxs < Shape::ndim,
                      "number of indices must be smaller than the number of dimensions");
        size_t offset = index(idxs...);
        return {data_ + offset, Base::size() - offset};
    }

    /** Extract sub-array (compile-time dynamic size); same alignment caveat as above. */
    template <typename... Idx>
    typename std::enable_if<!detail::TailNth<sizeof...(Idx), Shape>::type::fixed,
                            ArrayView<Scalar, typename detail::TailNth<sizeof...(Idx), Shape>::type, Alignment> >::type
    part(Idx... idxs) const
    {
        constexpr size_t nidxs = sizeof...(Idx);
        static_assert(nidxs < Shape::ndim,
                      "number of indices must be smaller than the number of dimensions");
        size_t offset = index(idxs...);
        std::array<size_t, Shape::ndim - nidxs> shape;
        for (size_t i = nidxs; i < Shape::ndim; ++i)
            shape[i - nidxs] = Base::dim(i);
        return {data_ + offset, Base::size() - offset, shape};
    }

    typedef Eigen::Matrix<typename std::remove_cv<Scalar>::type, eigen_shape<0>(), eigen_shape<1>(), eigen_major()> EigenMatrix;
    typedef Eigen::Map<typename std::conditional<std::is_const<Scalar>::value, const EigenMatrix, EigenMatrix>::type, eigen_alignment()> EigenMap;

    /** Eigen map over a two-dimensional view. */
    EigenMap matrix() const
    {
        static_assert(Shape::ndim == 2, "matrix must be two-dimensional");
        return {data_, static_cast<Eigen::Index>(Base::dim(0)), static_cast<Eigen::Index>(Base::dim(1))};
    }

    /** Eigen map over a one-dimensional view (as an N x 1 matrix). */
    EigenMap vector() const
    {
        static_assert(Shape::ndim == 1, "vector must be one-dimensional");
        return {data_, static_cast<Eigen::Index>(Base::dim(0))};
    }

    operator EigenMap() const {
        if constexpr(Shape::ndim == 1) {
            return vector();
        } else {
            return matrix();
        }
    }

    /** Extract sub-matrix */
    template <typename... Idx>
    auto submatrix(Idx... idxs) const { return part(idxs...).matrix(); }

    /** Reshape (compile-time fixed size) */
    template <size_t... Dims>
    typename std::enable_if<detail::ShapeType<Dims...>::fixed,
                            ArrayView<Scalar, detail::ShapeType<Dims...>, Alignment> >::type
    reshape() const
    {
        static_assert(Shape::size == detail::ShapeType<Dims...>::size, "incompatible shapes");
        return {data(), Base::size()};
    }

    /** Reshape (compile-time dynamic size) */
    template <size_t... Dims>
    typename std::enable_if<!detail::ShapeType<Dims...>::fixed,
                            ArrayView<Scalar, detail::ShapeType<Dims...>, Alignment> >::type
    reshape(const std::array<size_t, detail::ShapeType<Dims...>::ndim> shape) const
    {
        static_assert(Shape::size == detail::ShapeType<Dims...>::size, "incompatible shapes");
        return {data(), Base::size(), shape};
    }

    /** Fill with value (all size() contiguous elements starting at data_). */
    void fill(const Scalar& value)
    {
        const size_t n = Base::size();
        for (size_t i = 0; i < n; ++i)
            data_[i] = value;
    }
};
/**
 * \class Array
 *
 * \brief ArrayView that also owns the storage holding its elements.
 *
 * Elements live in a Storage container (std::vector<Scalar> by default)
 * whose value_type may be smaller than Scalar: each Scalar spans `units`
 * storage elements, and the inherited view's data pointer is
 * storage_.data() reinterpreted as Scalar*.
 */
template <typename Scalar,
          typename Shape,
          typename Storage_=std::vector<Scalar>,
          size_t Alignment = ArrayView<Scalar, Shape>::Alignment>
class Array : public ArrayView<Scalar, Shape, Alignment>
{
public:
    using Storage = Storage_;

private:
    using BaseType = ArrayView<Scalar, Shape, Alignment>;
    using StorageUnit = typename Storage::value_type;

    static_assert(sizeof(Scalar) % sizeof(StorageUnit) == 0, "incommensurate sizes");
    static constexpr size_t units = sizeof(Scalar) / sizeof(StorageUnit);

    Storage storage_;

    /* Re-point the inherited view at the first element held by storage_. */
    void attach_storage()
    {
        BaseType::data_ = reinterpret_cast<Scalar *>(storage_.data());
    }

    /* Attach storage_, then validate its size and alignment against the shape. */
    void attach_and_check_storage()
    {
        attach_storage();
        BaseType::data_check(storage_.size() / units);
    }

public:
    /** Allocate storage for the given run-time shape. */
    Array(const std::array<size_t, Shape::ndim>& shape)
        : BaseType(nullptr, detail::size_from_shape<Shape>(shape), shape),
          storage_(units * detail::size_from_shape<Shape>(shape))
    {
        attach_and_check_storage();
    }

    /** Allocate storage for a compile-time fixed shape. */
    Array()
        : BaseType(nullptr, Shape::size),
          storage_(units * Shape::size)
    {
        attach_and_check_storage();
    }

    /**
     * Take over other's storage and shape.
     *
     * NOTE(review): other's inherited data_ is left unchanged, so for
     * std::vector storage it still points at the buffer now owned by *this.
     */
    Array(Array&& other)
        : BaseType(nullptr, other.size(), other.shape()),
          storage_(std::move(other.storage()))
    {
        attach_storage();
    }

    /** Adopt an existing container as storage for the given shape. */
    Array(Storage&& storage, const std::array<size_t, Shape::ndim>& shape)
        : BaseType(nullptr, detail::size_from_shape<Shape>(shape), shape),
          storage_(std::move(storage))
    {
        attach_and_check_storage();
    }

    Storage& storage() { return storage_; }
};
} // namespace array
#endif