/*******************************************************************************
* Copyright 2021-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*  Content:
*      dgetri_oop_batch (group API) OpenMP Offload Example
*******************************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include <omp.h>
#include "mkl.h"
#include "mkl_omp_offload.h"

/* Shape parameters for the randomly generated matrix batch. */
enum {
    GROUP_COUNT    = 4,   // number of groups in the batch

    MIN_GROUP_SIZE = 4,   // fewest matrices drawn for a single group
    MAX_GROUP_SIZE = 10,  // most matrices drawn for a single group

    MIN_N          = 4,   // smallest matrix order drawn
    MAX_N          = 32   // largest matrix order drawn
};


/*
 * Report every matrix in a grouped batch whose info value is non-zero.
 *
 * func_name   - name of the batched routine; used only in the printed messages
 * group_count - number of groups in the batch
 * group_sizes - group_sizes[g] is the number of matrices in group g
 * info_arr    - one info value per matrix, stored group after group
 *               (sum of all group_sizes entries in total)
 *
 * Prints one line per matrix with non-zero info and, if any were found, a
 * summary line. Returns the number of matrices with non-zero info.
 */
int check_info(const char* func_name, MKL_INT group_count, const MKL_INT* group_sizes, const MKL_INT *info_arr) {
    int num_errors = 0;
    // idx is the flat position of the matrix across the whole batch; imat is
    // its position within group igrp.
    for (int igrp = 0, idx = 0; igrp < group_count; igrp++) {
        for (int imat = 0; imat < group_sizes[igrp]; imat++, idx++) {
            if (info_arr[idx]) {
                printf("%s offload failed: Matrix %4d(matrix %4d from group %4d) returned with info=%4d\n",
                    func_name, idx, imat, igrp, (int)info_arr[idx]);
                num_errors++;
            }
        }
    }

    if (num_errors > 0)
        printf("Total number of failures:%4d\n", num_errors);

    return num_errors;
}

int main(void)
{
    // Total number of groups in the matrix batch
    MKL_INT  group_count = GROUP_COUNT;

    // Allocate memory for parameter arrays
    MKL_INT* group_sizes = (MKL_INT*) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    MKL_INT* n_arr       = (MKL_INT*) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    MKL_INT* lda_arr     = (MKL_INT*) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    MKL_INT* ldainv_arr  = (MKL_INT*) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    if (!group_sizes || !n_arr || !lda_arr || !ldainv_arr) {
        printf("ERROR: Memory allocation for parameter arrays failed!\n");
        return 1;
    }

    // Generate random matrix dimensions for the matrices in each group
    // n_arr[igrp]       - number of rows and columns in matrices for group igrp (MIN_N <= n_arr[igrp] <= MAX_N)
    // lda_arr[igrp]     - leading dimension in input matrices for group igrp (lda_arr[igrp] >= n_arr[igrp])
    // ldainv_arr[igrp]  - leading dimension in output matrices for group igrp (lda_arr[igrp] >= n_arr[igrp])
    // group_sizes[igrp] - number of matrices in group igrp (MIN_GROUP_SIZE <= group_sizes[igrp] <= MAX_GROUP_SIZE)
    printf("Computing the matrix inverse for the following batch of matrices:\n");
    printf("===================================================================\n");
    MKL_INT batch_size = 0;
    double  total_a_arr_size = 0;
    double  total_ainv_arr_size = 0;
    MKL_INT total_ipiv_arr_size = 0;
    for (int igrp = 0; igrp < group_count; igrp++) {
        // Random problem dimensions for each group
        n_arr[igrp]        = MIN_N + rand() % (MAX_N - MIN_N + 1);
        lda_arr[igrp]      = n_arr[igrp];
        ldainv_arr[igrp]   = n_arr[igrp];
        group_sizes[igrp]  = MIN_GROUP_SIZE + rand() % (MAX_GROUP_SIZE - MIN_GROUP_SIZE + 1);
        batch_size        += group_sizes[igrp];

        printf("Group=%4d, n=%4d, lda=%4d, ldainv=%4d, group_size=%4d\n",
            (int)igrp, (int)n_arr[igrp], (int)lda_arr[igrp], (int)ldainv_arr[igrp], (int)group_sizes[igrp]);
    }
    printf("Total number of matrices=%4d\n", (int)batch_size);
    printf("===================================================================\n");

    // Allocate memory for the matrices in the batch
    for (int igrp = 0; igrp < group_count; igrp++) {
        total_a_arr_size    += group_sizes[igrp] * lda_arr[igrp] * n_arr[igrp];
        total_ainv_arr_size += group_sizes[igrp] * ldainv_arr[igrp] * n_arr[igrp];
        total_ipiv_arr_size += group_sizes[igrp] * n_arr[igrp];
    }
    double*  a    = (double*)  mkl_malloc(total_a_arr_size    * sizeof(double),  64);
    double*  ainv = (double*)  mkl_malloc(total_ainv_arr_size * sizeof(double),  64);
    MKL_INT* ipiv = (MKL_INT*) mkl_malloc(total_ipiv_arr_size * sizeof(MKL_INT), 64);
    if (!a || !ainv || !ipiv) {
        printf("ERROR: Memory allocation for matrices failed!\n");
        return 1;
    }

    // Allocate memory for the array of pointers to the matrices in the batch
    double**  a_arr    = (double**)  mkl_malloc(batch_size * sizeof(double*),  64);
    double**  ainv_arr = (double**)  mkl_malloc(batch_size * sizeof(double*),  64);
    MKL_INT** ipiv_arr = (MKL_INT**) mkl_malloc(batch_size * sizeof(MKL_INT*), 64);
    MKL_INT*  info_arr = (MKL_INT*)  mkl_malloc(batch_size * sizeof(MKL_INT),  64);
    if (!a_arr || !ainv_arr || !ipiv_arr || !info_arr) {
        printf("ERROR: Memory allocation for pointer arrays failed!\n");
        return 1;
    }

    // Set up array of pointers and initialize data for each matrix
    MKL_INT* a_sizes    = (MKL_INT*) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    MKL_INT* ainv_sizes = (MKL_INT*) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    MKL_INT* ipiv_sizes = (MKL_INT*) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    if (!a_sizes || !ainv_sizes || !ipiv_sizes) {
        printf("ERROR: Memory allocation for size arrays failed!\n");
        return 1;
    }
    int a_off = 0;
    int ainv_off = 0;
    int ipiv_off = 0;
    for (int igrp = 0, idx = 0; igrp < group_count; igrp++) {
        a_sizes[igrp]    = lda_arr[igrp] * n_arr[igrp];
        ainv_sizes[igrp] = ldainv_arr[igrp] * n_arr[igrp];
        ipiv_sizes[igrp] = n_arr[igrp];
        for (int imat = 0; imat < group_sizes[igrp]; imat++, idx++) {
            // Set up pointer for matrix idx
            a_arr[idx]    = &a[a_off];
            ainv_arr[idx] = &ainv[ainv_off];
            ipiv_arr[idx] = &ipiv[ipiv_off];

            // Initialize entries of matrix idx
            for (int col = 0; col < n_arr[igrp]; col++) {
                for (int row = 0; row < n_arr[igrp]; row++) {
                    a_arr[idx][row + col*lda_arr[igrp]] = rand() / (double) RAND_MAX - 0.5;
                }
            }

            // Update pointer offset for the next matrix
            a_off    += a_sizes[igrp];
            ainv_off += ainv_sizes[igrp];
            ipiv_off += ipiv_sizes[igrp];
        }
    }

    // Map each array in a_arr, ainv_arr and ipiv_arr to the device and store their corresponding pointers in a new array
    double**  a_dev_arr    = (double**)  mkl_malloc(batch_size * sizeof(double*) , 64);
    double**  ainv_dev_arr = (double**)  mkl_malloc(batch_size * sizeof(double*) , 64);
    MKL_INT** ipiv_dev_arr = (MKL_INT**) mkl_malloc(batch_size * sizeof(MKL_INT*), 64);
    if (!a_dev_arr || !ainv_dev_arr || !ipiv_dev_arr) {
        printf("ERROR: Memory allocation for device pointer arrays failed!\n");
        return 1;
    }

    for (int igrp = 0, idx = 0; igrp < group_count; igrp++) {
        for (int imat = 0; imat < group_sizes[igrp]; imat++, idx++) {
            double*  a_ptr    = a_arr[idx];
            double*  ainv_ptr = ainv_arr[idx];
            MKL_INT* ipiv_ptr = ipiv_arr[idx];
            #pragma omp target enter data map(to:a_ptr[0:a_sizes[igrp]],ainv_ptr[0:ainv_sizes[igrp]],ipiv_ptr[0:ipiv_sizes[igrp]])
            #pragma omp target data use_device_ptr(a_ptr,ainv_ptr,ipiv_ptr)
            {
                a_dev_arr[idx]    = a_ptr;
                ainv_dev_arr[idx] = ainv_ptr;
                ipiv_dev_arr[idx] = ipiv_ptr;
            }
        }
    }

    // Compute batched LU factorization on GPU via dispatch construct
    #pragma omp target data map(to:a_dev_arr[0:batch_size], ipiv_dev_arr[0:batch_size]) map(from:info_arr[0:batch_size])
    {
          #pragma omp dispatch
          dgetrf_batch(n_arr, n_arr, a_dev_arr, lda_arr, ipiv_dev_arr, &group_count, group_sizes, info_arr);
    }
    printf("Finished call to dgetrf_batch\n");
    int num_errors = check_info("dgetrf_batch", group_count, group_sizes, info_arr);

    if (num_errors == 0) {
        // Compute batched matrix inverse on GPU via dispatch construct
        #pragma omp target data map(to:a_dev_arr[0:batch_size], ainv_dev_arr[0:batch_size], ipiv_dev_arr[0:batch_size]) map(from:info_arr[0:batch_size])
        {
            #pragma omp dispatch
            dgetri_oop_batch(n_arr, (const double**)a_dev_arr, lda_arr, (const MKL_INT**)ipiv_dev_arr,
                    ainv_dev_arr, ldainv_arr, &group_count, group_sizes, info_arr);
        }
        printf("Finished call to dgetri_oop_batch\n");
        num_errors = check_info("dgetri_oop_batch", group_count, group_sizes, info_arr);
    }
    printf("Finished batched matrix inverse computaitons.\n");

    // Bring a_arr, ainv_arr and ipiv_arr data back to the host
    for (int igrp = 0, idx = 0; igrp < group_count; igrp++) {
        for (int imat = 0; imat < group_sizes[igrp]; imat++, idx++) {
            double*  a_ptr    = a_arr[idx];
            double*  ainv_ptr = ainv_arr[idx];
            MKL_INT* ipiv_ptr = ipiv_arr[idx];
            #pragma omp target exit data map(from:a_ptr[0:a_sizes[igrp]],ainv_ptr[0:ainv_sizes[igrp]],ipiv_ptr[0:ipiv_sizes[igrp]])
        }
    }

    // Cleanup
    mkl_free(group_sizes);
    mkl_free(n_arr);
    mkl_free(lda_arr);
    mkl_free(ldainv_arr);
    mkl_free(a_arr);
    mkl_free(ainv_arr);
    mkl_free(ipiv_arr);
    mkl_free(info_arr);
    mkl_free(a);
    mkl_free(ainv);
    mkl_free(ipiv);
    mkl_free(a_sizes);
    mkl_free(ainv_sizes);
    mkl_free(ipiv_sizes);
    mkl_free(a_dev_arr);
    mkl_free(ainv_dev_arr);
    mkl_free(ipiv_dev_arr);

    if (num_errors == 0) {
        printf("Example executed successfully.\n");
    } else {
        printf("Example executed with errors.\n");
        return 1;
    }
    return 0;
}
