program main
  ! Driver program.
  !
  ! Fills an n x n lattice of 2x2x2 blocks Q with adjac independent
  ! variables and an n x n lattice of 4x4x4 blocks U with plain complex
  ! numbers, evaluates the scalar S_2(Q, U, h) with h = 0.1, and prints
  ! the first ten entries of its dense Jacobian followed by every entry
  ! of its Hessian in coordinate (COO) form.
  use adjac
  implicit none

  ! Lattice extent in each of the two lattice directions.
  integer, parameter :: n = 35

  type(adjac_complexan), dimension(:,:,:,:,:), allocatable :: Q
  double complex, dimension(:,:,:,:,:), allocatable :: U
  type(adjac_complexan), dimension(1) :: res

  double complex, dimension(:, :), allocatable :: jac
  double complex :: v

  ! COO Hessian storage. These are not allocated in this program;
  ! presumably adjac_get_coo_hessian allocates them -- confirm against
  ! the adjac interface.
  double complex, dimension(:), allocatable :: hess_v
  integer, dimension(:), allocatable :: hess_i
  integer, dimension(:), allocatable :: hess_j
  integer :: nnz

  integer :: i1, i2, i3, i4, i5, p, p2

  allocate(Q(n,n,2,2,2))
  allocate(U(n,n,4,4,4))
  ! One row (single dependent variable) by one column per element of Q.
  allocate(jac(1, n*n*2*2*2))

  call adjac_reset()

  ! p counts the elements of Q and p2 the elements of U, both starting
  ! at 1; each element receives its counter minus one as value.
  p = 1
  p2 = 1
  do i1 = 1, size(Q,1)
     do i2 = 1, size(Q,2)
        do i3 = 1, size(Q,3)
           do i4 = 1, size(Q,4)
              do i5 = 1, size(Q,5)
                 v = p - 1
                 ! Presumably registers Q(...) as independent variable
                 ! number p with value v -- confirm against adjac.
                 call adjac_set_independent(Q(i1,i2,i3,i4,i5), v, p)
                 p = p + 1
              end do
           end do
        end do
        do i3 = 1, size(U,3)
           do i4 = 1, size(U,4)
              do i5 = 1, size(U,5)
                 U(i1,i2,i3,i4,i5) = p2 - 1
                 p2 = p2 + 1
              end do
           end do
        end do
     end do
  end do

  res(1) = S_2(Q, U, 0.1d0)

  call adjac_get_dense_jacobian(res, jac)

  do i1 = 1, 10
     write(*,*) '->', jac(1,i1)
  end do

  call adjac_get_coo_hessian(res(1), nnz, hess_v, hess_i, hess_j)

  do i1 = 1, nnz
     write(*,*) hess_i(i1), hess_j(i1), '->>', hess_v(i1)
  end do

  call adjac_reset(.false.)

  ! Execution ends by reaching the end of the main program; a RETURN
  ! statement is only permitted inside function/subroutine subprograms.
contains
  function inverse(M) result(W)
    ! Closed-form inverse of a 2x2 matrix of adjac_complexan values.
    !
    !   M : 2x2 input matrix
    !   W : adjugate of M divided by M(1,1)*M(2,2) - M(1,2)*M(2,1)
    !
    ! The determinant is not tested against zero.
    use adjac
    implicit none
    type(adjac_complexan), dimension(2,2), intent(in) :: M
    type(adjac_complexan), dimension(2,2) :: W

    type(adjac_complexan) :: determinant

    determinant = M(1,1)*M(2,2) - M(1,2)*M(2,1)

    ! Diagonal entries trade places.
    W(1,1) = M(2,2) / determinant
    W(2,2) = M(1,1) / determinant

    ! Off-diagonal entries stay in place and change sign; the negation
    ! is applied to the quotient.
    W(1,2) = -(M(1,2) / determinant)
    W(2,1) = -(M(2,1) / determinant)
  end function inverse

  function get_Q_matrix(Q) result(Qm)
    ! Assemble a 4x4 matrix from the two 2x2 slices of Q.
    !
    ! With g = Q(1,:,:), gt = Q(2,:,:), I the 2x2 identity,
    ! A = inverse(I + g*gt) and B = inverse(I + gt*g), the result is
    !
    !        | A*(I - g*gt)      2*A*g         |
    !   Qm = |                                 |
    !        | 2*B*gt           -B*(I - gt*g)  |
    !
    !   Q  : 2x2x2 input; first index selects the slice
    !   Qm : 4x4 result laid out as above
    use adjac
    implicit none
    type(adjac_complexan), dimension(2,2,2), target, intent(in) :: Q
    type(adjac_complexan), dimension(4,4) :: Qm

    type(adjac_complexan), dimension(:,:), pointer :: gam, gam_t
    type(adjac_complexan), dimension(2,2) :: inv_a, inv_b
    double complex, dimension(2,2) :: eye
    integer :: d

    ! Alias the two slices instead of copying them.
    gam => Q(1,:,:)
    gam_t => Q(2,:,:)

    ! 2x2 identity.
    eye = 0
    do d = 1, 2
       eye(d,d) = 1
    end do

    inv_a = inverse(eye + matmul(gam,gam_t))
    inv_b = inverse(eye + matmul(gam_t,gam))

    ! Upper block row.
    Qm(1:2,1:2) = matmul(inv_a, eye - matmul(gam, gam_t))
    Qm(1:2,3:4) = 2 * matmul(inv_a, gam)
    ! Lower block row.
    Qm(3:4,1:2) = 2 * matmul(inv_b, gam_t)
    Qm(3:4,3:4) = -matmul(inv_b, eye - matmul(gam_t, gam))
  end function get_Q_matrix

  function trace(M) result(res)
    ! Sum of the diagonal entries of a 4x4 adjac_complexan matrix.
    !
    !   M   : 4x4 input matrix
    !   res : 0 + M(1,1) + M(2,2) + M(3,3) + M(4,4), added in that order
    use adjac
    implicit none
    type(adjac_complexan), dimension(4,4), intent(in) :: M
    type(adjac_complexan) :: res

    integer :: diag, order

    order = size(M, 1)

    res = 0
    do diag = 1, order
       res = res + M(diag,diag)
    end do
  end function trace

  function S_2(Q, U, h) result(res)
    ! Sum a nearest-neighbour term over a two-dimensional lattice.
    !
    ! For every site (i,j) and each of its neighbours (i, j+1) [k = 1]
    ! and (i+1, j) [k = 2] that lies inside the lattice, form
    !
    !   M = Q1 - U12 * Q2 * U21
    !
    ! where Q1 and Q2 are get_Q_matrix of the two sites,
    ! U12 = U(i,j,k,:,:) and U21 = U(i2,j2,invk(k),:,:), both with their
    ! off-diagonal 2x2 blocks set to zero, and accumulate
    ! 2/(4*h) * trace(M*M).
    !
    !   Q   : (nx, ny, 2, 2, 2) site data passed to get_Q_matrix
    !   U   : (>= nx, >= ny, 4, 4, 4) link matrices; third index selects
    !         the link
    !   h   : step entering the prefactor 2/(4*h)
    !   res : accumulated sum
    !
    ! Stops the program with a message if the trailing extents of Q are
    ! not 2x2x2, the trailing extents of U are not 4x4x4, or U does not
    ! cover the lattice spanned by Q.
    use adjac
    implicit none

    type(adjac_complexan), dimension(:,:,:,:,:), intent(in) :: Q
    double complex, dimension(:,:,:,:,:), intent(in) :: U
    double precision, intent(in) :: h
    type(adjac_complexan) :: res

    ! Neighbour offsets: k = 1 -> (0,+1), k = 2 -> (+1,0).
    integer, dimension(2), parameter :: Ni = (/ 0, 1 /)
    integer, dimension(2), parameter :: Nj = (/ 1, 0 /)
    ! Link index used on the neighbour side for link k.
    integer, dimension(4), parameter :: invk = (/ 4, 3, 2, 1 /)

    ! Cache of get_Q_matrix results per site, filled on first use.
    type(adjac_complexan), dimension(:,:,:,:), allocatable :: Qmat
    logical, dimension(:,:), allocatable :: initQ

    type(adjac_complexan), dimension(4,4) :: M, Q1, Q2
    double complex, dimension(4,4) :: U12, U21

    integer :: i, j, k, i2, j2, nx, ny

    nx = size(Q, 1)
    ny = size(Q, 2)
    if (size(Q, 3) .ne. 2 .or. size(Q, 4) .ne. 2 .or. size(Q, 5) .ne. 2) then
       write(*,*) 'wrong size for Q'
       stop
    end if
    if (size(U, 3) .ne. 4 .or. size(U, 4) .ne. 4 .or. size(U, 5) .ne. 4) then
       write(*,*) 'wrong size for U'
       stop
    end if
    ! U is indexed with the same (i,j) as Q below, so it must be at
    ! least as large in the lattice directions.
    if (size(U, 1) < nx .or. size(U, 2) < ny) then
       write(*,*) 'wrong size for U'
       stop
    end if

    allocate(Qmat(nx,ny,4,4))
    ! initQ is indexed as (i,j) with i in 1..nx and j in 1..ny, so its
    ! extents must be (nx,ny); (ny,nx) is out of bounds when nx /= ny.
    allocate(initQ(nx,ny))
    initQ = .false.

    res = 0

    do i = 1, nx
       do j = 1, ny
          if (.not. initQ(i,j)) then
             Qmat(i,j,:,:) = get_Q_matrix(Q(i,j,:,:,:))
             initQ(i,j) = .true.
          end if

          do k = 1, 2
             i2 = i + Ni(k)
             j2 = j + Nj(k)

             ! Skip neighbours that fall outside the lattice.
             if (i2 < 1 .or. i2 > nx) cycle
             if (j2 < 1 .or. j2 > ny) cycle

             if (.not. initQ(i2,j2)) then
                Qmat(i2,j2,:,:) = get_Q_matrix(Q(i2,j2,:,:,:))
                initQ(i2,j2) = .true.
             end if

             Q1 = Qmat(i,j,:,:)
             Q2 = Qmat(i2,j2,:,:)

             U12 = U(i,j,k,:,:)
             U21 = U(i2,j2,invk(k),:,:)

             ! Keep only the two diagonal 2x2 blocks of the link matrices.
             U12(1:2,3:4) = 0
             U12(3:4,1:2) = 0
             U21(1:2,3:4) = 0
             U21(3:4,1:2) = 0

             M = Q1 - matmul(U12, matmul(Q2, U21))
             res = res + 2/(4*h) * trace(matmul(M, M))
          end do
       end do
    end do
  end function S_2
end program main