/*
 * Numerical integration example, as discussed in the textbook.
 *
 * Computes pi by approximating the area under the curve f(x) = 4 / (1 + x*x)
 * between 0 and 1 (the exact area is pi), using the midpoint rule.
 *
 * Parallel version using Java threads.
 *
 * The command-line argument specifies the number of threads.
 */
public class NumIntPar {

    /* variables/constants to be used by all threads */

    /* number of subintervals of [0, 1]; f is evaluated once per subinterval */
    private static final int NUM_STEPS = 10000000;
    /* width of each subinterval */
    private static final double STEP = 1.0 / (double) NUM_STEPS;
    /* printed when the command-line argument is missing or invalid */
    private static final String USAGE = "usage:  NumIntPar nThreads";

    /* running total of f(midpoint) over all subintervals; updated by the
       worker threads only while holding lockObj */
    private static double sum = 0.0;
    private static final Object lockObj = new Object();
    /* number of worker threads; assigned in main before any worker is
       started, and Thread.start() makes that write visible to the workers */
    private static int nThreads;

    /**
     * Reads the thread count from the command line, computes pi in parallel,
     * and prints the result together with the elapsed wall-clock time.
     *
     * Exits with status 1 (after printing a usage message) if the argument
     * is missing, is not an integer, or is less than 1.
     *
     * @param args args[0] is the number of worker threads (must be >= 1)
     */
    public static void main(String[] args) {

        /* process command-line argument */
        nThreads = parseThreadCount(args);
        if (nThreads < 1) {
            // 0 threads would silently report pi = 0.0; a negative count
            // would fail when sizing the thread array.
            System.err.println(USAGE);
            System.exit(1);
            return;
        }

        /* start timing; nanoTime is monotonic, unlike currentTimeMillis */
        long startTime = System.nanoTime();

        /* reset the shared accumulator so that repeated calls to main in the
           same JVM do not add onto a previous run's total */
        sum = 0.0;

        /* create threads */
        Thread[] threads = new Thread[nThreads];
        for (int i = 0; i < threads.length; ++i) {
            threads[i] = new Thread(new CodeForThread(i));
        }

        /* start them up */
        for (Thread t : threads) {
            t.start();
        }

        /* wait for every thread to finish: the result is only correct once
           every partial sum has been added, so an interrupt must not cut the
           wait short; it is remembered and restored below instead */
        boolean interrupted = false;
        for (Thread t : threads) {
            interrupted |= joinUninterruptibly(t);
        }

        /* finish computation; each join() happens-before this read, so all
           workers' updates to sum are visible without taking the lock.
           The last bits may differ between runs because the partial sums
           are added in whatever order the threads finish. */
        double pi = sum * STEP;

        /* end timing and print result */
        long endTime = System.nanoTime();
        System.out.println("parallel program results with " + nThreads +
                " threads:");
        System.out.println("pi = " + pi);
        System.out.println("time to compute = " +
                (endTime - startTime) / 1e9);

        /* preserve an interrupt request that arrived while waiting */
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /* Returns the integer in args[0], or -1 if it is missing or not an
       integer. The caller treats any value below 1 as invalid. */
    private static int parseThreadCount(String[] args) {
        if (args.length < 1) {
            return -1;
        }
        try {
            return Integer.parseInt(args[0]);
        } catch (NumberFormatException e) {
            return -1;
        }
    }

    /* Waits until t terminates, retrying the join if the calling thread is
       interrupted. Returns true if at least one interrupt was received, so
       the caller can restore the interrupt status once all waiting is done. */
    private static boolean joinUninterruptibly(Thread t) {
        boolean interrupted = false;
        while (true) {
            try {
                t.join();
                return interrupted;
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
    }

    /* static nested class containing the code run by each thread.
       Thread k handles subintervals k, k + nThreads, k + 2*nThreads, ...
       (a cyclic split), so threads with k >= NUM_STEPS simply do no work. */

    private static class CodeForThread implements Runnable {

        private final int myID;

        public CodeForThread(int myID) {
            this.myID = myID;
        }

        @Override
        public void run() {
            /* accumulate locally so the shared lock is taken only once */
            double partsum = 0.0;
            for (int i = myID; i < NUM_STEPS; i += nThreads) {
                double x = (i + 0.5) * STEP;  // midpoint of subinterval i
                partsum += 4.0 / (1.0 + x * x);
            }
            /* only one thread at a time can do this */
            synchronized (lockObj) {
                sum += partsum;
            }
        }
    }
}