diff --git a/src/core.cpp b/src/core.cpp index d69ee7e27a87d585ccab067074b0ff95fdd7ccb5..79072df8daeaa4b61fc9d6e6c16c5162f6b975e0 100644 --- a/src/core.cpp +++ b/src/core.cpp @@ -5,6 +5,9 @@ * * SPDX-License-Identifier: AGPL-3.0-or-later */ +#include <unordered_map> +#include <functional> + #include <pybind11/pybind11.h> #include <pybind11/complex.h> @@ -26,6 +29,16 @@ public: typedef PyArrayView<const Complex,Shape<Dynamic,Dynamic,2,2,2> > QType; typedef PyArrayView<const Mask,Shape<Dynamic,Dynamic> > MaskType; + struct pair_hash { + size_t operator () (const std::pair<size_t, size_t>& p) const + { + size_t h1 = std::hash<size_t>{}(p.first); + size_t h2 = std::hash<size_t>{}(p.second); + h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2); + return h1; + } + }; + private: UType U_; OmegaType Omega_; @@ -46,6 +59,8 @@ private: CppAD::sparse_rcv<SizeVector, ComplexVector> hes_subset_; bool computed_ = false; + std::unordered_map<std::pair<size_t,size_t>, bool, pair_hash> hes_pattern_map_; + pybind11::object py_sparse_matrix_type_; Vector<Complex> ravel(const QType& Q) @@ -67,7 +82,8 @@ private: public: Action(double Lx, double Ly, Complex alpha, MaskType&& mask, UType&& U, OmegaType&& Omega) - : U_(std::move(U)), Omega_(std::move(Omega)), mask_(std::move(mask)), nx_(U_.dim(0)), ny_(U_.dim(1)), Lx_(Lx), Ly_(Ly), alpha_(alpha), omega_(0) + : U_(std::move(U)), Omega_(std::move(Omega)), mask_(std::move(mask)), nx_(U_.dim(0)), ny_(U_.dim(1)), Lx_(Lx), Ly_(Ly), alpha_(alpha), omega_(0), + hes_pattern_map_() { if (Omega.dim(0) != nx_ || Omega.dim(1) != ny_) throw std::domain_error("U and Omega have incompatible shape"); @@ -145,6 +161,13 @@ public: hes_subset_ = hes_pattern_; + /* Update map */ + hes_pattern_map_.clear(); + auto r = hes_pattern_.row(); + auto c = hes_pattern_.col(); + for (size_t k = 0; k < hes_pattern_.nnz(); ++k) + hes_pattern_map_.emplace(std::make_pair(std::pair(r[k], c[k]), true)); + computed_ = true; } @@ -214,6 +237,46 @@ public: pybind11::make_tuple(std::move(val), pybind11::make_tuple(std::move(row), std::move(col))), pybind11::make_tuple(x.size(), x.size())); } + + void set_hess_pattern(PyArrayView<const int32_t,Shape<Dynamic> > row_in, + PyArrayView<const int32_t,Shape<Dynamic> > col_in) + { + size_t nnz_in = row_in.size(); + + if (row_in.size() != col_in.size()) + throw std::domain_error("row/col vectors of different size"); + + if (!computed_) + throw std::runtime_error("hessian not computed"); + + size_t n = hes_pattern_.nr(); + + /* Modified sparsity pattern */ + CppAD::sparse_rc<SizeVector> pat(n, n, nnz_in); + size_t nnz_out = 0; + for (size_t k = 0; k < nnz_in; ++k) { + size_t r = static_cast<size_t>(row_in(k)); + size_t c = static_cast<size_t>(col_in(k)); + if (hes_pattern_map_.find(std::make_pair(r, c)) != hes_pattern_map_.end()) + pat.set(nnz_out++, r, c); + } + pat.resize(n, n, nnz_out); + + hes_subset_ = pat; + hes_work_.clear(); + } + + auto get_hess_pattern() + { + if (!computed_) + throw std::runtime_error("hessian not computed"); + + size_t nnz = hes_pattern_.nnz(); + Array<size_t, Shape<Dynamic>, SizeVector> row(std::move(SizeVector(hes_pattern_.row())), {nnz}); + Array<size_t, Shape<Dynamic>, SizeVector> col(std::move(SizeVector(hes_pattern_.col())), {nnz}); + + return pybind11::make_tuple(std::move(row), std::move(col)); + } }; PYBIND11_MODULE(_core, m) @@ -231,6 +294,8 @@ PYBIND11_MODULE(_core, m) .def("grad", &Action::grad, "Evaluate gradient of action", py::arg("Q")) .def("hess", &Action::hess, "Evaluate hessian of action", py::arg("Q")) .def("hess_mul", &Action::hess_mul, "Evaluate hessian(Q) * dQ", py::arg("Q"), py::arg("dQ")) + .def("set_hess_pattern", &Action::set_hess_pattern, "Set Hessian sparsity pattern", py::arg("row"), py::arg("col")) + .def("get_hess_pattern", &Action::get_hess_pattern, "Get original Hessian sparsity pattern") .def("compute", &Action::compute, "Precompute for AD", py::arg("Q")) .def_property("omega", &Action::get_omega, &Action::set_omega) ;