From 760c46e8e764269ab89b9432b55e04afc793cd7c Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Tue, 2 Aug 2022 15:38:02 +0300
Subject: [PATCH] array: better error checking etc

---
 src/action.hpp |  2 +-
 src/array.hpp  | 34 ++++++++++++++++++++++++++--------
 src/common.hpp |  2 --
 src/main.cpp   |  6 +++---
 4 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/src/action.hpp b/src/action.hpp
index c8c57b4..1006075 100644
--- a/src/action.hpp
+++ b/src/action.hpp
@@ -10,7 +10,7 @@ inline Scalar S_2(Array<Scalar,2> Q)
 {
     Q(0u,0u) = 123;
     Q(1u,1u) = 123;
-    return 0;
+    return Scalar(0);
 }
 
 
diff --git a/src/array.hpp b/src/array.hpp
index 672fed9..bb2bd7f 100644
--- a/src/array.hpp
+++ b/src/array.hpp
@@ -1,13 +1,19 @@
 #ifndef ARRAY_HPP_
 #define ARRAY_HPP_
 
+#include <vector>
+#include <array>
+#include <type_traits>
+#include <stdexcept>
+
+
 /*! Simple data-by-reference strided array class
  */
-template <typename Scalar, int NDim>
+template <typename Scalar, int NDim, typename Storage = std::vector<Scalar> >
 class Array
 {
 private:
-    std::vector<Scalar>& data_;
+    Storage& data_;
     std::array<size_t, NDim> shape_;
     std::array<size_t, NDim> stride_;
     size_t offset_ = 0;
@@ -17,11 +23,11 @@ public:
         : data_(array.data_), shape_(array.shape_), stride_(array.stride_), offset_(array.offset_)
         {}
 
-    Array(std::vector<Scalar>& data, const std::array<size_t, NDim> shape, const std::array<size_t, NDim> stride, const size_t offset=0)
+    Array(Storage& data, const std::array<size_t, NDim> shape, const std::array<size_t, NDim> stride, const size_t offset=0)
         : data_(data), shape_(shape), stride_(stride), offset_(offset)
         {}
 
-    Array(std::vector<Scalar>& data, const std::array<size_t, NDim> shape)
+    Array(Storage& data, const std::array<size_t, NDim> shape)
         : data_(data), shape_(shape), offset_(0)
         {
             // row-major strides
@@ -42,6 +48,8 @@ public:
     template <size_t axis>
     Array<Scalar, NDim-1> slice(size_t pos=0) const
         {
+            static_assert(axis < NDim, "invalid slice index");
+
             std::array<size_t, NDim-1> shape;
             std::array<size_t, NDim-1> stride;
             size_t offset;
@@ -56,12 +64,14 @@ public:
                 ++j;
             }
 
-            return Array<Scalar,NDim-1>(data_, shape, stride, offset);
+            return {data_, shape, stride, offset};
         }
 
     template <size_t axis>
     Array<Scalar, NDim> slice(size_t begin, size_t end) const
         {
+            static_assert(axis < NDim, "invalid slice index");
+
             std::array<size_t, NDim> shape;
             std::array<size_t, NDim> stride;
             size_t offset;
@@ -76,10 +86,10 @@ public:
                 stride[i] = stride_[i];
             }
 
-            return Array<Scalar,NDim>(data_, shape, stride, offset);
+            return {data_, shape, stride, offset};
         }
 
-    std::vector<Scalar>& data() { return data_; }
+    Storage& data() { return data_; }
     
     template <size_t axis>
     const size_t shape() const { return shape_[axis]; }
@@ -92,11 +102,19 @@ public:
     template <typename... Idx>
     size_t index(Idx... idxs) const
         {
+            static_assert(sizeof...(idxs) == NDim,
+                          "number of indices must equal the number of dimensions");
+
             const std::array<size_t, NDim> m{idxs...};
             size_t idx = offset_;
 
-            for (size_t i = 0; i < NDim; ++i)
+            for (size_t i = 0; i < NDim; ++i) {
                 idx += stride_[i] * m[i];
+#ifndef NO_BOUNDS_CHECK
+                if (m[i] >= shape_[i])
+                    throw std::out_of_range("index out of bounds");
+#endif
+            }
 
             return idx;
         }
diff --git a/src/common.hpp b/src/common.hpp
index 0aac231..b7f4c4b 100644
--- a/src/common.hpp
+++ b/src/common.hpp
@@ -1,8 +1,6 @@
 #ifndef COMMON_HPP_
 #define COMMON_HPP_
 
-#include <array>
-#include <tuple>
 #include <iostream>
 
 #include <cppad/cppad.hpp>
diff --git a/src/main.cpp b/src/main.cpp
index 9a79fdc..e6cc131 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -4,10 +4,10 @@
 
 int main()
 {
-    std::vector<double> v(2*2*2);
-    Array<double,3> Q(v,{2,2,2});
+    std::vector<ADComplex> v(2*2*2);
+    Array<ADComplex,3> Q(v,{2,2,2});
 
-    S_2(Q.slice<3>());
+    S_2(Q.slice<2>());
 
     for (size_t i = 0; i < Q.shape<0>(); ++i) {
         for (size_t j = 0; j < Q.shape<1>(); ++j) {
-- 
GitLab