//
// program to solve 1D heat-distribution problem (1D version of
//   problem presented in textbook, section 6.3.2)
//
// command-line arguments:  
//   number of interior points (must be a multiple of number
//     of processes)
//   value for left end
//   value for right end
//   maximum iterations (optional)
//   convergence threshold (optional)
//
// output:
//   status messages to standard error
//   values of all points to standard output when iteration stops
//     (printed whether or not convergence was reached)
// 

#include <stdio.h>
#include <stdlib.h>		// has exit(), etc.
#include <algorithm>		// has STL fill()
#include <cmath>		// has std::fabs()
#include <iostream>		// has std::cerr, std::cout
#include <vector>		// has std::vector
#include "mpi++.h"		// MPI header file

// Defaults used when the optional command-line arguments are omitted.
// Typed constants (rather than macros) so they obey scope and type rules.
const double DEFAULT_THRESHOLD      = 0.01;	// convergence threshold
const int    DEFAULT_MAX_ITERATIONS = 1000;	// iteration cap

// Program entry point.
//
// Splits npoints interior points evenly across the MPI processes and
// repeatedly replaces each point by the average of its two neighbours
// until every point changes by no more than the threshold, or the
// iteration limit is reached.  The two end values stay fixed.
//
// argv[1] = number of interior points (positive multiple of nprocs)
// argv[2] = value for left end
// argv[3] = value for right end
// argv[4] = maximum iterations (optional)
// argv[5] = convergence threshold (optional)
//
// Returns EXIT_SUCCESS after printing results, or EXIT_FAILURE if the
// command-line arguments are unusable.
int main(int argc, char *argv[]) {

  MPI::Init(argc, argv);
  const int nprocs = MPI::COMM_WORLD.Get_size();
  const int myID = MPI::COMM_WORLD.Get_rank();

  // ---- process command-line arguments ----
  //      every process checks them; only process 0 reports problems

  if (argc < 4) {
    if (myID == 0)
      std::cerr << "Usage:  heat-eqn npoints left-temp right-temp "
                << "[maxiter] [threshold]\n";
    MPI::Finalize();
    return EXIT_FAILURE;
  }

  const int nPoints = atoi(argv[1]);
  // reject zero, negative, and non-numeric input (atoi yields 0 for the
  //   latter) before the value is used as an array size
  if (nPoints <= 0) {
    if (myID == 0)
      std::cerr << "npoints must be a positive integer\n";
    MPI::Finalize();
    return EXIT_FAILURE;
  }
  if ((nPoints % nprocs) != 0) {
    if (myID == 0)
      std::cerr << "npoints must be a multiple of nprocs\n";
    MPI::Finalize();
    return EXIT_FAILURE;
  }

  int maxIterations = DEFAULT_MAX_ITERATIONS;
  if (argc > 4)
    maxIterations = atoi(argv[4]);

  double threshold = DEFAULT_THRESHOLD;
  if (argc > 5)
    threshold = atof(argv[5]);

  // ---- declare and initialize other variables ----

  // timing values are only meaningful in process 0
  double start_time = 0.0;
  double end_time = 0.0;

  if (myID == 0) {
    start_time = MPI::Wtime();
    std::cerr << "Computing for " << nPoints << " points using "
              << nprocs << " processes\n";
  }

  const int nPointsLocal = nPoints/nprocs;

  // arrays for oldval, newval
  // elements 0, nPointsLocal+1 are "ghost cells"
  // elements 1 .. nPointsLocal are this process's values
  // (vectors release their storage automatically and start out all 0.0)
  std::vector<double> oldval(nPointsLocal + 2, 0.0);
  std::vector<double> newval(nPointsLocal + 2, 0.0);

  // initialize left, right ends; both arrays get the fixed end values
  //   because the arrays trade roles on every iteration
  if (myID == 0) {
    oldval[0] = atof(argv[2]);
    newval[0] = oldval[0];
  }
  if (myID == (nprocs-1)) {
    oldval[nPointsLocal+1] = atof(argv[3]);
    newval[nPointsLocal+1] = oldval[nPointsLocal+1];
  }

  // neighbors for the ghost-cell exchange; the end processes use
  //   MPI::PROC_NULL, which turns the corresponding send/receive into a
  //   no-op and leaves the fixed end value in the ghost cell untouched
  const int leftNeighbor  = (myID == 0) ? MPI::PROC_NULL : myID - 1;
  const int rightNeighbor = (myID == (nprocs-1)) ? MPI::PROC_NULL : myID + 1;

  const int TAG_TO_LEFT = 0;   // messages travelling toward lower ranks
  const int TAG_TO_RIGHT = 1;  // messages travelling toward higher ranks

  // ---- main processing loop ----
  //      on each iteration, compute new values for all interior
  //        points based on values computed last time
  //      iterations continue until "convergence" (new values and
  //        old values differ by no more than threshold value) or
  //        until maximum number of iterations is reached

  bool converged = false;
  int iter = 0;

  for ( ; (iter < maxIterations) && !converged; ++iter) {

    // exchange edge values with neighbors
    // Sendrecv is used instead of separate blocking Send/Recv calls:
    //   if every process sends first, the program can deadlock
    //   whenever the MPI implementation does not buffer the messages

    // send left edge to left neighbor, receive right ghost cell from
    //   right neighbor
    MPI::COMM_WORLD.Sendrecv(&oldval[1], 1, MPI::DOUBLE,
                             leftNeighbor, TAG_TO_LEFT,
                             &oldval[nPointsLocal+1], 1, MPI::DOUBLE,
                             rightNeighbor, TAG_TO_LEFT);
    // send right edge to right neighbor, receive left ghost cell from
    //   left neighbor
    MPI::COMM_WORLD.Sendrecv(&oldval[nPointsLocal], 1, MPI::DOUBLE,
                             rightNeighbor, TAG_TO_RIGHT,
                             &oldval[0], 1, MPI::DOUBLE,
                             leftNeighbor, TAG_TO_RIGHT);

    // compute new values
    for (int i = 1; i <= nPointsLocal; ++i)
      newval[i] = 0.5*(oldval[i-1] + oldval[i+1]);

    // check for convergence of this process's points
    bool convergedThisTime = true;
    for (int i = 1; i <= nPointsLocal; ++i) {
      if (std::fabs(oldval[i] - newval[i]) > threshold) {
        convergedThisTime = false;
        break;
      }
    }

    // use MPI "reduction" to compute "logical and" of variable
    //   convergedThisTime in all processes
    // MPI::INT is the datatype matching a C++ int (MPI::INTEGER is the
    //   Fortran INTEGER type)
    int localFlag = convergedThisTime ? 1 : 0;
    int globalFlag = 0;
    MPI::COMM_WORLD.Allreduce(&localFlag, &globalFlag, 1,
                              MPI::INT, MPI::LAND);
    converged = (globalFlag != 0);

    // make the new values the "old" values for the next iteration
    //   (by swapping the arrays' storage rather than copying)
    oldval.swap(newval);
  }

  // ---- do final output and clean up ----
  //      each process prints values for "its" points

  if (myID == 0) {
    end_time = MPI::Wtime();
    if (converged)
      std::cerr << "Converged in " << iter << " iterations";
    else
      std::cerr << "Failed to converge in " << iter << " iterations";
    std::cerr << " (convergence threshold " << threshold << ")\n";
    std::cerr << "Computation time = " << end_time - start_time
              << std::endl;
  }

  // print results -- for each point, its index relative to the
  //   whole array and its final value (after the last swap, oldval
  //   holds the most recently computed values)
  // NOTE(review): processes print independently, so lines from
  //   different processes may not appear in index order
  for (int i = 1; i <= nPointsLocal; ++i)
    std::cout << "value for point " << (myID*nPointsLocal)+(i-1)
              << " = " << oldval[i] << std::endl;

  if (myID == 0) {
    const double end_output_time = MPI::Wtime();
    std::cerr << "Output time = " << end_output_time - end_time
              << std::endl;
  }

  MPI::Finalize();

  return EXIT_SUCCESS;
}