/*******************************************************************************
* Copyright (C) 2022 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
!  Content:
!      Intel(R) oneAPI Math Kernel Library ScaLAPACK C example source file
!
!*****************************************************************************
!===============================================================================
!==== ScaLAPACK psgetrf and psgetrs example program ============================
!===============================================================================
!
! Example solves a system of distributed linear equations
!
!       A * X = Y  or  A * X =  I
!
!  with a general random n-by-n distributed matrix A using the LU
!  factorization computed by psgetrf.
!  Y is a random n-by-nrhs distributed matrix, 
!  I denotes the n-by-nrhs matrix that has 1 in every element
!  of its leading (main) diagonal
!  and 0 in all other elements.
!  If n==nrhs, I is the identity matrix and X is the inverse of A.
!
! Example also demonstrates BLACS routines usage.
!
! It works with block-cyclic distribution on 2D process grid
!
! List of routines demonstrated in the example:
!
! psgetrf
! psgetrs 
! pslamch 
! pslange
! psgemr2d
! psgemm
! blacs_get
! blacs_gridinit
! blacs_gridinfo
! blacs_gridexit
! blacs_exit 
! descinit
! igamx2d 
! igebr2d
! igebs2d  
! sgebr2d
! sgebs2d
! numroc
! indxg2p
!
!  The program must be driven by a short data file.
!
! 0               type, problem type, if type=0 the system A*X = I is solved, otherwise A*X = Y
! 2000            n, dimension of matrix, 0 < n
! 1000            nrhs,number of right hand sides, 0 < nrhs
! 16              nb, size of blocks, must be > 0
! 2               p, number of rows in the process grid, must be > 0
! 2               q, number of columns in the process grid, must be > 0, p*q = number of processes
! 20.0            threshold for residual check
!
!*******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include "matrix_generator.h"

/*  test_solve
    Generates a distributed n-by-n matrix A and n-by-nrhs matrices X and Y,
    solves the system with psgetrf / psgetrs and compares the scaled residual
        ||A*X-Y|| / (n*eps*(||A||*||X||+||Y||))
    with thresh.

    ictxt         BLACS context of the 2D process grid
    type          problem type; forwarded to psmatgen_random for X and Y and,
                  when nonzero, Y is recomputed as A*X and copied to X
                  (per the file header: 0 solves A*X = I, otherwise A*X = Y)
    n             order of the matrix A
    nrhs          number of right hand sides
    nb            block size, used for both dimensions of every descriptor
    nprow, npcol  process grid shape; overwritten locally by blacs_gridinfo
    seed          generator seed; seed[0] is modified in place
    thresh        residual threshold; a value below RZERO always reports PASSED

    Returns 0 on success and 1 on failure. Descriptor, allocation and solver
    failures are agreed between all processes with igamx2d before leaving, so
    every process takes the same exit path. A residual above the threshold is
    detected and reported by process 0 only. */
MKL_INT test_solve(MKL_INT ictxt, MKL_INT type, MKL_INT n, MKL_INT nrhs, MKL_INT nb, MKL_INT nprow,  MKL_INT npcol, MKL_INT* seed, float thresh)
{
    MKL_INT info = 0;
    MKL_INT err = 0;

    float  *A = NULL;
    float  *A_INI = NULL;
    float  *X = NULL;
    float  *Y = NULL;
    MKL_INT   *IPIV = NULL;
/*  Matrix descriptors */
    MKL_INT  desca[9];
    MKL_INT  descx[9];
    MKL_INT  descy[9];
/*  Local variables */

    MKL_INT myrow;
    MKL_INT mycol;
    MKL_INT iam, nprocs;
    MKL_INT desc_failed = 0;
    float eps;
    float normA, normRhs, normSol, normRes, residual;

    float fzero = RZERO;
    float fone = RONE;
    float fnegone = RNEGONE;
/*  Compute machine epsilon */
    eps = pslamch( &ictxt, "e" );
/*  Get information about how many processes are used for program execution
    and number of current process */
    blacs_pinfo( &iam, &nprocs );
    blacs_gridinfo(&ictxt, &nprow, &npcol, &myrow, &mycol);
/*  Compute precise length of local pieces of distributed matrices */
    MKL_INT loc_ra = numroc(&n, &nb, &myrow, &izero, &nprow);
    MKL_INT loc_ca = numroc(&n, &nb, &mycol, &izero, &npcol);
    MKL_INT loc_lda = max(1, loc_ra);
    MKL_INT loc_sizeA =loc_ra*loc_ca;

    MKL_INT loc_rx = numroc(&n, &nb, &myrow, &izero, &nprow);
    MKL_INT loc_cx = numroc(&nrhs, &nb, &mycol, &izero, &npcol);
    MKL_INT loc_ldx = max(1, loc_rx);
    MKL_INT loc_ry = numroc(&n, &nb, &myrow, &izero, &nprow);
    MKL_INT loc_cy = numroc(&nrhs, &nb, &mycol, &izero, &npcol);
    MKL_INT loc_ldy = max(1, loc_ry);

/*  Element counts used for allocation. The product is formed in size_t so it
    cannot overflow MKL_INT, and it is never 0: a process that owns no part of
    a matrix still gets a valid one-element buffer instead of a NULL pointer
    that would be mistaken for an allocation failure. */
    size_t cnt_a = ( loc_ra > 0 && loc_ca > 0 ) ? (size_t)loc_ra * (size_t)loc_ca : 1;
    size_t cnt_x = ( loc_rx > 0 && loc_cx > 0 ) ? (size_t)loc_rx * (size_t)loc_cx : 1;
    size_t cnt_y = ( loc_ry > 0 && loc_cy > 0 ) ? (size_t)loc_ry * (size_t)loc_cy : 1;
    size_t cnt_ipiv = (size_t)( loc_ra + nb );

/*  Initialize descriptors for distributed arrays. Each call overwrites info,
    so a failure of any of them is remembered separately. */
    descinit(desca, &n,  &n,   &nb, &nb, &izero, &izero, &ictxt, &loc_lda, &info);
    if( info ) {
        desc_failed = 1;
    }
    descinit(descx, &n, &nrhs, &nb, &nb, &izero, &izero, &ictxt, &loc_ldx, &info);
    if( info ) {
        desc_failed = 1;
    }
    descinit(descy, &n, &nrhs, &nb, &nb, &izero, &izero, &ictxt, &loc_ldy, &info);
    if( info ) {
        desc_failed = 1;
    }
/*  Agree on the descriptor status across the grid before going any further */
    info = desc_failed;
    igamx2d(&ictxt, "ALL", " ", &ione, &ione, &info, &ione, NULL, NULL, &inegone, &inegone, &inegone);
    if( info ) {
        err = 1;
        if ( iam == 0 ) printf( "\n descinit fails to initialize array descriptors \n" );
        goto early_exit;
    }

/*  Allocate matrix, right-hand-side, vector solution x and the IPIV array
    containing the pivoting information. The array A_INI is used to store
    original matrix */
    A     = (float*)mkl_malloc( cnt_a * sizeof(float), 64 );
    A_INI = (float*)mkl_malloc( cnt_a * sizeof(float), 64 );
    X     = (float*)mkl_malloc( cnt_x * sizeof(float), 64 );
    Y     = (float*)mkl_malloc( cnt_y * sizeof(float), 64 );
    IPIV  = (MKL_INT*)mkl_malloc( cnt_ipiv * sizeof(MKL_INT), 64 );
    info = ( A == NULL || X == NULL || Y == NULL || A_INI == NULL || IPIV == NULL ) ? 1 : 0;
/*  All processes leave together if any one of them could not allocate */
    igamx2d(&ictxt, "ALL", " ", &ione, &ione, &info, &ione, NULL, NULL, &inegone, &inegone, &inegone);
    if( info ) {
        err = 1;
        if ( iam == 0 ) printf( "\n Can't allocate memory for arrays\n" );
        goto early_exit;
    }
/*  Generate matrix and solution. Copy the initial matrix. */
    seed[0]+= myrow + mycol;
    psmatgen_random(1, desca, A, seed, nprow, npcol);
    scopy(&loc_sizeA, A, &ione, A_INI, &ione);
    psmatgen_random(type, descx, X, seed, nprow, npcol);
    psmatgen_random(type, descy, Y, seed, nprow, npcol);
/*  Compute the right hand side */
    if( type ) {
        psgemm("N", "N", &n, &nrhs, &n, &fone, A, &ione, &ione, desca,
                X, &ione, &ione, descx,
                &fzero, Y, &ione,  &ione,  descy);
/*  Copy the right hand side Y to X */
        psgemr2d(&n, &nrhs, Y, &ione, &ione, descy, X, &ione, &ione, descx, &ictxt);
    }

/*  Compute norm of the initial matrix and right hand side  */
    normA = pslange("M", &n, &n, A, &ione, &ione, desca, NULL);
    normRhs = pslange("M", &n, &nrhs, Y, &ione, &ione, descy, NULL);

/*  Compute the LU factorization of the initial matrix */
    psgetrf(&n, &n, A, &ione, &ione, desca, IPIV, &info);
    igamx2d(&ictxt, "ALL", " ", &ione, &ione, &info, &ione, NULL, NULL, &inegone, &inegone, &inegone);
/* If psgetrf fails, deallocate all arrays and report an error */
    if( info) {
        err = 1;
        if ( iam == 0 ) printf( "\n psgetrf fails to compute LU decomposition \n" );
        goto early_exit;
    }

/*  Solve the system in place: X holds the right hand side on entry */
    psgetrs("N", &n, &nrhs, A, &ione, &ione, desca, IPIV, X, &ione, &ione, descx, &info);
    igamx2d(&ictxt, "ALL", " ", &ione, &ione, &info, &ione, NULL, NULL, &inegone, &inegone, &inegone);
/* If psgetrs fails, deallocate all arrays and report an error */
    if( info ) {
        err = 1;
        if ( iam == 0 ) printf( "\n psgetrs fails to solve \n" );
        goto early_exit;
    }
/* Normsol is the norm of the solution matrix:  normsol = || X|| */
    normSol = pslange("M", &n, &nrhs, X, &ione, &ione, descx, NULL);


/* Compute the residual matrix A*X - Y using the saved copy of A; the result
   overwrites Y */
    psgemm("N", "N", &n, &nrhs, &n, &fone, A_INI, &ione, &ione, desca,
            X, &ione, &ione, descx,
            &fnegone, Y, &ione,  &ione,  descy);
/* Compute the residual as:
   residual = ||A*X-Y||/(n*eps*(||A||*||X||+||Y||)) */

    normRes = pslange("M", &n, &nrhs, Y, &ione, &ione, descy, NULL);
    residual = normRes/( ( (float)n ) * eps *( normA * normSol + normRhs ) );

/*  Check if residual passed or failed the threshold. The negated comparison
    also reports FAILED when the residual is NaN. */
    if( iam == 0 ) {
        if ( ( thresh >= RZERO ) && !( residual <= thresh ) ){
            printf( "FAILED. Residual = %05.11f\n", residual );
            err = 1;
        } else {
            printf( "PASSED. Residual = %05.11f\n", residual );
        }
        printf( "=== END OF EXAMPLE =====================\n" );
    }
early_exit :
/*  Destroy arrays */
    if( A ) mkl_free(A);
    if( A_INI ) mkl_free(A_INI);
    if( X ) mkl_free(X);
    if( Y ) mkl_free(Y);
    if( IPIV ) mkl_free(IPIV);
    return err;
}

int main(int argc, char *argv[])
{
/*  ==== Declarations =================================================== */

/*  File variables */
    FILE    *fin;

/*  Local scalars */
    MKL_INT iam, nprocs, ictxt=-1, nprow, npcol;
    MKL_INT n, nb, nrhs;
    MKL_INT err;
    MKL_INT type=1;
    int     n_int, nb_int, nprow_int, npcol_int, nrhs_int, type_int;
    float  thresh;
    MKL_INT seed[4] = { 105, 1410, 1860, 4085 };

/*  Local arrays */
    MKL_INT iw[ 6 ];

/*  ==== Executable statements ========================================== */

/*  Get information about how many processes are used for program execution
    and number of current process */
    blacs_pinfo( &iam, &nprocs );

/*  Init temporary 1D process grid */
    blacs_get( &inegone, &izero, &ictxt );
    blacs_gridinit( &ictxt, "C", &nprocs, &ione );

/*  Open input file */
    if ( iam == 0 ) {
        fin = fopen( "psgetrf_example.in", "r" );
        if ( fin == NULL ) {
            printf( "Error while open input file." );
            return 1;
        }
    }

/*  Read data and send it to all processes */
    if ( iam == 0 ) {

/*      Read parameters */
        fscanf( fin, "%d type, problem type, if type=0 the system A*X = I is solved, otherwise A*X = Y", &type_int );
        fscanf( fin, "%d n, dimension of matrix, 0 < n", &n_int );
        fscanf( fin, "%d nrhs,number of right hand sides, 0 < nrhs", &nrhs_int );
        fscanf( fin, "%d nb, size of blocks, must be > 0", &nb_int );
        fscanf( fin, "%d p, number of rows in the process grid, must be > 0", &nprow_int );
        fscanf( fin, "%d q, number of columns in the process grid, must be > 0, p*q = number of processes", &npcol_int );
        fscanf( fin, "%f threshold for residual check", &thresh );
        fclose( fin );
        type = (MKL_INT) type_int;
        n = (MKL_INT) n_int;
        nrhs = (MKL_INT) nrhs_int;
        nb = (MKL_INT) nb_int;
        nprow = (MKL_INT) nprow_int;
        npcol = (MKL_INT) npcol_int;
/*      Check if all parameters are correct */
        if( ( n<=0 )||( nrhs<=0 )||( nb<=0 )||( nprow<=0 )||
            ( npcol<=0 )||( nprow*npcol != nprocs ) ) {
            printf( "One or several input parameters has incorrect value. Limitations:\n" );
            printf( "n > 0, nrhs > 0, nb > 0, p > 0, q > 0 - integer\n" );
            printf( "p*q = number of processes\n" );
            printf( "threshold - float \n");
            return 1;
        }

/*      Pack data into array and send it to other processes */
        iw[ 0 ] = type;
        iw[ 1 ] = n;
        iw[ 2 ] = nrhs;
        iw[ 3 ] = nb;
        iw[ 4 ] = nprow;
        iw[ 5 ] = npcol;
        igebs2d( &ictxt, "All", " ", &isix, &ione, iw, &isix );
        sgebs2d( &ictxt, "All", " ", &ione, &ione, &thresh, &ione );
    } else {

/*      Recieve and unpack data */
        igebr2d( &ictxt, "All", " ", &isix, &ione, iw, &isix, &izero, &izero );
        sgebr2d( &ictxt, "All", " ", &ione, &ione, &thresh, &ione, &izero, &izero );
        type = iw[ 0 ];
        n = iw[ 1 ];
        nrhs = iw[ 2 ];
        nb = iw[ 3 ];
        nprow = iw[ 4 ];
        npcol = iw[ 5 ];
    }
/*  Destroy temporary process grid */
    blacs_gridexit( &ictxt ); 
/*  Init workind 2D process grid */ 
    blacs_get(&inegone, &izero, &ictxt);
    blacs_gridinit(&ictxt, "R", &nprow, &npcol);
    if ( iam == 0 ) {
/*      Print information of task */
        printf( "=== START OF EXAMPLE ===================\n" );
        if( type ) {
            printf( "Solve matrix equation: A*X = Y\n" );
            printf( "with a general random n-by-n distributed matrix A \n" );
            printf( "Y is a random n-by-nrhs distributed matrix\n\n" );
        } else {
            printf( "Solve matrix equation: A*X = I\n" );
            printf( "with a general random n-by-n distributed matrix A \n" );
            printf( "I denotes a n-by-nrhs matrix which possesses 1 \n" ); 
            printf( "as all elements along its leading (main) diagonal \n" );
            printf( "while all other elements in the matrix are 0. \n" ); 
            printf( "If n==nrhs, I is the identity matrix and X is the matrix inverse of A.\n\n" );
        }
        printf( "n = %d, nrhs = %d, nb = %d; %dx%d - process grid\n\n", n_int, nrhs_int, nb_int, nprow_int, npcol_int);
        printf( "Threshold for residual check = %05.11f\n", thresh );
    }
    err = test_solve(ictxt, type, n, nrhs, nb,  nprow, npcol, seed, thresh);
/*  Destroy process grid */
    blacs_gridexit(&ictxt);
    blacs_exit(&izero);
    return err;
}
