/* nag_pde_parab_1d_keller_ode_remesh (d03prc) Example Program.
 *
 * Copyright 2014 Numerical Algorithms Group.
 *
 * Mark 7, 2001.
 */

#include <stdio.h>
#include <math.h>
#include <nag.h>
#include <nag_stdlib.h>
#include <nagd03.h>
#include <nagx01.h>

/* Callbacks passed to the NAG routines below.  They are wrapped in
 * extern "C" when compiled as C++ so their linkage matches what the
 * C library expects. */
#ifdef __cplusplus
extern "C" {
#endif
/* PDE residual function. */
static void NAG_CALL pdedef(Integer, double, double, const double[],
                            const double[], const double[], Integer,
                            const double[], const double[], double[],
                            Integer *, Nag_Comm *);

/* Boundary-condition residual function. */
static void NAG_CALL bndary(Integer, double, Integer, Integer, const double[],
                            const double[], Integer, const double[],
                            const double[], double[], Integer *, Nag_Comm *);

/* Initial values of the solution on the mesh. */
static void NAG_CALL uvinit(Integer, Integer, Integer, const double[],
                            const double[], double[], Integer, double[],
                            Nag_Comm *);

/* Monitor function used when remeshing. */
static void NAG_CALL monitf(double, Integer, Integer, const double[],
                            const double[], double[], Nag_Comm *);
#ifdef __cplusplus
}
#endif

/* Exact solution of the test problem, used to check computed results. */
static void exact(double, Integer, Integer, double *, double *);

/* 1-based accessors for arrays that store the npde components of each
 * point contiguously:
 *   UE(I, J), U(I, J)  -> component I at point J;
 *   UOUT(I, J, K)      -> component I at output point J in the K-th block
 *                         of npde*intpts values.
 * The expansions refer to variables named npde (and, for UOUT, intpts),
 * which must be in scope wherever the macros are used. */
#define UE(I, J)      ue[npde*((J) -1)+(I) -1]
#define U(I, J)       u[npde*((J) -1)+(I) -1]
#define UOUT(I, J, K) uout[npde*(intpts*((K) -1)+(J) -1)+(I) -1]


/* Print n values vals[0], vals[stride], ..., vals[stride*(n-1)] in %10.4f
 * fields, ending the line after every fifth value and after the last one. */
static void print_values(const double vals[], Integer stride, Integer n)
{
  Integer j;

  for (j = 1; j <= n; ++j)
    {
      printf("%10.4f", vals[stride*(j-1)]);
      if (j%5 == 0 || j == n)
        putchar('\n');
    }
}

/* Solve the example PDE system with nag_pde_parab_1d_keller_ode_remesh
 * (d03prc), interpolate the solution at intpts output points with
 * nag_pde_interp_1d_fd (d03pzc) at t = 0.05, 0.10, ..., 0.25, and print it
 * next to the exact solution computed by exact().
 * Returns 0 on success, 1 on allocation failure or a NAG error. */
int main(void)
{
  const Integer npde = 2, npts = 61, ncode = 0, nxi = 0, nxfix = 0, nleft = 1;
  const Integer itype = 1, intpts = 5, neqn = npde*npts+ncode;
  const Integer lisave = 25+nxfix;
  const Integer nwkres = npde*(npts+3*npde+21)+7*npts+nxfix+3;
  const Integer lenode = 11*neqn+50, lrsave = neqn*neqn+neqn+nwkres+lenode;
  /* -1.0 in slot k tells callback k to announce its first invocation. */
  static double ruser[4] = {-1.0, -1.0, -1.0, -1.0};
  double        con, dxmesh, tout, trmesh, ts, xratio;
  Integer       exit_status = 0, i, ind, ipminf, it, itask, itol, itrace, k,
                nrmesh;
  Nag_Boolean   remesh, theta;
  double        *algopt = 0, *atol = 0, *rsave = 0, *rtol = 0, *u = 0, *ue = 0;
  double        *uout = 0, *x = 0, *xfix = 0, *xi = 0, *xout = 0;
  Integer       *isave = 0;
  NagError      fail;
  Nag_Comm      comm;
  Nag_D03_Save  saved;

  INIT_FAIL(fail);

  printf("nag_pde_parab_1d_keller_ode_remesh (d03prc) Example Program"
          " Results\n\n");

  /* For communication with user-supplied functions: */
  comm.user = ruser;

  /* Allocate memory.  ue receives the exact solution at the intpts output
   * points, so it is sized like one block of uout. */
  if (!(algopt = NAG_ALLOC(30, double)) ||
      !(atol = NAG_ALLOC(1, double)) ||
      !(rsave = NAG_ALLOC(lrsave, double)) ||
      !(rtol = NAG_ALLOC(1, double)) ||
      !(u = NAG_ALLOC(npde*npts, double)) ||
      !(ue = NAG_ALLOC(npde*intpts, double)) ||
      !(uout = NAG_ALLOC(npde*intpts*itype, double)) ||
      !(x = NAG_ALLOC(npts, double)) ||
      !(xfix = NAG_ALLOC(1, double)) ||
      !(xi = NAG_ALLOC(1, double)) ||
      !(xout = NAG_ALLOC(intpts, double)) ||
      !(isave = NAG_ALLOC(lisave, Integer)))
    {
      printf("Allocation failure\n");
      exit_status = 1;
      goto END;
    }

  itrace = 0;
  itol = 1;
  atol[0] = 5.0e-5;
  rtol[0] = atol[0];

  /* Integer need not be long, so cast explicitly to match %ld. */
  printf("  Accuracy requirement =%12.3e", atol[0]);
  printf(" Number of points = %3ld\n\n", (long) npts);

  /* Set remesh parameters */

  remesh = Nag_TRUE;
  nrmesh = 3;
  dxmesh = 0.0;
  trmesh = 0.0;
  con = 5.0/(npts-1.0);
  xratio = 1.2;
  ipminf = 0;
  printf(" Remeshing every %3ld time steps\n\n", (long) nrmesh);

  /* Initialise uniform mesh on [0, 1] */

  for (i = 0; i < npts; ++i) x[i] = i/(npts-1.0);

  /* Equally spaced output points 0, ..., 1 (requires intpts >= 2). */
  for (i = 0; i < intpts; ++i) xout[i] = (double) i/(intpts-1);

  printf(" x        ");
  for (i = 0; i < intpts; ++i)
    {
      printf("%10.4f", xout[i]);
      putchar(((i+1)%5 == 0 || i == intpts-1) ? '\n' : ' ');
    }
  printf("\n\n");

  xi[0] = 0.0;
  ind = 0;
  itask = 1;

  /* Set theta to TRUE if the Theta integrator is required */

  theta = Nag_FALSE;
  for (i = 0; i < 30; ++i) algopt[i] = 0.0;
  if (theta)
    {
      algopt[0] = 2.0;
      algopt[5] = 2.0;
      algopt[6] = 1.0;
    }

  /* Loop over output value of t */

  ts = 0.0;
  tout = 0.0;

  for (it = 0; it < 5; ++it)
    {
      tout = 0.05*(it+1);

      /* nag_pde_parab_1d_keller_ode_remesh (d03prc).
       * General system of first-order PDEs, coupled DAEs, method
       * of lines, Keller box discretisation, remeshing, one space
       * variable
       */
      nag_pde_parab_1d_keller_ode_remesh(npde, &ts, tout, pdedef, bndary,
                                         uvinit, u, npts, x, nleft, ncode,
                                         NULLFN, nxi, xi, neqn, rtol, atol,
                                         itol, Nag_TwoNorm, Nag_LinAlgFull,
                                         algopt, remesh, nxfix, xfix, nrmesh,
                                         dxmesh, trmesh, ipminf, xratio, con,
                                         monitf, rsave, lrsave, isave, lisave,
                                         itask, itrace, 0, &ind, &comm, &saved,
                                         &fail);

      if (fail.code != NE_NOERROR)
        {
          printf("Error from "
                  "nag_pde_parab_1d_keller_ode_remesh (d03prc).\n%s\n",
                  fail.message);
          exit_status = 1;
          goto END;
        }

      /* Interpolate at output points */

      /* nag_pde_interp_1d_fd (d03pzc). PDEs, spatial interpolation with
       * nag_pde_parab_1d_keller_ode_remesh (d03prc).
       */
      nag_pde_interp_1d_fd(npde, 0, u, npts, x, xout, intpts, itype, uout,
                           &fail);

      if (fail.code != NE_NOERROR)
        {
          printf("Error from nag_pde_interp_1d_fd (d03pzc).\n%s\n",
                  fail.message);
          exit_status = 1;
          goto END;
        }

      /* Check against exact solution */

      exact(ts, npde, intpts, xout, ue);

      printf(" t = %6.3f\n", ts);

      /* One approximate row and one exact row per component; both arrays
       * hold the npde components of each point contiguously. */
      for (k = 1; k <= npde; ++k)
        {
          printf(" Approx u%ld", (long) k);
          print_values(&UOUT(k, 1, 1), npde, intpts);
          printf(" Exact  u%ld", (long) k);
          print_values(&UE(k, 1), npde, intpts);
        }

      printf("\n");
    }

  printf(" Number of integration steps in time = %6ld\n", (long) isave[0]);
  printf(" Number of function evaluations = %6ld\n", (long) isave[1]);
  printf(" Number of Jacobian evaluations =%6ld\n", (long) isave[2]);
  printf(" Number of iterations = %6ld\n\n", (long) isave[4]);

 END:
  NAG_FREE(algopt);
  NAG_FREE(atol);
  NAG_FREE(rsave);
  NAG_FREE(rtol);
  NAG_FREE(u);
  NAG_FREE(ue);
  NAG_FREE(uout);
  NAG_FREE(x);
  NAG_FREE(xfix);
  NAG_FREE(xi);
  NAG_FREE(xout);
  NAG_FREE(isave);

  return exit_status;
}


static void NAG_CALL uvinit(Integer npde, Integer npts, Integer nxi,
                            const double x[], const double xi[], double u[],
                            Integer ncode, double v[], Nag_Comm *comm)
{
  /* Initial values at the npts mesh points x[0..npts-1]:
   *   u1 = exp(x),   u2 = x^2 + sin(2*pi*x^2).
   * Point j occupies u[npde*j .. npde*j + npde-1]; only the first two
   * components are written.  nxi, xi, ncode and v are not used.
   * comm->user[0] == -1.0 flags the first call; it is reported once and
   * then reset to 0.0. */
  Integer j;
  double  *upt, xsq;

  if (comm->user[0] == -1.0)
    {
      printf("(User-supplied callback uvinit, first invocation.)\n");
      comm->user[0] = 0.0;
    }

  for (j = 0, upt = u; j < npts; ++j, upt += npde)
    {
      xsq = x[j]*x[j];
      upt[0] = exp(x[j]);
      upt[1] = xsq + sin(2.0*nag_pi*xsq);
    }
}


static void NAG_CALL pdedef(Integer npde, double t, double x, const double u[],
                            const double udot[], const double ux[],
                            Integer ncode, const double v[],
                            const double vdot[], double res[], Integer *ires,
                            Nag_Comm *comm)
{
  /* Residuals of the first-order system
   *   u1_t +   u1_x + u2_x = 0
   *   u2_t + 4 u1_x + u2_x = 0.
   * When *ires == -1 only the udot terms are returned; otherwise the full
   * residual is.  comm->user[1] == -1.0 flags the first call; it is
   * reported once and then reset to 0.0. */
  if (comm->user[1] == -1.0)
    {
      printf("(User-supplied callback pdedef, first invocation.)\n");
      comm->user[1] = 0.0;
    }

  switch (*ires)
    {
    case -1:
      res[0] = udot[0];
      res[1] = udot[1];
      break;
    default:
      res[0] = udot[0] + ux[0] + ux[1];
      res[1] = udot[1] + 4.0*ux[0] + ux[1];
      break;
    }
}

static void NAG_CALL bndary(Integer npde, double t, Integer ibnd, Integer nobc,
                            const double u[], const double udot[],
                            Integer ncode, const double v[],
                            const double vdot[], double res[], Integer *ires,
                            Nag_Comm *comm)
{
  /* Boundary-condition residual.
   *   ibnd == 0: u1 equals the exact solution at x = 0;
   *   otherwise: u2 equals the exact solution at x = 1.
   * Neither condition involves udot, so the *ires == -1 residual is zero.
   * comm->user[2] == -1.0 flags the first call; it is reported once and
   * then reset to 0.0. */
  double twopi;

  if (comm->user[2] == -1.0)
    {
      printf("(User-supplied callback bndary, first invocation.)\n");
      comm->user[2] = 0.0;
    }

  twopi = 2.0*nag_pi;

  if (*ires == -1)
    {
      res[0] = 0.0;
      return;
    }

  if (ibnd == 0)
    res[0] = u[0] - 0.5*(exp(t) + exp(-3.0*t))
             - 0.25*(sin(9.0*twopi*t*t) - sin(twopi*t*t)) - 2.0*t*t;
  else
    res[0] = u[1] - (exp(1.0-3.0*t) - exp(1.0+t) +
                     0.5*sin(twopi*(1.0-3.0*t)*(1.0-3.0*t)) +
                     0.5*sin(twopi*(1.0+t)*(1.0+t))
                     + 1.0 + 5.0*t*t - 2.0*t);
}

static void NAG_CALL monitf(double t, Integer npts, Integer npde,
                            const double x[], const double u[], double fmon[],
                            Nag_Comm *comm)
{
  /* Monitor values for remeshing.  At each interior point x[j], fmon[j]
   * is the larger of |u1''| and |u2''|, each estimated by a divided
   * difference on the non-uniform three-point stencil.  The two end values
   * copy their neighbours, so npts >= 3 is assumed.
   * comm->user[3] == -1.0 flags the first call; it is reported once and
   * then reset to 0.0. */
  Integer      j, k;
  double       hleft, hright, hmid, d2, biggest;
  const double *uprev, *ucur, *unext;

  if (comm->user[3] == -1.0)
    {
      printf("(User-supplied callback monitf, first invocation.)\n");
      comm->user[3] = 0.0;
    }

  for (j = 1; j < npts-1; ++j)
    {
      hleft = x[j] - x[j-1];
      hright = x[j+1] - x[j];
      hmid = 0.5*(x[j+1] - x[j-1]);

      uprev = u + npde*(j-1);
      ucur = u + npde*j;
      unext = u + npde*(j+1);

      /* Keep component 0's value unless component 1 is strictly larger. */
      biggest = 0.0;
      for (k = 0; k < 2; ++k)
        {
          d2 = fabs(((unext[k]-ucur[k])/hright-(ucur[k]-uprev[k])/hleft)/hmid);
          if (k == 0 || d2 > biggest) biggest = d2;
        }
      fmon[j] = biggest;
    }

  fmon[0] = fmon[1];
  fmon[npts-1] = fmon[npts-2];
}

static void exact(double t, Integer npde, Integer npts, double *x, double *u)
{
  /* Exact solution (for comparison purposes) at the npts points
   * x[0..npts-1] and time t:
   *   u1 = (e^(x+t) + e^(x-3t))/2
   *        + (sin(2pi (x-3t)^2) - sin(2pi (x+t)^2))/4 + 2t^2 - 2xt
   *   u2 = e^(x-3t) - e^(x+t)
   *        + (sin(2pi (x-3t)^2) + sin(2pi (x+t)^2))/2 + x^2 + 5t^2 - 2xt
   * Point j's components are written to u[npde*j] and u[npde*j + 1]. */
  const double twopi = 2.0*nag_pi;
  Integer      j;
  double       xj, xm, xp;

  for (j = 0; j < npts; ++j)
    {
      xj = x[j];
      xm = xj - 3.0*t;
      xp = xj + t;

      u[npde*j] = 0.5*(exp(xp) + exp(xm)) +
                  0.25*(sin(twopi*xm*xm) - sin(twopi*xp*xp)) +
                  2.0*t*t - 2.0*xj*t;

      u[npde*j+1] = exp(xm) - exp(xp) +
                    0.5*(sin(twopi*(xm*xm)) + sin(twopi*(xp*xp))) +
                    xj*xj + 5.0*t*t - 2.0*xj*t;
    }
}