/*******************************************************************************
* Copyright 2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*  Content:
*       This example demonstrates the use of oneapi::mkl::lapack::gesvda_batch
*       to compute the truncated SVD of a batch of matrices.
*
*       The supported floating point data types for matrix data are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*******************************************************************************/

#include <oneapi/mkl.hpp>
#include "common_for_examples.hpp"

template <typename data_t, typename real_t = decltype(std::real((data_t)0)), bool is_real = std::is_same_v<data_t,real_t>>
int run_gesvda_batch_example(sycl::device &dev)
{
    //
    // Runs gesvda_batch on a fixed batch of two 5x5 matrices and checks the
    // residuals reported by the routine.
    //
    // Template parameters:
    //   data_t  - element type of the matrices
    //   real_t  - type of std::real(data_t), used for singular values,
    //             residuals and tolerances
    //   is_real - true when data_t and real_t are the same type
    //
    // dev is the SYCL device the queue is created on.
    //
    // Returns 0 when every residual is a number (not NaN) below the
    // precision-dependent threshold, 1 otherwise.
    //

    // Converts a real input value into a matrix element: the value itself for
    // real data types, data_t{0, value} for complex ones.
    auto el = [] (real_t value) {
        if constexpr (is_real) {
            return value;
        } else {
            return data_t{0, value};
        }
    };

    // Batch layout: sizes, leading dimensions and per-matrix strides
    const int64_t batch_size = 2;
    const int64_t m = 5, n = 5;
    const int64_t lda  = m, stride_a  = n*lda;
    const int64_t ldu  = m, stride_u  = m*ldu;
    const int64_t ldvt = m, stride_vt = n*ldvt;
    const int64_t stride_s = m;

    // Tolerance handed to gesvda_batch, and acceptance bound for the residuals
    const real_t tolerance = std::is_same_v<real_t, float> ? 1e-6 : 1e-8;
    const real_t threshold = std::is_same_v<real_t, float> ? 1e-5 : 1e-10;

    // Host storage that the routine fills through the buffers below
    real_t S[batch_size*stride_s], Residual[batch_size];
    data_t U[batch_size*stride_u], Vt[batch_size*stride_vt];

    // Input batch: stride_a elements per matrix, two matrices back to back
    data_t A[] = {
        // first matrix of the batch
        el( 1.0), el( 0.0), el( 0.0), el( 0.0), el( 0.0),
        el( 1.0), el( 0.2), el(-0.4), el(-0.4), el(-0.8),
        el( 1.0), el( 0.6), el(-0.2), el( 0.4), el(-1.2),
        el( 1.0), el( 1.0), el(-1.0), el( 0.6), el(-0.8),
        el( 1.0), el( 1.8), el(-0.6), el( 0.2), el(-0.6),
        // second matrix of the batch
        el( 0.2), el(-0.4), el(-0.4), el(-0.8), el( 0.0),
        el( 0.4), el( 0.2), el( 0.8), el(-0.4), el( 0.0),
        el( 0.4), el(-0.8), el( 0.2), el( 0.4), el( 0.0),
        el( 0.8), el( 0.4), el(-0.4), el( 0.2), el( 0.0),
        el( 0.0), el( 0.0), el( 0.0), el( 0.0), el( 1.0)
    };

    // Every batch entry starts with irank = n.
    // NOTE(review): exact input/output meaning of irank is defined by the
    // gesvda_batch API - confirm against the oneMKL reference.
    int64_t irank[batch_size];
    for (int64_t &rank : irank) {
        rank = n;
    }

    // All sixteen options are zero except entries 0 and 3, which are set to 1.
    // NOTE(review): the meaning of the individual iparm entries is defined by
    // the gesvda_batch API - confirm against the oneMKL reference.
    int64_t iparm[16] = {};
    iparm[0] = 1;
    iparm[3] = 1;

    {
        // The buffers live only inside this scope; S and Residual are read on
        // the host after it closes.
        sycl::queue queue { dev };

        sycl::buffer<data_t>  a_buf{A, stride_a*batch_size};
        sycl::buffer<int64_t> iparm_buf{iparm, 16};
        sycl::buffer<int64_t> irank_buf{irank, batch_size};
        sycl::buffer<real_t>  s_buf{S, batch_size*stride_s};
        sycl::buffer<real_t>  residual_buf{Residual, batch_size};
        sycl::buffer<data_t>  u_buf{U, batch_size*stride_u};
        sycl::buffer<data_t>  vt_buf{Vt, batch_size*stride_vt};

        // Ask the library how much workspace the call below needs
        int64_t scratchpad_size = oneapi::mkl::lapack::gesvda_batch_scratchpad_size<data_t>(
            queue, m, n, lda, stride_a, stride_s, ldu, stride_u, ldvt, stride_vt, batch_size);
        sycl::buffer<data_t> scratchpad{scratchpad_size};

        oneapi::mkl::lapack::gesvda_batch(
            queue, iparm_buf, irank_buf, m, n,
            a_buf, lda, stride_a,
            s_buf, stride_s,
            u_buf, ldu, stride_u,
            vt_buf, ldvt, stride_vt,
            tolerance, residual_buf,
            batch_size, scratchpad, scratchpad_size);
    }

    // Print every slot of S, stride_s values per batch entry
    for (int64_t idx = 0; idx < stride_s*batch_size; ++idx) {
        std::cout << " Singular values # " << idx << " Value " << S[idx] << std::endl;
    }

    // A batch entry fails when its residual is NaN or not below the threshold
    bool passed = true;
    for (int64_t entry = 0; entry < batch_size; ++entry) {
        const real_t residual = Residual[entry];
        std::cout << " Residual entry # " << entry << " Value " << residual << std::endl;
        // residual != residual holds only for NaN
        if (residual != residual || !(residual < threshold)) {
            passed = false;
        }
    }

    if (!passed) {
        std::cout << " Computed residual exceeds the tolerance threshold  " << std::endl;
        return 1;
    }

    std::cout << " Calculations successfully finished " << std::endl;
    return 0;
}

//
// Description of example setup, APIs used and supported floating point type precisions
//
void print_example_banner() {
    // Banner text, one entry per output line; the empty entries produce the
    // blank lines printed before and after the framed description.
    static const char *const banner_lines[] = {
        "",
        "########################################################################",
        "# Batched strided truncated SVD example:",
        "# ",
        "# Computes truncated SVD of a batch of matrices.",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        "# ",
        "########################################################################",
        ""
    };

    // Each line is written to standard output and flushed individually.
    for (const char *line : banner_lines) {
        std::cout << line << std::endl;
    }
}


//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU device
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU device
// -DSYCL_DEVICES_all (default) -- runs on all: CPU and GPU devices
//
//  For each selected device that is found, the gesvda_batch example runs with
//  every supported data type (double precision types only where supported).
//
int main(int argc, char **argv) {

    print_example_banner();

    // Find list of devices
    std::list<my_sycl_device_types> listOfDevices;
    set_list_of_devices(listOfDevices);

    bool failed = false;

    for (auto &deviceType: listOfDevices) {
        sycl::device myDev;
        bool myDevIsFound = false;
        get_sycl_device(myDev, myDevIsFound, deviceType);

        if (myDevIsFound) {
          std::cout << std::endl << "Running gesvda_batch examples on " << sycl_device_names[deviceType] << "." << std::endl;

          std::cout << "Running with single precision real data type:" << std::endl;
          failed |= run_gesvda_batch_example<float>(myDev);

          std::cout << "Running with single precision complex data type:" << std::endl;
          failed |= run_gesvda_batch_example<std::complex<float>>(myDev);

          if (isDoubleSupported(myDev)) {
              std::cout << "Running with double precision real data type:" << std::endl;
              failed |= run_gesvda_batch_example<double>(myDev);

              std::cout << "Running with double precision complex data type:" << std::endl;
              failed |= run_gesvda_batch_example<std::complex<double>>(myDev);
          } else {
              std::cout << "Double precision not supported on this device " << std::endl;
              std::cout << std::endl;
          }

        } else {
#ifdef FAIL_ON_MISSING_DEVICES
          std::cout << "No " << sycl_device_names[deviceType] << " devices found; Fail on missing devices is enabled.\n";
          return 1;
#else
          std::cout << "No " << sycl_device_names[deviceType] << " devices found; skipping " << sycl_device_names[deviceType] << " tests.\n";
#endif
        }
    }
    return failed;
}
