diff --git a/src/action.hpp b/src/action.hpp index 9e957ae5b91098c02b09d5e3ebe15c166ff0f3ce..a5be36d9781a29e9c92367e543fcde9b181ee710 100644 --- a/src/action.hpp +++ b/src/action.hpp @@ -15,15 +15,15 @@ #include "array.hpp" -template <typename Scalar> -inline Matrix<Scalar,4,4> +template <typename Scalar, typename ScalarW = typename std::remove_const<Scalar>::type> +inline Matrix<ScalarW,4,4> get_Q_matrix(array::ArrayView<Scalar, array::Shape<2,2,2> > Q) { auto g = Q.part(0u).matrix(); auto gt = Q.part(1u).matrix(); Matrix<double,2,2> I; - Matrix<Scalar,2,2> N, Nt; - Matrix<Scalar,4,4> Qm; + Matrix<ScalarW,2,2> N, Nt; + Matrix<ScalarW,4,4> Qm; I << 1, 0, 0, 1; @@ -39,13 +39,13 @@ get_Q_matrix(array::ArrayView<Scalar, array::Shape<2,2,2> > Q) return Qm; } -template <typename Scalar, typename Shape, typename Shape2> -inline Scalar S_2(array::ArrayView<Scalar, Shape> Q, array::ArrayView<Complex, Shape2> U, double h) +template <typename Scalar, typename Shape, typename Shape2, typename ScalarW = typename std::remove_const<Scalar>::type> +inline ScalarW S_2(array::ArrayView<Scalar, Shape> Q, array::ArrayView<const Complex, Shape2> U, double h) { const int N_i[4] = {0, 1, -1, 0}; const int N_j[4] = {1, 0, 0, -1}; const size_t invk[4] = {3, 2, 1, 0}; - Scalar S = Scalar(0); + ScalarW S = Scalar(0); size_t nx = Q.dim(0); size_t ny = Q.dim(1); diff --git a/src/common.hpp b/src/common.hpp index f4acc44f862f2be0ccb064b2177c0525036a2566..e20407d277eb9cbdae21737c1af0b6b5f87d7555 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -16,7 +16,7 @@ #include <cppad/example/cppad_eigen.hpp> #include "array.hpp" -#include "simplevector.hpp" +#include "vector.hpp" typedef std::complex<double> Complex; typedef CppAD::AD<double> ADDouble; diff --git a/src/core.cpp b/src/core.cpp index 85f3f684620c1794d64f3374d5c858dd8a483ba3..408b068cbdf0f14d9085857e467cece667935024 100644 --- a/src/core.cpp +++ b/src/core.cpp @@ -14,8 +14,8 @@ using namespace array::python; class CLASS_Action { public: - typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,4,4,4> > UType; - typedef PyArrayView<Complex,Shape<Dynamic,Dynamic,2,2,2> > QType; + typedef PyArrayView<const Complex,Shape<Dynamic,Dynamic,4,4,4> > UType; + typedef PyArrayView<const Complex,Shape<Dynamic,Dynamic,2,2,2> > QType; private: UType U_; @@ -39,9 +39,9 @@ public: if (Qval.dim(0) != nx_ || Qval.dim(1) != ny_) throw std::out_of_range("Q array wrong size"); - StoredComplexArray<ADComplex, Shape<Dynamic,Dynamic,2,2,2> > + Array<ADComplex, Shape<Dynamic,Dynamic,2,2,2>, std::vector<ADDouble> > Q({nx_, ny_, 2, 2, 2}); - StoredComplexArray<ADComplex, Shape<1> > res; + Array<ADComplex, Shape<1>, std::vector<ADDouble> > res; auto a = Qval.reshape<Dynamic>({Qval.size()}); auto b = Q.reshape<Dynamic>({Q.size()}); @@ -57,14 +57,13 @@ public: auto grad(QType Q) { - SimpleVector<double> x(reinterpret_cast<double *>(Q.data()), - 2 * Q.size()); + Vector<double> x(reinterpret_cast<double *>(const_cast<Complex *>(Q.data())), 2 * Q.size()); CppAD::ADFun<double> f = eval_ad(Q); auto jac = f.Jacobian(x); std::array<size_t, 1> shape = {jac.size()/2}; - return StoredArray<Complex, Shape<Dynamic>, 0, true, SimpleVector<double> >(std::move(jac), shape); + return Array<Complex, Shape<Dynamic>, Vector<double> >(std::move(jac), shape); } }; diff --git a/src/main.cpp b/src/main.cpp index c9bfd71ab0856f99158a8ecfa998a759f77bb9eb..387618793611f2126cca65a397d4089f803bc87e 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -25,10 +25,10 @@ void dump(ArrayView<Scalar,Shape> Q) int main0() { - StoredComplexArray<Complex,Shape<2,2,2,2,2> > Qval; + Array<Complex,Shape<2,2,2,2,2>,std::vector<double> > Qval; - StoredComplexArray<ADComplex,Shape<2,2,2,2,2> > Q; - StoredArray<Complex,Shape<2,2,4,4,4> > U; + Array<ADComplex,Shape<2,2,2,2,2>,std::vector<ADDouble> > Q; + Array<Complex,Shape<2,2,4,4,4> > U; for (size_t i = 0; i < Qval.dim(0); ++i) { for (size_t j = 0; i < Qval.dim(1); ++i) { @@ -51,8 +51,10 @@ int main0() CppAD::Independent(Q.storage()); - StoredComplexArray<ADComplex,Shape<1> > res; - res(0u) = S_2(Q, U, 0.1); + Array<ADComplex,Shape<1>,std::vector<ADDouble> > res; + + ArrayView<const Complex,Shape<2,2,4,4,4> > U_const = U; + res(0u) = S_2(Q, U_const, 0.1); CppAD::ADFun<double> f(Q.storage(), res.storage()); @@ -70,7 +72,7 @@ int main0() int main() { - StoredArray<ADDouble,Shape<8,8> > Q; + Array<ADDouble,Shape<8,8> > Q; for (size_t i = 0; i < Q.dim(0); ++i) { for (size_t j = 0; i < Q.dim(1); ++i) { @@ -80,7 +82,7 @@ int main() CppAD::Independent(Q.storage()); - StoredArray<ADDouble,Shape<1> > res; + Array<ADDouble,Shape<1> > res; res(0u) = laplacian(Q); CppAD::ADFun<double> f(Q.storage(), res.storage()); diff --git a/src/simplevector.hpp b/src/vector.hpp similarity index 90% rename from src/simplevector.hpp rename to src/vector.hpp index 14f4ab512dca04f67ae4e7f6d5abe143dcc17006..d5aaad3c97e32f48a2b461dbdf9ba1900c36f6db 100644 --- a/src/simplevector.hpp +++ b/src/vector.hpp @@ -10,8 +10,10 @@ #include <vector> +namespace array { + template <class Scalar> -class SimpleVector +class Vector { public: typedef Scalar value_type; @@ -40,15 +42,15 @@ private: } public: - SimpleVector() + Vector() : storage_(), data_(storage_.data()), size_(storage_.size()) {} - SimpleVector(size_t n) + Vector(size_t n) : storage_(n), data_(storage_.data()), size_(storage_.size()) {} - SimpleVector(const SimpleVector& other) + Vector(const Vector& other) : storage_(other.storage_), data_(other.data_), size_(other.size_) { if (other.own_data()) { @@ -59,18 +61,18 @@ public: } } - SimpleVector(SimpleVector&& other) + Vector(Vector&& other) : storage_(std::move(other.storage_)), data_(other.data_), size_(other.size_) { other.data_ = other.storage_.data(); other.size_ = other.storage_.size(); } - SimpleVector(Scalar *data, size_t size) + Vector(Scalar *data, size_t size) : storage_(), data_(data), size_(size) {} - SimpleVector& operator=(const SimpleVector& other) + Vector& operator=(const Vector& other) { if (other.own_data()) { storage_ = other.storage_; @@ -84,7 +86,7 @@ public: return *this; } - SimpleVector& operator=(SimpleVector&& other) + Vector& operator=(Vector&& other) { if (other.own_data()) { storage_ = std::move(other.storage_); @@ -131,4 +133,6 @@ public: } }; +} + #endif // SIMPLEVECTOR_HPP_