From de33c2d97eaf446e08ffbc43c7f4019e0ab34070 Mon Sep 17 00:00:00 2001 From: Pauli Virtanen <pauli.t.virtanen@jyu.fi> Date: Wed, 3 Aug 2022 12:52:55 +0300 Subject: [PATCH] array: support compile-time known array sizes + restrict to row-major Let the compiler better optimize cases where array sizes are known at compile time. This implies known strides, so restrict to row-major. Slicing is then limited to picking from the first axis. --- src/action.hpp | 4 +- src/array.hpp | 217 ++++++++++++++++++++++++++++--------------------- src/main.cpp | 14 ++-- 3 files changed, 132 insertions(+), 103 deletions(-) diff --git a/src/action.hpp b/src/action.hpp index 51e45b5..feea85e 100644 --- a/src/action.hpp +++ b/src/action.hpp @@ -5,8 +5,8 @@ #include "array.hpp" -template <typename Scalar> -inline Scalar S_2(Array<Scalar,2> Q) +template <typename Scalar, typename Shape> +inline Scalar S_2(Array<Scalar, Shape> Q) { Q(0u,0u) = 1; Q(0u,1u) = 2; diff --git a/src/array.hpp b/src/array.hpp index e53a603..58a1786 100644 --- a/src/array.hpp +++ b/src/array.hpp @@ -7,104 +7,141 @@ #include <initializer_list> #include <stdexcept> +constexpr size_t Dynamic = SIZE_MAX; + + +/*! Array dimensions and strides: known at compile-time & dynamic + */ +template <size_t...> +struct Shape; + +template <> +struct Shape<> +{ + static const size_t head = 1; + static const size_t ndim = 0; + static const bool fixed = true; + static constexpr size_t dim(size_t i) { return 0; } + static constexpr size_t stride(size_t i) { return 1; } +}; + +template <size_t Dim0, size_t... Dims> +struct Shape<Dim0,Dims...> +{ + static const size_t head = Dim0; + typedef Shape<Dims...> Tail; + static const size_t ndim = 1 + Tail::ndim; + static const bool fixed = (Dim0 != Dynamic) && Tail::fixed; + + static constexpr size_t dim(size_t i) + { + return (i == 0) ? Dim0 : Tail::dim(i-1); + } + + static constexpr size_t stride(size_t i) + { + if (i == 0) { + constexpr size_t tail_stride = Tail::stride(0); + + if (Tail::head == Dynamic || tail_stride == Dynamic) + return Dynamic; + + return Tail::head * tail_stride; + } else { + return Tail::stride(i-1); + } + } + + static constexpr std::array<size_t, ndim> dims() { return {Dim0, Dims...}; } +}; /*! Simple data-by-reference strided array class a la Fortran */ -template <typename Scalar, size_t NDim, typename Storage = std::vector<Scalar> > +template <typename Scalar, typename Shape, typename Storage = std::vector<Scalar> > class Array { private: Storage& data_; - std::array<size_t, NDim> shape_; - std::array<size_t, NDim> stride_; + std::array<size_t, Shape::ndim> shape_; + std::array<size_t, Shape::ndim> stride_; size_t offset_ = 0; -public: - Array(Storage& data, const std::array<size_t, NDim> shape, const std::array<size_t, NDim> stride, const size_t offset) - : data_(data), shape_(shape), stride_(stride), offset_(offset) - {} - - Array(Storage& data, const std::array<size_t, NDim> shape, const bool row_major=true) - : data_(data), shape_(shape), offset_(0) + void data_bounds_check() const { - if (row_major) { - for (size_t i = NDim; i > 0; --i) { - if (i == NDim) - stride_[i-1] = 1; - else - stride_[i-1] = shape[i] * stride_[i]; - } - } else { - for (size_t i = 0; i < NDim; ++i) { - if (i == 0) - stride_[i] = 1; - else - stride_[i] = shape[i-1] * stride_[i-1]; - } - } - #ifndef NO_BOUNDS_CHECK size_t total = 1; - for (size_t i = 0; i < NDim; ++i) - total *= shape[i]; - if (total > data.size()) + + for (size_t i = 0; i < Shape::ndim; ++i) { + total *= shape_[i]; + if ((Shape::dim(i) != Dynamic && Shape::dim(i) != shape_[i]) + || shape_[i] == Dynamic) + throw std::out_of_range("mismatch with fixed shape"); + } + total += offset_; + if (total > data_.size()) throw std::out_of_range("data array too small"); #endif } - template <typename... Idx> - Scalar& operator()(Idx... idxs) { return data_[index(idxs...)]; } - - template <typename... Idx> - const Scalar&& operator()(Idx... idxs) const { return data_[index(idxs...)]; } + void init_strides() + { + for (size_t i = Shape::ndim; i > 0; --i) { + if (i == Shape::ndim) + stride_[i-1] = 1; + else + stride_[i-1] = shape_[i] * stride_[i]; + } + for (size_t i = 0; i < Shape::ndim; ++i) { + std::cout << "stride "<< i << ": " << stride_[i] << " vs " << Shape::stride(i) << std::endl; + } + } template <size_t axis> - Array<Scalar, NDim-1> slice(size_t pos=0) const + static constexpr Eigen::Index eigen_shape() { - static_assert(axis < NDim, "invalid axis"); + static_assert(axis < Shape::ndim && axis < 2, "invalid axis"); + size_t n = (axis == 0) ? Shape::head : Shape::Tail::head; + return (n == Dynamic) ? Eigen::Dynamic : n; + } - std::array<size_t, NDim-1> shape; - std::array<size_t, NDim-1> stride; - size_t offset; +public: + Array(Storage& data, const std::array<size_t, Shape::ndim> shape, const size_t offset=0) + : data_(data), shape_(shape), offset_(offset) + { + init_strides(); + data_bounds_check(); + } - offset = offset_ + stride_[axis] * pos; + Array(Storage& data) + : data_(data), offset_(0) + { + static_assert(Shape::fixed, "array shape is not fixed"); + shape_ = Shape::dims(); + init_strides(); + data_bounds_check(); + } - for (size_t i = 0, j = 0; i < NDim; ++i) { - if (i == axis) - continue; - shape[j] = shape_[i]; - stride[j] = stride_[i]; - ++j; - } + template <typename... Idx> + Scalar& operator()(Idx... idxs) { return data_[index(idxs...)]; } - return {data_, shape, stride, offset}; - } + template <typename... Idx> + const Scalar&& operator()(Idx... idxs) const { return data_[index(idxs...)]; } - template <size_t axis> - Array<Scalar, NDim> slice(size_t begin, size_t end) const + Array<Scalar, typename Shape::Tail> part(size_t pos) const { - static_assert(axis < NDim, "invalid axis"); - - std::array<size_t, NDim> shape; - std::array<size_t, NDim> stride; + std::array<size_t, Shape::ndim-1> shape; size_t offset; #ifndef NO_BOUNDS_CHECK - if (end < begin || begin >= shape[axis] || end > shape[axis]) - throw std::out_of_range("begin/end indices out of bounds"); + if (pos >= shape_[0]) + throw std::out_of_range("index out of bounds"); #endif - offset = offset_ + stride_[axis] * begin; + offset = offset_ + stride_[0] * pos; - for (size_t i = 0; i < NDim; ++i) { - if (i == axis) - shape[i] = end - begin; - else - shape[i] = shape_[i]; - stride[i] = stride_[i]; - } - - return {data_, shape, stride, offset}; + for (size_t i = 1; i < Shape::ndim; ++i) + shape[i-1] = shape_[i]; + return {data_, shape, offset}; } Storage& data() { return data_; } @@ -112,14 +149,19 @@ public: template <size_t axis> const size_t shape() const { - static_assert(axis < NDim, "invalid axis"); + static_assert(axis < Shape::ndim, "invalid axis"); + constexpr size_t n = Shape::dim(axis); + if (n != Dynamic) + return n; return shape_[axis]; } template <size_t axis> const size_t stride() const { - static_assert(axis < NDim, "invalid axis"); + static_assert(axis < Shape::ndim, "invalid axis"); + if (Shape::stride(axis) != Dynamic) + return Shape::stride(axis); return stride_[axis]; } @@ -128,14 +170,18 @@ public: template <typename... Idx> size_t index(Idx... idxs) const { - static_assert(sizeof...(idxs) == NDim, + static_assert(sizeof...(idxs) == Shape::ndim, "number of indices must equal the number of dimensions"); - const std::array<size_t, NDim> m{idxs...}; + const std::array<size_t, Shape::ndim> m{idxs...}; size_t idx = offset_; - for (size_t i = 0; i < NDim; ++i) { - idx += stride_[i] * m[i]; + for (size_t i = 0; i < Shape::ndim; ++i) { + if (Shape::stride(i) != Dynamic) { + idx += Shape::stride(i) * m[i]; + } else { + idx += stride_[i] * m[i]; + } #ifndef NO_BOUNDS_CHECK if (m[i] >= shape_[i]) throw std::out_of_range("index out of bounds"); @@ -149,31 +195,14 @@ public: return idx; } - typedef Eigen::Matrix<Scalar,Eigen::Dynamic,Eigen::Dynamic> EigenMatrix; - typedef Eigen::Map<EigenMatrix, 0, Eigen::Stride<Eigen::Dynamic,Eigen::Dynamic> > EigenMap; + typedef Eigen::Matrix<Scalar, eigen_shape<0>(), eigen_shape<1>(), Eigen::RowMajor> EigenMatrix; + typedef Eigen::Map<EigenMatrix> EigenMap; EigenMap to_matrix() const { - static_assert(NDim == 2, "matrix must be two-dimensional"); - return {(Scalar *)data_.data() + offset_, - static_cast<Eigen::Index>(shape_[0]), static_cast<Eigen::Index>(shape_[1]), - {static_cast<Eigen::Index>(stride_[1]), static_cast<Eigen::Index>(stride_[0])}}; - } - - template <size_t Rows, size_t Cols> - Eigen::Map< Eigen::Matrix<Scalar,Rows,Cols,Eigen::RowMajor> > to_fixed_matrix() const - { - static_assert(NDim == 2, "matrix must be two-dimensional"); - -#ifndef NO_BOUNDS_CHECK - if (shape_[0] != Rows || shape_[1] != Cols) - throw std::out_of_range("incorrect shape"); - if (stride_[0] != shape_[1] || stride_[1] != 1) - throw std::out_of_range("data not in row major order"); -#endif - + static_assert(Shape::ndim == 2, "matrix must be two-dimensional"); return {(Scalar *)data_.data() + offset_, - static_cast<Eigen::Index>(shape_[0]), static_cast<Eigen::Index>(shape_[1])}; + static_cast<Eigen::Index>(shape<0>()), static_cast<Eigen::Index>(shape<1>())}; } operator EigenMap() const { return to_matrix(); } diff --git a/src/main.cpp b/src/main.cpp index 1a79aad..54c3aee 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -4,8 +4,8 @@ using namespace std::complex_literals; -template <typename Scalar> -void dump(Array<Scalar,3> Q) +template <typename Scalar, typename Shape> +void dump(Array<Scalar,Shape> Q) { for (size_t i = 0; i < Q.template shape<0>(); ++i) { for (size_t j = 0; j < Q.template shape<1>(); ++j) { @@ -23,8 +23,8 @@ void dump(Array<Scalar,3> Q) int main() { std::vector<ADComplex> v(2*2*2); - Array<ADComplex,3> Q0(v,{2,2,2},true); - auto Q = Q0.slice<0>(0); + Array<ADComplex,Shape<Dynamic,2,2> > Q0(v, {2,2,2}); + auto Q = Q0.part(1); std::cout << Q0.index(1u,0u,0u) << std::endl; @@ -32,15 +32,15 @@ int main() std::cout << Q.index(0u,0u) << std::endl; - //auto mat = Q.to_matrix(); - auto mat = Q.to_fixed_matrix<2,2>(); + auto mat = Q.to_matrix(); mat = mat * mat * mat * adcomplex(1j); - mat(0,0) = mat(1,1) = mat.trace(); dump(Q0); std::cout << mat << std::endl; + static_assert(std::is_same<decltype(mat), Eigen::Map<Eigen::Matrix<ADComplex, 2, 2, Eigen::RowMajor> > >::value, "wrong type"); + return 0; } -- GitLab