/* nag_pde_parab_1d_coll_ode (d03pjc) Example Program.
 *
 * Copyright 2014 Numerical Algorithms Group.
 *
 * Mark 7, 2001.
 * Mark 7b revised, 2004.
 */

#include <stdio.h>
#include <math.h>
#include <nag.h>
#include <nag_stdlib.h>
#include <nagd03.h>

/*
 * Forward declarations of the file-local routines used by main().
 * pdedef, bndary, odedef and uvinit are handed to
 * nag_pde_parab_1d_coll_ode as callbacks; exact is called directly by main
 * to tabulate the reference solution.  The extern "C" guards give the
 * declarations C linkage when the file is compiled as C++.
 */
#ifdef __cplusplus
extern "C" {
#endif
/* PDE definition callback: fills the p, q and r coefficient arrays. */
static void NAG_CALL pdedef(Integer, double, const double[], Integer, const
                            double[], const double[], Integer, const double[],
                            const double[], double[], double[], double[],
                            Integer *, Nag_Comm *);
/* Boundary-condition callback: fills beta and gamma for one boundary. */
static void NAG_CALL bndary(Integer, double, const double[], const double[],
                            Integer, const double[], const double[], Integer,
                            double[], double [], Integer *, Nag_Comm *);
/* Coupled-ODE callback: fills the residual array f. */
static void NAG_CALL odedef(Integer, double, Integer, const double[],
                            const double[], Integer, const double[],
                            const double[], const double[], const double[],
                            const double[], const double[], double[],
                            Integer *, Nag_Comm *);
/* Initial-value callback: fills u (PDE values) and v (ODE value). */
static void NAG_CALL uvinit(Integer, Integer, const double[], double[],
                            Integer, double[], Nag_Comm *);
/* Reference solution evaluated at a given time on a set of points. */
static void NAG_CALL exact(double, Integer, double *, double *);
#ifdef __cplusplus
}
#endif


/*
 * 1-based accessors into the flat arrays exchanged with the callbacks.
 * Each macro expands to an element of the array with the matching
 * lower-case name (u, ux, ucp, ucpx, p, q, r) and uses a variable called
 * npde, so both must be in scope wherever a macro is used.
 * For the two-index forms, I is the fastest-varying index (stride 1) and
 * J advances in steps of npde.  P takes three indices: I (stride 1),
 * J (stride npde) and K (stride npde*npde).
 */
#define U(I, J)    u[npde*((J) -1)+(I) -1]
#define UX(I, J)   ux[npde*((J) -1)+(I) -1]
#define UCP(I, J)  ucp[npde*((J) -1)+(I) -1]
#define UCPX(I, J) ucpx[npde*((J) -1)+(I) -1]
#define P(I, J, K) p[npde*(npde*((K) -1)+(J) -1)+(I) -1]
#define Q(I, J)    q[npde*((J) -1)+(I) -1]
#define R(I, J)    r[npde*((J) -1)+(I) -1]


/*
 * Example driver for nag_pde_parab_1d_coll_ode (d03pjc).
 *
 * Sets up one PDE coupled to one ODE on 11 equally spaced break-points in
 * [0,1], integrates to the five output times 0.2, 0.4, 0.8, 1.6 and 3.2 and,
 * after every solver call, prints the computed solution at selected
 * break-points next to the values produced by exact().  Solver statistics
 * held in isave are printed at the end.
 *
 * Returns 0 on success, 1 if an allocation fails or the solver reports an
 * error through fail.code.
 */
int main(void)
{
  /* Problem dimensions. */
  const Integer npde = 1, ncode = 1, npoly = 2, m = 0, nbkpts = 11;
  const Integer nel = nbkpts-1, npts = nel*npoly+1, neqn = npde*npts+ncode;
  /* Workspace lengths handed to the solver together with rsave/isave. */
  const Integer nxi = 1, lisave = 24, npl1 = npoly+1;
  const Integer nwkres = 3*npl1*npl1+npl1*(npde*npde+6*npde+nbkpts+1)+8*npde
                         +nxi*(5*npde+1)+ncode+3;
  const Integer lenode = 11*neqn+50, lrsave = neqn*neqn+neqn+nwkres+lenode;
  /* Break-points at which results are tabulated (the last one is x = 1). */
  const Integer k0 = 0, k1 = 2, k2 = 4, k3 = 6, k4 = nbkpts-1;
  /* Offset in u between consecutive break-points: each element carries
   * npoly mesh points and each mesh point npde values, so break-point k
   * sits at u[ustep*k] (this reproduces the original u[0], u[4], ... for
   * npoly = 2, npde = 1). */
  const Integer ustep = npde*npoly;
  /* First ODE component: stored directly after the npde*npts PDE values. */
  const Integer iode = npde*npts;
  /* One "first call" flag per callback, reached through comm.user. */
  static double ruser[4] = {-1.0, -1.0, -1.0, -1.0};
  double        tout, ts;
  Integer       exit_status = 0, i, ind, it, itask, itol, itrace;
  Nag_Boolean   theta;
  double        *algopt = 0, *atol = 0, *exy = 0, *rsave = 0, *rtol = 0;
  double        *u = 0, *x = 0, *xbkpts = 0, *xi = 0;
  Integer       *isave = 0;
  NagError      fail;
  Nag_Comm      comm;
  Nag_D03_Save  saved;

  INIT_FAIL(fail);

  printf(
          " nag_pde_parab_1d_coll_ode (d03pjc) Example Program Results\n");

  /* For communication with user-supplied functions: */
  comm.user = ruser;

  /* Allocate memory */

  if (!(algopt = NAG_ALLOC(30, double)) ||
      !(atol = NAG_ALLOC(1, double)) ||
      !(exy = NAG_ALLOC(nbkpts, double)) ||
      !(rsave = NAG_ALLOC(lrsave, double)) ||
      !(rtol = NAG_ALLOC(1, double)) ||
      !(u = NAG_ALLOC(neqn, double)) ||
      !(x = NAG_ALLOC(npts, double)) ||
      !(xbkpts = NAG_ALLOC(nbkpts, double)) ||
      !(xi = NAG_ALLOC(nxi, double)) ||
      !(isave = NAG_ALLOC(lisave, Integer)))
    {
      printf("Allocation failure\n");
      exit_status = 1;
      goto END;
    }

  itrace = 0;
  itol = 1;
  atol[0] = 1e-4;
  rtol[0] = atol[0];
  /* Integer is a library typedef whose width is not fixed here, so every
   * value printed with %ld is converted to long explicitly. */
  printf(" Degree of Polynomial =%4ld", (long) npoly);
  printf("   No. of elements =%4ld\n\n\n", (long) nel);
  printf("  Simple coupled PDE using BDF\n ");
  printf(" Accuracy requirement =%12.3e", atol[0]);
  printf(" Number of points = %4ld\n\n", (long) npts);

  /* Set break-points: equally spaced on [0,1] */

  for (i = 0; i < nbkpts; ++i) xbkpts[i] = i/(nbkpts-1.0);

  /* Single ODE/PDE coupling point at the right-hand boundary. */
  xi[0] = 1.0;
  ind = 0;
  itask = 1;

  /* Set theta = TRUE if the Theta integrator is required */

  theta = Nag_FALSE;
  for (i = 0; i < 30; ++i) algopt[i] = 0.0;
  algopt[0] = theta ? 2.0 : 0.0;

  /* Loop over output value of t */

  ts = 1.e-4;
  /* uvinit reads the start time through comm.p. */
  comm.p = (Pointer)&ts;
  tout = 0.0;
  printf("  x        %9.3f%9.3f%9.3f%9.3f%9.3f\n\n",
          xbkpts[k0], xbkpts[k1], xbkpts[k2], xbkpts[k3], xbkpts[k4]);

  for (it = 0; it < 5; ++it)
    {
      /* Output times double on every pass: 0.2, 0.4, 0.8, 1.6, 3.2.  The
       * base is a property of the time schedule, not of the polynomial
       * degree, so it is written as a literal. */
      tout = 0.1*pow(2.0, (it+1.0));
      /* nag_pde_parab_1d_coll_ode (d03pjc).
       * General system of parabolic PDEs, coupled DAEs, method of
       * lines, Chebyshev C^0 collocation, one space variable
       */
      nag_pde_parab_1d_coll_ode(npde, m, &ts, tout, pdedef, bndary, u, nbkpts,
                                xbkpts, npoly, npts, x, ncode, odedef, nxi, xi,
                                neqn, uvinit, rtol, atol, itol, Nag_TwoNorm,
                                Nag_LinAlgFull, algopt, rsave, lrsave, isave,
                                lisave, itask, itrace, 0, &ind, &comm, &saved,
                                &fail);

      if (fail.code != NE_NOERROR)
        {
          printf(
                  "Error from nag_pde_parab_1d_coll_ode (d03pjc).\n%s\n",
                  fail.message);
          exit_status = 1;
          goto END;
        }

      /* Check against the exact solution */

      exact(tout, nbkpts, xbkpts, exy);
      printf(" t = %6.3f\n", ts);
      printf(" App.  sol.  %7.3f%9.3f%9.3f%9.3f%9.3f",
              u[ustep*k0], u[ustep*k1], u[ustep*k2], u[ustep*k3],
              u[ustep*k4]);
      printf("  ODE sol. =%8.3f\n", u[iode]);
      printf(" Exact sol.  %7.3f%9.3f%9.3f%9.3f%9.3f",
              exy[k0], exy[k1], exy[k2], exy[k3], exy[k4]);
      /* The reference value printed for the ODE variable is the time. */
      printf("  ODE sol. =%8.3f\n\n", ts);
    }
  printf(" Number of integration steps in time = %6ld\n", (long) isave[0]);
  printf(" Number of function evaluations = %6ld\n", (long) isave[1]);
  printf(" Number of Jacobian evaluations =%6ld\n", (long) isave[2]);
  printf(" Number of iterations = %6ld\n\n", (long) isave[4]);
 END:
  /* Reached on every path; pointers that were never allocated are still 0
   * (the original relies on NAG_FREE accepting them). */
  NAG_FREE(algopt);
  NAG_FREE(atol);
  NAG_FREE(exy);
  NAG_FREE(rsave);
  NAG_FREE(rtol);
  NAG_FREE(u);
  NAG_FREE(x);
  NAG_FREE(xbkpts);
  NAG_FREE(xi);
  NAG_FREE(isave);

  return exit_status;
}



/*
 * uvinit: initial values for the PDE solution and the coupled ODE variable.
 *
 * The start time is not passed as an argument; it is read through comm->p,
 * which main() points at its ts variable.
 *
 *   npde, npts - number of PDE components / mesh points
 *   x          - mesh points, npts entries read
 *   u          - output: first component at mesh point k is written to
 *                u[npde*k] as exp(t0*(1 - x[k])) - 1
 *   ncode      - not used
 *   v          - output: v[0] is set to the start time
 *   comm       - comm->user[0] is the "first call" flag, comm->p the
 *                start time
 *
 * Prints a one-off notice the first time it is invoked.
 */
static void NAG_CALL uvinit(Integer npde, Integer npts, const double x[],
                            double u[], Integer ncode, double v[],
                            Nag_Comm *comm)
{
  const double *t0 = (const double *) comm->p;
  Integer      k;

  if (comm->user[0] == -1.0)
    {
      printf("(User-supplied callback uvinit, first invocation.)\n");
      comm->user[0] = 0.0;
    }

  v[0] = *t0;

  for (k = 0; k < npts; ++k)
    {
      u[npde*k] = exp(*t0*(1.0- x[k])) - 1.0;
    }
}


/*
 * odedef: residual of the coupled ODE.
 *
 *   *ires ==  1 : f[0] = vdot[0] - v[0]*ucp[0] - ucpx[0] - 1 - t
 *   *ires == -1 : f[0] = vdot[0]
 *   otherwise   : f is left untouched
 *
 * ucp[0] and ucpx[0] are the first entries of the arrays supplied for the
 * coupling point(s).  The remaining arguments (npde, ncode, nxi, xi, rcp,
 * ucpt, ucptx) are not used.  comm->user[1] serves as the "first call"
 * flag; a one-off notice is printed on the first invocation.
 */
static void NAG_CALL odedef(Integer npde, double t, Integer ncode,
                            const double v[], const double vdot[], Integer nxi,
                            const double xi[], const double ucp[],
                            const double ucpx[], const double rcp[],
                            const double ucpt[], const double ucptx[],
                            double f[], Integer *ires, Nag_Comm *comm)
{
  if (comm->user[1] == -1.0)
    {
      printf("(User-supplied callback odedef, first invocation.)\n");
      comm->user[1] = 0.0;
    }

  switch (*ires)
    {
    case 1:
      f[0] = vdot[0] - v[0]*ucp[0] - ucpx[0] - 1.0 - t;
      break;
    case -1:
      f[0] = vdot[0];
      break;
    default:
      break;
    }
}


/*
 * pdedef: coefficient arrays of the PDE at nptl points.
 *
 * For every point k = 0 .. nptl-1 (first component only):
 *   p[npde*npde*k] = v[0]^2
 *   r[npde*k]      = ux[npde*k]
 *   q[npde*k]      = -x[k] * ux[npde*k] * v[0] * vdot[0]
 *
 * t, u, ncode and ires are not used.  comm->user[2] serves as the
 * "first call" flag; a one-off notice is printed on the first invocation.
 */
static void NAG_CALL pdedef(Integer npde, double t, const double x[],
                            Integer nptl, const double u[], const double ux[],
                            Integer ncode, const double v[],
                            const double vdot[], double p[], double q[],
                            double r[], Integer *ires, Nag_Comm *comm)
{
  Integer k;

  if (comm->user[2] == -1.0)
    {
      printf("(User-supplied callback pdedef, first invocation.)\n");
      comm->user[2] = 0.0;
    }

  for (k = 0; k < nptl; ++k)
    {
      const Integer at = npde*k;   /* offset of component 1 at point k */

      p[npde*at] = v[0]*v[0];
      r[at] = ux[at];
      /* Same left-to-right product as the original expression. */
      q[at] = -x[k]*ux[at]*v[0]*vdot[0];
    }
}


/*
 * bndary: boundary-condition coefficients for one boundary.
 *
 * beta[0] is always 1.  gamma[0] depends on which boundary is requested:
 *   ibnd == 0 : gamma[0] = -v[0]*exp(t)
 *   otherwise : gamma[0] = -v[0]*vdot[0]
 *
 * npde, u, ux, ncode and ires are not used.  comm->user[3] serves as the
 * "first call" flag; a one-off notice is printed on the first invocation.
 */
static void NAG_CALL bndary(Integer npde, double t, const double u[],
                            const double ux[], Integer ncode, const double v[],
                            const double vdot[], Integer ibnd, double beta[],
                            double gamma[], Integer *ires, Nag_Comm *comm)
{
  if (comm->user[3] == -1.0)
    {
      printf("(User-supplied callback bndary, first invocation.)\n");
      comm->user[3] = 0.0;
    }

  beta[0] = 1.0;
  gamma[0] = (ibnd == 0) ? -v[0]*exp(t) : -v[0]*vdot[0];
}


/*
 * exact: reference solution used for comparison.
 *
 * Writes u[k] = exp(time*(1 - x[k])) - 1 for k = 0 .. npts-1.
 *
 *   time - time at which the solution is evaluated
 *   npts - number of entries of x to process
 *   x    - evaluation points (read only)
 *   u    - output, npts entries written
 */
static void NAG_CALL exact(double time, Integer npts, double *x, double *u)
{
  Integer k = 0;

  while (k < npts)
    {
      u[k] = exp(time*(1.0 - x[k])) - 1.0;
      ++k;
    }
}