Skip to content
Snippets Groups Projects
Commit 9cd3030f authored by patavirt's avatar patavirt
Browse files

Add simple vector storage compatible with CppAD

parent 4ded7719
No related branches found
No related tags found
No related merge requests found
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <cppad/example/cppad_eigen.hpp> #include <cppad/example/cppad_eigen.hpp>
#include "array.hpp" #include "array.hpp"
#include "simplevector.hpp"
typedef std::complex<double> Complex; typedef std::complex<double> Complex;
typedef CppAD::AD<double> ADDouble; typedef CppAD::AD<double> ADDouble;
......
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/complex.h>
#include "common.hpp" #include "common.hpp"
#include "array.hpp" #include "array.hpp"
...@@ -9,26 +10,75 @@ ...@@ -9,26 +10,75 @@
using namespace array; using namespace array;
using namespace array::python; using namespace array::python;
typedef pybind11::detail::type_caster<array::ArrayView<double, array::Shape<4> > > xcaster; #define CLASS_Action PY_VISIBILITY Action
class CLASS_Action
array::StoredArray<double, Shape<1> >
add(array::python::PyArrayView<double, Shape<4> > array)
{ {
StoredArray<double, Shape<1> > out; public:
typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,4,4,4> > UType;
typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,2,2,2> > QType;
double sum = 0; private:
for (size_t i = 0; i < 4; ++i) UType U_;
sum += array(i); size_t nx_, ny_;
double h_;
out(0u) = sum; public:
Action(double h, UType&& U)
: U_(std::move(U)), nx_(U_.dim(0)), ny_(U_.dim(1)), h_(h)
{}
return out; Complex eval(QType Q)
} {
if (Q.dim(0) != nx_ || Q.dim(1) != ny_)
throw std::out_of_range("Q array wrong size");
return S_2(Q, U_, h_);
}
CppAD::ADFun<double> eval_ad(QType& Qval)
{
if (Qval.dim(0) != nx_ || Qval.dim(1) != ny_)
throw std::out_of_range("Q array wrong size");
StoredComplexArray<ADComplex, Shape<Dynamic,Dynamic,2,2,2> >
Q({nx_, ny_, 2, 2, 2});
StoredComplexArray<ADComplex, Shape<1> > res;
auto a = Qval.reshape<Dynamic>({Qval.size()});
auto b = Q.reshape<Dynamic>({Q.size()});
for (size_t i = 0; i < a.dim(0); ++i)
b(i) = a(i);
CppAD::Independent(Q.storage());
res(0u) = S_2(Q, U_, h_);
return {Q.storage(), res.storage()};
}
auto grad(QType Q)
{
SimpleVector<double> x(reinterpret_cast<double *>(Q.data()),
2 * Q.size());
CppAD::ADFun<double> f = eval_ad(Q);
auto jac = f.Jacobian(x);
std::array<size_t, 1> shape = {jac.size()/2};
return StoredArray<Complex, Shape<Dynamic>, 0, true, SimpleVector<double> >(std::move(jac), shape);
}
};
PYBIND11_MODULE(_core, m)
{
namespace py = pybind11;
PYBIND11_MODULE(_core, m) {
import_array1(); import_array1();
m.doc() = "pybind11 example plugin"; // optional module docstring m.doc() = "usadelndsoc._core";
m.def("add", &add, "A function that sums an array"); py::class_<Action>(m, "Action")
.def(py::init<double, Action::UType>())
.def("eval", &Action::eval)
.def("grad", &Action::grad)
;
} }
/* usadelndsoc
*
* Copyright © 2022 Pauli Virtanen
* @author Pauli Virtanen <pauli.t.virtanen@jyu.fi>
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
#ifndef SIMPLEVECTOR_HPP_
#define SIMPLEVECTOR_HPP_
#include <vector>
template <class Scalar>
class SimpleVector
{
public:
typedef Scalar value_type;
private:
std::vector<Scalar> storage_;
Scalar *data_;
size_t size_;
bool own_data() const
{
return data_ == storage_.data();
}
void copy_storage()
{
if (own_data())
return;
storage_.resize(size_);
for (size_t i = 0; i < size_; ++i)
storage_[i] = data_[i];
data_ = storage_.data();
size_ = storage_.size();
}
public:
SimpleVector()
: storage_(), data_(storage_.data()), size_(storage_.size())
{}
SimpleVector(size_t n)
: storage_(n), data_(storage_.data()), size_(storage_.size())
{}
SimpleVector(const SimpleVector& other)
: storage_(other.storage_), data_(other.data_), size_(other.size_)
{
if (other.own_data()) {
data_ = storage_.data();
size_ = storage_.size();
} else {
copy_storage();
}
}
SimpleVector(SimpleVector&& other)
: storage_(std::move(other.storage_)), data_(other.data_), size_(other.size_)
{
other.data_ = other.storage_.data();
other.size_ = other.storage_.size();
}
SimpleVector(Scalar *data, size_t size)
: storage_(), data_(data), size_(size)
{}
SimpleVector& operator=(const SimpleVector& other)
{
if (other.own_data()) {
storage_ = other.storage_;
data_ = storage_.data();
size_ = storage_.size();
} else {
data_ = other.data_;
size_ = other.size_;
copy_storage();
}
return *this;
}
SimpleVector& operator=(SimpleVector&& other)
{
if (other.own_data()) {
storage_ = std::move(other.storage_);
data_ = storage_.data();
size_ = storage_.size();
} else {
storage_ = std::move(other.storage_);
data_ = other.data_;
size_ = other.size_;
copy_storage();
}
other.data_ = other.storage_.data();
other.size_ = other.storage_.size();
return *this;
}
Scalar *data() const
{
return data_;
}
size_t size() const
{
return size_;
}
void resize(size_t n)
{
copy_storage();
storage_.resize(n);
data_ = storage_.data();
size_ = storage_.size();
}
Scalar& operator[](size_t i)
{
return data_[i];
}
const Scalar& operator[](size_t i) const
{
return data_[i];
}
};
#endif // SIMPLEVECTOR_HPP_
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment