/*
 * Simplified program in which pairs of neighboring processes periodically
 * exchange messages.  Each process alternates exchanging messages with its
 * neighbors and "computing" (actually just waiting).
 *
 * This version uses regular blocking sends/receives but corrects the
 * problem in exchange.c by having even-numbered processes send first
 * and odd-numbered processes receive first.  Requires that the number of
 * processes be even.
 *
 * Command-line arguments (all must be positive integers):  number of
 * exchange-then-compute cycles, time to "compute" (wait) in seconds, and
 * message length in ints.
 */
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <mpi.h>

/* a short function to print a message, stop all processes, and exit */
void error_exit(char* msg) {
    fprintf(stderr, msg);
    MPI_Finalize();
    exit(EXIT_FAILURE);
}

/* main program */
int main(int argc, char *argv[]) {

    int nprocs, myID, left_nbr, right_nbr;
    int *to_left, *to_right, *from_left, *from_right;
    int numsteps, compute_time, msglength;
    double start_time, end_time;
    double time_diff, max_time_diff;
    int k;
    MPI_Status status; 

    /* MPI initialization */
    MPI_Init(&argc, &argv);
    MPI_Comm_size (MPI_COMM_WORLD, &nprocs);
    MPI_Comm_rank(MPI_COMM_WORLD, &myID);
    if (nprocs %2 != 0) {
        error_exit("number of processes must be even\n");
    }

    /* process command-line arguments */
    if (argc < 3) 
        error_exit("parameters:  numsteps compute_time msglength\n");
    numsteps = atoi(argv[1]);
    if (numsteps <= 0)
        error_exit("parameters:  numsteps compute_time msglength\n");
    compute_time = atoi(argv[2]);
    if (compute_time <= 0)
        error_exit("parameters:  numsteps compute_time msglength\n");
    msglength = atoi(argv[3]);
    if (msglength <= 0)
        error_exit("parameters:  numsteps compute_time msglength\n");
    if (myID == 0)
        printf("numsteps = %d, compute_time = %d, msglength = %d, nprocs = %d\n",
                numsteps, compute_time, msglength, nprocs);

    /* initialization of other variables */
    left_nbr = (myID == 0) ? nprocs-1 : myID - 1; 
    right_nbr = (myID == (nprocs-1)) ? 0 : myID + 1;
    to_left = malloc(sizeof(int) * msglength);
    to_right = malloc(sizeof(int) * msglength);
    from_left = malloc(sizeof(int) * msglength);
    from_right = malloc(sizeof(int) * msglength);
    if (to_left == NULL || to_right == NULL || 
            from_left == NULL || from_right == NULL)
        error_exit("unable to allocate space for buffers\n");

    /* barrier before we start timing */
    MPI_Barrier(MPI_COMM_WORLD);
    start_time = MPI_Wtime();

    /* repeat numsteps times */
    for (k = 0; k < numsteps; ++k) {

        /* exchange information with neighbors */
        if (myID % 2 == 0) {
            /* even-numbered processes send first */
            MPI_Send(to_left, msglength, MPI_INT, left_nbr, 0, MPI_COMM_WORLD);
            MPI_Send(to_right, msglength, MPI_INT, right_nbr, 0, MPI_COMM_WORLD);
            MPI_Recv(from_left, msglength, MPI_INT, left_nbr, 0, MPI_COMM_WORLD, 
                    &status);
            MPI_Recv(from_right, msglength, MPI_INT, right_nbr, 0, MPI_COMM_WORLD, 
                    &status);
        }
        else {
            /* odd-numbered processes receive first */
            MPI_Recv(from_right, msglength, MPI_INT, right_nbr, 0, MPI_COMM_WORLD, 
                    &status);
            MPI_Recv(from_left, msglength, MPI_INT, left_nbr, 0, MPI_COMM_WORLD, 
                    &status);
            MPI_Send(to_right, msglength, MPI_INT, right_nbr, 0, MPI_COMM_WORLD);
            MPI_Send(to_left, msglength, MPI_INT, left_nbr, 0, MPI_COMM_WORLD);
        }

        /* fake computation */
        sleep(compute_time);
    }

    /* end timing */
    end_time = MPI_Wtime();
    time_diff = end_time - start_time;

    /* get maximum time over all processes */
    MPI_Reduce(&time_diff, &max_time_diff, 1, MPI_DOUBLE, MPI_MAX, 0,
            MPI_COMM_WORLD);

    /* print maximum time */
    if (myID == 0)
        printf("time %g seconds\n", max_time_diff);
 
    /* clean up and end */
    MPI_Finalize();
    return EXIT_SUCCESS;
}