/* nag_pde_parab_1d_fd_ode_remesh (d03ppc) Example Program.
 *
 * NAGPRODCODE Version.
 *
 * Copyright 2016 Numerical Algorithms Group.
 *
 * Mark 26, 2016.
 */

#include <stdio.h>
#include <math.h>
#include <nag.h>
#include <nag_stdlib.h>
#include <nagd03.h>

/* Prototypes for the user-supplied callbacks handed to
 * nag_pde_parab_1d_fd_ode_remesh (d03ppc).  When compiled as C++ they are
 * wrapped in an extern "C" block so they keep C linkage.
 */
#ifdef __cplusplus
extern "C"
{
#endif
  /* pdedef: fills the PDE coefficient arrays p, q, r at one point (t, x). */
  static void NAG_CALL pdedef(Integer, double, double, const double[],
                              const double[], Integer, const double[],
                              const double[], double[], double[], double[],
                              Integer *, Nag_Comm *);

  /* bndary: sets beta, gamma at the left end (ibnd == 0) or the right end
   * (any other ibnd) of the spatial interval. */
  static void NAG_CALL bndary(Integer, double, const double[], const double[],
                              Integer, const double[], const double[],
                              Integer, double[], double[], Integer *,
                              Nag_Comm *);

  /* uvinit: writes the initial values of u on the mesh x. */
  static void NAG_CALL uvinit(Integer, Integer, Integer, const double[],
                              const double[], double[], Integer, double[],
                              Nag_Comm *);

  /* monitf: computes the remeshing monitor values fmon on the mesh x. */
  static void NAG_CALL monitf(double, Integer, Integer, const double[],
                              const double[], const double[], double[],
                              Nag_Comm *);
#ifdef __cplusplus
}
#endif

/* exact: closed-form solution at time t for npts abscissae, written to u;
 * main uses it to check the interpolated numerical solution. */
static void exact(double, double *, Integer, double *, Nag_Comm *);

/* 1-based accessors for flat arrays stored with the PDE-component index
 * varying fastest.  They expand to expressions using the variable npde
 * (and, for UOUT, intpts), which must be in scope wherever they are used. */
#define P(I, J)       p[npde*((J) -1)+(I) -1]
#define R(I, J)       r[npde*((J) -1)+(I) -1]
#define U(I, J)       u[npde*((J) -1)+(I) -1]
#define UOUT(I, J, K) uout[npde*(intpts*((K) -1)+(J) -1)+(I) -1]

/* Example driver for nag_pde_parab_1d_fd_ode_remesh (d03ppc).
 *
 * Integrates the single PDE defined by pdedef on [0, 1] from t = 0 to
 * t = 0.2*NOUT.  pdedef sets P = 1, Q = u*u_x, R = e*u_x and m = 0 here;
 * in d03ppc's P*u_t + Q = x^-m (x^m R)_x form this is
 * u_t + u*u_x = (e*u_x)_x, with e = 0.005.
 *
 * The mesh is remeshed every nrmesh time steps using the monitor function
 * monitf.  At each output time the solution is interpolated with
 * nag_pde_interp_1d_fd (d03pzc) at intpts points and printed next to the
 * exact solution from exact().
 *
 * Returns 0 on success, or 1 on allocation failure or if a NAG routine
 * reports an error.
 */
int main(void)
{
  const Integer npde = 1, npts = 61, ncode = 0, m = 0, nxi = 0, nxfix = 0;
  const Integer itype = 1, neqn = npde * npts + ncode, intpts = 5;
  const Integer lisave = 25 + nxfix;
  const Integer nwkres = npde * (npts + 3 * npde + 21) + 7 * npts + nxfix + 3;
  const Integer lenode = 11 * neqn + 50, lrsave =
         neqn * neqn + neqn + nwkres + lenode;
  /* Number of output times: t = 0.2, 0.4, ..., 0.2*NOUT. */
  enum { NOUT = 5 };
  /* Interpolation points at output time number it are
   * xout[i] = xstart[it] + xstep[it] * i, for i = 0, ..., intpts-1.  The
   * window is moved right (and narrowed) at later times.  Driving the fill
   * from intpts keeps it within the intpts-element xout allocation. */
  static const double xstart[NOUT] = { 0.3, 0.4, 0.6, 0.7, 0.8 };
  static const double xstep[NOUT] = { 0.1, 0.1, 0.05, 0.05, 0.05 };
  /* ruser[k] == -1.0 means callback k has not yet printed its
   * first-invocation note; each callback resets its own entry to 0.0. */
  static double ruser[4] = { -1.0, -1.0, -1.0, -1.0 };
  double con, dxmesh, e, tout, trmesh, ts, xratio;
  Integer exit_status, i, ind, ipminf, it, itask, itol, itrace, nrmesh;
  Nag_Boolean remesh, theta;
  double *algopt = 0, *atol = 0, *rsave = 0, *rtol = 0, *u = 0, *ue = 0;
  double *uout = 0, *x = 0, *xfix = 0, *xi = 0, *xout = 0;
  Integer *isave = 0;
  NagError fail;
  Nag_Comm comm;
  Nag_D03_Save saved;

  INIT_FAIL(fail);

  exit_status = 0;

  /* Allocate memory */

  if (!(algopt = NAG_ALLOC(30, double)) ||
      !(atol = NAG_ALLOC(1, double)) ||
      !(rsave = NAG_ALLOC(lrsave, double)) ||
      !(rtol = NAG_ALLOC(1, double)) ||
      !(u = NAG_ALLOC(neqn, double)) ||
      !(ue = NAG_ALLOC(intpts, double)) ||
      !(uout = NAG_ALLOC(npde * intpts * itype, double)) ||
      !(x = NAG_ALLOC(npts, double)) ||
      !(xfix = NAG_ALLOC(1, double)) ||
      !(xi = NAG_ALLOC(1, double)) ||
      !(xout = NAG_ALLOC(intpts, double)) ||
      !(isave = NAG_ALLOC(lisave, Integer)))
  {
    printf("Allocation failure\n");
    exit_status = 1;
    goto END;
  }

  printf("nag_pde_parab_1d_fd_ode_remesh (d03ppc) Example Program"
         " Results\n\n");

  /* For communication with user-supplied functions: the callbacks read the
   * diffusion parameter e through comm.p and their trace flags through
   * comm.user. */
  comm.user = ruser;

  e = 0.005;
  comm.p = (Pointer) &e;
  itrace = 0;
  itol = 1;
  atol[0] = 5e-5;
  rtol[0] = atol[0];

  printf("  Accuracy requirement =%12.3e", atol[0]);
  printf(" Number of points = %3" NAG_IFMT "\n\n", npts);

  /* Initialize a uniform mesh on [0, 1] */

  for (i = 0; i < npts; ++i)
    x[i] = i / (npts - 1.0);

  /* Set remesh parameters */

  remesh = Nag_TRUE;
  nrmesh = 3;
  dxmesh = 0.5;
  trmesh = 0.0;
  con = 2.0 / (npts - 1.0);
  xratio = 1.5;
  ipminf = 0;

  printf("  Remeshing every %3" NAG_IFMT " time steps\n\n", nrmesh);
  printf("  e =%8.3f\n\n\n", e);

  xi[0] = 0.0;
  ind = 0;
  itask = 1;

  /* Set theta to TRUE if the Theta integrator is required */

  theta = Nag_FALSE;
  for (i = 0; i < 30; ++i)
    algopt[i] = 0.0;
  algopt[0] = theta ? 2.0 : 0.0;

  /* Loop over output value of t */

  ts = 0.0;
  for (it = 0; it < NOUT; ++it) {
    tout = 0.2 * (it + 1);

    /* nag_pde_parab_1d_fd_ode_remesh (d03ppc).
     * General system of parabolic PDEs, coupled DAEs, method of
     * lines, finite differences, remeshing, one space variable.
     * ind is 0 on the first call; the routine updates it so later calls
     * continue the integration from ts.
     */
    nag_pde_parab_1d_fd_ode_remesh(npde, m, &ts, tout, pdedef, bndary,
                                   uvinit, u, npts, x, ncode, NULLFN, nxi,
                                   xi, neqn, rtol, atol, itol, Nag_TwoNorm,
                                   Nag_LinAlgFull, algopt, remesh, nxfix,
                                   xfix, nrmesh, dxmesh, trmesh, ipminf,
                                   xratio, con, monitf, rsave, lrsave, isave,
                                   lisave, itask, itrace, 0, &ind, &comm,
                                   &saved, &fail);

    if (fail.code != NE_NOERROR) {
      printf("Error from nag_pde_parab_1d_fd_ode_remesh (d03ppc).\n%s\n",
             fail.message);
      exit_status = 1;
      goto END;
    }

    /* Set output points for this output time */

    for (i = 0; i < intpts; ++i)
      xout[i] = xstart[it] + xstep[it] * i;

    printf(" t = %6.3f\n", ts);
    printf(" x           ");

    /* Values are printed five per line */
    for (i = 0; i < intpts; ++i) {
      printf("%9.4f", xout[i]);
      putchar((i + 1) % 5 == 0 || i == intpts - 1 ? '\n' : ' ');
    }

    /* Interpolate at output points */

    /* nag_pde_interp_1d_fd (d03pzc). PDEs, spatial interpolation with
     * nag_pde_parab_1d_fd_ode_remesh (d03ppc),
     */
    nag_pde_interp_1d_fd(npde, m, u, npts, x, xout, intpts, itype, uout,
                         &fail);

    if (fail.code != NE_NOERROR) {
      printf("Error from nag_pde_interp_1d_fd (d03pzc).\n%s\n", fail.message);
      exit_status = 1;
      goto END;
    }

    /* Check against exact solution */

    exact(ts, xout, intpts, ue, &comm);

    printf(" Approx sol. ");

    for (i = 1; i <= intpts; ++i) {
      printf("%9.4f", UOUT(1, i, 1));
      putchar(i % 5 == 0 || i == intpts ? '\n' : ' ');
    }

    printf(" Exact  sol. ");

    for (i = 1; i <= intpts; ++i) {
      printf("%9.4f", ue[i - 1]);
      putchar(i % 5 == 0 || i == intpts ? '\n' : ' ');
    }
    printf("\n");
  }

  /* Integration statistics returned by d03ppc in isave */
  printf(" Number of integration steps in time = %6" NAG_IFMT "\n", isave[0]);
  printf(" Number of function evaluations = %6" NAG_IFMT "\n", isave[1]);
  printf(" Number of Jacobian evaluations = %6" NAG_IFMT "\n", isave[2]);
  printf(" Number of iterations = %6" NAG_IFMT "\n\n", isave[4]);

END:
  NAG_FREE(algopt);
  NAG_FREE(atol);
  NAG_FREE(rsave);
  NAG_FREE(rtol);
  NAG_FREE(u);
  NAG_FREE(ue);
  NAG_FREE(uout);
  NAG_FREE(x);
  NAG_FREE(xfix);
  NAG_FREE(xi);
  NAG_FREE(xout);
  NAG_FREE(isave);

  return exit_status;
}

/* uvinit: initial-condition callback for d03ppc.
 *
 * For every mesh point x[j], j = 0..npts-1, stores the closed-form solution
 * evaluated at t = 0 into component 1 of u at that point.  The diffusion
 * parameter e is read through comm->p.  nxi, xi, ncode and v are unused.
 * On its first call (comm->user[0] still -1.0) it prints a one-line trace
 * note and clears the flag.
 */
static void NAG_CALL uvinit(Integer npde, Integer npts, Integer nxi,
                            const double x[], const double xi[], double u[],
                            Integer ncode, double v[], Nag_Comm *comm)
{
  const double *eps = (const double *) comm->p;
  const double t = 0.0;
  Integer j;

  if (comm->user[0] == -1.0) {
    printf("(User-supplied callback uvinit, first invocation.)\n");
    comm->user[0] = 0.0;
  }

  for (j = 0; j < npts; ++j) {
    const double xj = x[j];
    const double ga = (xj - 0.25 - 0.75 * t) / (*eps * 4.0);
    const double gb = (0.9 * xj - 0.325 - 0.495 * t) / (*eps * 2.0);
    double value;

    /* Three equivalent forms of the same expression, chosen by the signs
     * and relative size of ga and gb. */
    if (ga > 0.0 && ga > gb) {
      const double ea = exp(-ga);
      const double ec = exp((0.8 * xj - 0.4 - 0.24 * t) / (*eps * 4.0));
      value = (0.1 * ec + 0.5 + ea) / (ec + 1.0 + ea);
    } else if (gb > 0.0 && gb >= ga) {
      const double eb = exp(-gb);
      const double ec = exp((-0.8 * xj + 0.4 + 0.24 * t) / (*eps * 4.0));
      value = (0.5 * ec + 0.1 + eb) / (ec + 1.0 + eb);
    } else {
      const double ea = exp(ga);
      const double eb = exp(gb);
      value = (0.5 * ea + 1.0 + 0.1 * eb) / (ea + 1.0 + eb);
    }

    U(1, j + 1) = value;
  }
}

/* pdedef: PDE-coefficient callback for d03ppc.
 *
 * For the single equation sets
 *   P(1,1) = 1,   Q = u * u_x,   R = e * u_x,
 * where e is read through comm->p.  With d03ppc's
 * P*u_t + Q = x^-m (x^m R)_x form and m = 0 this is u_t + u*u_x = (e*u_x)_x.
 * t, x, ncode, v, vdot and ires are not used.  On its first call
 * (comm->user[1] still -1.0) it prints a one-line trace note.
 */
static void NAG_CALL pdedef(Integer npde, double t, double x,
                            const double u[], const double ux[],
                            Integer ncode, const double v[],
                            const double vdot[], double p[], double q[],
                            double r[], Integer *ires, Nag_Comm *comm)
{
  const double *eps = (const double *) comm->p;

  if (comm->user[1] == -1.0) {
    printf("(User-supplied callback pdedef, first invocation.)\n");
    comm->user[1] = 0.0;
  }

  P(1, 1) = 1.0;               /* coefficient of u_t */
  R(1, 1) = *eps * ux[0];      /* flux term, differentiated in x by d03ppc */
  q[0] = u[0] * ux[0];         /* convective term */
}

/* bndary: boundary-condition callback for d03ppc.
 *
 * Sets beta[0] = 0 and gamma[0] = u[0] - ue, where ue is the closed-form
 * solution (the same formula as exact()) at the boundary point and time t.
 * In d03ppc's beta*R = gamma boundary form this pins u to ue at the
 * boundary.  ibnd == 0 selects the left end x = 0; any other value selects
 * the right end x = 1.  The diffusion parameter e is read through comm->p.
 *
 * The formula is evaluated once for the selected boundary point rather
 * than duplicated for each end.  On its first call (comm->user[2] still
 * -1.0) it prints a one-line trace note.
 */
static void NAG_CALL bndary(Integer npde, double t, const double u[],
                            const double ux[], Integer ncode,
                            const double v[], const double vdot[],
                            Integer ibnd, double beta[], double gamma[],
                            Integer *ires, Nag_Comm *comm)
{
  double a, b, c, ue, x;
  double *e = (double *) comm->p;

  if (comm->user[2] == -1.0) {
    printf("(User-supplied callback bndary, first invocation.)\n");
    comm->user[2] = 0.0;
  }
  beta[0] = 0.0;

  /* Boundary abscissa: left end for ibnd == 0, right end otherwise. */
  x = (ibnd == 0) ? 0.0 : 1.0;

  a = (x - 0.25 - 0.75 * t) / (*e * 4.0);
  b = (0.9 * x - 0.325 - 0.495 * t) / (*e * 2.0);
  if (a > 0.0 && a > b) {
    a = exp(-a);
    c = (0.8 * x - 0.4 - 0.24 * t) / (*e * 4.0);
    c = exp(c);
    ue = (0.1 * c + 0.5 + a) / (c + 1.0 + a);
  }
  else if (b > 0.0 && b >= a) {
    b = exp(-b);
    c = (-0.8 * x + 0.4 + 0.24 * t) / (*e * 4.0);
    c = exp(c);
    ue = (0.5 * c + 0.1 + b) / (c + 1.0 + b);
  }
  else {
    a = exp(a);
    b = exp(b);
    ue = (0.5 * a + 1.0 + 0.1 * b) / (a + 1.0 + b);
  }
  gamma[0] = u[0] - ue;

  return;
}

/* exact: closed-form solution used to check the numerical results.
 *
 * For k = 0..npts-1 writes to u[k] the exact solution at time t and
 * position x[k].  The diffusion parameter e is read through comm->p.
 * One of three algebraically equivalent expressions is chosen from the
 * signs and relative size of ga and gb.  Presumably this keeps the exp()
 * arguments of the dominant terms non-positive to avoid overflow for
 * small e (TODO confirm intent).
 */
static void exact(double t, double *x, Integer npts, double *u,
                  Nag_Comm *comm)
{
  const double *eps = (const double *) comm->p;
  Integer k = 0;

  while (k < npts) {
    const double xk = x[k];
    const double ga = (xk - 0.25 - 0.75 * t) / (*eps * 4.0);
    const double gb = (0.9 * xk - 0.325 - 0.495 * t) / (*eps * 2.0);

    if (ga > 0.0 && ga > gb) {
      const double ea = exp(-ga);
      const double ec = exp((0.8 * xk - 0.4 - 0.24 * t) / (*eps * 4.0));
      u[k] = (0.1 * ec + 0.5 + ea) / (ec + 1.0 + ea);
    } else if (gb > 0.0 && gb >= ga) {
      const double eb = exp(-gb);
      const double ec = exp((-0.8 * xk + 0.4 + 0.24 * t) / (*eps * 4.0));
      u[k] = (0.5 * ec + 0.1 + eb) / (ec + 1.0 + eb);
    } else {
      const double ea = exp(ga);
      const double eb = exp(gb);
      u[k] = (0.5 * ea + 1.0 + 0.1 * eb) / (ea + 1.0 + eb);
    }
    ++k;
  }
}

/* monitf: remeshing monitor callback for d03ppc.
 *
 * For j = 0..npts-2 sets fmon[j] to |R(1, j+2) - R(1, j+1)| / h, where
 * h = 0.5 * (x[j+1] - x[max(j-1, 0)]).  pdedef defines R = e*u_x, so this
 * is a difference approximation scaled like a second derivative of u.
 * The last entry copies its neighbour: fmon[npts-1] = fmon[npts-2].
 * t and u are unused.  On its first call (comm->user[3] still -1.0) it
 * prints a one-line trace note.
 */
static void NAG_CALL monitf(double t, Integer npts, Integer npde,
                            const double x[], const double u[],
                            const double r[], double fmon[], Nag_Comm *comm)
{
  Integer j;

  if (comm->user[3] == -1.0) {
    printf("(User-supplied callback monitf, first invocation.)\n");
    comm->user[3] = 0.0;
  }

  for (j = 0; j < npts - 1; ++j) {
    /* Left neighbour index, clamped at the first mesh point. */
    const Integer left = (j > 0) ? j - 1 : 0;
    const double half_width = 0.5 * (x[j + 1] - x[left]);
    const double slope = (R(1, j + 2) - R(1, j + 1)) / half_width;

    fmon[j] = (slope < 0) ? -slope : slope;
  }
  fmon[npts - 1] = fmon[npts - 2];
}