From 680d5ce22c0f0ba3597294a1eedd6adfea4258ca Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Thu, 4 Aug 2022 15:54:25 +0300
Subject: [PATCH] common: fix binops and Eigen binops for CppAD

---
 src/common.hpp | 49 ++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 46 insertions(+), 3 deletions(-)

diff --git a/src/common.hpp b/src/common.hpp
index e518a9e..64bb7b3 100644
--- a/src/common.hpp
+++ b/src/common.hpp
@@ -10,23 +10,66 @@
 
 #include <iostream>
 
+#include <Eigen/Core>
+
 #include <cppad/cppad.hpp>
 #include <cppad/example/cppad_eigen.hpp>
 
-#include <Eigen/Core>
+#include "array.hpp"
 
 typedef std::complex<double> Complex;
 typedef CppAD::AD<double> ADDouble;
 typedef std::complex<ADDouble> ADComplex;
 
-ADComplex adcomplex(const ADComplex& cpx)
+template <typename Scalar_, int Rows_, int Cols_>
+using Matrix = Eigen::Matrix<Scalar_, Rows_, Cols_, Eigen::RowMajor>;
+
+ADComplex adcomplex(const ADComplex&& cpx)
 {
     return {cpx};
 }
 
-ADComplex adcomplex(const std::complex<double>& cpx)
+ADComplex adcomplex(const std::complex<double>&& cpx)
 {
     return {cpx};
 }
 
+namespace Eigen
+{
+    /* Eigen binop maps: (double, ADDouble) -> ADDouble */
+    template<typename BinaryOp> struct ScalarBinaryOpTraits<ADDouble,double,BinaryOp> { typedef ADDouble ReturnType;  };
+    template<typename BinaryOp> struct ScalarBinaryOpTraits<double,ADDouble,BinaryOp> { typedef ADDouble ReturnType;  };
+
+    /* Eigen binop maps: (double/complex, ADComplex) -> ADComplex */
+    template<typename BinaryOp> struct ScalarBinaryOpTraits<ADComplex,double,BinaryOp> { typedef ADComplex ReturnType;  };
+    template<typename BinaryOp> struct ScalarBinaryOpTraits<double,ADComplex,BinaryOp> { typedef ADComplex ReturnType;  };
+    template<typename BinaryOp> struct ScalarBinaryOpTraits<ADComplex,Complex,BinaryOp> { typedef ADComplex ReturnType;  };
+    template<typename BinaryOp> struct ScalarBinaryOpTraits<Complex,ADComplex,BinaryOp> { typedef ADComplex ReturnType;  };
+
+    namespace usadelndsoc {
+        ADComplex operator*(const ADComplex& a, const double& b) { return a * adcomplex(b); }
+        ADComplex operator*(const double& a, const ADComplex& b) { return adcomplex(a) * b; }
+        ADComplex operator*(const ADComplex& a, const Complex& b) { return a * adcomplex(b); }
+        ADComplex operator*(const Complex& a, const ADComplex& b) { return adcomplex(a) * b; }
+
+        ADComplex operator+(const ADComplex& a, const double& b) { return a + adcomplex(b); }
+        ADComplex operator+(const double& a, const ADComplex& b) { return adcomplex(a) + b; }
+        ADComplex operator+(const ADComplex& a, const Complex& b) { return a + adcomplex(b); }
+        ADComplex operator+(const Complex& a, const ADComplex& b) { return adcomplex(a) + b; }
+
+        ADComplex operator-(const ADComplex& a, const double& b) { return a - adcomplex(b); }
+        ADComplex operator-(const double& a, const ADComplex& b) { return adcomplex(a) - b; }
+        ADComplex operator-(const ADComplex& a, const Complex& b) { return a - adcomplex(b); }
+        ADComplex operator-(const Complex& a, const ADComplex& b) { return adcomplex(a) - b; }
+
+        ADComplex operator/(const ADComplex& a, const double& b) { return a / adcomplex(b); }
+        ADComplex operator/(const double& a, const ADComplex& b) { return adcomplex(a) / b; }
+        ADComplex operator/(const ADComplex& a, const Complex& b) { return a / adcomplex(b); }
+        ADComplex operator/(const Complex& a, const ADComplex& b) { return adcomplex(a) / b; }
+    }
+    using namespace usadelndsoc;
+}
+
+using namespace Eigen::usadelndsoc;
+
 #endif
-- 
GitLab