//
// Program to solve 1D heat-distribution problem (1D version of
//   problem presented in textbook, section 6.3.2)
//
// Command-line arguments:  
//   number of threads
//   number of interior points (must be a multiple of number
//     of threads)
//   value for left end
//   value for right end
//   maximum iterations (optional)
//   convergence threshold (optional)
//
// Output:
//   status messages to standard error
//   values of all points on convergence to standard output
//
#include <iostream.h>
#include <iomanip>		// has setprecision()
#include <stdlib.h>		// has exit(), etc.
#include <math.h>		// has fabs()
#include <pthread.h>		// has pthread_ routines
#include <algorithm>		// has fill()
#include <numeric>		// has accumulate()
#include <functional>		// has logical_and()
#include "threads-timer.h"	// has timer()
#include "threads-threadWrapper.h"	// has threadWrapper class
#include "threads-barrier.h"	// has barrier class

// Defaults used by main() when the optional command-line arguments
//   (maxIter, threshold) are not supplied.
#define DEFAULT_THRESHOLD      0.01     // convergence threshold
#define DEFAULT_MAX_ITERATIONS 1000     // cap on trips through the main loop

// ---- Class for thread ---------------------------------------------

// Each thread "owns" one segment of arrays oldVals and newVals,
//   and is responsible for updating their values.

class computeThreadObj {
public:
  typedef computeThreadObj * ptr;

  // member variables are parameters for thread
  int myID;			// thread ID
  int nThreads;			// number of threads
  int firstIndex;		// index of first element we "own"
  int lastIndex;		// index of last element we "own"
  double threshold;		// convergence threshold
  int maxIter;			// maximum number of iterations
  double * oldVals;		// whole "old values" array
  double * newVals;		// whole "new values" array
  bool * convergedLocal;	// convergedLocal[i] means convergence
				//   has occurred in thread i during
				//   this trip through the main loop
  barrier * barPtr;		// barrier to use for waiting for
				//   other threads
  int * iterPtr;		// count of iterations -- one thread
				//   will update *iterPtr, each time
				//   through loop; all threads have
				//   access, as does main thread at
				//   the end
  bool * convergedPtr;		// *convergedPtr == true means 
				//   convergedLocal[i] is true for all 
				//   i -- one thread updates this
  computeThreadObj(const int myID, const int nThreads,
		   const int firstIndex, const int lastIndex,
		   const double threshold, const int maxIter,
		   double * const oldVals, double * const newVals,
		   bool * const convergedLocal,
		   barrier * const barPtr,
		   int * const iterPtr,
		   bool * const convergedPtr)
  {
    this->myID = myID;
    this->nThreads = nThreads;
    this->firstIndex = firstIndex;
    this->lastIndex = lastIndex;
    this->threshold = threshold;
    this->maxIter = maxIter;
    this->oldVals = oldVals;
    this->newVals = newVals;
    this->convergedLocal = convergedLocal;
    this->barPtr = barPtr;
    this->iterPtr = iterPtr;
    this->convergedPtr = convergedPtr;
  }

  // this member function is the code the thread should execute
  void run(void) {

    // Performs main loop:
    //
    //   Each time through, we:
    //     Compute new values for points we "own", using values
    //       computed last time.
    //     Check for convergence.
    //     Do barrier synchronization.
    //     In first thread, check for global convergence and update
    //       count of iterations.
    //     "Copy" old values to new values (by switching pointers).
    //     Do barrier synchronization.
    //
    //   The loop continues until convergence or the maximum number of
    //     iterations is reached.

    while ((*iterPtr < maxIter) && !(*convergedPtr)) {

      // Compute new values for points (from firstIndex to lastIndex).
      for (int i = firstIndex; i <= lastIndex; ++i)
	newVals[i] = 0.5*(oldVals[i-1] + oldVals[i+1]);

      // Check for convergence in this section.
      convergedLocal[myID] = true;
      for (int i = firstIndex; i <= lastIndex && convergedLocal[myID]; ++i)
	if (fabs(oldVals[i] - newVals[i]) > threshold)
	  convergedLocal[myID] = false;

      // Do barrier synchronization (wait for other threads to get this
      //   far).
      barPtr->wait();

      // In first thread, check for global convergence and update
      //   count of iterations.
      if (myID == 0) {
	*convergedPtr = 
	  accumulate(convergedLocal, convergedLocal+nThreads,
		     true, logical_and<bool>());
	++(*iterPtr);
      }
      // "Copy" old values to new values (by switching pointers).
      double * temp = oldVals;
      oldVals = newVals;
      newVals = temp;

      // Do barrier synchronization (wait for other threads to get this
      //   far).
      barPtr->wait();
    } // end of main loop
  } // end of run() 
};

// ---- Main program -------------------------------------------------

// Program entry point.
//
// Parses the command line, allocates the two value arrays and the
//   shared bookkeeping variables, starts nThreads worker threads (each
//   running computeThreadObj::run() via threadWrapper), waits for them,
//   then reports timing and convergence on standard error and the
//   final interior-point values on standard output.
//
// Exits with EXIT_FAILURE on a bad command line or if a thread cannot
//   be created; returns EXIT_SUCCESS otherwise.
int main(int argc, char* argv[]) {

  // Process command-line arguments.

  if (argc < 5) {
    cerr << "Usage:  " << argv[0] 
	 << " nThreads nPoints leftTemp rightTemp"
	 << " [maxIter] [threshold]\n";
    exit(EXIT_FAILURE);
  }

  int nThreads = atoi(argv[1]);
  int nPoints = atoi(argv[2]);

  // atoi() yields 0 for non-numeric text; a zero thread count would
  //   divide by zero below, and a negative point count would give a
  //   negative array size.
  if (nThreads < 1) {
    cerr << "nThreads must be a positive integer\n";
    exit(EXIT_FAILURE);
  }
  if (nPoints < 0) {
    cerr << "nPoints must not be negative\n";
    exit(EXIT_FAILURE);
  }
  if ((nPoints % nThreads) != 0) {
    cerr << "nPoints must be a multiple of nThreads\n";
    exit(EXIT_FAILURE);
  }

  // int, to match the maxIter parameter of computeThreadObj
  int maxIterations = DEFAULT_MAX_ITERATIONS;
  if (argc > 5) 
    maxIterations = atoi(argv[5]);

  double threshold = DEFAULT_THRESHOLD;
  if (argc > 6)
    threshold = atof(argv[6]);

  cerr << "Computing for " << nPoints << " points using "
       << nThreads << " threads\n";

  // Set up to time things.

  double startTime = timer();

  // Declare and initialize other variables.

  double * oldval = new double[nPoints + 2];
				// old values
  double * newval = new double[nPoints + 2];
				// new values

  // interior points initially = 0
  std::fill(oldval+1, oldval+nPoints+1, 0.0);

  // boundary points' values given by command-line arguments
  oldval[0] = atof(argv[3]);
  oldval[nPoints+1] = atof(argv[4]);

  // boundary points do not change, so set new values here
  newval[0] = oldval[0];
  newval[nPoints+1] = oldval[nPoints+1];

  bool * convergedLocal = new bool[nThreads];
				// convergedLocal[i] is true iff
				//   abs(oldval - newval) <= threshold
				//   for all points owned by thread i

				// these shared variables don't need
				//   locks -- access is controlled by
				//   use of barriers within threads
  int iter = 0;			// count of trips through main loop
				//   (in threads) -- updated each time
				//   through by first thread
  bool converged = false;	// convergence reached in all threads 
				//   -- updated each time through the
				//   loop by first thread

  barrier * barPtr = new barrier(nThreads);
				// barrier to use for synchronization

  int nPointsPerThread = nPoints / nThreads;

  // Start nThreads new threads, each running a wrapper function, 
  //   with a computeThreadObj object as parameter.  (The wrapper 
  //   function just invokes the object's "run()" method.)
  computeThreadObj::ptr * threadObj = 
    new computeThreadObj::ptr[nThreads]; 
				// parameters for threads
  pthread_t * threads = new pthread_t[nThreads];
				// "handles" for threads
  for (int i = 0; i < nThreads; ++i) {
    // thread i owns interior points
    //   i*nPointsPerThread + 1 .. (i+1)*nPointsPerThread
    threadObj[i] = new computeThreadObj(i, nThreads,
					i*nPointsPerThread + 1,
					(i+1)*nPointsPerThread,
					threshold, maxIterations,
					oldval, newval,	convergedLocal,
					barPtr,
					&iter, &converged);
    if (pthread_create(&threads[i], NULL,
                       threadWrapper<computeThreadObj::ptr>::run, 
                       static_cast<void *>(threadObj[i])) != 0) {
      // The barrier was built for nThreads participants, so the
      //   threads already started could never get past it, and
      //   joining a handle that was never filled in is undefined.
      //   Give up rather than hang.
      cerr << "Unable to create thread " << i << endl;
      exit(EXIT_FAILURE);
    }
  }

  // Wait for all threads to complete.

  for (int i = 0; i < nThreads; ++i) {
    if (pthread_join(threads[i], NULL) != 0)
      cerr << "Unable to perform join on thread " << i << endl;
  }

  // Print results.

  double endTime = timer();
  cerr << "Elapsed time (seconds) = " << std::setprecision(4) 
       << endTime - startTime << endl;
  
  if (converged)
    cerr << "Converged in " << iter << " iterations";
  else
    cerr << "Failed to converge in " << iter << " iterations";
  cerr << " (convergence threshold " << threshold << ")\n";

  // The threads exchange their own oldVals/newVals pointers after
  //   every iteration, so the most recent values are in whichever
  //   array a thread now calls oldVals -- that is newval, not oldval,
  //   after an odd number of iterations.  All threads swap in step,
  //   so thread 0's pointer is representative.
  const double * finalVals = threadObj[0]->oldVals;
  for (int i = 1; i <= nPoints; ++i) 
    cout << "value for point " << i << " = " << finalVals[i] << '\n';

  // Free everything allocated with "new" in this function.

  delete [] oldval;
  delete [] newval;
  delete [] convergedLocal;
  delete barPtr;
  for (int i = 0; i < nThreads; ++i)
    delete threadObj[i];
  delete [] threadObj;
  delete [] threads;

  return EXIT_SUCCESS;
}