From 680d5ce22c0f0ba3597294a1eedd6adfea4258ca Mon Sep 17 00:00:00 2001 From: Pauli Virtanen <pauli.t.virtanen@jyu.fi> Date: Thu, 4 Aug 2022 15:54:25 +0300 Subject: [PATCH] common: fix binops and Eigen binops for CppAD --- src/common.hpp | 49 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/src/common.hpp b/src/common.hpp index e518a9e..64bb7b3 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -10,23 +10,66 @@ #include <iostream> +#include <Eigen/Core> + #include <cppad/cppad.hpp> #include <cppad/example/cppad_eigen.hpp> -#include <Eigen/Core> +#include "array.hpp" typedef std::complex<double> Complex; typedef CppAD::AD<double> ADDouble; typedef std::complex<ADDouble> ADComplex; -ADComplex adcomplex(const ADComplex& cpx) +template <typename Scalar_, int Rows_, int Cols_> +using Matrix = Eigen::Matrix<Scalar_, Rows_, Cols_, Eigen::RowMajor>; + +ADComplex adcomplex(const ADComplex&& cpx) { return {cpx}; } -ADComplex adcomplex(const std::complex<double>& cpx) +ADComplex adcomplex(const std::complex<double>&& cpx) { return {cpx}; } +namespace Eigen +{ + /* Eigen binop maps: (double, ADDouble) -> ADDouble */ + template<typename BinaryOp> struct ScalarBinaryOpTraits<ADDouble,double,BinaryOp> { typedef ADDouble ReturnType; }; + template<typename BinaryOp> struct ScalarBinaryOpTraits<double,ADDouble,BinaryOp> { typedef ADDouble ReturnType; }; + + /* Eigen binop maps: (double/complex, ADComplex) -> ADComplex */ + template<typename BinaryOp> struct ScalarBinaryOpTraits<ADComplex,double,BinaryOp> { typedef ADComplex ReturnType; }; + template<typename BinaryOp> struct ScalarBinaryOpTraits<double,ADComplex,BinaryOp> { typedef ADComplex ReturnType; }; + template<typename BinaryOp> struct ScalarBinaryOpTraits<ADComplex,Complex,BinaryOp> { typedef ADComplex ReturnType; }; + template<typename BinaryOp> struct ScalarBinaryOpTraits<Complex,ADComplex,BinaryOp> { typedef ADComplex ReturnType; }; + + namespace usadelndsoc { + ADComplex operator*(const ADComplex& a, const double& b) { return a * adcomplex(b); } + ADComplex operator*(const double& a, const ADComplex& b) { return adcomplex(a) * b; } + ADComplex operator*(const ADComplex& a, const Complex& b) { return a * adcomplex(b); } + ADComplex operator*(const Complex& a, const ADComplex& b) { return adcomplex(a) * b; } + + ADComplex operator+(const ADComplex& a, const double& b) { return a + adcomplex(b); } + ADComplex operator+(const double& a, const ADComplex& b) { return adcomplex(a) + b; } + ADComplex operator+(const ADComplex& a, const Complex& b) { return a + adcomplex(b); } + ADComplex operator+(const Complex& a, const ADComplex& b) { return adcomplex(a) + b; } + + ADComplex operator-(const ADComplex& a, const double& b) { return a - adcomplex(b); } + ADComplex operator-(const double& a, const ADComplex& b) { return adcomplex(a) - b; } + ADComplex operator-(const ADComplex& a, const Complex& b) { return a - adcomplex(b); } + ADComplex operator-(const Complex& a, const ADComplex& b) { return adcomplex(a) - b; } + + ADComplex operator/(const ADComplex& a, const double& b) { return a / adcomplex(b); } + ADComplex operator/(const double& a, const ADComplex& b) { return adcomplex(a) / b; } + ADComplex operator/(const ADComplex& a, const Complex& b) { return a / adcomplex(b); } + ADComplex operator/(const Complex& a, const ADComplex& b) { return adcomplex(a) / b; } + } + using namespace usadelndsoc; +} + +using namespace Eigen::usadelndsoc; + #endif -- GitLab