/*******************************************************************************
* Copyright (C) 2021 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*  Content:
*      POTRF/POTRI OpenMP GPU Offload Example
*******************************************************************************/

#include <stdio.h>
#include <math.h>
#include <omp.h>
#include "mkl.h"
#include "mkl_omp_offload.h"


/*
 * Inverts a small symmetric matrix on the offload device and verifies the
 * result against a precomputed reference.
 *
 * The 3x3 input is stored column by column with a leading dimension of 4, so
 * every column carries one padding element (-1) that the LAPACK routines are
 * not expected to touch. Only the lower triangle is printed and compared.
 *
 * Steps:
 *   1. dpotrf - Cholesky factorization of the matrix (lower triangle).
 *   2. dpotri - inverse computed from that factorization, run only when
 *               step 1 reported success.
 *
 * Returns 0 when both calls succeed and the computed lower triangle matches
 * the reference within 1e-7; returns 1 otherwise.
 */
int main(void)
{
    char L = 'L';               /* uplo argument: operate on the lower triangle */
    MKL_INT m    = 3;           /* number of rows printed and verified */
    MKL_INT n    = 3;           /* matrix order passed to LAPACK */
    MKL_INT lda  = 4;           /* leading dimension (column stride) */
    MKL_INT info = 0;           /* LAPACK status, 0 means success */
    MKL_INT matrix_size = lda * n;

    double matrix[] = { 14,  15,  24,  -1,
                        15,  17,  27,  -1,
                        24,  27,  43,  -1 };

    /* Expected inverse of the input, laid out the same way as `matrix`. */
    double result[] = {  2,   3,  -3,  -1,
                         3,  26, -18,  -1,
                        -3, -18,  13,  -1 };

    printf("Input:\n");
    for (int i = 0; i < m; i++) {
        for (int j = 0; j <= i; j++) {
            printf("%6.2f ", matrix[i + j * lda]);
        }
        printf("\n");
    }

    /* Array sections in OpenMP clauses need pointer (or array) base names. */
    MKL_INT *info_ptr  = &info;
    double *matrix_ptr = &matrix[0];

    /*
     * Keep the matrix and the status word resident on the device for both
     * LAPACK calls. No map type is given, so the default applies and the data
     * is copied back to the host when the region ends; the checks after the
     * region rely on that.
     */
    #pragma omp target data map(matrix_ptr[0:matrix_size], info_ptr[0:1])
    {

        #pragma omp dispatch
        dpotrf(&L, &n, matrix_ptr, &lda, info_ptr);

        /*
         * Bring the factorization status back to the host before deciding
         * whether to continue. The array section is required here: naming the
         * bare pointer would designate the pointer variable itself rather
         * than the mapped MKL_INT it points to, leaving the host `info` stale.
         */
        #pragma omp target update from(info_ptr[0:1])
        if (info == 0) {
            #pragma omp dispatch
            dpotri(&L, &n, matrix_ptr, &lda, info_ptr);
        }
    }

    if (info != 0) {
        printf("ERROR: Calculations failed with info = %d!\n", (int)info);
        return 1;
    }

    int num_errors = 0;
    printf("Result:\n");
    for (int i = 0; i < m; i++) {
        for (int j = 0; j <= i; j++) {
            printf("%6.2f ", matrix[i + j * lda]);
            /* The comparison yields 0 or 1, so this counts mismatching entries. */
            num_errors += fabs(matrix[i + j * lda] - result[i + j * lda]) > 1e-7;
        }
        printf("\n");
    }
    if (num_errors != 0) {
        printf("ERROR: result mismatches!\n");
        return 1;
    }

    return 0;
}
