Skip to content
Snippets Groups Projects
Select Git revision
  • 205aca8fdcdafe5303166295a2b34e6e7c07868b
  • main default protected
  • 0.1 protected
3 results

array.hpp

Blame
  • array.hpp 10.19 KiB
    #ifndef ARRAY_HPP_
    #define ARRAY_HPP_
    
    #include <array>
    #include <cstddef>
    #include <cstdint>
    #include <initializer_list>
    #include <stdexcept>
    #include <type_traits>
    #include <vector>
    
    /*! Marker for an extent or stride that is only known at run time. */
    constexpr size_t Dynamic = SIZE_MAX;

    /*! Array dimensions and strides: known at compile-time & dynamic
     *
     * Shape<D0, D1, ..., Dn-1> describes a row-major layout: the last axis has
     * stride 1 and every other axis has stride (extent of next axis) * (stride
     * of next axis). An extent may be `Dynamic`; any stride that depends on a
     * `Dynamic` extent is reported as `Dynamic` as well.
     */
    template <size_t...>
    struct Shape;

    /*! Empty shape: terminates the recursion over the dimension pack. */
    template <>
    struct Shape<>
    {
        static constexpr size_t head = 1;
        static constexpr size_t ndim = 0;
        static constexpr size_t size = 1;
        static constexpr bool fixed = true;
        typedef Shape<> Tail;
        static constexpr size_t dim(size_t /*i*/) { return 0; }
        static constexpr size_t stride(size_t /*i*/) { return 1; }
    };

    template <size_t Dim0, size_t... Dims>
    struct Shape<Dim0,Dims...>
    {
        static constexpr size_t head = Dim0;
        typedef Shape<Dims...> Tail;
        static constexpr size_t ndim = 1 + sizeof...(Dims);
        /* true when no extent is Dynamic */
        static constexpr bool fixed = (Dim0 != Dynamic) && Tail::fixed;
        /* total element count, or Dynamic when not fixed */
        static constexpr size_t size = fixed ? Dim0*Tail::size : Dynamic;

        /*! Extent of axis `i` (Dynamic if run-time; 0 for i >= ndim). */
        static constexpr size_t dim(size_t i)
            {
                return (i == 0) ? Dim0 : Tail::dim(i-1);
            }

        /*! Stride of axis `i` in elements (Dynamic if it depends on a
         *  run-time extent; 1 for the last axis and for i >= ndim). */
        static constexpr size_t stride(size_t i)
            {
                if (i != 0)
                    return Tail::stride(i-1);

                // stride(0) = extent of axis 1 * stride of axis 1
                constexpr size_t tail_stride = Tail::stride(0);
                if (Tail::head == Dynamic || tail_stride == Dynamic)
                    return Dynamic;
                return Tail::head * tail_stride;
            }

        /*! All extents (Dynamic entries included as-is). */
        static constexpr std::array<size_t, ndim> shape() { return {Dim0, Dims...}; }

        /*! All strides. The array is value-initialised so this function can
         *  be evaluated in a constant expression before C++20. */
        static constexpr std::array<size_t, ndim> strides()
            {
                std::array<size_t, ndim> m{};

                for (size_t i = 0; i < ndim; ++i)
                    m[i] = stride(i);
                return m;
            }
    };
    
    namespace detail
    {
        /*! Extract Nth tail shape
         *
         * TailNth<I, S>::type is the shape reached by following S::Tail I times,
         * i.e. S with its first I dimensions removed.
         */
        template <size_t I, typename Shape>
        struct TailNth;

        template <typename Shape>
        struct TailNth<0, Shape>
        {
            using type = Shape;
        };

        template <size_t I, typename Shape>
        struct TailNth
        {
            using type = typename TailNth<I-1, typename Shape::Tail>::type;
        };

        /*! Base class for compile-time fixed size arrays
         *
         * Extents and strides are all read from the Shape type, so the class
         * holds no data.
         */
        template <typename Shape>
        class fixed_base
        {
        protected:
            /*! Throw std::out_of_range if a buffer of `size` elements is smaller
             *  than Shape::size (no-op when NO_BOUNDS_CHECK is defined). */
            void data_bounds_check(size_t size) const
                {
    #ifndef NO_BOUNDS_CHECK
                    if (Shape::size > size)
                        throw std::out_of_range("data array too small");
    #else
                    (void)size;
    #endif
                }

        public:
            fixed_base() = default;

            constexpr size_t dim(size_t i) const { return Shape::dim(i); }
            constexpr size_t stride(size_t i) const { return Shape::stride(i); }

            constexpr std::array<size_t, Shape::ndim> shape() const { return Shape::shape(); }
            constexpr std::array<size_t, Shape::ndim> strides() const { return Shape::strides(); }
            constexpr size_t size() const { return Shape::size; }
        };

        /*! Base class for compile-time dynamic size arrays
         *
         * Stores the extents given at construction and the strides derived
         * from them. Compile-time-known extents/strides still take precedence
         * in dim()/stride().
         */
        template <typename Shape>
        class dynamic_base
        {
        protected:
            std::array<size_t, Shape::ndim> shape_;   // extents as passed by the caller
            std::array<size_t, Shape::ndim> strides_; // strides in elements

            /*! Throw std::out_of_range if the run-time extents disagree with
             *  the fixed extents of Shape, or if a buffer of `size` elements
             *  cannot hold the array (no-op when NO_BOUNDS_CHECK is defined). */
            void data_bounds_check(size_t size) const
                {
    #ifndef NO_BOUNDS_CHECK
                    size_t total = 1;

                    for (size_t i = 0; i < Shape::ndim; ++i) {
                        total *= dim(i);
                        if (shape_[i] != dim(i))
                            throw std::out_of_range("mismatch with fixed shape");
                    }
                    if (total > size)
                        throw std::out_of_range("data array too small");
    #else
                    (void)size;
    #endif
                }

        public:
            dynamic_base(const std::array<size_t, Shape::ndim>& shape)
                : shape_(shape), strides_{}
                {
                    // Walk from the last axis to the first. A compile-time stride
                    // is used when Shape knows it; otherwise the stride is the
                    // product of the extents to the right. dim() (not shape_)
                    // supplies those extents so fixed axes always contribute
                    // their compile-time value, even with NO_BOUNDS_CHECK.
                    size_t inner = 1;
                    for (size_t i = Shape::ndim; i-- > 0; ) {
                        const size_t known = Shape::stride(i);
                        strides_[i] = (known != Dynamic) ? known : inner;
                        inner = strides_[i] * dim(i);
                    }
                }

            /*! Extent of `axis`: compile-time value if fixed, else run-time. */
            size_t dim(size_t axis) const
                {
    #ifndef NO_BOUNDS_CHECK
                    if (axis >= Shape::ndim)
                        throw std::out_of_range("array axis must be < ndim");
    #endif
                    size_t n = Shape::dim(axis);
                    return n == Dynamic ? shape_[axis] : n;
                }

            /*! Stride of `axis`: compile-time value if known, else run-time. */
            size_t stride(size_t axis) const
                {
    #ifndef NO_BOUNDS_CHECK
                    if (axis >= Shape::ndim)
                        throw std::out_of_range("array axis must be < ndim");
    #endif
                    size_t n = Shape::stride(axis);
                    return n == Dynamic ? strides_[axis] : n;
                }

            const std::array<size_t, Shape::ndim>& shape() const { return shape_; }
            const std::array<size_t, Shape::ndim>& strides() const { return strides_; }

            /*! Total number of elements (product of dim(i)). */
            size_t size() const
                {
                    size_t size = 1;
                    for (size_t i = 0; i < Shape::ndim; ++i)
                        size *= dim(i);
                    return size;
                }
        };
    }
    
    /*! Simple data-by-reference strided array class a la Fortran
     *
     * A non-owning view: it stores only the `data` pointer handed to the
     * constructor and never frees it. Layout and strides come from Shape;
     * shapes with a Dynamic extent take their run-time extents at construction.
     *
     * NOTE: matrix()/EigenMap use Eigen types, but this header does not
     * include Eigen itself; Eigen must be included before this header.
     */
    template <typename Scalar, typename Shape, size_t Alignment=0>
    class Array : public std::conditional_t<Shape::fixed, detail::fixed_base<Shape>, detail::dynamic_base<Shape> >
    {
    private:
        typedef std::conditional_t<Shape::fixed, detail::fixed_base<Shape>, detail::dynamic_base<Shape> > Base;

        Scalar *data_;

        /*! Throw std::runtime_error unless `data` is a multiple of
         *  alignment() (no-op when NO_BOUNDS_CHECK is defined). */
        static void check_alignment(const Scalar *data)
            {
    #ifndef NO_BOUNDS_CHECK
                if (reinterpret_cast<std::uintptr_t>(data) % alignment() != 0)
                    throw std::runtime_error("data is not aligned");
    #else
                (void)data;
    #endif
            }

        /*! Eigen extent for `axis` (0 or 1): Eigen::Dynamic for Dynamic. */
        template <size_t axis>
        static constexpr Eigen::Index eigen_shape()
            {
                /* bad Shape handled in matrix() */
                size_t n = (axis == 0) ? Shape::head : Shape::Tail::head;
                return (n == Dynamic) ? Eigen::Dynamic : n;
            }

        /*! Map alignment option corresponding to alignment(). */
        static constexpr int eigen_alignment()
            {
                switch (alignment()) {
                case 128: return Eigen::Aligned128;
                case 64: return Eigen::Aligned64;
                case 32: return Eigen::Aligned32;
                case 16: return Eigen::Aligned16;
                case 8: return Eigen::Aligned8;
                default: return Eigen::Unaligned;
                }
            }

    public:
        /*! View `data` (at least `size` elements) with run-time extents
         *  `shape`. Only valid when Shape has a Dynamic extent.
         *  Throws std::out_of_range / std::runtime_error on bad input unless
         *  NO_BOUNDS_CHECK is defined. */
        Array(Scalar *data, size_t size, const std::array<size_t, Shape::ndim> shape)
            : Base(shape), data_(data)
            {
                static_assert(!Shape::fixed, "array shape must not be compile-time fixed");
                Base::data_bounds_check(size);
                check_alignment(data);
            }

        /*! View `data` (at least `size` elements) with a fully compile-time
         *  Shape. Throws on bad input unless NO_BOUNDS_CHECK is defined. */
        Array(Scalar *data, size_t size)
            : Base(), data_(data)
            {
                static_assert(Shape::fixed, "array shape must be compile-time fixed");
                Base::data_bounds_check(size);
                check_alignment(data);
            }

        /*! Element access; exactly ndim indices are required. */
        template <typename... Idx>
        Scalar& operator()(Idx... idxs)
            {
                static_assert(sizeof...(Idx) == Shape::ndim,
                              "number of indices must be equal to the number of dimensions");
                return data_[index(idxs...)];
            }

        /*! Read-only element access; exactly ndim indices are required. */
        template <typename... Idx>
        const Scalar& operator()(Idx... idxs) const
            {
                static_assert(sizeof...(Idx) == Shape::ndim,
                              "number of indices must be equal to the number of dimensions");
                return data_[index(idxs...)];
            }

        Scalar* data() { return data_; }
        const Scalar* data() const { return data_; }

        /*! Required byte alignment of data: the Alignment template argument
         *  if non-zero, otherwise 16 or 8 when sizeof(Scalar) is a multiple
         *  of it, otherwise 1. */
        static constexpr int alignment()
            {
                if (Alignment > 0)
                    return Alignment;
                else if (sizeof(Scalar) % 16 == 0)
                    return 16;
                else if (sizeof(Scalar) % 8 == 0)
                    return 8;
                else
                    return 1;
            }

        /*! Linear element offset of the given leading indices (missing
         *  trailing indices count as 0). Throws std::out_of_range for an
         *  index >= its extent unless NO_BOUNDS_CHECK is defined. */
        template <typename... Idx>
        size_t index(Idx... idxs) const
            {
                static_assert(sizeof...(idxs) <= Shape::ndim,
                              "number of indices must not exceed the number of dimensions");

                // Explicit casts: brace-initialising size_t from e.g. int is a
                // narrowing error. A negative index wraps to a huge value and
                // is rejected by the bounds check below.
                const std::array<size_t, Shape::ndim> m{static_cast<size_t>(idxs)...};
                size_t idx = 0;

                for (size_t i = 0; i < sizeof...(idxs); ++i) {
    #ifndef NO_BOUNDS_CHECK
                    if (m[i] >= Base::dim(i))
                        throw std::out_of_range("index out of bounds");
    #endif
                    idx += Base::stride(i) * m[i];
                }

                return idx;
            }

        /*! Extract partial matrix (compile-time fixed size)
         *  Fixes the leading indices; the result shares this array's data. */
        template <typename... Idx>
        typename std::enable_if<detail::TailNth<sizeof...(Idx), Shape>::type::fixed,
                                Array<Scalar, typename detail::TailNth<sizeof...(Idx), Shape>::type > >::type
        part(Idx... idxs) const
            {
                constexpr size_t nidxs = sizeof...(Idx);
                static_assert(nidxs < Shape::ndim,
                              "number of indices must be smaller than the number of dimensions");

                size_t offset = index(idxs...);
                return {data_ + offset, Base::size() - offset};
            }

        /*! Extract partial matrix (compile-time dynamic size)
         *  Fixes the leading indices; the result shares this array's data and
         *  receives the remaining run-time extents. */
        template <typename... Idx>
        typename std::enable_if<!detail::TailNth<sizeof...(Idx), Shape>::type::fixed,
                                Array<Scalar, typename detail::TailNth<sizeof...(Idx), Shape>::type > >::type
        part(Idx... idxs) const
            {
                constexpr size_t nidxs = sizeof...(Idx);
                static_assert(nidxs < Shape::ndim,
                              "number of indices must be smaller than the number of dimensions");

                size_t offset = index(idxs...);
                std::array<size_t, Shape::ndim - nidxs> sub_shape;
                for (size_t i = nidxs; i < Shape::ndim; ++i)
                    sub_shape[i - nidxs] = Base::dim(i);
                return {data_ + offset, Base::size() - offset, sub_shape};
            }

        typedef Eigen::Matrix<Scalar, eigen_shape<0>(), eigen_shape<1>(), Eigen::RowMajor> EigenMatrix;
        typedef Eigen::Map<EigenMatrix, eigen_alignment()> EigenMap;

        /*! Row-major Eigen::Map over the data of a two-dimensional array. */
        EigenMap matrix() const
            {
                static_assert(Shape::ndim == 2, "matrix must be two-dimensional");
                return {data_, static_cast<Eigen::Index>(Base::dim(0)), static_cast<Eigen::Index>(Base::dim(1))};
            }

        operator EigenMap() const { return matrix(); }
    };
    
    
    #endif