Skip to content
Snippets Groups Projects
array.hpp 18.30 KiB
/* usadelndsoc
 *
 * Copyright © 2022 Pauli Virtanen
 *    @author Pauli Virtanen <pauli.t.virtanen@jyu.fi>
 *
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */
#ifndef ARRAY_HPP_
#define ARRAY_HPP_

#include <cstddef>
#include <cstdint>
#include <vector>
#include <array>
#include <type_traits>
#include <initializer_list>
#include <stdexcept>

#include <Eigen/Core>

namespace array {

/** Sentinel extent meaning "known only at runtime", e.g. Shape<Dynamic, 3>.
 *  Declared `inline` so this header constant has a single definition shared
 *  by all translation units. */
inline constexpr size_t Dynamic = SIZE_MAX;

/* ARRAY_RESTRICT: pointer qualifier promising no aliasing. GCC-compatible
 * compilers get __restrict; otherwise EIGEN_RESTRICT is used if it is defined
 * at this point, and failing that the macro expands to nothing. */
#if defined(__GNUC__)
#define ARRAY_RESTRICT __restrict
#elif defined(EIGEN_RESTRICT)
#define ARRAY_RESTRICT EIGEN_RESTRICT
#else
#define ARRAY_RESTRICT
#endif

/* Index, shape and alignment checks follow the same switch as assert():
 * enabled unless NDEBUG is defined. */
#ifndef NDEBUG
constexpr bool bounds_check = true;
#else
constexpr bool bounds_check = false;
#endif

/**
 * \class Shape
 *
 * \brief Array dimensions and strides. Either known at compile-time or dynamic.
 *
 * The Shape template keeps track of the dimensions and strides of an
 * N-dimensional row-major array, with the shape given in its template argument
 * pack. Each extent is either a compile-time constant or Dynamic (known only
 * at runtime).
 */
template <size_t...>
struct Shape;

/**
 * Empty shape: terminates the recursion of Shape<Dim0, Dims...>.
 *
 * It acts as a zero-dimensional, one-element, fixed shape. Its head is 1 and
 * its stride is 1, so the last real axis gets unit stride. dim() is 0 for any
 * axis.
 *
 * Members are `static constexpr`, hence implicitly inline in C++17, so they
 * may be ODR-used (e.g. bound to a const reference) without an out-of-class
 * definition.
 */
template <>
struct Shape<>
{
    static constexpr size_t head = 1;
    static constexpr size_t ndim = 0;
    static constexpr size_t size = 1;
    static constexpr bool fixed = true;
    typedef Shape<> Tail;
    static constexpr size_t dim(size_t /* i */) { return 0; }
    static constexpr size_t stride(size_t /* i */) { return 1; }

    /** Only another zero-dimensional shape is compatible. */
    template <typename Shape2>
    static constexpr bool compatible = Shape2::ndim == 0;
};

template <size_t Dim0, size_t... Dims>
struct Shape<Dim0,Dims...>
{
    /** Extent of the leading axis (may be Dynamic). */
    static constexpr size_t head = Dim0;
    /** Shape of the remaining axes. */
    typedef Shape<Dims...> Tail;
    /** Number of axes. */
    static constexpr size_t ndim = 1 + sizeof...(Dims);
    /** True when no extent is Dynamic. */
    static constexpr bool fixed = (Dim0 != Dynamic) && Tail::fixed;
    /** Total element count, or Dynamic if any extent is Dynamic. */
    static constexpr size_t size = fixed ? Dim0*Tail::size : Dynamic;

    /** Compile-time extent of axis i: Dynamic if unknown, 0 past the last axis. */
    static constexpr size_t dim(size_t i)
        {
            return (i == 0) ? Dim0 : Tail::dim(i-1);
        }

    /**
     * Row-major stride of axis i, in elements.
     *
     * This is the product of the extents of all later axes, or Dynamic if any
     * of them is Dynamic. The last axis has stride 1.
     */
    static constexpr size_t stride(size_t i)
        {
            if (i == 0) {
                constexpr size_t tail_stride = Tail::stride(0);

                if (Tail::head == Dynamic || tail_stride == Dynamic)
                    return Dynamic;

                return Tail::head * tail_stride;
            } else {
                return Tail::stride(i-1);
            }
        }

    /** Compile-time extents of all axes (Dynamic where unknown). */
    static constexpr std::array<size_t, ndim> shape() { return {Dim0, Dims...}; }

    /** Compile-time strides of all axes (Dynamic where unknown). */
    static constexpr std::array<size_t, ndim> strides()
        {
            /* Value-initialize: C++17 forbids uninitialized locals in
             * constant-evaluated constexpr functions. */
            std::array<size_t, ndim> m{};

            for (size_t i = 0; i < ndim; ++i)
                m[i] = stride(i);
            return m;
        }

    /**
     * True if Shape2 has the same number of axes and each pair of extents is
     * equal, or at least one of the pair is Dynamic.
     */
    template <typename Shape2>
    static constexpr bool compatible = (ndim == Shape2::ndim) &&
        (head == Dynamic || Shape2::head == Dynamic || head == Shape2::head) &&
        Tail::template compatible<typename Shape2::Tail>;
};

namespace detail
{
    /** \class TailNth
     *
     * \brief Strip the first I axes off a Shape.
     *
     * Each recursion step follows Shape::Tail once. For example,
     * TailNth<2, Shape<1,2,3,4,5>>::type is Shape<3,4,5>, and
     * TailNth<0, S>::type is S itself.
     */
    template <size_t I, typename Shape>
    struct TailNth
    {
        typedef typename TailNth<I - 1, typename Shape::Tail>::type type;
    };

    /* Recursion end: nothing left to strip. */
    template <typename Shape>
    struct TailNth<0, Shape>
    {
        typedef Shape type;
    };

    /**
     * Total element count for a runtime shape.
     *
     * Axes whose extent Shape fixes at compile time use that value. Dynamic
     * axes take the corresponding entry of `shape`.
     */
    template <typename Shape>
    inline size_t size_from_shape(const std::array<size_t, Shape::ndim>& shape)
    {
        size_t count = 1;

        for (size_t axis = 0; axis != Shape::ndim; ++axis) {
            const size_t compile_dim = Shape::dim(axis);
            count *= (compile_dim == Dynamic) ? shape[axis] : compile_dim;
        }

        return count;
    }

    /**
     * \class fixed_base
     *
     * \brief Shape/stride policy for arrays whose extents are all compile-time constants.
     *
     * Holds no data members: every query is answered from the Shape type.
     */
    template <typename Shape>
    class fixed_base
    {
    public:
        /** Default construction; the shape is implied by the type. */
        fixed_base() {}

        /**
         * Construct from a runtime shape, which must equal the compile-time one.
         *
         * \throws std::out_of_range on mismatch (only when bounds_check is on).
         */
        fixed_base(const std::array<size_t, Shape::ndim>& shape)
            {
                if constexpr(bounds_check) {
                    for (size_t axis = 0; axis != Shape::ndim; ++axis) {
                        const bool same = (shape[axis] == Shape::dim(axis));
                        if (!same)
                            throw std::out_of_range("mismatch with fixed shape");
                    }
                }
            }

        constexpr size_t dim(size_t i) const { return Shape::dim(i); }
        constexpr size_t stride(size_t i) const { return Shape::stride(i); }

        constexpr std::array<size_t, Shape::ndim> shape() const { return Shape::shape(); }
        constexpr std::array<size_t, Shape::ndim> strides() const { return Shape::strides(); }
        constexpr size_t size() const { return Shape::size; }

    protected:
        /**
         * Throw std::out_of_range if a buffer of `size` elements cannot hold
         * Shape::size elements (only when bounds_check is on).
         */
        void data_bounds_check(size_t size) const
            {
                if constexpr(bounds_check) {
                    if (size < Shape::size)
                        throw std::out_of_range("data array too small");
                }
            }
    };

    /**
     * \class dynamic_base
     *
     * \brief Base class for compile-time dynamic size arrays
     *
     * Shape and stride information for shapes with at least one Dynamic
     * extent, stored in member variables. dim() and stride() still return the
     * compile-time value for axes where the Shape type fixes it.
     */
    template <typename Shape>
    class dynamic_base
    {
    private:
        std::array<size_t, Shape::ndim> shape_;
        std::array<size_t, Shape::ndim> strides_;

        /*
         * Fill strides_ in row-major order, starting from the last axis.
         *
         * Compile-time strides are copied as is. A Dynamic stride is computed
         * as strides_[i-1] = dim(i) * strides_[i]. Using dim() rather than
         * shape_ means a fixed extent always contributes its compile-time
         * value. That keeps strides_ consistent with dim() even if the runtime
         * shape passed in disagrees with the type, a mismatch that is only
         * diagnosed when bounds_check is on.
         */
        void init_strides()
            {
                for (size_t i = Shape::ndim; i > 0; --i) {
                    if (Shape::stride(i-1) != Dynamic)
                        strides_[i-1] = Shape::stride(i-1);
                    else if (i == Shape::ndim)
                        strides_[i-1] = 1;
                    else
                        strides_[i-1] = dim(i) * strides_[i];
                }
            }

    protected:
        /**
         * Validate the stored shape against a data buffer (checked builds only).
         *
         * \throws std::out_of_range if the stored runtime shape disagrees with
         *         a fixed extent of Shape, or if a buffer of `size` elements is
         *         too small for the shape.
         */
        void data_bounds_check(size_t size) const
            {
                if constexpr(bounds_check) {
                    size_t total = 1;

                    for (size_t i = 0; i < Shape::ndim; ++i) {
                        total *= dim(i);
                        if (shape_[i] != dim(i))
                            throw std::out_of_range("mismatch with fixed shape");
                    }
                    if (total > size)
                        throw std::out_of_range("data array too small");
                }
            }

    public:
        /** Default: Dynamic extents are 0 and fixed ones take their compile-time value. */
        dynamic_base()
            {
                for (size_t i = 0; i < Shape::ndim; ++i)
                    shape_[i] = (Shape::dim(i) == Dynamic) ? 0 : Shape::dim(i);
                init_strides();
            }

        /** Store the runtime shape as given; it is validated only by data_bounds_check(). */
        dynamic_base(const std::array<size_t, Shape::ndim>& shape)
            : shape_(shape)
            {
                init_strides();
            }

        /**
         * Extent of `axis`.
         *
         * \throws std::out_of_range if axis >= ndim (checked builds only).
         */
        size_t dim(size_t axis) const
            {
                if constexpr(bounds_check) {
                    if (axis >= Shape::ndim)
                        throw std::out_of_range("array axis must be < ndim");
                }
                size_t n = Shape::dim(axis);
                return n == Dynamic ? shape_[axis] : n;
            }

        /**
         * Stride of `axis`, in elements.
         *
         * \throws std::out_of_range if axis >= ndim (checked builds only).
         */
        size_t stride(size_t axis) const
            {
                if constexpr(bounds_check) {
                    if (axis >= Shape::ndim)
                        throw std::out_of_range("array axis must be < ndim");
                }
                size_t n = Shape::stride(axis);
                return n == Dynamic ? strides_[axis] : n;
            }

        const std::array<size_t, Shape::ndim>& shape() const { return shape_; }
        const std::array<size_t, Shape::ndim>& strides() const { return strides_; }

        /** Product of all extents. */
        size_t size() const
            {
                size_t size = 1;
                for (size_t i = 0; i < Shape::ndim; ++i)
                    size *= dim(i);
                return size;
            }
    };

    template <typename Shape>
    using Base = std::conditional_t<Shape::fixed, detail::fixed_base<Shape>, detail::dynamic_base<Shape> >;
    template <size_t... Dims>
    using ShapeType = Shape<Dims...>;
}

/**
 * \class ArrayView
 *
 * \brief N-dimensional view to 1-dimensional array data
 *
 * Array view for index mapping to 1-dim array. The view does not own the
 * data. Shape and strides come from detail::Base<Shape_>: compile-time for
 * fixed shapes, member variables otherwise. The mapping is row-major.
 *
 * Alignment_ is the byte alignment promised for the data pointer. A value of
 * 0 selects a default derived from sizeof(Scalar); see alignment().
 */
template <typename Scalar_, typename Shape_, size_t Alignment_=0>
class ArrayView : public detail::Base<Shape_>
{
public:
    typedef Scalar_ Scalar;
    typedef Shape_ Shape;
    static constexpr size_t Alignment = Alignment_;

protected:
    typedef detail::Base<Shape> Base;

    /** First element of the viewed data (not owned by the view). */
    Scalar *ARRAY_RESTRICT data_;

    /**
     * Eigen compile-time extent for axis 0 (rows) or axis 1 (columns).
     *
     * For a 1-dim Shape, Tail is Shape<> whose head is 1, i.e. a single
     * column.
     */
    template <size_t axis>
    static constexpr Eigen::Index eigen_shape()
        {
            /* bad Shape handled in matrix() */
            size_t n = (axis == 0) ? Shape::head : Shape::Tail::head;
            return (n == Dynamic) ? Eigen::Dynamic : n;
        }

    /** Eigen::Map alignment option matching alignment(). */
    static constexpr int eigen_alignment()
        {
            switch (alignment()) {
            case 128: return Eigen::Aligned128;
            case 64: return Eigen::Aligned64;
            case 32: return Eigen::Aligned32;
            case 16: return Eigen::Aligned16;
            case 8: return Eigen::Aligned8;
            default: return Eigen::Unaligned;
            }
        }

    /**
     * Compile-time gate: may a view declared with alignment parameter `align`
     * be converted to this one?
     *
     * True if this view promises nothing extra (Alignment == 0) or `align` is
     * a multiple of Alignment. Note that align == 0 also passes, because
     * 0 % Alignment == 0. The converting constructors therefore re-check the
     * actual pointer through data_check() in checked builds.
     */
    static constexpr bool compatible_align(size_t align)
        {
            return Alignment == 0 || (align % Alignment) == 0;
        }

    /** Storage order for the Eigen matrix type (row-major except for N x 1). */
    static constexpr int eigen_major()
        {
            constexpr size_t rows = Shape::head, cols = Shape::Tail::head;
            if (rows != 1 && cols == 1)
                /* Eigen expects Nx1 is column-major; it's compatible with row-major */
                return Eigen::ColMajor;
            else
                return Eigen::RowMajor;
        }

    /**
     * Validate the view against a buffer of `size` elements.
     *
     * Runs the base class shape/size check, then verifies that data_ is a
     * multiple of alignment(). Both checks throw only when bounds_check is on.
     */
    void data_check(size_t size) const
        {
            Base::data_bounds_check(size);
            if constexpr(bounds_check) {
                if (reinterpret_cast<intptr_t>(data_) % alignment() != 0)
                    throw std::runtime_error("data is not aligned");
            }
        }

    static_assert(Shape::template compatible<Shape>);
public:
    /*
     * Converting constructors. Two of them take a view of the same Scalar;
     * the other two add const (view of `const T` from a view of `T`). Each
     * comes in a both-shapes-fixed variant and a variant that copies the
     * runtime shape from `other`.
     *
     * The same-Scalar variants require C == Scalar. Without that, C is
     * deduced freely and a mismatched Scalar fails as a hard error inside the
     * constructor instead of being removed from overload resolution.
     *
     * All four run data_check() (a no-op when bounds_check is off). That way,
     * in checked builds, the runtime shape must match this view's fixed
     * extents and the pointer must satisfy this view's alignment().
     */
    template <typename S, size_t A, typename C = Scalar>
    ArrayView(const ArrayView<C,S,A>& other,
              typename std::enable_if_t<(Shape::fixed && S::fixed) && std::is_same<C, Scalar>::value && Shape::template compatible<S> && compatible_align(A), void*> dummy = 0)
        : Base(), data_(other.data())
        {
            data_check(other.size());
        }

    template <typename S, size_t A, typename C = Scalar>
    ArrayView(const ArrayView<C,S,A>& other,
              typename std::enable_if_t<!(Shape::fixed && S::fixed) && std::is_same<C, Scalar>::value && Shape::template compatible<S> && compatible_align(A), void*> dummy = 0)
        : Base(other.shape()), data_(other.data())
        {
            data_check(other.size());
        }

    template <typename S, size_t A, typename C = Scalar>
    ArrayView(const ArrayView<std::remove_const_t<C>,S,A>& other,
              typename std::enable_if_t<(Shape::fixed && S::fixed) && std::is_const<C>::value && Shape::template compatible<S> && compatible_align(A), void*> dummy = 0)
        : Base(), data_(other.data())
        {
            data_check(other.size());
        }

    template <typename S, size_t A, typename C = Scalar>
    ArrayView(const ArrayView<std::remove_const_t<C>,S,A>& other,
              typename std::enable_if_t<!(Shape::fixed && S::fixed) && std::is_const<C>::value && Shape::template compatible<S> && compatible_align(A), void*> dummy = 0)
        : Base(other.shape()), data_(other.data())
        {
            data_check(other.size());
        }

    /** View `size` elements at `data` with the given runtime shape (checked by data_check). */
    ArrayView(Scalar *data, size_t size, const std::array<size_t, Shape::ndim> shape)
        : Base(shape), data_(data)
        {
            data_check(size);
        }

    /** View `size` elements at `data`; the shape must be compile-time fixed. */
    ArrayView(Scalar *data, size_t size)
        : Base(), data_(data)
        {
            static_assert(Shape::fixed, "array shape must be compile-time fixed");
            data_check(size);
        }

    /** Element access with exactly ndim indices (bounds-checked in checked builds). */
    template <typename... Idx>
    Scalar& operator()(Idx... idxs)
        {
            static_assert(sizeof...(Idx) == Shape::ndim,
                          "number of indices must be equal to the number of dimensions");
            return data_[index(idxs...)];
        }

    template <typename... Idx>
    const Scalar& operator()(Idx... idxs) const
        {
            static_assert(sizeof...(Idx) == Shape::ndim,
                          "number of indices must be equal to the number of dimensions");
            return data_[index(idxs...)];
        }

    Scalar* data() const { return data_; }

    /** Byte alignment assumed for data(): Alignment if set, else 16, 8 or 1 by sizeof(Scalar). */
    static constexpr int alignment()
        {
            if (Alignment > 0)
                return Alignment;
            else if (sizeof(Scalar) % 16 == 0)
                return 16;
            else if (sizeof(Scalar) % 8 == 0)
                return 8;
            else
                return 1;
        }

    /**
     * Flat offset addressed by the given leading indices.
     *
     * Only the supplied indices contribute, which amounts to treating the
     * omitted trailing indices as 0.
     *
     * \throws std::out_of_range if an index is >= its extent (checked builds only).
     */
    template <typename... Idx>
    size_t index(Idx... idxs) const
        {
            static_assert(sizeof...(idxs) <= Shape::ndim,
                          "number of indices must not exceed the number of dimensions");

            const std::array<size_t, Shape::ndim> m{static_cast<size_t>(idxs)...};
            size_t idx = 0;

            for (size_t i = 0; i < sizeof...(idxs); ++i) {
                idx += Base::stride(i) * m[i];
                if constexpr(bounds_check) {
                    if (m[i] >= Base::dim(i))
                        throw std::out_of_range("index out of bounds");
                }
            }

            return idx;
        }

    /**
     * Extract sub-array (compile-time fixed size) by fixing the leading indices.
     *
     * NOTE(review): the sub-view keeps this view's Alignment. If Alignment
     * exceeds sizeof(Scalar), the offset may break it, and data_check will
     * then throw in checked builds.
     */
    template <typename... Idx>
    typename std::enable_if<detail::TailNth<sizeof...(Idx), Shape>::type::fixed,
                            ArrayView<Scalar, typename detail::TailNth<sizeof...(Idx), Shape>::type, Alignment> >::type
    part(Idx... idxs) const
        {
            constexpr size_t nidxs = sizeof...(Idx);
            static_assert(nidxs < Shape::ndim,
                          "number of indices must be smaller than the number of dimensions");

            size_t offset = index(idxs...);
            return {data_ + offset, Base::size() - offset};
        }

    /** Extract sub-array (compile-time dynamic size); same alignment caveat as above. */
    template <typename... Idx>
    typename std::enable_if<!detail::TailNth<sizeof...(Idx), Shape>::type::fixed,
                            ArrayView<Scalar, typename detail::TailNth<sizeof...(Idx), Shape>::type, Alignment> >::type
    part(Idx... idxs) const
        {
            constexpr size_t nidxs = sizeof...(Idx);
            static_assert(nidxs < Shape::ndim,
                          "number of indices must be smaller than the number of dimensions");

            size_t offset = index(idxs...);
            std::array<size_t, Shape::ndim - nidxs> shape;
            for (size_t i = nidxs; i < Shape::ndim; ++i)
                shape[i - nidxs] = Base::dim(i);
            return {data_ + offset, Base::size() - offset, shape};
        }

    typedef Eigen::Matrix<typename std::remove_cv<Scalar>::type, eigen_shape<0>(), eigen_shape<1>(), eigen_major()> EigenMatrix;
    typedef Eigen::Map<typename std::conditional<std::is_const<Scalar>::value, const EigenMatrix, EigenMatrix>::type, eigen_alignment()> EigenMap;

    /** Eigen::Map over a 2-dim view. */
    EigenMap matrix() const
        {
            static_assert(Shape::ndim == 2, "matrix must be two-dimensional");
            return {data_, static_cast<Eigen::Index>(Base::dim(0)), static_cast<Eigen::Index>(Base::dim(1))};
        }

    /** Eigen::Map over a 1-dim view (an N x 1 matrix type). */
    EigenMap vector() const
        {
            static_assert(Shape::ndim == 1, "vector must be one-dimensional");
            return {data_, static_cast<Eigen::Index>(Base::dim(0))};
        }

    /** Conversion to Eigen: vector() for 1-dim views, matrix() otherwise. */
    operator EigenMap() const {
        if constexpr(Shape::ndim == 1) {
            return vector();
        } else {
            return matrix();
        }
    }

    /** Extract sub-matrix */
    template <typename... Idx>
    auto submatrix(Idx... idxs) const { return part(idxs...).matrix(); }

    /** Reshape (compile-time fixed size); total sizes must agree at compile time. */
    template <size_t... Dims>
    typename std::enable_if<detail::ShapeType<Dims...>::fixed,
                            ArrayView<Scalar, detail::ShapeType<Dims...>, Alignment> >::type
    reshape() const
        {
            static_assert(Shape::size == detail::ShapeType<Dims...>::size, "incompatible shapes");
            return {data(), Base::size()};
        }

    /**
     * Reshape (compile-time dynamic size).
     *
     * NOTE(review): here both sizes are Dynamic, so the static_assert only
     * requires the source to be dynamic too. The runtime check in the
     * constructor only verifies that `shape` fits within size(), not that it
     * equals it.
     */
    template <size_t... Dims>
    typename std::enable_if<!detail::ShapeType<Dims...>::fixed,
                            ArrayView<Scalar, detail::ShapeType<Dims...>, Alignment> >::type
    reshape(const std::array<size_t, detail::ShapeType<Dims...>::ndim> shape) const
        {
            static_assert(Shape::size == detail::ShapeType<Dims...>::size, "incompatible shapes");
            return {data(), Base::size(), shape};
        }

    /** Assign `value` to each of the size() consecutive elements starting at data(). */
    void fill(const Scalar& value)
        {
            const size_t n = Base::size();

            for (size_t i = 0; i < n; ++i)
                data_[i] = value;
        }
};

/**
 * \class Array
 *
 * \brief Array that is also the storage for the data.
 *
 * The elements live in a Storage container (std::vector<Scalar> by default)
 * owned by the Array. The inherited ArrayView data pointer is aimed at that
 * container's buffer. One Scalar occupies `units` consecutive Storage entries,
 * so Storage::value_type may be smaller than Scalar.
 *
 * Array is move-only. The user-declared move operations suppress the implicit
 * copies, and a memberwise copy would leave the view pointing into the
 * source's buffer anyway.
 */
template <typename Scalar,
          typename Shape,
          typename Storage_=std::vector<Scalar>,
          size_t Alignment = ArrayView<Scalar, Shape>::Alignment>
class Array : public ArrayView<Scalar, Shape, Alignment>
{
public:
    typedef Storage_ Storage;

private:
    typedef ArrayView<Scalar, Shape, Alignment> BaseType;
    static_assert(sizeof(Scalar) % sizeof(typename Storage::value_type) == 0, "incommensurate sizes");
    /* Number of Storage entries per Scalar. */
    static constexpr size_t units = sizeof(Scalar) / sizeof(typename Storage::value_type);

    Storage storage_;

    /* Aim the inherited view at the owned buffer. */
    void attach_storage()
        {
            BaseType::data_ = reinterpret_cast<Scalar *>(storage_.data());
        }

public:
    /** Allocate storage for an array of the given runtime shape. */
    Array(const std::array<size_t, Shape::ndim>& shape)
        : BaseType(nullptr, detail::size_from_shape<Shape>(shape), shape),
          storage_(units * detail::size_from_shape<Shape>(shape))
        {
            attach_storage();
            BaseType::data_check(storage_.size() / units);
        }

    /** Allocate storage for a compile-time fixed shape. */
    Array()
        : BaseType(nullptr, Shape::size),
          storage_(units * Shape::size)
        {
            attach_storage();
            BaseType::data_check(storage_.size() / units);
        }

    /**
     * Take over other's storage.
     *
     * The moved-from Array keeps its shape but gets a null data pointer, so
     * it can no longer read or write the buffer now owned by *this. It should
     * only be destroyed or assigned to.
     */
    Array(Array&& other) noexcept(std::is_nothrow_move_constructible<Storage>::value)
        : BaseType(nullptr, other.size(), other.shape()),
          storage_(std::move(other.storage_))
        {
            attach_storage();
            other.BaseType::data_ = nullptr;
        }

    /** Move-assign: take other's shape and storage; other is left as after a move. */
    Array& operator=(Array&& other) noexcept(std::is_nothrow_move_assignable<Storage>::value)
        {
            if (this != &other) {
                BaseType::operator=(static_cast<const BaseType&>(other));
                storage_ = std::move(other.storage_);
                attach_storage();
                other.BaseType::data_ = nullptr;
            }
            return *this;
        }

    /**
     * Adopt an existing container for the given runtime shape.
     *
     * In checked builds, throws if the container holds too few elements.
     */
    Array(Storage&& storage, const std::array<size_t, Shape::ndim>& shape)
        : BaseType(nullptr, detail::size_from_shape<Shape>(shape), shape),
          storage_(std::move(storage))
        {
            attach_storage();
            BaseType::data_check(storage_.size() / units);
        }

    Storage& storage() { return storage_; }
    const Storage& storage() const { return storage_; }
};

} // namespace array

#endif