From 57bd4a709a5a8a881c9c653ba54e7702742d8611 Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Mon, 8 Aug 2022 15:20:00 +0300
Subject: [PATCH] Adjust to const + rename SimpleVector

---
 src/action.hpp                       | 14 +++++++-------
 src/common.hpp                       |  2 +-
 src/core.cpp                         | 13 ++++++-------
 src/main.cpp                         | 16 +++++++++-------
 src/{simplevector.hpp => vector.hpp} | 20 ++++++++++++--------
 5 files changed, 35 insertions(+), 30 deletions(-)
 rename src/{simplevector.hpp => vector.hpp} (90%)

diff --git a/src/action.hpp b/src/action.hpp
index 9e957ae..a5be36d 100644
--- a/src/action.hpp
+++ b/src/action.hpp
@@ -15,15 +15,15 @@
 #include "array.hpp"
 
 
-template <typename Scalar>
-inline Matrix<Scalar,4,4>
+template <typename Scalar, typename ScalarW = typename std::remove_const<Scalar>::type>
+inline Matrix<ScalarW,4,4>
 get_Q_matrix(array::ArrayView<Scalar, array::Shape<2,2,2> > Q)
 {
     auto g = Q.part(0u).matrix();
     auto gt = Q.part(1u).matrix();
     Matrix<double,2,2> I;
-    Matrix<Scalar,2,2> N, Nt;
-    Matrix<Scalar,4,4> Qm;
+    Matrix<ScalarW,2,2> N, Nt;
+    Matrix<ScalarW,4,4> Qm;
 
     I << 1, 0,
         0, 1;
@@ -39,13 +39,13 @@ get_Q_matrix(array::ArrayView<Scalar, array::Shape<2,2,2> > Q)
     return Qm;
 }
 
-template <typename Scalar, typename Shape, typename Shape2>
-inline Scalar S_2(array::ArrayView<Scalar, Shape> Q, array::ArrayView<Complex, Shape2> U, double h)
+template <typename Scalar, typename Shape, typename Shape2, typename ScalarW = typename std::remove_const<Scalar>::type>
+inline ScalarW S_2(array::ArrayView<Scalar, Shape> Q, array::ArrayView<const Complex, Shape2> U, double h)
 {
     const int N_i[4] = {0, 1, -1, 0};
     const int N_j[4] = {1, 0, 0, -1};
     const size_t invk[4] = {3, 2, 1, 0};
-    Scalar S = Scalar(0);
+    ScalarW S = Scalar(0);
 
     size_t nx = Q.dim(0);
     size_t ny = Q.dim(1);
diff --git a/src/common.hpp b/src/common.hpp
index f4acc44..e20407d 100644
--- a/src/common.hpp
+++ b/src/common.hpp
@@ -16,7 +16,7 @@
 #include <cppad/example/cppad_eigen.hpp>
 
 #include "array.hpp"
-#include "simplevector.hpp"
+#include "vector.hpp"
 
 typedef std::complex<double> Complex;
 typedef CppAD::AD<double> ADDouble;
diff --git a/src/core.cpp b/src/core.cpp
index 85f3f68..408b068 100644
--- a/src/core.cpp
+++ b/src/core.cpp
@@ -14,8 +14,8 @@ using namespace array::python;
 class CLASS_Action
 {
 public:
-    typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,4,4,4> > UType;
-    typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,2,2,2> > QType;
+    typedef PyArrayView<const Complex,Shape<Dynamic,Dynamic,4,4,4> > UType;
+    typedef PyArrayView<const Complex,Shape<Dynamic,Dynamic,2,2,2> > QType;
 
 private:
     UType U_;
@@ -39,9 +39,9 @@ public:
             if (Qval.dim(0) != nx_ || Qval.dim(1) != ny_)
                 throw std::out_of_range("Q array wrong size");
 
-            StoredComplexArray<ADComplex, Shape<Dynamic,Dynamic,2,2,2> >
+            Array<ADComplex, Shape<Dynamic,Dynamic,2,2,2>, std::vector<ADDouble> >
                 Q({nx_, ny_, 2, 2, 2});
-            StoredComplexArray<ADComplex, Shape<1> > res;
+            Array<ADComplex, Shape<1>, std::vector<ADDouble> > res;
 
             auto a = Qval.reshape<Dynamic>({Qval.size()});
             auto b = Q.reshape<Dynamic>({Q.size()});
@@ -57,14 +57,13 @@ public:
 
     auto grad(QType Q)
         {
-            SimpleVector<double> x(reinterpret_cast<double *>(Q.data()),
-                                   2 * Q.size());
+            Vector<double> x(reinterpret_cast<double *>(const_cast<Complex *>(Q.data())), 2 * Q.size());
 
             CppAD::ADFun<double> f = eval_ad(Q);
             auto jac = f.Jacobian(x);
 
             std::array<size_t, 1> shape = {jac.size()/2};
-            return StoredArray<Complex, Shape<Dynamic>, 0, true, SimpleVector<double> >(std::move(jac), shape);
+            return Array<Complex, Shape<Dynamic>, Vector<double> >(std::move(jac), shape);
         }
 };
 
diff --git a/src/main.cpp b/src/main.cpp
index c9bfd71..3876187 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -25,10 +25,10 @@ void dump(ArrayView<Scalar,Shape> Q)
 
 int main0()
 {
-    StoredComplexArray<Complex,Shape<2,2,2,2,2> > Qval;
+    Array<Complex,Shape<2,2,2,2,2>,std::vector<double> > Qval;
 
-    StoredComplexArray<ADComplex,Shape<2,2,2,2,2> > Q;
-    StoredArray<Complex,Shape<2,2,4,4,4> > U;
+    Array<ADComplex,Shape<2,2,2,2,2>,std::vector<ADDouble> > Q;
+    Array<Complex,Shape<2,2,4,4,4> > U;
 
     for (size_t i = 0; i < Qval.dim(0); ++i) {
         for (size_t j = 0; i < Qval.dim(1); ++i) {
@@ -51,8 +51,10 @@ int main0()
 
     CppAD::Independent(Q.storage());
 
-    StoredComplexArray<ADComplex,Shape<1> > res;
-    res(0u) = S_2(Q, U, 0.1);
+    Array<ADComplex,Shape<1>,std::vector<ADDouble> > res;
+
+    ArrayView<const Complex,Shape<2,2,4,4,4> > U_const = U;
+    res(0u) = S_2(Q, U_const, 0.1);
 
     CppAD::ADFun<double> f(Q.storage(), res.storage());
 
@@ -70,7 +72,7 @@ int main0()
 
 int main()
 {
-    StoredArray<ADDouble,Shape<8,8> > Q;
+    Array<ADDouble,Shape<8,8> > Q;
 
     for (size_t i = 0; i < Q.dim(0); ++i) {
         for (size_t j = 0; i < Q.dim(1); ++i) {
@@ -80,7 +82,7 @@ int main()
 
     CppAD::Independent(Q.storage());
 
-    StoredArray<ADDouble,Shape<1> > res;
+    Array<ADDouble,Shape<1> > res;
     res(0u) = laplacian(Q);
 
     CppAD::ADFun<double> f(Q.storage(), res.storage());
diff --git a/src/simplevector.hpp b/src/vector.hpp
similarity index 90%
rename from src/simplevector.hpp
rename to src/vector.hpp
index 14f4ab5..d5aaad3 100644
--- a/src/simplevector.hpp
+++ b/src/vector.hpp
@@ -10,8 +10,10 @@
 
 #include <vector>
 
+namespace array {
+
 template <class Scalar>
-class SimpleVector
+class Vector
 {
 public:
     typedef Scalar value_type;
@@ -40,15 +42,15 @@ private:
         }
 
 public:
-    SimpleVector()
+    Vector()
         : storage_(), data_(storage_.data()), size_(storage_.size())
         {}
 
-    SimpleVector(size_t n)
+    Vector(size_t n)
         : storage_(n), data_(storage_.data()), size_(storage_.size())
         {}
 
-    SimpleVector(const SimpleVector& other)
+    Vector(const Vector& other)
         : storage_(other.storage_), data_(other.data_), size_(other.size_)
         {
             if (other.own_data()) {
@@ -59,18 +61,18 @@ public:
             }
         }
 
-    SimpleVector(SimpleVector&& other)
+    Vector(Vector&& other)
         : storage_(std::move(other.storage_)), data_(other.data_), size_(other.size_)
         {
             other.data_ = other.storage_.data();
             other.size_ = other.storage_.size();
         }
 
-    SimpleVector(Scalar *data, size_t size)
+    Vector(Scalar *data, size_t size)
         : storage_(), data_(data), size_(size)
         {}
 
-    SimpleVector& operator=(const SimpleVector& other)
+    Vector& operator=(const Vector& other)
         {
             if (other.own_data()) {
                 storage_ = other.storage_;
@@ -84,7 +86,7 @@ public:
             return *this;
         }
 
-    SimpleVector& operator=(SimpleVector&& other)
+    Vector& operator=(Vector&& other)
         {
             if (other.own_data()) {
                 storage_ = std::move(other.storage_);
@@ -131,4 +133,6 @@ public:
         }
 };
 
+}
+
 #endif // SIMPLEVECTOR_HPP_
-- 
GitLab