From 9cd3030f7d5cb26f25e72ce5e64c26dc69a7def0 Mon Sep 17 00:00:00 2001 From: Pauli Virtanen <pauli.t.virtanen@jyu.fi> Date: Fri, 5 Aug 2022 19:02:05 +0300 Subject: [PATCH] Add simple vector storage compatible with CppAD --- src/common.hpp | 1 + src/core.cpp | 78 ++++++++++++++++++++----- src/simplevector.hpp | 134 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 199 insertions(+), 14 deletions(-) create mode 100644 src/simplevector.hpp diff --git a/src/common.hpp b/src/common.hpp index 64bb7b3..f4acc44 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -16,6 +16,7 @@ #include <cppad/example/cppad_eigen.hpp> #include "array.hpp" +#include "simplevector.hpp" typedef std::complex<double> Complex; typedef CppAD::AD<double> ADDouble; diff --git a/src/core.cpp b/src/core.cpp index 3df0f81..85f3f68 100644 --- a/src/core.cpp +++ b/src/core.cpp @@ -1,4 +1,5 @@ #include <pybind11/pybind11.h> +#include <pybind11/complex.h> #include "common.hpp" #include "array.hpp" @@ -9,26 +10,75 @@ using namespace array; using namespace array::python; -typedef pybind11::detail::type_caster<array::ArrayView<double, array::Shape<4> > > xcaster; - -array::StoredArray<double, Shape<1> > -add(array::python::PyArrayView<double, Shape<4> > array) +#define CLASS_Action PY_VISIBILITY Action +class CLASS_Action { - StoredArray<double, Shape<1> > out; +public: + typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,4,4,4> > UType; + typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,2,2,2> > QType; - double sum = 0; - for (size_t i = 0; i < 4; ++i) - sum += array(i); +private: + UType U_; + size_t nx_, ny_; + double h_; - out(0u) = sum; +public: + Action(double h, UType&& U) + : U_(std::move(U)), nx_(U_.dim(0)), ny_(U_.dim(1)), h_(h) + {} - return out; -} + Complex eval(QType Q) + { + if (Q.dim(0) != nx_ || Q.dim(1) != ny_) + throw std::out_of_range("Q array wrong size"); + return S_2(Q, U_, h_); + } + + CppAD::ADFun<double> eval_ad(QType& Qval) + { + if (Qval.dim(0) != nx_ || Qval.dim(1) != ny_) + throw std::out_of_range("Q array wrong size"); + + StoredComplexArray<ADComplex, Shape<Dynamic,Dynamic,2,2,2> > + Q({nx_, ny_, 2, 2, 2}); + StoredComplexArray<ADComplex, Shape<1> > res; + + auto a = Qval.reshape<Dynamic>({Qval.size()}); + auto b = Q.reshape<Dynamic>({Q.size()}); + for (size_t i = 0; i < a.dim(0); ++i) + b(i) = a(i); + + CppAD::Independent(Q.storage()); + + res(0u) = S_2(Q, U_, h_); + + return {Q.storage(), res.storage()}; + } + + auto grad(QType Q) + { + SimpleVector<double> x(reinterpret_cast<double *>(Q.data()), + 2 * Q.size()); + + CppAD::ADFun<double> f = eval_ad(Q); + auto jac = f.Jacobian(x); + + std::array<size_t, 1> shape = {jac.size()/2}; + return StoredArray<Complex, Shape<Dynamic>, 0, true, SimpleVector<double> >(std::move(jac), shape); + } +}; + +PYBIND11_MODULE(_core, m) +{ + namespace py = pybind11; -PYBIND11_MODULE(_core, m) { import_array1(); - m.doc() = "pybind11 example plugin"; // optional module docstring + m.doc() = "usadelndsoc._core"; - m.def("add", &add, "A function that sums an array"); + py::class_<Action>(m, "Action") + .def(py::init<double, Action::UType>()) + .def("eval", &Action::eval) + .def("grad", &Action::grad) + ; } diff --git a/src/simplevector.hpp b/src/simplevector.hpp new file mode 100644 index 0000000..14f4ab5 --- /dev/null +++ b/src/simplevector.hpp @@ -0,0 +1,134 @@ +/* usadelndsoc + * + * Copyright © 2022 Pauli Virtanen + * @author Pauli Virtanen <pauli.t.virtanen@jyu.fi> + * + * SPDX-License-Identifier: AGPL-3.0-or-later + */ +#ifndef SIMPLEVECTOR_HPP_ +#define SIMPLEVECTOR_HPP_ + +#include <vector> + +template <class Scalar> +class SimpleVector +{ +public: + typedef Scalar value_type; + +private: + std::vector<Scalar> storage_; + Scalar *data_; + size_t size_; + + bool own_data() const + { + return data_ == storage_.data(); + } + + void copy_storage() + { + if (own_data()) + return; + + storage_.resize(size_); + for (size_t i = 0; i < size_; ++i) + storage_[i] = data_[i]; + + data_ = storage_.data(); + size_ = storage_.size(); + } + +public: + SimpleVector() + : storage_(), data_(storage_.data()), size_(storage_.size()) + {} + + SimpleVector(size_t n) + : storage_(n), data_(storage_.data()), size_(storage_.size()) + {} + + SimpleVector(const SimpleVector& other) + : storage_(other.storage_), data_(other.data_), size_(other.size_) + { + if (other.own_data()) { + data_ = storage_.data(); + size_ = storage_.size(); + } else { + copy_storage(); + } + } + + SimpleVector(SimpleVector&& other) + : storage_(std::move(other.storage_)), data_(other.data_), size_(other.size_) + { + other.data_ = other.storage_.data(); + other.size_ = other.storage_.size(); + } + + SimpleVector(Scalar *data, size_t size) + : storage_(), data_(data), size_(size) + {} + + SimpleVector& operator=(const SimpleVector& other) + { + if (other.own_data()) { + storage_ = other.storage_; + data_ = storage_.data(); + size_ = storage_.size(); + } else { + data_ = other.data_; + size_ = other.size_; + copy_storage(); + } + return *this; + } + + SimpleVector& operator=(SimpleVector&& other) + { + if (other.own_data()) { + storage_ = std::move(other.storage_); + data_ = storage_.data(); + size_ = storage_.size(); + } else { + storage_ = std::move(other.storage_); + data_ = other.data_; + size_ = other.size_; + copy_storage(); + } + other.data_ = other.storage_.data(); + other.size_ = other.storage_.size(); + return *this; + } + + Scalar *data() const + { + return data_; + } + + size_t size() const + { + return size_; + } + + void resize(size_t n) + { + copy_storage(); + + storage_.resize(n); + data_ = storage_.data(); + size_ = storage_.size(); + } + + Scalar& operator[](size_t i) + { + return data_[i]; + } + + const Scalar& operator[](size_t i) const + { + return data_[i]; + } +}; + +#endif // SIMPLEVECTOR_HPP_ -- GitLab