From 9cd3030f7d5cb26f25e72ce5e64c26dc69a7def0 Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Fri, 5 Aug 2022 19:02:05 +0300
Subject: [PATCH] Add simple vector storage compatible with CppAD

---
 src/common.hpp       |   1 +
 src/core.cpp         |  78 ++++++++++++++++++++-----
 src/simplevector.hpp | 134 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 199 insertions(+), 14 deletions(-)
 create mode 100644 src/simplevector.hpp

diff --git a/src/common.hpp b/src/common.hpp
index 64bb7b3..f4acc44 100644
--- a/src/common.hpp
+++ b/src/common.hpp
@@ -16,6 +16,7 @@
 #include <cppad/example/cppad_eigen.hpp>
 
 #include "array.hpp"
+#include "simplevector.hpp"
 
 typedef std::complex<double> Complex;
 typedef CppAD::AD<double> ADDouble;
diff --git a/src/core.cpp b/src/core.cpp
index 3df0f81..85f3f68 100644
--- a/src/core.cpp
+++ b/src/core.cpp
@@ -1,4 +1,5 @@
 #include <pybind11/pybind11.h>
+#include <pybind11/complex.h>
 
 #include "common.hpp"
 #include "array.hpp"
@@ -9,26 +10,75 @@
 using namespace array;
 using namespace array::python;
 
-typedef pybind11::detail::type_caster<array::ArrayView<double, array::Shape<4> > > xcaster;
-
-array::StoredArray<double, Shape<1> >
-add(array::python::PyArrayView<double, Shape<4> > array)
+#define CLASS_Action PY_VISIBILITY Action
+class CLASS_Action
 {
-    StoredArray<double, Shape<1> > out;
+public:
+    typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,4,4,4> > UType;
+    typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,2,2,2> > QType;
 
-    double sum = 0;
-    for (size_t i = 0; i < 4; ++i)
-        sum += array(i);
+private:
+    UType U_;
+    size_t nx_, ny_;
+    double h_;
 
-    out(0u) = sum;
+public:
+    Action(double h, UType&& U)
+        : U_(std::move(U)), nx_(U_.dim(0)), ny_(U_.dim(1)), h_(h)
+        {}
 
-    return out;
-}
+    Complex eval(QType Q)
+        {
+            if (Q.dim(0) != nx_ || Q.dim(1) != ny_)
+                throw std::out_of_range("Q array wrong size");
+            return S_2(Q, U_, h_);
+        }
+
+    CppAD::ADFun<double> eval_ad(QType& Qval)
+        {
+            if (Qval.dim(0) != nx_ || Qval.dim(1) != ny_)
+                throw std::out_of_range("Q array wrong size");
+
+            StoredComplexArray<ADComplex, Shape<Dynamic,Dynamic,2,2,2> >
+                Q({nx_, ny_, 2, 2, 2});
+            StoredComplexArray<ADComplex, Shape<1> > res;
+
+            auto a = Qval.reshape<Dynamic>({Qval.size()});
+            auto b = Q.reshape<Dynamic>({Q.size()});
+            for (size_t i = 0; i < a.dim(0); ++i)
+                b(i) = a(i);
+
+            CppAD::Independent(Q.storage());
+
+            res(0u) = S_2(Q, U_, h_);
+
+            return {Q.storage(), res.storage()};
+        }
+
+    auto grad(QType Q)
+        {
+            SimpleVector<double> x(reinterpret_cast<double *>(Q.data()),
+                                   2 * Q.size());
+
+            CppAD::ADFun<double> f = eval_ad(Q);
+            auto jac = f.Jacobian(x);
+
+            std::array<size_t, 1> shape = {jac.size()/2};
+            return StoredArray<Complex, Shape<Dynamic>, 0, true, SimpleVector<double> >(std::move(jac), shape);
+        }
+};
+
+PYBIND11_MODULE(_core, m)
+{
+    namespace py = pybind11;
 
-PYBIND11_MODULE(_core, m) {
     import_array1();
 
-    m.doc() = "pybind11 example plugin"; // optional module docstring
+    m.doc() = "usadelndsoc._core";
 
-    m.def("add", &add, "A function that sums an array");
+    py::class_<Action>(m, "Action")
+        .def(py::init<double, Action::UType>())
+        .def("eval", &Action::eval)
+        .def("grad", &Action::grad)
+        ;
 }
diff --git a/src/simplevector.hpp b/src/simplevector.hpp
new file mode 100644
index 0000000..14f4ab5
--- /dev/null
+++ b/src/simplevector.hpp
@@ -0,0 +1,134 @@
+/* usadelndsoc
+ *
+ * Copyright © 2022 Pauli Virtanen
+ *    @author Pauli Virtanen <pauli.t.virtanen@jyu.fi>
+ *
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ */
+#ifndef SIMPLEVECTOR_HPP_
+#define SIMPLEVECTOR_HPP_
+
+#include <vector>
+
+template <class Scalar>
+class SimpleVector
+{
+public:
+    typedef Scalar value_type;
+
+private:
+    std::vector<Scalar> storage_;
+    Scalar *data_;
+    size_t size_;
+
+    bool own_data() const
+        {
+            return data_ == storage_.data();
+        }
+
+    void copy_storage()
+        {
+            if (own_data())
+                return;
+
+            storage_.resize(size_);
+            for (size_t i = 0; i < size_; ++i)
+                storage_[i] = data_[i];
+
+            data_ = storage_.data();
+            size_ = storage_.size();
+        }
+
+public:
+    SimpleVector()
+        : storage_(), data_(storage_.data()), size_(storage_.size())
+        {}
+
+    SimpleVector(size_t n)
+        : storage_(n), data_(storage_.data()), size_(storage_.size())
+        {}
+
+    SimpleVector(const SimpleVector& other)
+        : storage_(other.storage_), data_(other.data_), size_(other.size_)
+        {
+            if (other.own_data()) {
+                data_ = storage_.data();
+                size_ = storage_.size();
+            } else {
+                copy_storage();
+            }
+        }
+
+    SimpleVector(SimpleVector&& other)
+        : storage_(std::move(other.storage_)), data_(other.data_), size_(other.size_)
+        {
+            other.data_ = other.storage_.data();
+            other.size_ = other.storage_.size();
+        }
+
+    SimpleVector(Scalar *data, size_t size)
+        : storage_(), data_(data), size_(size)
+        {}
+
+    SimpleVector& operator=(const SimpleVector& other)
+        {
+            if (other.own_data()) {
+                storage_ = other.storage_;
+                data_ = storage_.data();
+                size_ = storage_.size();
+            } else {
+                data_ = other.data_;
+                size_ = other.size_;
+                copy_storage();
+            }
+            return *this;
+        }
+
+    SimpleVector& operator=(SimpleVector&& other)
+        {
+            if (other.own_data()) {
+                storage_ = std::move(other.storage_);
+                data_ = storage_.data();
+                size_ = storage_.size();
+            } else {
+                storage_ = std::move(other.storage_);
+                data_ = other.data_;
+                size_ = other.size_;
+                copy_storage();
+            }
+            other.data_ = other.storage_.data();
+            other.size_ = other.storage_.size();
+            return *this;
+        }
+
+    Scalar *data() const
+        {
+            return data_;
+        }
+
+    size_t size() const
+        {
+            return size_;
+        }
+
+    void resize(size_t n)
+        {
+            copy_storage();
+
+            storage_.resize(n);
+            data_ = storage_.data();
+            size_ = storage_.size();
+        }
+
+    Scalar& operator[](size_t i)
+        {
+            return data_[i];
+        }
+
+    const Scalar& operator[](size_t i) const
+        {
+            return data_[i];
+        }
+};
+
+#endif // SIMPLEVECTOR_HPP_
-- 
GitLab