diff --git a/src/core.cpp b/src/core.cpp
index d69ee7e27a87d585ccab067074b0ff95fdd7ccb5..79072df8daeaa4b61fc9d6e6c16c5162f6b975e0 100644
--- a/src/core.cpp
+++ b/src/core.cpp
@@ -5,6 +5,9 @@
  *
  * SPDX-License-Identifier: AGPL-3.0-or-later
  */
+#include <unordered_map>
+#include <functional>
+
 #include <pybind11/pybind11.h>
 #include <pybind11/complex.h>
 
@@ -26,6 +29,16 @@ public:
     typedef PyArrayView<const Complex,Shape<Dynamic,Dynamic,2,2,2> > QType;
     typedef PyArrayView<const Mask,Shape<Dynamic,Dynamic> > MaskType;
 
+    struct pair_hash {
+        size_t operator () (const std::pair<size_t, size_t>& p) const
+            {
+                size_t h1 = std::hash<size_t>{}(p.first);
+                size_t h2 = std::hash<size_t>{}(p.second);
+                h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2);
+                return h1;
+            }
+    };
+
 private:
     UType U_;
     OmegaType Omega_;
@@ -46,6 +59,8 @@ private:
     CppAD::sparse_rcv<SizeVector, ComplexVector> hes_subset_;
     bool computed_ = false;
 
+    std::unordered_map<std::pair<size_t,size_t>, bool, pair_hash> hes_pattern_map_;
+
     pybind11::object py_sparse_matrix_type_;
 
     Vector<Complex> ravel(const QType& Q)
@@ -67,7 +82,8 @@ private:
 
 public:
     Action(double Lx, double Ly, Complex alpha, MaskType&& mask, UType&& U, OmegaType&& Omega)
-        : U_(std::move(U)), Omega_(std::move(Omega)), mask_(std::move(mask)), nx_(U_.dim(0)), ny_(U_.dim(1)), Lx_(Lx), Ly_(Ly), alpha_(alpha), omega_(0)
+        : U_(std::move(U)), Omega_(std::move(Omega)), mask_(std::move(mask)), nx_(U_.dim(0)), ny_(U_.dim(1)), Lx_(Lx), Ly_(Ly), alpha_(alpha), omega_(0),
+          hes_pattern_map_()
         {
             if (Omega.dim(0) != nx_ || Omega.dim(1) != ny_)
                 throw std::domain_error("U and Omega have incompatible shape");
@@ -145,6 +161,13 @@ public:
 
             hes_subset_ = hes_pattern_;
 
+            /* Update map */
+            hes_pattern_map_.clear();
+            auto r = hes_pattern_.row();
+            auto c = hes_pattern_.col();
+            for (size_t k = 0; k < hes_pattern_.nnz(); ++k)
+                hes_pattern_map_.emplace(std::make_pair(std::pair(r[k], c[k]), true));
+
             computed_ = true;
         }
 
@@ -214,6 +237,46 @@ public:
                 pybind11::make_tuple(std::move(val), pybind11::make_tuple(std::move(row), std::move(col))),
                 pybind11::make_tuple(x.size(), x.size()));
         }
+
+    void set_hess_pattern(PyArrayView<const int32_t,Shape<Dynamic> > row_in,
+                          PyArrayView<const int32_t,Shape<Dynamic> > col_in)
+        {
+            size_t nnz_in = row_in.size();
+
+            if (row_in.size() != col_in.size())
+                throw std::domain_error("row/col vectors of different size");
+
+            if (!computed_)
+                throw std::runtime_error("hessian not computed");
+
+            size_t n = hes_pattern_.nr();
+
+            /* Modified sparsity pattern */
+            CppAD::sparse_rc<SizeVector> pat(n, n, nnz_in);
+            size_t nnz_out = 0;
+            for (size_t k = 0; k < nnz_in; ++k) {
+                size_t r = static_cast<size_t>(row_in(k));
+                size_t c = static_cast<size_t>(col_in(k));
+                if (hes_pattern_map_.find(std::make_pair(r, c)) != hes_pattern_map_.end())
+                    pat.set(nnz_out++, r, c);
+            }
+            pat.resize(n, n, nnz_out);
+
+            hes_subset_ = pat;
+            hes_work_.clear();
+        }
+
+    auto get_hess_pattern()
+        {
+            if (!computed_)
+                throw std::runtime_error("hessian not computed");
+
+            size_t nnz = hes_pattern_.nnz();
+            Array<size_t, Shape<Dynamic>, SizeVector> row(std::move(SizeVector(hes_pattern_.row())), {nnz});
+            Array<size_t, Shape<Dynamic>, SizeVector> col(std::move(SizeVector(hes_pattern_.col())), {nnz});
+
+            return pybind11::make_tuple(std::move(row), std::move(col));
+        }
 };
 
 PYBIND11_MODULE(_core, m)
@@ -231,6 +294,8 @@ PYBIND11_MODULE(_core, m)
         .def("grad", &Action::grad, "Evaluate gradient of action", py::arg("Q"))
         .def("hess", &Action::hess, "Evaluate hessian of action", py::arg("Q"))
         .def("hess_mul", &Action::hess_mul, "Evaluate hessian(Q) * dQ", py::arg("Q"), py::arg("dQ"))
+        .def("set_hess_pattern", &Action::set_hess_pattern, "Set Hessian sparsity pattern", py::arg("row"), py::arg("col"))
+        .def("get_hess_pattern", &Action::get_hess_pattern, "Get original Hessian sparsity pattern")
         .def("compute", &Action::compute, "Precompute for AD", py::arg("Q"))
         .def_property("omega", &Action::get_omega, &Action::set_omega)
         ;