From 5c9b7917411f99e8ce6da68843ee999e0bc7b4e8 Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Wed, 3 Aug 2022 14:02:09 +0300
Subject: [PATCH] array: allow multi-index Array::part(idxs...)

---
 src/array.hpp | 77 ++++++++++++++++++++++++++++++++++-----------------
 src/main.cpp  |  4 +--
 2 files changed, 53 insertions(+), 28 deletions(-)

diff --git a/src/array.hpp b/src/array.hpp
index 58a1786..65f9302 100644
--- a/src/array.hpp
+++ b/src/array.hpp
@@ -30,7 +30,7 @@ struct Shape<Dim0,Dims...>
 {
     static const size_t head = Dim0;
     typedef Shape<Dims...> Tail;
-    static const size_t ndim = 1 + Tail::ndim;
+    static const size_t ndim = 1 + sizeof...(Dims);
     static const bool fixed = (Dim0 != Dynamic) && Tail::fixed;
 
     static constexpr size_t dim(size_t i)
@@ -55,6 +55,22 @@ struct Shape<Dim0,Dims...>
     static constexpr std::array<size_t, ndim> dims() { return {Dim0, Dims...}; }
 };
 
+template <size_t I, typename Shape>
+struct TailNth;
+
+template <typename Shape>
+struct TailNth<0, Shape>
+{
+    using type = Shape;
+};
+
+template <size_t I, typename Shape>
+struct TailNth
+{
+    using type = typename TailNth<I-1, typename Shape::Tail>::type;
+};
+
+
 /*! Simple data-by-reference strided array class a la Fortran
  */
 template <typename Scalar, typename Shape, typename Storage = std::vector<Scalar> >
@@ -86,20 +102,18 @@ private:
     void init_strides()
         {
             for (size_t i = Shape::ndim; i > 0; --i) {
-                if (i == Shape::ndim)
+                if (Shape::stride(i-1) != Dynamic)
+                    stride_[i-1] = Shape::stride(i-1);
+                else if (i == Shape::ndim)
                     stride_[i-1] = 1;
                 else
                     stride_[i-1] = shape_[i] * stride_[i];
             }
-            for (size_t i = 0; i < Shape::ndim; ++i) {
-                std::cout << "stride "<< i << ": " << stride_[i] << " vs " << Shape::stride(i) << std::endl;
-            }
         }
 
     template <size_t axis>
     static constexpr Eigen::Index eigen_shape()
         {
-            static_assert(axis < Shape::ndim && axis < 2, "invalid axis");
             size_t n = (axis == 0) ? Shape::head : Shape::Tail::head;
             return (n == Dynamic) ? Eigen::Dynamic : n;
         }
@@ -115,7 +129,7 @@ public:
     Array(Storage& data)
         : data_(data), offset_(0)
         {
-            static_assert(Shape::fixed, "array shape is not fixed");
+            static_assert(Shape::fixed, "array shape is not compile-time fixed");
             shape_ = Shape::dims();
             init_strides();
             data_bounds_check();
@@ -127,23 +141,6 @@ public:
     template <typename... Idx>
     const Scalar&& operator()(Idx... idxs) const { return data_[index(idxs...)]; }
 
-    Array<Scalar, typename Shape::Tail> part(size_t pos) const
-        {
-            std::array<size_t, Shape::ndim-1> shape;
-            size_t offset;
-
-#ifndef NO_BOUNDS_CHECK
-            if (pos >= shape_[0])
-                throw std::out_of_range("index out of bounds");
-#endif
-
-            offset = offset_ + stride_[0] * pos;
-
-            for (size_t i = 1; i < Shape::ndim; ++i)
-                shape[i-1] = shape_[i];
-            return {data_, shape, offset};
-        }
-
     Storage& data() { return data_; }
 
     template <size_t axis>
@@ -195,17 +192,45 @@ public:
             return idx;
         }
 
+    template <typename... Idx>
+    Array<Scalar, typename TailNth<sizeof...(Idx), Shape>::type > part(Idx... idxs) const
+        {
+            constexpr size_t nidxs = sizeof...(idxs);
+            static_assert(nidxs < Shape::ndim,
+                          "number of indices must be small than the number of dimensions");
+
+            const std::array<size_t, Shape::ndim> m{idxs...};
+            std::array<size_t, Shape::ndim - nidxs> shape;
+            size_t offset = offset_;
+
+            for (size_t i = 0; i < nidxs; ++i) {
+                if (Shape::stride(i) != Dynamic) {
+                    offset += Shape::stride(i) * m[i];
+                } else {
+                    offset += stride_[i] * m[i];
+                }
+#ifndef NO_BOUNDS_CHECK
+                if (m[i] >= shape_[i])
+                    throw std::out_of_range("index out of bounds");
+#endif
+            }
+
+            for (size_t i = nidxs; i < Shape::ndim; ++i)
+                shape[i - nidxs] = shape_[i];
+            return {data_, shape, offset};
+        }
+
     typedef Eigen::Matrix<Scalar, eigen_shape<0>(), eigen_shape<1>(), Eigen::RowMajor> EigenMatrix;
     typedef Eigen::Map<EigenMatrix> EigenMap;
 
-    EigenMap to_matrix() const
+    EigenMap matrix() const
         {
             static_assert(Shape::ndim == 2, "matrix must be two-dimensional");
             return {(Scalar *)data_.data() + offset_,
                 static_cast<Eigen::Index>(shape<0>()), static_cast<Eigen::Index>(shape<1>())};
         }
 
-    operator EigenMap() const { return to_matrix(); }
+    operator EigenMap() const { return matrix(); }
 };
 
 
diff --git a/src/main.cpp b/src/main.cpp
index 54c3aee..867f7f7 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -24,7 +24,7 @@ int main()
 {
     std::vector<ADComplex> v(2*2*2);
     Array<ADComplex,Shape<Dynamic,2,2> > Q0(v, {2,2,2});
-    auto Q = Q0.part(1);
+    auto Q = Q0.part(1u);
 
     std::cout << Q0.index(1u,0u,0u) << std::endl;
 
@@ -32,7 +32,7 @@ int main()
 
     std::cout << Q.index(0u,0u) << std::endl;
 
-    auto mat = Q.to_matrix();
+    auto mat = Q.matrix();
 
     mat = mat * mat * mat * adcomplex(1j);
     mat(0,0) = mat(1,1) = mat.trace();
-- 
GitLab