/* nag_ode_bvp_coll_nlin_contin (d02txc) Example Program.
 *
 * Copyright 2014 Numerical Algorithms Group.
 * 
 * Mark 24, 2013.
 */

#include <stdio.h>
#include <math.h>
#include <nag.h>
#include <nag_stdlib.h>
#include <nagd02.h>

/* Y(I, J): element for component I (1-based) and derivative index J
 * (0-based) of an array laid out as y[J*neq + I-1]; the names `y` and `neq`
 * must be in scope where the macro is used.  Both arguments are
 * parenthesised so that expression arguments such as Y(1, j + 1) index the
 * intended element. */
#define Y(I, J)   y[(J)*neq + (I)-1]

/* Problem parameters handed to the user-supplied callbacks via comm->p. */
typedef struct {
  double el;   /* length of the original interval; the solver works on
                  [0,1] and main reports x * el as the original coordinate */
  double en;   /* model parameter n appearing in ffun/fjac */
  double s;    /* continuation parameter (main multiplies it by 0.6 per
                  continuation step) */
} func_data;

/* Forward declarations of the user-supplied callbacks passed to
 * nag_ode_bvp_coll_nlin_solve (d02tlc) in main:
 *   ffun  - right-hand sides of the ODE system
 *   fjac  - Jacobian of ffun with respect to y and its derivatives
 *   gafun - left boundary conditions,  gajac - their Jacobian
 *   gbfun - right boundary conditions, gbjac - their Jacobian
 *   guess - initial approximation to the solution
 * The extern "C" wrapper presumably keeps C linkage for these when the file
 * is compiled as C++, so that the C library can call them. */
#ifdef __cplusplus
extern "C" {
#endif
static void NAG_CALL ffun(double x, const double y[], Integer neq, 
                          const Integer m[], double f[], Nag_Comm *comm);
static void NAG_CALL fjac(double x, const double y[], Integer neq, 
                          const Integer m[], double dfdy[], Nag_Comm *comm);
static void NAG_CALL gafun(const double ya[], Integer neq, const Integer m[],
                           Integer nlbc, double ga[], Nag_Comm *comm);
static void NAG_CALL gbfun(const double yb[], Integer neq, const Integer m[],
                           Integer nrbc, double gb[], Nag_Comm *comm);
static void NAG_CALL gajac(const double ya[], Integer neq, const Integer m[],
                           Integer nlbc, double dgady[], Nag_Comm *comm);
static void NAG_CALL gbjac(const double yb[], Integer neq, const Integer m[],
                           Integer nrbc, double dgbdy[], Nag_Comm *comm);
static void NAG_CALL guess(double x, Integer neq, const Integer m[], double y[],
                           double dym[], Nag_Comm *comm);
#ifdef __cplusplus
}
#endif

int main(void)
{
  static double ruser[7] = {-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0};
  Integer     exit_status = 0;
  Integer     neq, mmax, nlbc, nrbc, nleft, nright;
  Integer     i, iermx, ijermx, j, licomm, lrcomm, mxmesh, ncol, ncont, nmesh;
  double      xsplit = 30.0;
  double      dx, el, el_init, en, ermx, s, s_init, xx;
  double      *mesh = 0, *rcomm = 0;
  double      *tol = 0, *y = 0;
  Integer     *ipmesh = 0, *icomm = 0, *m = 0;
  func_data   fd;
  Nag_Comm    comm;
  NagError    fail;

  INIT_FAIL(fail);

  printf("nag_ode_bvp_coll_nlin_contin (d02txc) Example Program Results\n\n");

  /* For communication with user-supplied functions: */
  comm.user = ruser;

  /* Skip heading in data file*/
  scanf("%*[^\n] ");
  scanf("%"NAG_IFMT "%"NAG_IFMT "", &neq, &mmax);
  scanf("%"NAG_IFMT "%"NAG_IFMT "%*[^\n] ", &nlbc, &nrbc);
  scanf("%"NAG_IFMT "%"NAG_IFMT "%*[^\n] ", &nleft, &nright);
  /* Read method parameters*/
  scanf("%"NAG_IFMT "%"NAG_IFMT "%"NAG_IFMT "%*[^\n] ", &ncol, &nmesh, &mxmesh);
  licomm = mxmesh * (11 * neq + 6);
  lrcomm = mxmesh * (109 *  pow((double) neq, 2) + 78 * neq + 7);
  /* Allocate memory */
  if (!(tol = NAG_ALLOC(neq, double)) ||
      !(y = NAG_ALLOC(neq*mmax, double)) ||
      !(m = NAG_ALLOC(neq, Integer)) ||
      !(mesh = NAG_ALLOC(mxmesh, double)) ||
      !(rcomm = NAG_ALLOC(lrcomm, double)) ||
      !(ipmesh = NAG_ALLOC(mxmesh, Integer)) ||
      !(icomm = NAG_ALLOC(licomm, Integer))
      )
    {
      printf("Allocation failure\n");
      exit_status = -1;
      goto END;
    }
  for (i = 0; i < neq; i++) {
    scanf("%"NAG_IFMT "", &m[i]);
  } 
  scanf("%*[^\n] ");
  for (i = 0; i < neq; i++) {
    scanf("%lf", &tol[i]);
  }
  scanf("%*[^\n] ");
  /* Read problem (initial) parameters*/
  scanf("%lf%lf%lf%*[^\n] ", &en, &el_init, &s_init);
  /* Initialise data*/
  el = el_init;
  s = s_init;

  /* Set data required for the user-supplied functions */
  fd.el = el;
  fd.en = en;
  fd.s = s;
  /* Associate the data structure with comm.p */
  comm.p = (Pointer) &fd;

  dx = 1.0/(double) (nmesh - 1);
  mesh[0] = 0.0;
  for (i = 1; i < nmesh - 1; i++) {
    mesh[i] = mesh[i - 1] + dx;
  }
  mesh[nmesh - 1] = 1.0;
  ipmesh[0] = 1;
  for (i = 1; i < nmesh - 1; i++) {
    ipmesh[i] = 2;
  }
  ipmesh[nmesh-1] = 1;

  /* nag_ode_bvp_coll_nlin_setup (d02tvc).
   * Ordinary differential equations, general nonlinear boundary value problem,
   * setup for nag_ode_bvp_coll_nlin_solve (d02tlc).
   */
  nag_ode_bvp_coll_nlin_setup(neq, m, nlbc, nrbc, ncol, tol, mxmesh, nmesh,
                              mesh, ipmesh, rcomm, lrcomm, icomm, licomm,
                              &fail);
  if (fail.code != NE_NOERROR) {
    printf("Error from nag_ode_bvp_coll_nlin_setup (d02tvc).\n%s\n",
           fail.message);
    exit_status = 1;
    goto END;
  }

  /* Initialize number of continuation steps in el and s*/
  scanf("%"NAG_IFMT "%*[^\n] ", &ncont);

  for (j = 0; j < ncont; j++) {
    printf("\n Tolerance = %8.1e", tol[0]);
    printf("  l = %8.3f  s =%7.4f\n", el, s);
    /* Solve*/

    /* nag_ode_bvp_coll_nlin_solve (d02tlc).
     * Ordinary differential equations, general nonlinear boundary value 
     * problem, collocation technique.
     */
    nag_ode_bvp_coll_nlin_solve(ffun, fjac, gafun, gbfun, gajac, gbjac, guess,
                                rcomm, icomm, &comm, &fail);
    if (fail.code != NE_NOERROR) {
      printf("Error from nag_ode_bvp_coll_nlin_solve (d02tlc).\n%s\n",
             fail.message);
      exit_status = 2;
      goto END;
    }

    /* Extract mesh*/

    /* nag_ode_bvp_coll_nlin_diag (d02tzc).
     * Ordinary differential equations, general nonlinear boundary value
     * problem, diagnostics for nag_ode_bvp_coll_nlin_solve (d02tlc).
     */
    nag_ode_bvp_coll_nlin_diag(mxmesh, &nmesh, mesh, ipmesh, &ermx, &iermx,
                               &ijermx, rcomm, icomm, &fail);
    if (fail.code != NE_NOERROR) {
      printf("Error from nag_ode_bvp_coll_nlin_diag (d02tzc).\n%s\n",
             fail.message);
      exit_status = 3;
      goto END;
    }

    printf("\n Used a mesh of %4"NAG_IFMT "  points\n", nmesh);
    printf(" Maximum error = %10.2e", ermx);
    printf("  in interval %4"NAG_IFMT " ", iermx);
    printf(" for component %4"NAG_IFMT " \n", ijermx);
    /* Print solution components on mesh*/
    printf("\n\n Solution on original interval:\n     x        f          g\n");
    /* Left side domain [0,xsplit], evaluate at nleft+1 uniform grid points.*/
    dx = xsplit/(double) (nleft)/el;
    xx = 0.0;
    for (i = 0; i <= nleft; i++) {

      /* nag_ode_bvp_coll_nlin_interp (d02tyc).
       * Ordinary differential equations, general nonlinear boundary value 
       * problem, interpolation for nag_ode_bvp_coll_nlin_solve (d02tlc).
       */
      nag_ode_bvp_coll_nlin_interp(xx, y, neq, mmax, rcomm, icomm, &fail);
      if (fail.code != NE_NOERROR) {
        printf("Error from nag_ode_bvp_coll_nlin_interp (d02tyc).\n%s\n",
               fail.message);
        exit_status = 5;
        goto END;
      }

      printf("%8.2f %10.4f %10.4f \n", xx * el, Y(1, 0), Y(2, 0));
      xx = xx + dx;
    }
    /* Right side domain (xsplit,L], evaluate at nright uniform grid points.*/
    dx = (el - xsplit)/(double) (nright)/el;
    xx = xsplit/el;
    for (i = 0; i < nright; i++) {
      xx = MIN(1.0, xx + dx);

      /* nag_ode_bvp_coll_nlin_interp (d02tyc).
       * Ordinary differential equations, general nonlinear boundary value 
       * problem, interpolation for nag_ode_bvp_coll_nlin_solve (d02tlc).
       */
      nag_ode_bvp_coll_nlin_interp(xx, y, neq, mmax, rcomm, icomm, &fail);
      if (fail.code != NE_NOERROR) {
        printf("Error from nag_ode_bvp_coll_nlin_interp (d02tyc).\n%s\n",
               fail.message);
        exit_status = 6;
        goto END;
      }

      printf("%8.2f %10.4f %10.4f \n", xx * el, Y(1, 0), Y(2, 0));
    }
    /* Select mesh for continuation and update continuation parameters.*/
    if (j < ncont-1) {
      el = 2.0 * el;
      s = 0.6 * s;
      fd.el = el;
      fd.s = s;
      nmesh = (nmesh + 1)/2;

      /* nag_ode_bvp_coll_nlin_contin (d02txc).
       * Ordinary differential equations, general nonlinear boundary value 
       * problem, continuation facility for
       * nag_ode_bvp_coll_nlin_solve (d02tlc).
       */
      nag_ode_bvp_coll_nlin_contin(mxmesh, nmesh, mesh, ipmesh, rcomm, icomm,
                                   &fail);
      if (fail.code != NE_NOERROR) {
        printf("Error from nag_ode_bvp_coll_nlin_contin (d02txc).\n%s\n",
               fail.message);
        exit_status = 7;
        goto END;
      }
    }
  }

 END :
  NAG_FREE(mesh);
  NAG_FREE(m);
  NAG_FREE(tol);
  NAG_FREE(rcomm);
  NAG_FREE(y);
  NAG_FREE(ipmesh);
  NAG_FREE(icomm);
  return exit_status;
}

/* Right-hand sides of the two ODEs for d02tlc.  Reads the parameters
 * el, en and s from the func_data pointed to by comm->p and evaluates f[0]
 * and f[1] from y (accessed through Y(component, derivative)).  Prints a
 * one-time notice on the first call, tracked by comm->user[0]. */
static void NAG_CALL ffun(double x, const double y[], Integer neq, 
                          const Integer m[], double f[], Nag_Comm *comm)
{
  const func_data *par = (const func_data *) comm->p;
  double lam, nexp, sw, coef, dy1, y2;

  if (comm->user[0] == -1.0) {
    printf("(User-supplied callback ffun, first invocation.)\n");
    comm->user[0] = 0.0;
  }

  lam = par->el;
  nexp = par->en;
  sw = par->s;

  /* Shared factor 0.5*(3-n)*y1 multiplies the highest derivative terms. */
  coef = 0.5 * (3.0 - nexp) * Y(1, 0);
  dy1 = Y(1, 1);
  y2 = Y(2, 0);

  f[0] = (pow(lam, 3)) * (1.0 - pow(y2, 2)) + (pow(lam, 2)) * sw * dy1 -
    lam * (coef * Y(1, 2) + nexp * pow(dy1, 2));
  f[1] = (pow(lam, 2)) * sw * (y2 - 1.0) - lam *
    (coef * Y(2, 1) + (nexp - 1.0) * dy1 * y2);
}

/* Jacobian of ffun for d02tlc.  DFDY(I, J, K) receives the partial
 * derivative of f[I-1] with respect to Y(J, K), i.e. derivative K of
 * component J.  Only the non-zero partials are assigned; the remaining
 * entries are presumably zeroed by the solver before the call (confirm
 * against the d02tlc documentation).  Prints a one-time notice on the first
 * call, tracked by comm->user[1]. */
static void NAG_CALL fjac(double x, const double y[], Integer neq, 
                          const Integer m[], double dfdy[], Nag_Comm *comm)
{
  /* Arguments parenthesised for macro hygiene; #undef'd at the end so the
   * name does not leak into the rest of the file. */
#define DFDY(I, J, K) dfdy[(I)-1 + ((J)-1)*neq + (K)*neq*neq]
  func_data *fd = (func_data *)comm->p;
  double    el, en, s;
  double    half = 0.5;
  double    one = 1.0;
  double    two = 2.0;
  double    three = 3.0;

  if (comm->user[1] == -1.0)
    {
      printf("(User-supplied callback fjac, first invocation.)\n");
      comm->user[1] = 0.0;
    }
  el = fd->el;
  en = fd->en; 
  s = fd->s;
  /* Partials of f[0]. */
  DFDY(1, 2, 0) = -two * pow(el, 3) * Y(2, 0);
  DFDY(1, 1, 0) = -el * half * (three - en) * Y(1, 2);
  DFDY(1, 1, 1) = pow(el, 2) * s - el * two * en * Y(1, 1);
  DFDY(1, 1, 2) = -el * half * (three - en) * Y(1, 0);
  /* Partials of f[1]. */
  DFDY(2, 2, 0) = pow(el, 2) * s - el * (en - one) * Y(1, 1);
  DFDY(2, 2, 1) = -el * half * (three - en) * Y(1, 0);
  DFDY(2, 1, 0) = -el * half * (three - en) * Y(2, 1);
  DFDY(2, 1, 1) = -el * (en - one) * Y(2, 0);
#undef DFDY
}

/* Left boundary conditions for d02tlc: residuals
 *   ga[0] = y1, ga[1] = y1', ga[2] = y2   (evaluated at the left end),
 * where YA(I, J) is element J*neq + I-1 of ya (derivative J of component
 * I).  Prints a one-time notice on the first call, tracked by
 * comm->user[2]. */
static void NAG_CALL gafun(const double ya[], Integer neq, const Integer m[],
                           Integer nlbc, double ga[], Nag_Comm *comm)
{
  /* Arguments parenthesised; #undef'd below to keep the name local. */
#define YA(I, J) ya[(J) * neq + (I)-1]

  if (comm->user[2] == -1.0)
    {
      printf("(User-supplied callback gafun, first invocation.)\n");
      comm->user[2] = 0.0;
    }
  ga[0] = YA(1, 0);
  ga[1] = YA(1, 1);
  ga[2] = YA(2, 0);
#undef YA
}

/* Right boundary conditions for d02tlc: residuals
 *   gb[0] = y1', gb[1] = y2 - 1   (evaluated at the right end),
 * where YB(I, J) is element J*neq + I-1 of yb (derivative J of component
 * I).  Prints a one-time notice on the first call, tracked by
 * comm->user[3]. */
static void NAG_CALL gbfun(const double yb[], Integer neq, const Integer m[],
                           Integer nrbc, double gb[], Nag_Comm *comm)
{
  /* Arguments parenthesised; #undef'd below to keep the name local. */
#define YB(I, J) yb[(J) * neq + (I)-1]

  if (comm->user[3] == -1.0)
    {
      printf("(User-supplied callback gbfun, first invocation.)\n");
      comm->user[3] = 0.0;
    }
  gb[0] = YB(1, 1);
  gb[1] = YB(2, 0) - 1.0;
#undef YB
}

/* Jacobian of gafun: DGADY(I, J, K) is the partial derivative of ga[I-1]
 * with respect to YA(J, K).  Since ga is simply (y1, y1', y2), the only
 * non-zero entries are the three unit entries set below; other entries are
 * presumably zeroed by the solver beforehand (confirm in the d02tlc
 * documentation).  Prints a one-time notice on the first call, tracked by
 * comm->user[4]. */
static void NAG_CALL gajac(const double ya[], Integer neq, const Integer m[],
                           Integer nlbc, double dgady[], Nag_Comm *comm)
{
  /* Arguments parenthesised; #undef'd below to keep the name local. */
#define DGADY(I, J, K) dgady[(I)-1 + ((J)-1)* nlbc + (K) * nlbc * neq]
  double  one = 1.0;

  if (comm->user[4] == -1.0)
    {
      printf("(User-supplied callback gajac, first invocation.)\n");
      comm->user[4] = 0.0;
    }
  DGADY(1, 1, 0) = one;   /* d ga[0] / d y1  */
  DGADY(2, 1, 1) = one;   /* d ga[1] / d y1' */
  DGADY(3, 2, 0) = one;   /* d ga[2] / d y2  */
#undef DGADY
}

/* Jacobian of gbfun: DGBDY(I, J, K) is the partial derivative of gb[I-1]
 * with respect to YB(J, K).  Since gb is (y1', y2 - 1), only two unit
 * entries are non-zero; other entries are presumably zeroed by the solver
 * beforehand (confirm in the d02tlc documentation).  Prints a one-time
 * notice on the first call, tracked by comm->user[5]. */
static void NAG_CALL gbjac(const double yb[], Integer neq, const Integer m[],
                           Integer nrbc, double dgbdy[], Nag_Comm *comm)
{
  /* Arguments parenthesised; #undef'd below to keep the name local. */
#define DGBDY(I, J, K) dgbdy[(I)-1 + ((J)-1)* nrbc + (K) * nrbc * neq]
  double  one = 1.0;

  if (comm->user[5] == -1.0)
    {
      printf("(User-supplied callback gbjac, first invocation.)\n");
      comm->user[5] = 0.0;
    }
  DGBDY(1, 1, 1) = one;   /* d gb[0] / d y1' */
  DGBDY(2, 2, 0) = one;   /* d gb[1] / d y2  */
#undef DGBDY
}

/* Initial approximation for d02tlc at point x.  With t = x * el it sets
 *   y1 = -t^2 e^{-t} (and its first two derivatives),
 *   y2 = 1 - e^{-t}  (and its first derivative),
 * and dym[] to the next derivative of each component.
 * NOTE(review): all derivatives here are taken with respect to t, not x,
 * while ffun's el^k factors suggest the ODE is posed in x; the values are
 * not rescaled by el^k.  This seems acceptable for a starting guess, but
 * confirm the intent.  Prints a one-time notice on the first call, tracked
 * by comm->user[6]. */
static void NAG_CALL guess(double x, Integer neq, const Integer m[], double y[],
                           double dym[], Nag_Comm *comm)
{
  const func_data *par = (const func_data *) comm->p;
  double t, e, tsq;

  if (comm->user[6] == -1.0) {
    printf("(User-supplied callback guess, first invocation.)\n");
    comm->user[6] = 0.0;
  }

  t = x * par->el;
  e = exp(-t);
  tsq = pow(t, 2);

  /* Component 1 and its derivatives. */
  Y(1, 0) = -tsq * e;
  Y(1, 1) = (-2.0 * t + tsq) * e;
  Y(1, 2) = (-2.0 + 4.0 * t - tsq) * e;
  /* Component 2 and its derivative. */
  Y(2, 0) = 1.0 - e;
  Y(2, 1) = e;
  /* Next derivative of each component. */
  dym[0] = (6.0 - 6.0 * t + tsq) * e;
  dym[1] = -e;
}