#include <pybind11/pybind11.h>
#include <pybind11/complex.h>

#include "common.hpp"
#include "array.hpp"
#include "action.hpp"

#include "pythonutil.hpp"

using namespace array;
using namespace array::python;

#define CLASS_Action PY_VISIBILITY Action

/// Python-facing wrapper around the action S_2.
///
/// Holds a read-only view of the U array together with the scalar h and
/// exposes S_2(Q, U, h) and its CppAD-computed derivative with respect to Q.
class CLASS_Action
{
public:
    /// Read-only complex array view with shape (nx, ny, 4, 4, 4).
    typedef PyArrayView<const Complex,Shape<Dynamic,Dynamic,4,4,4> > UType;
    /// Read-only complex array view with shape (nx, ny, 2, 2, 2).
    typedef PyArrayView<const Complex,Shape<Dynamic,Dynamic,2,2,2> > QType;

private:
    UType U_;          // fixed U array passed to every S_2 evaluation
    size_t nx_, ny_;   // leading dims of U_; every Q must match them
    double h_;         // scalar parameter forwarded to S_2

public:
    /// @param h  scalar parameter forwarded to S_2.
    /// @param U  array view taken over by this object; its first two
    ///           dimensions fix the (nx, ny) size accepted for Q.
    Action(double h, UType&& U)
        // Members are initialised in declaration order, so U_ is already
        // populated when nx_/ny_ read its dimensions.
        : U_(std::move(U)), nx_(U_.dim(0)), ny_(U_.dim(1)), h_(h)
        {}

    /// Evaluate S_2(Q, U, h).
    /// @throws std::out_of_range if Q's first two dims differ from U's.
    Complex eval(QType Q)
        {
            if (Q.dim(0) != nx_ || Q.dim(1) != ny_)
                throw std::out_of_range("Q array wrong size");
            return S_2(Q, U_, h_);
        }

    /// Record S_2(Q, U, h) on a CppAD tape with the storage of an AD copy of
    /// Qval as the independent variables. Returns the recorded function,
    /// which maps Q's underlying doubles to res's underlying doubles.
    /// @throws std::out_of_range if Qval's first two dims differ from U's.
    CppAD::ADFun<double> eval_ad(QType& Qval)
        {
            if (Qval.dim(0) != nx_ || Qval.dim(1) != ny_)
                throw std::out_of_range("Q array wrong size");

            Array<ADComplex, Shape<Dynamic,Dynamic,2,2,2>, std::vector<ADDouble> >
                Q({nx_, ny_, 2, 2, 2});
            Array<ADComplex, Shape<1>, std::vector<ADDouble> > res;

            // Copy the input into the AD array before recording starts: the
            // values present at Independent() become the independent
            // variables' values.
            auto a = Qval.reshape<Dynamic>({Qval.size()});
            auto b = Q.reshape<Dynamic>({Q.size()});
            for (size_t i = 0; i < a.dim(0); ++i)
                b(i) = a(i);

            CppAD::Independent(Q.storage());

            // Once Independent() has run, the tape stays active until an
            // ADFun is built from it. If anything throws before then, the
            // recording must be aborted explicitly. Otherwise every later
            // Independent() call on this thread fails because a tape is
            // already active.
            try {
                res(0u) = S_2(Q, U_, h_);
                return {Q.storage(), res.storage()};
            }
            catch (...) {
                CppAD::AD<double>::abort_recording();
                throw;
            }
        }

    /// Jacobian of the recorded S_2 with respect to the underlying doubles of
    /// Q, evaluated at Q.
    ///
    /// CppAD returns the Jacobian row-major (entry i*n + j is d f_i / d x_j).
    /// Here x is Q viewed as interleaved (re, im) doubles, and f is res's
    /// storage (presumably Re S and Im S; verify against the Array
    /// implementation). The flat result is reinterpreted as jac.size()/2
    /// complex values.
    /// @throws std::out_of_range if Q's first two dims differ from U's.
    auto grad(QType Q)
        {
            // Validate and record first, so Q's buffer is only reinterpreted
            // once its shape is known to be right.
            CppAD::ADFun<double> f = eval_ad(Q);

            // std::complex<double> arrays may be accessed as interleaved
            // double pairs; this assumes Q's elements are stored contiguously.
            Vector<double> x(reinterpret_cast<double *>(const_cast<Complex *>(Q.data())), 2 * Q.size());
            auto jac = f.Jacobian(x);

            std::array<size_t, 1> shape = {jac.size()/2};
            return Array<Complex, Shape<Dynamic>, Vector<double> >(std::move(jac), shape);
        }
};

PYBIND11_MODULE(_core, m)
{
    namespace py = pybind11;

    // On failure, import_array1(ret) sets a Python ImportError and executes
    // `return ret;` from the enclosing function.
    //
    // Using it directly in this void init function would return normally with
    // the error still pending. pybind11 would then hand CPython a module object
    // together with an unreported exception.
    //
    // Instead, run the macro inside a bool-returning lambda and turn failure
    // into error_already_set. pybind11's module-init wrapper catches that and
    // reports a proper import failure.
    const bool numpy_ok = []() -> bool { import_array1(false); return true; }();
    if (!numpy_ok)
        throw py::error_already_set();

    m.doc() = "usadelndsoc._core";

    // Keyword names are added for readability; positional calls still work.
    py::class_<Action>(m, "Action")
        .def(py::init<double, Action::UType>(), py::arg("h"), py::arg("U"))
        .def("eval", &Action::eval, py::arg("Q"))
        .def("grad", &Action::grad, py::arg("Q"))
        ;
}