/**
 * Parallel (OpenMP) program to test and time a block-based
 * matrix-multiplication routine.
 *
 * Command-line arguments:  size of the (square) matrices to multiply, and
 *   number of blocks in each dimension (e.g., 2 means each matrix is split
 *   into 4 blocks).  numblocks must evenly divide size.
 */
#include <errno.h>
#include <limits.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <omp.h>

/*
 * Pointer to the first (top-left) element of block (i, j) of the row-major
 * matrix M, where each block is rows_per_blk x cols_per_blk elements and
 * stride is the number of elements in one full row of M.
 * Every argument is parenthesized so that arbitrary expressions (including
 * conditionals) can be passed as M.  The offset is computed in ptrdiff_t
 * rather than int, so large matrices do not overflow signed int arithmetic.
 */
#define blockstart(M,i,j,rows_per_blk,cols_per_blk,stride) \
    ((M) + (ptrdiff_t)(i)*(rows_per_blk)*(stride) + (ptrdiff_t)(j)*(cols_per_blk))

#include "matrix-mult-utility.h"
#include "matrix-mult-initfortest.h"
#include "matrix-mult-print.h"

/*
 * Parse text as a strictly positive decimal int.
 * Returns the parsed value, or 0 when text contains no digits, has trailing
 * characters, is out of int range, or is not positive.  Callers treat 0 as
 * "invalid argument".
 */
static int parse_positive_int_arg(const char *text)
{
    char *end;
    long value;

    errno = 0;
    value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE
            || value <= 0 || value > INT_MAX) {
        return 0;
    }
    return (int) value;
}

/*
 * Allocate an uninitialized rows x cols matrix of doubles (rows and cols are
 * expected to be non-negative).  The size is computed in size_t with an
 * explicit overflow check.  Returns NULL on overflow or if malloc fails.
 * The caller owns the result and must free() it.
 */
static double *malloc_double_matrix(int rows, int cols)
{
    size_t r = (size_t) rows;
    size_t c = (size_t) cols;

    if (c != 0 && r > SIZE_MAX / sizeof(double) / c) {
        return NULL;
    }
    return malloc(r * c * sizeof(double));
}

/*
 * Usage:  prog size numblocks
 *
 * Allocates size x size matrices A, B and C and fills A and B via
 * initialize().  It then computes C = A*B block by block inside an OpenMP
 * parallel loop, timing only the multiplication.  Finally it prints A, B and
 * A*B to stdout and reports the thread count and elapsed time to stderr.
 * Returns EXIT_FAILURE on bad arguments or allocation failure.
 */
int main(int argc, char *argv[]) {

    double start_time, end_time;
    int N = 0;                      /* input size */
    int NB = 0;                     /* number of blocks per dimension */
    int dimN, dimP, dimM;           /* A is dimN x dimP, B is dimP x dimM */
    double *A, *B, *C;              /* matrices */
    int dimNb, dimPb, dimMb;        /* block dimensions */
    int ib, jb, kb;                 /* block indices */
    int nthreads = 1;

    /* command-line arguments (non-numeric or trailing junk is rejected) */
    if (argc < 3) {
        fprintf(stderr, "usage:  %s size numblocks\n", argv[0]);
        return EXIT_FAILURE;
    }
    N = parse_positive_int_arg(argv[1]);
    if (N <= 0) {
        fprintf(stderr, "usage:  %s size numblocks\n", argv[0]);
        return EXIT_FAILURE;
    }
    NB = parse_positive_int_arg(argv[2]);
    if (NB <= 0) {
        fprintf(stderr, "usage:  %s size numblocks\n", argv[0]);
        return EXIT_FAILURE;
    }
    if ((N % NB) != 0) {
        fprintf(stderr, "numblocks must evenly divide size\n");
        return EXIT_FAILURE;
    }
    dimN = dimP = dimM = N;
    dimNb = dimN / NB;
    dimPb = dimP / NB;
    dimMb = dimM / NB;

    /* sizes computed in size_t with overflow check, not int*int */
    A = malloc_double_matrix(dimN, dimP);
    B = malloc_double_matrix(dimP, dimM);
    C = malloc_double_matrix(dimN, dimM);

    if ((A == NULL) || (B == NULL) || (C == NULL)) {
        fprintf(stderr, "unable to allocate space for matrices of size %d\n",
                dimN);
        /* release whichever allocations did succeed (free(NULL) is a no-op) */
        free(A);
        free(B);
        free(C);
        return EXIT_FAILURE;
    }

    /* record the team size for the report printed at the end */
    #pragma omp parallel
    {
        #pragma omp single
        {
            nthreads = omp_get_num_threads();
        }
    }

    /* Initialize matrices */

    initialize(A, B, dimN, dimP, dimM);

    /* Do the multiply */

    start_time = omp_get_wtime();

    /*
     * Iteration ib writes only row-block ib of C, so threads touch disjoint
     * parts of C.  ib is the loop variable (private by default); jb and kb
     * must be made private explicitly.
     */
    #pragma omp parallel for private(jb,kb) schedule(static)
    for (ib=0; ib < NB; ++ib) {
        for (jb=0; jb < NB; ++jb) {
            /* find block[ib][jb] of C */
            double * blockPtr = blockstart(C, ib, jb, dimNb, dimMb, dimM);
            /* clear block[ib][jb] of C (set all elements to zero) */
            matclear(blockPtr, dimNb, dimMb, dimM);
            for (kb=0; kb < NB; ++kb) {
                /* compute product of block[ib][kb] of A and block[kb][jb] of B 
                   and add to block[ib][jb] of C */
                matmul_add(
                        blockstart(A, ib, kb, dimNb, dimPb, dimP),
                        blockstart(B, kb, jb, dimPb, dimMb, dimM),
                        blockPtr, dimNb, dimPb, dimMb, dimP, dimM, dimM
                        );
            }
        }
    }

    end_time = omp_get_wtime();

    /* Print results */

    printMatrix(stdout, "A", A, dimN, dimP, dimP);
    printMatrix(stdout, "B", B, dimP, dimM, dimM);
    printMatrix(stdout, "A*B", C, dimN, dimM, dimM);

    fprintf(stderr, "Block-based program, parallel with %d threads\n", 
            nthreads);
    fprintf(stderr, "Size = %d, numblocks = %d, time for multiplication = %g\n",
            N, NB, end_time - start_time);

    free(A);
    free(B);
    free(C);

    return EXIT_SUCCESS;
}