/* nag_mv_gaussian_mixture (g03gac) Example Program.
 *
 * Copyright 2014 Numerical Algorithms Group.
 *
 * Mark 24, 2013.
 */
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <nag.h>
#include <nag_stdlib.h>
#include <nagg03.h>
#include <nagx04.h>

/* 1-based element accessors used by main().  Every argument is
 * parenthesised so that argument expressions built with lower-precedence
 * operators (e.g. ?:, <<, &) still index correctly.  The macros read the
 * dimension variables of the scope they are expanded in.
 *
 * S(I,J,K): I varies fastest, J strides by ng when sopt is Nag_GroupVar
 *           (nvar otherwise), K strides by nvar*nvar.
 * X(I,J), PROB(I,J), G(I,J), F(I,J): row-major with row strides tdx,
 *           tdprob, ng and ng respectively.
 */
#define S(I,J,K) s[((I)-1) + ((J)-1)*(sopt==Nag_GroupVar ? ng : nvar) + ((K)-1)*nvar*nvar]
#define X(I,J) x[((I)-1)*tdx + (J)-1]
#define PROB(I,J) prob[((I)-1)*tdprob + (J)-1]
#define G(I,J) g[((I)-1)*ng + (J)-1]
#define F(I,J) f[((I)-1)*ng + (J)-1]

int main(void)
{
    /* Integer scalar and array declarations */
    Integer     exit_status = 0, i, j, lens, m, n, ng, niter, nvar, riter,
        tdprob, tdx;
    Integer     *isx = 0;

    /* Double scalar and array declarations */
    double      loglik, tol;
    double      *f = 0, *g = 0, *prob = 0, *s = 0, *w = 0, *x = 0;

    /* NAG structures */
    Nag_Boolean  popt;
    Nag_VarCovar sopt;
    NagError     fail;

    /* Character scalar and array declarations */
    char         nag_enum_popt[30+1], nag_enum_sopt[30+1];

    printf("nag_mv_gaussian_mixture (g03gac) Example Program Results\n\n");
    fflush(stdout);

    /* Skip heading in data file */
    scanf("%*[^\n] ");
    
    /* Problem size */
    scanf("%ld", &n);
    scanf("%ld", &m);
    scanf("%ld", &nvar);
    scanf("%*[^\n] ");
    
    /* Number of groups */
    scanf("%ld", &ng);
    scanf("%*[^\n] ");
    
    /* Scaling option */
    scanf("%30s", nag_enum_sopt);
    scanf("%*[^\n] ");
    
    /* Initial probabilities option */
    scanf("%30s", nag_enum_popt);
    scanf("%*[^\n] ");
    
    /* Maximum number of iterations */
    scanf("%ld", &niter);
    scanf("%*[^\n] ");
    
    /* Principal dimensions */
    tdx = nvar;
    tdprob = ng;

    /* nag_enum_name_to_value (x04nac).
     * Converts NAG enum member name to value
     */
    popt = (Nag_Boolean)nag_enum_name_to_value(nag_enum_popt);
    sopt = (Nag_VarCovar)nag_enum_name_to_value(nag_enum_sopt);

    /* Variance/covariance array */ 
    switch (sopt)
    {
    case Nag_GroupCovar:
        lens = nvar*nvar*ng;
        break;
    case Nag_PooledCovar:
        lens = nvar*nvar;
        break;
    case Nag_GroupVar:
        lens = nvar*ng;
        break;
    case Nag_PooledVar:
        lens = nvar;
        break;
    case Nag_OverallVar:
        lens = 1;
        break;
    }

    if (!(x = NAG_ALLOC(n*tdx, double)) ||
        !(prob = NAG_ALLOC(n*tdprob, double)) ||
        !(g = NAG_ALLOC(ng*nvar, double)) ||
        !(w = NAG_ALLOC(ng, double)) ||
        !(isx = NAG_ALLOC(m, Integer)) ||
        !(f = NAG_ALLOC(ng*n, double)) ||
        !(s = NAG_ALLOC(lens, double)))
    {
        printf("Allocation failure\n");
        exit_status = -1;
        goto END;
    }

    /* Data matrix X */
    for (i=1; i<=n; i++)
        for (j=1;j<=m; j++)
            scanf("%lf", &X(i,j));
    scanf("%*[^\n] ");

    /* Included variables */
    if (nvar != m)
    {
        for (j=1; j<=m; j++)
            scanf("%ld", &isx[j-1]);
        scanf("%*[^\n] ");
    }

    /* Optionally read initial probabilities of group membership */
    if (popt==Nag_FALSE)
    {
        for (i=1; i<=n; i++)
            for (j=1; j<=ng; j++)
                scanf("%lf", &PROB(i,j));
        scanf("%*[^\n] ");
    }

    /* Optimisation parameters */
    tol = 0.0;
    riter = 5;

    /* Fit the model */
    /* nag_mv_gaussian_mixture (g03gac).
     * Computes a Gaussian mixture model
     */
    INIT_FAIL(fail);
    nag_mv_gaussian_mixture(n, m, x, tdx, isx, nvar, ng, popt, prob, tdprob,
                            &niter, riter, w, g, sopt, s, f, tol, &loglik,
                            &fail);

    if (fail.code != NE_NOERROR)
      {
        printf("nag_mv_gaussian_mixture (g03gac) failed.\n%s\n",fail.message);
        exit_status = 1;
        goto END;
      }
    
    /* Results */
    /* nag_gen_real_mat_print (x04cac).
     * Print real general matrix (easy-to-use)
     */
    nag_gen_real_mat_print(Nag_RowMajor, Nag_GeneralMatrix, Nag_NonUnitDiag,
                           1, ng, w, ng, "Mixing proportions", NULL, &fail);

    nag_gen_real_mat_print(Nag_RowMajor, Nag_GeneralMatrix, Nag_NonUnitDiag,
                           nvar, ng, g, ng, "\n Group means", NULL, &fail);

    /* Variance/Covariance */
    switch (sopt) {
    case Nag_GroupCovar:
      for (i=1; i<=ng; i++)
        {
          nag_gen_real_mat_print(Nag_RowMajor, Nag_GeneralMatrix,
                                 Nag_NonUnitDiag, nvar, nvar, &S(1,1,i), nvar,
                                 "\n Variance-covariance matrix", NULL, &fail);
        }
      break;
    case Nag_PooledCovar:
      nag_gen_real_mat_print(Nag_RowMajor, Nag_GeneralMatrix, Nag_NonUnitDiag,
                             nvar, nvar, s, nvar,
                             "\n Pooled Variance-covariance matrix", NULL,
                             &fail);
      break;
    case Nag_GroupVar:
      nag_gen_real_mat_print(Nag_RowMajor, Nag_GeneralMatrix, Nag_NonUnitDiag,
                             nvar, ng, s, ng, "\n Groupwise Variance", NULL,
                             &fail);
      break;
    case Nag_PooledVar:
      nag_gen_real_mat_print(Nag_RowMajor, Nag_GeneralMatrix, Nag_NonUnitDiag,
                             nvar, 1, s, 1, "\n Pooled Variance", NULL, &fail);
      break;
    case Nag_OverallVar:
      printf("\n Overall Variance = %g\n", S(1,1,1));
      break;
    }

    nag_gen_real_mat_print(Nag_RowMajor, Nag_GeneralMatrix, Nag_NonUnitDiag, n,
                           ng, f, ng, "\n Densities", NULL, &fail);

    nag_gen_real_mat_print(Nag_RowMajor, Nag_GeneralMatrix, Nag_NonUnitDiag, n,
                           ng, prob, ng, "\n Membership probabilities", NULL,
                           &fail);
    
    printf("\nNo. iterations: %ld\n", niter);
    printf("Log-likelihood: %g\n\n", loglik);
    
  END:
    NAG_FREE(f);
    NAG_FREE(g);
    NAG_FREE(prob);
    NAG_FREE(s);
    NAG_FREE(w);
    NAG_FREE(x);
    NAG_FREE(isx);
  
    return exit_status;
}