//
// Program to solve 1D heat-distribution problem (1D version of
//   problem presented in textbook, section 6.3.2)
//
// Command-line arguments:  
//   number of threads
//   number of interior points (must be a multiple of number
//     of threads)
//   value for left end
//   value for right end
//   maximum iterations (optional)
//   convergence threshold (optional)
//
// Output:
//   status messages to standard error
//   values of all points on convergence to standard output
//
// Version 1:  For each iteration of the main loop, starts
//   up P threads and waits for them to complete.
//
#include <iostream>
#include <iomanip>		// has setprecision()
#include <cstdlib>		// has exit(), etc.
#include <algorithm>		// has fill()
#include <cmath>		// has fabs()
#include "timer.h"		// has timer()
#include "pthreads-threadmgr.h"	// has threadManager class

#define DEFAULT_THRESHOLD      0.01     // convergence threshold
#define DEFAULT_MAX_ITERATIONS 1000

// ---- Class for doing calculations ---------------------------------

// This will be an "enclosing class" to hold variables to be shared
//   among threads.  We'll build one object of this class, and pass
//   a pointer to it to all the threads.  This seems tidier than
//   either of the previous approaches (global variables, or static
//   class variables).  This class will also provide methods that
//   do the actual desired work, including starting up threads and
//   waiting for them to complete.

class computeObj {
public:
  // Constructor.
  //   nThreads:   number of threads to divide the work among
  //   nPoints:    number of interior points
  //   maxIter:    maximum number of iterations of the main loop
  //   threshold:  convergence threshold
  //   leftEnd:    value for left boundary point
  //   rightEnd:   value for right boundary point
  computeObj(const int nThreads, const int nPoints, 
	     const int maxIter, const double threshold,
	     const double leftEnd, const double rightEnd);
  // Destructor.
  ~computeObj(void);

  // Do calculations.
  void calculate(void);

  // Do output.
  void output(void);

private:
  // Copying is disabled.  This class holds oldval, newval, and
  //   convergedLocal as raw pointers to arrays that the destructor
  //   releases with delete [], so a memberwise copy would lead to
  //   the same arrays being deleted twice.  These two members are
  //   declared but intentionally never defined.
  computeObj(const computeObj &);
  computeObj & operator=(const computeObj &);

  // These variables hold data to be shared among all threads.
  int nThreads;
  int nPoints;
  int maxIter;			// maximum iterations
  double threshold;		// convergence threshold
  double leftEnd;		// left boundary point
  double rightEnd;		// right boundary point
  int iter;			// number of iterations
  bool converged;		// global convergence?
  double * oldval;		// array of old values
  double * newval;		// array of new values
  bool * convergedLocal;	// per-thread convergences

  // ---- Class for thread -------------------------------------------

  // Holds the parameters for one thread (its ID, the range of
  //   indices it "owns", and a pointer back to the enclosing
  //   computeObj) and provides the function the thread executes.
  class computeThreadObj {
  public:
    typedef computeThreadObj * ptr;
    computeThreadObj(const int myID,
		     const int firstIndex, const int lastIndex,
		     computeObj * context);
    // Code to be executed by the thread.
    void run(void);
  private:
    int myID;			// also index into convergedLocal
    int firstIndex;		// index of first element we "own"
    int lastIndex;		// index of last element we "own"
    computeObj * context;	// pointer to "enclosing object"
  };
};

// ---- Main program -------------------------------------------------

// Program entry point.
//
// Parses the command line (nThreads nPoints leftTemp rightTemp
//   [maxIter] [threshold]), builds a computeObj, runs the calculation,
//   prints the results, and reports timing to standard error.
// Returns EXIT_SUCCESS on normal completion, EXIT_FAILURE if the
//   command-line arguments are missing or invalid.
int main(int argc, char* argv[]) {

  // Process command-line arguments.

  if (argc < 5) {
    std::cerr << "Usage:  " << argv[0]
              << " nThreads nPoints leftTemp rightTemp"
              << " [maxIter] [threshold]\n";
    return EXIT_FAILURE;
  }

  const int nThreads = std::atoi(argv[1]);
  const int nPoints = std::atoi(argv[2]);

  // Validate before the modulus below:  atoi() yields 0 for text that
  //   is not a number, and (nPoints % 0) is undefined behavior.
  if (nThreads < 1) {
    std::cerr << "nThreads must be a positive integer\n";
    return EXIT_FAILURE;
  }
  // A negative count would produce a negative array size when the
  //   computeObj allocates its nPoints + 2 element arrays.
  if (nPoints < 0) {
    std::cerr << "nPoints must not be negative\n";
    return EXIT_FAILURE;
  }
  if ((nPoints % nThreads) != 0) {
    std::cerr << "nPoints must be a multiple of nThreads\n";
    return EXIT_FAILURE;
  }

  int maxIter = DEFAULT_MAX_ITERATIONS;
  if (argc > 5)
    maxIter = std::atoi(argv[5]);

  double threshold = DEFAULT_THRESHOLD;
  if (argc > 6)
    threshold = std::atof(argv[6]);

  std::cerr << "Computing for " << nPoints << " points using "
            << nThreads << " threads\n";

  // Build object, using values for boundary points from command line.
  //   (timer() comes from timer.h; the messages below treat its
  //   result as seconds.)

  const double startInitTime = timer();
  computeObj obj(nThreads, nPoints, maxIter, threshold,
                 std::atof(argv[3]), std::atof(argv[4]));

  // Do calculations.
  const double startTime = timer();
  obj.calculate();
  const double endTime = timer();

  // Print results.
  obj.output();
  const double endOutputTime = timer();
  std::cerr << "Time for computation (seconds) = " << std::setprecision(4)
            << endTime - startTime << '\n';
  std::cerr << "Total time (seconds) = " << std::setprecision(4)
            << endOutputTime - startInitTime << '\n';

  // Exit.
  return EXIT_SUCCESS;
}

// ---- Functions for computeObj class -------------------------------

// Constructor.
//
// Records the problem parameters and allocates the working arrays:
//   oldval, newval:  nPoints + 2 doubles each
//   convergedLocal:  nThreads flags
// All array elements start out as zero / false, and iter and
//   converged start out as 0 / false, so that the object has
//   well-defined contents before any other member function runs.
// If any allocation throws, arrays already allocated are released
//   before the exception is passed on (the destructor does not run
//   for an object whose constructor did not complete).
computeObj::computeObj(const int nThreads, const int nPoints, 
		       const int maxIter, const double threshold,
		       const double leftEnd, const double rightEnd)
  // Initializers are listed in member declaration order.  Inside
  //   each initializer the bare name refers to the parameter.
  : nThreads(nThreads),
    nPoints(nPoints),
    maxIter(maxIter),
    threshold(threshold),
    leftEnd(leftEnd),
    rightEnd(rightEnd),
    iter(0),
    converged(false),
    oldval(0),
    newval(0),
    convergedLocal(0) {
  try {
    // Trailing "()" value-initializes the elements.
    oldval = new double[nPoints + 2]();
    newval = new double[nPoints + 2]();
    convergedLocal = new bool[nThreads]();
  }
  catch (...) {
    // Pointers not yet assigned are still null; delete [] on a null
    //   pointer has no effect.
    delete [] convergedLocal;
    delete [] newval;
    delete [] oldval;
    throw;
  }
}

// Destructor:  releases the three arrays this object holds
//   (the per-thread convergence flags and the two value arrays).
computeObj::~computeObj() {
  delete [] convergedLocal;
  delete [] newval;
  delete [] oldval;
}

// Do calculations.
//
// Initializes the arrays (interior points 0, boundary points from
//   leftEnd and rightEnd), then repeats the main loop until either
//   every thread reports convergence or maxIter iterations have been
//   done.  On return:
//     oldval holds the most recent values (the initial values if no
//       iteration ran),
//     iter holds the number of iterations performed,
//     converged tells whether the convergence threshold was met.
//
// The interior points 1 .. nPoints are divided into contiguous
//   segments, one per thread.  If nPoints is not an exact multiple of
//   nThreads, the first (nPoints % nThreads) threads each get one
//   extra point, so that every interior point is owned by exactly one
//   thread.  (When the division is exact this is the same as giving
//   each thread nPoints / nThreads points.)
//
// Requires nThreads > 0 (nPoints is divided by nThreads below).
void computeObj::calculate(void) {

  // Initialize arrays.
  //   interior points initially = 0
  std::fill(oldval+1, oldval+nPoints+1, 0.0);
  //   boundary points from constructor arguments
  oldval[0] = leftEnd;
  oldval[nPoints+1] = rightEnd;
  //   boundary points do not change, so set new values here; both
  //   arrays then have them no matter how often the pointers are
  //   switched below
  newval[0] = oldval[0];
  newval[nPoints+1] = oldval[nPoints+1];

  converged = false;		// will become true if, at all points,
				//   abs(oldval - newval) <= threshold.
				// convergedLocal[i] is true iff
				//   abs(oldval - newval) <= threshold
				//   for all points owned by thread i

  iter = 0;			// number of iterations so far

  computeThreadObj::ptr * threadObjs = 
    new computeThreadObj::ptr[nThreads]; // parameters for threads

  // Create computeThreadObj objects, one for each thread, to hold
  //   parameters to pass to threads.  
  const int basePoints = nPoints / nThreads;
  const int extraPoints = nPoints % nThreads;
  int firstIndex = 1;		// first index owned by thread i
  for (int i = 0; i < nThreads; ++i) {
    const int count = basePoints + ((i < extraPoints) ? 1 : 0);
    // A thread with count == 0 gets lastIndex == firstIndex - 1,
    //   i.e., an empty range.
    threadObjs[i] = new computeThreadObj(i, 
					 firstIndex,
					 firstIndex + count - 1,
					 this);
    firstIndex += count;
  }

  // Main processing loop:
  //   Each time through, we:
  //     Compute new values for all interior points, using values
  //       computed last time.
  //     Check for convergence.
  //     "Copy" new values to old values (by switching pointers).
  //   The loop continues until convergence or the maximum number of
  //     iterations is reached.
  //   Concurrency is obtained by dividing up the work of computing
  //     new values and checking for convergence among nThreads threads.

  for ( ; (iter < maxIter) && !converged; ++iter) {

    // Create and start nThreads new threads, and wait for them to 
    //   complete.  (threadManager comes from pthreads-threadmgr.h;
    //   presumably each thread executes run() on one of threadObjs.)
    threadManager<computeThreadObj> tMgr(nThreads, threadObjs);
    tMgr.join();

    // "Copy" new values to old values by switching pointers, so that
    //   oldval always refers to the most recently computed values.
    double * const justComputed = newval;
    newval = oldval;
    oldval = justComputed;

    // Check for overall convergence:  true only if every thread
    //   reported convergence for its own points.
    converged = true;
    for (int i = 0; i < nThreads && converged; ++i) 
      converged = convergedLocal[i];

  } // end of main loop

  // Clean up.
  for (int i = 0; i < nThreads; ++i) 
    delete threadObjs[i];
  delete [] threadObjs;
}

// Do output.
//
// Writes a one-line status message (converged or not, number of
//   iterations, convergence threshold) to standard error, then the
//   value of each interior point (indices 1 through nPoints of
//   oldval), one per line, to standard output.
void computeObj::output(void) {
  std::cerr << (converged ? "Converged in " : "Failed to converge in ")
            << iter << " iterations"
            << " (convergence threshold " << threshold << ")\n";

  // Use '\n' rather than endl inside the loop:  endl would flush the
  //   stream after every point.
  for (int i = 1; i <= nPoints; ++i) 
    std::cout << "value for point " << i << " = " << oldval[i] << '\n';

  // Flush once, so that all values have been written by the time
  //   this function returns.
  std::cout.flush();
}

// ---- Functions for computeObj::computeThreadObj class -------------

// Each thread "owns" one segment of arrays oldval and newval,
//   and is responsible for updating their values.

// Constructor:  stores the thread's ID, the first and last indices
//   of the segment it owns, and the pointer to the enclosing object.
computeObj::computeThreadObj::computeThreadObj(const int myID,
					       const int firstIndex, 
					       const int lastIndex,
					       computeObj * context)
  : myID(myID),
    firstIndex(firstIndex),
    lastIndex(lastIndex),
    context(context) {
}

// Code executed by each thread.
//
// For every index from firstIndex to lastIndex, stores in the new
//   array the average of the two neighboring values in the old
//   array.  Then records in context->convergedLocal[myID] whether
//   all of those points changed by no more than the convergence
//   threshold.
void computeObj::computeThreadObj::run(void) {

  // Local copies of the shared values used below.
  const double * const previous = context->oldval;
  double * const current = context->newval;
  const double limit = context->threshold;

  // Single pass:  update each point and check its change right away.
  //   Once one point exceeds the threshold the flag stays false.
  bool withinLimit = true;
  for (int idx = firstIndex; idx <= lastIndex; ++idx) {
    current[idx] = 0.5*(previous[idx-1] + previous[idx+1]);
    if (withinLimit && std::fabs(previous[idx] - current[idx]) > limit)
      withinLimit = false;
  }

  context->convergedLocal[myID] = withinLimit;
}