From 5c9b7917411f99e8ce6da68843ee999e0bc7b4e8 Mon Sep 17 00:00:00 2001 From: Pauli Virtanen <pauli.t.virtanen@jyu.fi> Date: Wed, 3 Aug 2022 14:02:09 +0300 Subject: [PATCH] array: allow multi-index Array::part(idxs...) --- src/array.hpp | 77 ++++++++++++++++++++++++++++++++++----------------- src/main.cpp | 4 +-- 2 files changed, 53 insertions(+), 28 deletions(-) diff --git a/src/array.hpp b/src/array.hpp index 58a1786..65f9302 100644 --- a/src/array.hpp +++ b/src/array.hpp @@ -30,7 +30,7 @@ struct Shape<Dim0,Dims...> { static const size_t head = Dim0; typedef Shape<Dims...> Tail; - static const size_t ndim = 1 + Tail::ndim; + static const size_t ndim = 1 + sizeof...(Dims); static const bool fixed = (Dim0 != Dynamic) && Tail::fixed; static constexpr size_t dim(size_t i) @@ -55,6 +55,22 @@ struct Shape<Dim0,Dims...> static constexpr std::array<size_t, ndim> dims() { return {Dim0, Dims...}; } }; +template <size_t I, typename Shape> +struct TailNth; + +template <typename Shape> +struct TailNth<0, Shape> +{ + using type = Shape; +}; + +template <size_t I, typename Shape> +struct TailNth +{ + using type = typename TailNth<I-1, typename Shape::Tail>::type; +}; + + /*! Simple data-by-reference strided array class a la Fortran */ template <typename Scalar, typename Shape, typename Storage = std::vector<Scalar> > @@ -86,20 +102,18 @@ private: void init_strides() { for (size_t i = Shape::ndim; i > 0; --i) { - if (i == Shape::ndim) + if (Shape::stride(i-1) != Dynamic) + stride_[i-1] = Shape::stride(i-1); + else if (i == Shape::ndim) stride_[i-1] = 1; else stride_[i-1] = shape_[i] * stride_[i]; } - for (size_t i = 0; i < Shape::ndim; ++i) { - std::cout << "stride "<< i << ": " << stride_[i] << " vs " << Shape::stride(i) << std::endl; - } } template <size_t axis> static constexpr Eigen::Index eigen_shape() { - static_assert(axis < Shape::ndim && axis < 2, "invalid axis"); size_t n = (axis == 0) ? Shape::head : Shape::Tail::head; return (n == Dynamic) ? Eigen::Dynamic : n; } @@ -115,7 +129,7 @@ public: Array(Storage& data) : data_(data), offset_(0) { - static_assert(Shape::fixed, "array shape is not fixed"); + static_assert(Shape::fixed, "array shape is not compile-time fixed"); shape_ = Shape::dims(); init_strides(); data_bounds_check(); @@ -127,23 +141,6 @@ public: template <typename... Idx> const Scalar&& operator()(Idx... idxs) const { return data_[index(idxs...)]; } - Array<Scalar, typename Shape::Tail> part(size_t pos) const - { - std::array<size_t, Shape::ndim-1> shape; - size_t offset; - -#ifndef NO_BOUNDS_CHECK - if (pos >= shape_[0]) - throw std::out_of_range("index out of bounds"); -#endif - - offset = offset_ + stride_[0] * pos; - - for (size_t i = 1; i < Shape::ndim; ++i) - shape[i-1] = shape_[i]; - return {data_, shape, offset}; - } - Storage& data() { return data_; } template <size_t axis> @@ -195,17 +192,45 @@ public: return idx; } + template <typename... Idx> + Array<Scalar, typename TailNth<sizeof...(Idx), Shape>::type > part(Idx... idxs) const + { + constexpr size_t nidxs = sizeof...(idxs); + static_assert(nidxs < Shape::ndim, + "number of indices must be small than the number of dimensions"); + + const std::array<size_t, Shape::ndim> m{idxs...}; + std::array<size_t, Shape::ndim - nidxs> shape; + size_t offset = offset_; + + for (size_t i = 0; i < nidxs; ++i) { + if (Shape::stride(i) != Dynamic) { + offset += Shape::stride(i) * m[i]; + } else { + offset += stride_[i] * m[i]; + } +#ifndef NO_BOUNDS_CHECK + if (m[i] >= shape_[i]) + throw std::out_of_range("index out of bounds"); +#endif + } + + for (size_t i = nidxs; i < Shape::ndim; ++i) + shape[i - nidxs] = shape_[i]; + return {data_, shape, offset}; + } + typedef Eigen::Matrix<Scalar, eigen_shape<0>(), eigen_shape<1>(), Eigen::RowMajor> EigenMatrix; typedef Eigen::Map<EigenMatrix> EigenMap; - EigenMap to_matrix() const + EigenMap matrix() const { static_assert(Shape::ndim == 2, "matrix must be two-dimensional"); return {(Scalar *)data_.data() + offset_, static_cast<Eigen::Index>(shape<0>()), static_cast<Eigen::Index>(shape<1>())}; } - operator EigenMap() const { return to_matrix(); } + operator EigenMap() const { return matrix(); } }; diff --git a/src/main.cpp b/src/main.cpp index 54c3aee..867f7f7 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -24,7 +24,7 @@ int main() { std::vector<ADComplex> v(2*2*2); Array<ADComplex,Shape<Dynamic,2,2> > Q0(v, {2,2,2}); - auto Q = Q0.part(1); + auto Q = Q0.part(1u); std::cout << Q0.index(1u,0u,0u) << std::endl; @@ -32,7 +32,7 @@ int main() std::cout << Q.index(0u,0u) << std::endl; - auto mat = Q.to_matrix(); + auto mat = Q.matrix(); mat = mat * mat * mat * adcomplex(1j); mat(0,0) = mat(1,1) = mat.trace(); -- GitLab