From de33c2d97eaf446e08ffbc43c7f4019e0ab34070 Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Wed, 3 Aug 2022 12:52:55 +0300
Subject: [PATCH] array: support compile-time known array sizes + restrict to
 row-major

Let the compiler better optimize cases where array sizes are known
at compile time.

This implies known strides, so restrict to row-major. Slicing is
then limited to picking from the first axis.
---
 src/action.hpp |   4 +-
 src/array.hpp  | 217 ++++++++++++++++++++++++++++---------------------
 src/main.cpp   |  14 ++--
 3 files changed, 132 insertions(+), 103 deletions(-)

diff --git a/src/action.hpp b/src/action.hpp
index 51e45b5..feea85e 100644
--- a/src/action.hpp
+++ b/src/action.hpp
@@ -5,8 +5,8 @@
 #include "array.hpp"
 
 
-template <typename Scalar>
-inline Scalar S_2(Array<Scalar,2> Q)
+template <typename Scalar, typename Shape>
+inline Scalar S_2(Array<Scalar, Shape> Q)
 {
     Q(0u,0u) = 1;
     Q(0u,1u) = 2;
diff --git a/src/array.hpp b/src/array.hpp
index e53a603..58a1786 100644
--- a/src/array.hpp
+++ b/src/array.hpp
@@ -7,104 +7,141 @@
 #include <initializer_list>
 #include <stdexcept>
 
+constexpr size_t Dynamic = SIZE_MAX;
+
+
+/*! Array dimensions and strides: known at compile-time & dynamic
+ */
+template <size_t...>
+struct Shape;
+
+template <>
+struct Shape<>
+{
+    static const size_t head = 1;
+    static const size_t ndim = 0;
+    static const bool fixed = true;
+    static constexpr size_t dim(size_t i) { return 0; }
+    static constexpr size_t stride(size_t i) { return 1; }
+};
+
+template <size_t Dim0, size_t... Dims>
+struct Shape<Dim0,Dims...>
+{
+    static const size_t head = Dim0;
+    typedef Shape<Dims...> Tail;
+    static const size_t ndim = 1 + Tail::ndim;
+    static const bool fixed = (Dim0 != Dynamic) && Tail::fixed;
+
+    static constexpr size_t dim(size_t i)
+        {
+            return (i == 0) ? Dim0 : Tail::dim(i-1);
+        }
+
+    static constexpr size_t stride(size_t i)
+        {
+            if (i == 0) {
+                constexpr size_t tail_stride = Tail::stride(0);
+
+                if (Tail::head == Dynamic || tail_stride == Dynamic)
+                    return Dynamic;
+
+                return Tail::head * tail_stride;
+            } else {
+                return Tail::stride(i-1);
+            }
+        }
+
+    static constexpr std::array<size_t, ndim> dims() { return {Dim0, Dims...}; }
+};
 
 /*! Simple data-by-reference strided array class a la Fortran
  */
-template <typename Scalar, size_t NDim, typename Storage = std::vector<Scalar> >
+template <typename Scalar, typename Shape, typename Storage = std::vector<Scalar> >
 class Array
 {
 private:
     Storage& data_;
-    std::array<size_t, NDim> shape_;
-    std::array<size_t, NDim> stride_;
+    std::array<size_t, Shape::ndim> shape_;
+    std::array<size_t, Shape::ndim> stride_;
     size_t offset_ = 0;
 
-public:
-    Array(Storage& data, const std::array<size_t, NDim> shape, const std::array<size_t, NDim> stride, const size_t offset)
-        : data_(data), shape_(shape), stride_(stride), offset_(offset)
-        {}
-
-    Array(Storage& data, const std::array<size_t, NDim> shape, const bool row_major=true)
-        : data_(data), shape_(shape), offset_(0)
+    void data_bounds_check() const
         {
-            if (row_major) {
-                for (size_t i = NDim; i > 0; --i) {
-                    if (i == NDim)
-                        stride_[i-1] = 1;
-                    else
-                        stride_[i-1] = shape[i] * stride_[i];
-                }
-            } else {
-                for (size_t i = 0; i < NDim; ++i) {
-                    if (i == 0)
-                        stride_[i] = 1;
-                    else
-                        stride_[i] = shape[i-1] * stride_[i-1];
-                }
-            }
-
 #ifndef NO_BOUNDS_CHECK
             size_t total = 1;
-            for (size_t i = 0; i < NDim; ++i)
-                total *= shape[i];
-            if (total > data.size())
+
+            for (size_t i = 0; i < Shape::ndim; ++i) {
+                total *= shape_[i];
+                if ((Shape::dim(i) != Dynamic && Shape::dim(i) != shape_[i])
+                    || shape_[i] == Dynamic)
+                    throw std::out_of_range("mismatch with fixed shape");
+            }
+            total += offset_;
+            if (total > data_.size())
                 throw std::out_of_range("data array too small");
 #endif
         }
 
-    template <typename... Idx>
-    Scalar& operator()(Idx... idxs) { return data_[index(idxs...)]; }
-
-    template <typename... Idx>
-    const Scalar&& operator()(Idx... idxs) const { return data_[index(idxs...)]; }
+    void init_strides()
+        {
+            for (size_t i = Shape::ndim; i > 0; --i) {
+                if (i == Shape::ndim)
+                    stride_[i-1] = 1;
+                else
+                    stride_[i-1] = shape_[i] * stride_[i];
+            }
+            for (size_t i = 0; i < Shape::ndim; ++i) {
+                std::cout << "stride "<< i << ": " << stride_[i] << " vs " << Shape::stride(i) << std::endl;
+            }
+        }
 
     template <size_t axis>
-    Array<Scalar, NDim-1> slice(size_t pos=0) const
+    static constexpr Eigen::Index eigen_shape()
         {
-            static_assert(axis < NDim, "invalid axis");
+            static_assert(axis < Shape::ndim && axis < 2, "invalid axis");
+            size_t n = (axis == 0) ? Shape::head : Shape::Tail::head;
+            return (n == Dynamic) ? Eigen::Dynamic : n;
+        }
 
-            std::array<size_t, NDim-1> shape;
-            std::array<size_t, NDim-1> stride;
-            size_t offset;
+public:
+    Array(Storage& data, const std::array<size_t, Shape::ndim> shape, const size_t offset=0)
+        : data_(data), shape_(shape), offset_(offset)
+        {
+            init_strides();
+            data_bounds_check();
+        }
 
-            offset = offset_ + stride_[axis] * pos;
+    Array(Storage& data)
+        : data_(data), offset_(0)
+        {
+            static_assert(Shape::fixed, "array shape is not fixed");
+            shape_ = Shape::dims();
+            init_strides();
+            data_bounds_check();
+        }
 
-            for (size_t i = 0, j = 0; i < NDim; ++i) {
-                if (i == axis)
-                    continue;
-                shape[j] = shape_[i];
-                stride[j] = stride_[i];
-                ++j;
-            }
+    template <typename... Idx>
+    Scalar& operator()(Idx... idxs) { return data_[index(idxs...)]; }
 
-            return {data_, shape, stride, offset};
-        }
+    template <typename... Idx>
+    const Scalar&& operator()(Idx... idxs) const { return data_[index(idxs...)]; }
 
-    template <size_t axis>
-    Array<Scalar, NDim> slice(size_t begin, size_t end) const
+    Array<Scalar, typename Shape::Tail> part(size_t pos) const
         {
-            static_assert(axis < NDim, "invalid axis");
-
-            std::array<size_t, NDim> shape;
-            std::array<size_t, NDim> stride;
+            std::array<size_t, Shape::ndim-1> shape;
             size_t offset;
 
 #ifndef NO_BOUNDS_CHECK
-            if (end < begin || begin >= shape[axis] || end > shape[axis])
-                throw std::out_of_range("begin/end indices out of bounds");
+            if (pos >= shape_[0])
+                throw std::out_of_range("index out of bounds");
 #endif
 
-            offset = offset_ + stride_[axis] * begin;
+            offset = offset_ + stride_[0] * pos;
 
-            for (size_t i = 0; i < NDim; ++i) {
-                if (i == axis)
-                    shape[i] = end - begin;
-                else
-                    shape[i] = shape_[i];
-                stride[i] = stride_[i];
-            }
-
-            return {data_, shape, stride, offset};
+            for (size_t i = 1; i < Shape::ndim; ++i)
+                shape[i-1] = shape_[i];
+            return {data_, shape, offset};
         }
 
     Storage& data() { return data_; }
@@ -112,14 +149,19 @@ public:
     template <size_t axis>
     const size_t shape() const
         {
-            static_assert(axis < NDim, "invalid axis");
+            static_assert(axis < Shape::ndim, "invalid axis");
+            constexpr size_t n = Shape::dim(axis);
+            if (n != Dynamic)
+                return n;
             return shape_[axis];
         }
 
     template <size_t axis>
     const size_t stride() const
         {
-            static_assert(axis < NDim, "invalid axis");
+            static_assert(axis < Shape::ndim, "invalid axis");
+            if (Shape::stride(axis) != Dynamic)
+                return Shape::stride(axis);
             return stride_[axis];
         }
 
@@ -128,14 +170,18 @@ public:
     template <typename... Idx>
     size_t index(Idx... idxs) const
         {
-            static_assert(sizeof...(idxs) == NDim,
+            static_assert(sizeof...(idxs) == Shape::ndim,
                           "number of indices must equal the number of dimensions");
 
-            const std::array<size_t, NDim> m{idxs...};
+            const std::array<size_t, Shape::ndim> m{idxs...};
             size_t idx = offset_;
 
-            for (size_t i = 0; i < NDim; ++i) {
-                idx += stride_[i] * m[i];
+            for (size_t i = 0; i < Shape::ndim; ++i) {
+                if (Shape::stride(i) != Dynamic) {
+                    idx += Shape::stride(i) * m[i];
+                } else {
+                    idx += stride_[i] * m[i];
+                }
 #ifndef NO_BOUNDS_CHECK
                 if (m[i] >= shape_[i])
                     throw std::out_of_range("index out of bounds");
@@ -149,31 +195,14 @@ public:
             return idx;
         }
 
-    typedef Eigen::Matrix<Scalar,Eigen::Dynamic,Eigen::Dynamic> EigenMatrix;
-    typedef Eigen::Map<EigenMatrix, 0, Eigen::Stride<Eigen::Dynamic,Eigen::Dynamic> > EigenMap;
+    typedef Eigen::Matrix<Scalar, eigen_shape<0>(), eigen_shape<1>(), Eigen::RowMajor> EigenMatrix;
+    typedef Eigen::Map<EigenMatrix> EigenMap;
 
     EigenMap to_matrix() const
         {
-            static_assert(NDim == 2, "matrix must be two-dimensional");
-            return {(Scalar *)data_.data() + offset_,
-                static_cast<Eigen::Index>(shape_[0]), static_cast<Eigen::Index>(shape_[1]),
-                {static_cast<Eigen::Index>(stride_[1]), static_cast<Eigen::Index>(stride_[0])}};
-        }
-
-    template <size_t Rows, size_t Cols>
-    Eigen::Map< Eigen::Matrix<Scalar,Rows,Cols,Eigen::RowMajor> > to_fixed_matrix() const
-        {
-            static_assert(NDim == 2, "matrix must be two-dimensional");
-
-#ifndef NO_BOUNDS_CHECK
-            if (shape_[0] != Rows || shape_[1] != Cols)
-                throw std::out_of_range("incorrect shape");
-            if (stride_[0] != shape_[1] || stride_[1] != 1)
-                throw std::out_of_range("data not in row major order");
-#endif
-
+            static_assert(Shape::ndim == 2, "matrix must be two-dimensional");
             return {(Scalar *)data_.data() + offset_,
-                static_cast<Eigen::Index>(shape_[0]), static_cast<Eigen::Index>(shape_[1])};
+                static_cast<Eigen::Index>(shape<0>()), static_cast<Eigen::Index>(shape<1>())};
         }
 
     operator EigenMap() const { return to_matrix(); }
diff --git a/src/main.cpp b/src/main.cpp
index 1a79aad..54c3aee 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -4,8 +4,8 @@
 
 using namespace std::complex_literals;
 
-template <typename Scalar>
-void dump(Array<Scalar,3> Q)
+template <typename Scalar, typename Shape>
+void dump(Array<Scalar,Shape> Q)
 {
     for (size_t i = 0; i < Q.template shape<0>(); ++i) {
         for (size_t j = 0; j < Q.template shape<1>(); ++j) {
@@ -23,8 +23,8 @@ void dump(Array<Scalar,3> Q)
 int main()
 {
     std::vector<ADComplex> v(2*2*2);
-    Array<ADComplex,3> Q0(v,{2,2,2},true);
-    auto Q = Q0.slice<0>(0);
+    Array<ADComplex,Shape<Dynamic,2,2> > Q0(v, {2,2,2});
+    auto Q = Q0.part(1);
 
     std::cout << Q0.index(1u,0u,0u) << std::endl;
 
@@ -32,15 +32,15 @@ int main()
 
     std::cout << Q.index(0u,0u) << std::endl;
 
-    //auto mat = Q.to_matrix();
-    auto mat = Q.to_fixed_matrix<2,2>();
+    auto mat = Q.to_matrix();
 
     mat = mat * mat * mat * adcomplex(1j);
-
     mat(0,0) = mat(1,1) = mat.trace();
 
     dump(Q0);
     std::cout << mat << std::endl;
 
+    static_assert(std::is_same<decltype(mat), Eigen::Map<Eigen::Matrix<ADComplex, 2, 2, Eigen::RowMajor> > >::value, "wrong type");
+
     return 0;
 }
-- 
GitLab