/*******************************************************************************
* Copyright 2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of the oneAPI Math Kernel Library (oneMKL)
*       API oneapi::mkl::dft to perform batched 3-D single-precision
*       real-to-complex (forward) and complex-to-real (backward) Fast Fourier
*       Transforms on a SYCL device (CPU, GPU), using both SYCL buffers and USM.
*
*       The supported floating point data types for data are:
*           float
*           std::complex<float>
*
*******************************************************************************/

#define _USE_MATH_DEFINES
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <list>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <vector>

#include <sycl/sycl.hpp>
#include "oneapi/mkl/dft.hpp"
#include "mkl.h"

// local includes
#define NO_MATRIX_HELPERS
#include "common_for_examples.hpp"

// oneMKL DFT descriptor for single-precision transforms with a real forward domain.
using descriptor_t = oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
                                                  oneapi::mkl::dft::domain::REAL>;

// Status codes returned by the verify_* and run_dft_* functions.
constexpr int SUCCESS{0};
constexpr int FAILURE{1};

// 2*pi in single precision, used to build the test signals.
constexpr float TWOPI{6.2831853071795864769f};

// Memory layouts exercised by the example (strides are chosen per layout
// in the run_dft_* functions).
enum data_layout {
    ROW_MAJOR = 0,
    COL_MAJOR = 1
};

// Computes (K*L) % M in 64-bit integer arithmetic, so the product cannot
// overflow int before the reduction, and returns the remainder as a float.
// For non-negative K and L and positive M the result lies in [0, M).
static float moda(int K, int L, int M)
{
    const long long product = static_cast<long long>(K) * L;
    const long long remainder = product % M;
    return static_cast<float>(remainder);
}

// Fills M batched real 3-D arrays with a scaled cosine of harmonic (H1,H2,H3):
//   value = factor * cos(2*pi*phase) / (N1*N2*N3),
//   phase = (n1*H1 mod N1)/N1 + (n2*H2 mod N2)/N2 + (n3*H3 mod N3)/N3,
// stored at data[j*FWD_DIST + n3*fwd_strides[1] + n2*fwd_strides[2] + n1*fwd_strides[3]].
// factor is 1 when 2*(N - H) is a multiple of N in every dimension (the
// harmonic coincides with its mirror -H) and 2 otherwise, so the forward
// transform is expected to be 1 at bin H (and at -H).
static void init_r(float *data, int M, int FWD_DIST,
                    int N1, int N2, int N3,
                    const std::vector<std::int64_t>& fwd_strides,
                    int H1, int H2, int H3)
{
    const int stride1 = fwd_strides[3];
    const int stride2 = fwd_strides[2];
    const int stride3 = fwd_strides[1];
    const int volume = N3*N2*N1;

    // True when H and -H name the same bin in a dimension of length N.
    const auto is_self_mirror = [](int N, int H) { return 2 * (N - H) % N == 0; };
    const bool self_mirror = is_self_mirror(N1, H1) &&
                             is_self_mirror(N2, H2) &&
                             is_self_mirror(N3, H3);
    const float factor = self_mirror ? 1.0f : 2.0f;

    for (int j = 0; j < M; ++j) {
        const int batch_offset = j*FWD_DIST;
        for (int n3 = 0; n3 < N3; ++n3) {
            const int plane_offset = batch_offset + n3*stride3;
            for (int n2 = 0; n2 < N2; ++n2) {
                const int row_offset = plane_offset + n2*stride2;
                for (int n1 = 0; n1 < N1; ++n1) {
                    const float phase = moda(n1, H1, N1) / N1
                                      + moda(n2, H2, N2) / N2
                                      + moda(n3, H3, N3) / N3;
                    data[row_offset + n1*stride1] = factor * cosf(TWOPI * phase) / volume;
                }
            }
        }
    }
}


// Checks the M batched conjugate-even outputs of the forward transform.
// Only the stored half n1 = 0..N1/2 is inspected. The expected value is
// 1 + 0i where (n1,n2,n3) equals (H1,H2,H3) or (-H1,-H2,-H3) modulo
// (N1,N2,N3), and 0 + 0i everywhere else.
// Complex values are (re, im) float pairs; the complex-element offset is
// j*BWD_DIST + n3*bwd_strides[1] + n2*bwd_strides[2] + n1*bwd_strides[3],
// which serves both the row-major and the column-major layout.
// On the first element whose error is not below the threshold, prints it
// and returns FAILURE; otherwise prints the maximum error and returns SUCCESS.
static int verify_c(const float* data, int M, int BWD_DIST,
                    int N1, int N2, int N3,
                    const std::vector<std::int64_t>& bwd_strides,
                    int H1, int H2, int H3)
{
    // Note: this simple error bound doesn't take into account error of
    //       input data
    const float errthr = 2.5f * log(static_cast<float>(N3) * N2 * N1) / log(2.0f) * FLT_EPSILON;
    std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    const int stride1 = bwd_strides[3];
    const int stride2 = bwd_strides[2];
    const int stride3 = bwd_strides[1];

    // True when (k1, k2, k3) coincides with the harmonic modulo each dimension.
    const auto matches_harmonic = [&](int k1, int k2, int k3) {
        return ((k1 - H1) % N1 == 0) &&
               ((k2 - H2) % N2 == 0) &&
               ((k3 - H3) % N3 == 0);
    };

    float maxerr = 0.0f;
    for (int j = 0; j < M; ++j) {
        for (int n3 = 0; n3 < N3; ++n3) {
            for (int n2 = 0; n2 < N2; ++n2) {
                for (int n1 = 0; n1 <= N1/2; ++n1) {
                    const bool peak = matches_harmonic(n1, n2, n3) ||
                                      matches_harmonic(-n1, -n2, -n3);
                    const float re_exp = peak ? 1.0f : 0.0f;
                    const float im_exp = 0.0f;

                    const int index = j*BWD_DIST + n3*stride3 + n2*stride2 + n1*stride1;
                    const float re_got = data[2*index];
                    const float im_got = data[2*index + 1];
                    const float err = fabsf(re_got - re_exp) + fabsf(im_got - im_exp);
                    maxerr = (err > maxerr) ? err : maxerr;
                    // Written as "continue when below" so a NaN error still fails.
                    if (err < errthr) {
                        continue;
                    }

                    std::cout << "\t\tdata in batch #" << j << " incorrect at "
                              << "[" << n3 << "][" << n2 << "][" << n1 << "]: "
                              << " expected (" << re_exp << "," << im_exp << ")"
                              << " got (" << re_got << "," << im_got << ")"
                              << " err " << err << std::endl;
                    std::cout << "\t\tVerification FAILED" << std::endl;
                    return FAILURE;
                }
            }
        }
    }
    std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;
    return SUCCESS;
}


// Fills the stored half (n1 = 0..N1/2) of M batched conjugate-even 3-D
// arrays with
//   exp(-i * 2*pi * phase) / (N1*N2*N3),
//   phase = (n1*H1 mod N1)/N1 + (n2*H2 mod N2)/N2 + (n3*H3 mod N3)/N3.
// Each complex value occupies two consecutive floats (re, im); the
// complex-element offset is
// j*BWD_DIST + n3*bwd_strides[1] + n2*bwd_strides[2] + n1*bwd_strides[3].
static void init_c(float *data, int M, int BWD_DIST,
                    int N1, int N2, int N3,
                    const std::vector<std::int64_t>& bwd_strides,
                    int H1, int H2, int H3)
{
    const int stride1 = bwd_strides[3];
    const int stride2 = bwd_strides[2];
    const int stride3 = bwd_strides[1];
    const int volume = N3*N2*N1;
    const int half_n1 = N1/2 + 1;

    for (int j = 0; j < M; ++j) {
        for (int n3 = 0; n3 < N3; ++n3) {
            for (int n2 = 0; n2 < N2; ++n2) {
                const int row_base = j*BWD_DIST + n3*stride3 + n2*stride2;
                for (int n1 = 0; n1 < half_n1; ++n1) {
                    const float phase = moda(n1, H1, N1) / N1
                                      + moda(n2, H2, N2) / N2
                                      + moda(n3, H3, N3) / N3;
                    const float angle = TWOPI * phase;
                    float *element = data + 2*(row_base + n1*stride1);
                    element[0] =  cosf(angle) / volume;
                    element[1] = -sinf(angle) / volume;
                }
            }
        }
    }
}

// Checks the M batched real outputs of the backward transform. Each batch
// must be a unit impulse: 1 where (n1,n2,n3) equals (H1,H2,H3) modulo
// (N1,N2,N3) and 0 elsewhere. Elements are read from
// data[j*FWD_DIST + n3*fwd_strides[1] + n2*fwd_strides[2] + n1*fwd_strides[3]].
// On the first element whose error is not below the threshold, prints it
// and returns FAILURE; otherwise prints the maximum error and returns SUCCESS.
static int verify_r(const float *data, int M, int FWD_DIST,
                    int N1, int N2, int N3,
                    const std::vector<std::int64_t>& fwd_strides,
                    int H1, int H2, int H3)
{
    const int stride1 = fwd_strides[3];
    const int stride2 = fwd_strides[2];
    const int stride3 = fwd_strides[1];

    // Note: this simple error bound doesn't take into account error of
    //       input data
    const float errthr = 2.5f * log(static_cast<float>(N3) * N2 * N1) / log(2.0f) * FLT_EPSILON;
    std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    float maxerr = 0.0f;
    for (int j = 0; j < M; ++j) {
        for (int n3 = 0; n3 < N3; ++n3) {
            const bool n3_hit = (n3 - H3) % N3 == 0;
            for (int n2 = 0; n2 < N2; ++n2) {
                const bool n2_hit = (n2 - H2) % N2 == 0;
                for (int n1 = 0; n1 < N1; ++n1) {
                    const bool n1_hit = (n1 - H1) % N1 == 0;
                    const float re_exp = (n1_hit && n2_hit && n3_hit) ? 1.0f : 0.0f;

                    const int index = j*FWD_DIST + n3*stride3 + n2*stride2 + n1*stride1;
                    const float re_got = data[index];
                    const float err = fabsf(re_got - re_exp);
                    maxerr = (err > maxerr) ? err : maxerr;
                    // Written as "continue when below" so a NaN error still fails.
                    if (err < errthr) {
                        continue;
                    }

                    std::cout << "\t\tdata in batch #" << j << " incorrect at "
                              << "[" << n3 << "][" << n2 << "][" << n1 << "]: "
                              << " expected (" << re_exp << ")"
                              << " got (" << re_got << ")"
                              << " err " << err << std::endl;
                    std::cout << "\t\tVerification FAILED" << std::endl;
                    return FAILURE;
                }
            }
        }
    }
    std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;
    return SUCCESS;
}


// Runs a batched 3-D single-precision real-to-complex (forward) DFT of size
// N3 x N2 x N1 (M transforms) on `dev`, once through SYCL buffers and once
// through shared USM, and checks both outputs with verify_c.
//
// LAYOUT selects how the M transforms are laid out in memory:
//   ROW_MAJOR - each transform is stored contiguously, n1 fastest
//               (batch distance = size of one transform);
//   COL_MAJOR - the M transforms are interleaved element by element
//               (batch distance 1).
//
// Returns SUCCESS only if both the buffer and the USM results verify, and
// FAILURE otherwise. Exceptions raised while setting up or executing the DFT
// are reported on std::cout and turned into FAILURE. The only exception
// that reaches the caller is std::invalid_argument, for an unknown LAYOUT.
int run_dft_forward_example(sycl::device &dev,
                            data_layout LAYOUT) {
    //
    // Initialize data for DFT
    //
    int N1 = 7, N2 = 13, N3 = 11;
    int M = 5;

    // Arbitrary harmonic used to verify FFT
    int H1 = 1, H2 = 2, H3 = 3;
    // Strides describing data layout in real and conjugate-even domain
    std::vector<std::int64_t> fwd_strides(4);
    std::vector<std::int64_t> bwd_strides(4);
    int FWD_DIST = 0, BWD_DIST = 0;
    switch (LAYOUT)
    {
        case ROW_MAJOR:
            // row-major without padding
            fwd_strides[3] = 1;         bwd_strides[3] = 1;
            fwd_strides[2] = N1;        bwd_strides[2] = (N1/2 + 1);
            fwd_strides[1] = N2*N1;     bwd_strides[1] = N2*(N1/2 + 1);
            fwd_strides[0] = 0;         bwd_strides[0] = 0;
            FWD_DIST = N1*N2*N3;
            BWD_DIST = (N1/2 + 1)*N2*N3;
            break;
        case COL_MAJOR:
            // column-major without padding; the M transforms are interleaved
            fwd_strides[3] = M*N2*N3;   bwd_strides[3] = M*N2*N3;
            fwd_strides[2] = M*N3;      bwd_strides[2] = M*N3;
            fwd_strides[1] = M;         bwd_strides[1] = M;
            fwd_strides[0] = 0;         bwd_strides[0] = 0;
            FWD_DIST = 1;
            BWD_DIST = 1;
            break;
        default:
            throw std::invalid_argument("Invalid data layout. The only available options are ROW_MAJOR and COL_MAJOR");
    }

    // Sizes, in floats, of the M real inputs and of the M conjugate-even
    // outputs (each complex value is stored as an (re, im) float pair).
    const std::size_t fwd_size = static_cast<std::size_t>(M) * N1 * N2 * N3;
    const std::size_t bwd_size = static_cast<std::size_t>(M) * 2 * (N1/2 + 1) * N2 * N3;

    int result = FAILURE;
    int result_buffer = FAILURE;
    int result_usm = FAILURE;

    float* x = static_cast<float*>(mkl_malloc(fwd_size * sizeof(float), 64));
    float* y = static_cast<float*>(mkl_malloc(bwd_size * sizeof(float), 64));
    if (x == nullptr || y == nullptr) {
        std::cout << "\t\tFailed to allocate host memory for FFT" << std::endl;
        if (x != nullptr) mkl_free(x);
        if (y != nullptr) mkl_free(y);
        return FAILURE;
    }
    init_r(x, M, FWD_DIST, N1, N2, N3, fwd_strides, H1, H2, H3);

    //
    // Execute DFT
    //
    try {
        // Reports asynchronous exceptions; it is invoked by the
        // wait_and_throw() calls below (plain wait() never calls it).
        auto exception_handler = [] (sycl::exception_list exceptions) {
            for (std::exception_ptr const& e : exceptions) {
                try {
                    std::rethrow_exception(e);
                } catch(sycl::exception const& e) {
                    std::cout << "Caught asynchronous SYCL exception:" << std::endl
                              << e.what() << std::endl;
                }
            }
        };

        // create execution queue with asynchronous error handling
        sycl::queue queue(dev, exception_handler);

        // SYCL buffers over the host arrays; results are not written back to x/y.
        sycl::buffer<float, 1> xBuffer(x, sycl::range<1>(fwd_size));
        sycl::buffer<float, 1> yBuffer(y, sycl::range<1>(bwd_size));
        xBuffer.set_write_back(false);
        yBuffer.set_write_back(false);

        // Shared USM arrays. The unique_ptr deleters free them on every exit
        // path, including when an exception is thrown further down.
        auto usm_deleter = [queue](float *p) { sycl::free(p, queue); };
        std::unique_ptr<float, decltype(usm_deleter)> x_usm(
            sycl::malloc_shared<float>(fwd_size, queue), usm_deleter);
        std::unique_ptr<float, decltype(usm_deleter)> y_usm(
            sycl::malloc_shared<float>(bwd_size, queue), usm_deleter);
        if (!x_usm || !y_usm)
            throw std::runtime_error("failed to allocate shared USM memory");
        init_r(x_usm.get(), M, FWD_DIST, N1, N2, N3, fwd_strides, H1, H2, H3);

        descriptor_t desc({N3, N2, N1});
        desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                       oneapi::mkl::dft::config_value::NOT_INPLACE);
        desc.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, M);
        desc.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, FWD_DIST);
        desc.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, BWD_DIST);
        desc.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides);
        desc.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides);
        desc.commit(queue);

        // Using SYCL buffers
        std::cout << "\tUsing SYCL buffers" << std::endl;
        oneapi::mkl::dft::compute_forward(desc, xBuffer, yBuffer);
        queue.wait_and_throw();
        {
            // Scoped so the host accessor releases yBuffer before the USM run.
            auto yAcc = yBuffer.get_host_access(sycl::read_only);
            result_buffer = verify_c(yAcc.get_pointer(), M, BWD_DIST, N1, N2, N3, bwd_strides, H1, H2, H3);
        }

        // Using USM
        std::cout << "\tUsing USM" << std::endl;
        auto fwd = oneapi::mkl::dft::compute_forward(desc, x_usm.get(), y_usm.get());
        fwd.wait_and_throw();
        result_usm = verify_c(y_usm.get(), M, BWD_DIST, N1, N2, N3, bwd_strides, H1, H2, H3);

        if ((result_buffer == SUCCESS) && (result_usm == SUCCESS))
            result = SUCCESS;
    }
    catch(sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
        std::cout << "\t\tError code: " << e.code().value() << std::endl;
    }
    catch(std::runtime_error const& e) {
        std::cout << "\t\truntime exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    catch(std::exception const& e) {
        // Any other standard exception (e.g. oneMKL's exception types derive
        // from std::exception rather than std::runtime_error). Without this,
        // it would escape, skip the mkl_free calls below and end the program.
        std::cout << "\t\texception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    mkl_free(x);
    mkl_free(y);

    return result;
}

// Runs a batched 3-D single-precision complex-to-real (backward) DFT of size
// N3 x N2 x N1 (M transforms) on `dev`, once through SYCL buffers and once
// through shared USM, and checks both outputs with verify_r.
//
// LAYOUT selects how the M transforms are laid out in memory:
//   ROW_MAJOR - each transform is stored contiguously, n1 fastest
//               (batch distance = size of one transform);
//   COL_MAJOR - the M transforms are interleaved element by element
//               (batch distance 1).
//
// Returns SUCCESS only if both the buffer and the USM results verify, and
// FAILURE otherwise. Exceptions raised while setting up or executing the DFT
// are reported on std::cout and turned into FAILURE. The only exception
// that reaches the caller is std::invalid_argument, for an unknown LAYOUT.
int run_dft_backward_example(sycl::device &dev,
                             data_layout LAYOUT) {
    //
    // Initialize data for DFT
    //
    int N1 = 7, N2 = 13, N3 = 11;
    int M = 5;

    // Arbitrary harmonic used to verify FFT
    int H1 = 1, H2 = 2, H3 = 3;
    // Strides describing data layout in real and conjugate-even domain
    std::vector<std::int64_t> fwd_strides(4);
    std::vector<std::int64_t> bwd_strides(4);
    int FWD_DIST = 0, BWD_DIST = 0;
    switch (LAYOUT)
    {
        case ROW_MAJOR:
            // row-major without padding
            fwd_strides[3] = 1;         bwd_strides[3] = 1;
            fwd_strides[2] = N1;        bwd_strides[2] = (N1/2 + 1);
            fwd_strides[1] = N2*N1;     bwd_strides[1] = N2*(N1/2 + 1);
            fwd_strides[0] = 0;         bwd_strides[0] = 0;
            FWD_DIST = N1*N2*N3;
            BWD_DIST = (N1/2 + 1)*N2*N3;
            break;
        case COL_MAJOR:
            // column-major without padding; the M transforms are interleaved
            fwd_strides[3] = M*N2*N3;   bwd_strides[3] = M*N2*N3;
            fwd_strides[2] = M*N3;      bwd_strides[2] = M*N3;
            fwd_strides[1] = M;         bwd_strides[1] = M;
            fwd_strides[0] = 0;         bwd_strides[0] = 0;
            FWD_DIST = 1;
            BWD_DIST = 1;
            break;
        default:
            throw std::invalid_argument("Invalid data layout. The only available options are ROW_MAJOR and COL_MAJOR");
    }

    // Sizes, in floats, of the M real outputs and of the M conjugate-even
    // inputs (each complex value is stored as an (re, im) float pair).
    const std::size_t fwd_size = static_cast<std::size_t>(M) * N1 * N2 * N3;
    const std::size_t bwd_size = static_cast<std::size_t>(M) * 2 * (N1/2 + 1) * N2 * N3;

    int result = FAILURE;
    int result_buffer = FAILURE;
    int result_usm = FAILURE;

    float* x = static_cast<float*>(mkl_malloc(fwd_size * sizeof(float), 64));
    float* y = static_cast<float*>(mkl_malloc(bwd_size * sizeof(float), 64));
    if (x == nullptr || y == nullptr) {
        std::cout << "\t\tFailed to allocate host memory for FFT" << std::endl;
        if (x != nullptr) mkl_free(x);
        if (y != nullptr) mkl_free(y);
        return FAILURE;
    }
    init_c(y, M, BWD_DIST, N1, N2, N3, bwd_strides, H1, H2, H3);

    //
    // Execute DFT
    //
    try {
        // Reports asynchronous exceptions; it is invoked by the
        // wait_and_throw() calls below (plain wait() never calls it).
        auto exception_handler = [] (sycl::exception_list exceptions) {
            for (std::exception_ptr const& e : exceptions) {
                try {
                    std::rethrow_exception(e);
                } catch(sycl::exception const& e) {
                    std::cout << "Caught asynchronous SYCL exception:" << std::endl
                              << e.what() << std::endl;
                }
            }
        };

        // create execution queue with asynchronous error handling
        sycl::queue queue(dev, exception_handler);

        // SYCL buffers over the host arrays; results are not written back to x/y.
        sycl::buffer<float, 1> xBuffer(x, sycl::range<1>(fwd_size));
        sycl::buffer<float, 1> yBuffer(y, sycl::range<1>(bwd_size));
        xBuffer.set_write_back(false);
        yBuffer.set_write_back(false);

        // Shared USM arrays. The unique_ptr deleters free them on every exit
        // path, including when an exception is thrown further down.
        auto usm_deleter = [queue](float *p) { sycl::free(p, queue); };
        std::unique_ptr<float, decltype(usm_deleter)> x_usm(
            sycl::malloc_shared<float>(fwd_size, queue), usm_deleter);
        std::unique_ptr<float, decltype(usm_deleter)> y_usm(
            sycl::malloc_shared<float>(bwd_size, queue), usm_deleter);
        if (!x_usm || !y_usm)
            throw std::runtime_error("failed to allocate shared USM memory");
        init_c(y_usm.get(), M, BWD_DIST, N1, N2, N3, bwd_strides, H1, H2, H3);

        descriptor_t desc({N3, N2, N1});
        desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                       oneapi::mkl::dft::config_value::NOT_INPLACE);
        desc.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, M);
        desc.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, FWD_DIST);
        desc.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, BWD_DIST);
        desc.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides);
        desc.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides);
        desc.commit(queue);

        // Using SYCL buffers
        std::cout<<"\tUsing SYCL buffers"<<std::endl;
        oneapi::mkl::dft::compute_backward(desc, yBuffer, xBuffer);
        queue.wait_and_throw();
        {
            // Scoped so the host accessor releases xBuffer before the USM run.
            auto xAcc = xBuffer.get_host_access(sycl::read_only);
            result_buffer = verify_r(xAcc.get_pointer(), M, FWD_DIST, N1, N2, N3, fwd_strides, H1, H2, H3);
        }

        // Using USM
        std::cout<<"\tUsing USM"<<std::endl;
        auto bwd = oneapi::mkl::dft::compute_backward(desc, y_usm.get(), x_usm.get());
        bwd.wait_and_throw();
        result_usm = verify_r(x_usm.get(), M, FWD_DIST, N1, N2, N3, fwd_strides, H1, H2, H3);

        if ((result_buffer == SUCCESS) && (result_usm == SUCCESS))
            result = SUCCESS;
    }
    catch(sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
        std::cout << "\t\tError code: " << e.code().value() << std::endl;
    }
    catch(std::runtime_error const& e) {
        std::cout << "\t\truntime exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    catch(std::exception const& e) {
        // Any other standard exception (e.g. oneMKL's exception types derive
        // from std::exception rather than std::runtime_error). Without this,
        // it would escape, skip the mkl_free calls below and end the program.
        std::cout << "\t\texception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    mkl_free(x);
    mkl_free(y);

    return result;
}

//
// Prints the example banner: title, APIs used and supported floating point
// type precisions, framed by '#' rules and surrounded by blank lines.
//
void print_example_banner() {
    static const char* const banner_lines[] = {
        "",
        "########################################################################",
        "# 3D FFT Real-Complex Single-Precision Example: ",
        "# ",
        "# Using apis:",
        "#   dft",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   std::complex<float>",
        "# ",
        "########################################################################",
        "",
    };
    for (const char* line : banner_lines) {
        std::cout << line << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//
//
int main() {
    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);
    

    int returnCode = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            int status;
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";
            
            for (auto layout : {ROW_MAJOR, COL_MAJOR}){
                std::cout << "\tRunning with single precision real-to-complex 3-D FFT:" << std::endl;
                std::cout << "\tData layout: " << (layout == ROW_MAJOR ? "row-major" : "column-major") << std::endl;
                status = run_dft_forward_example(my_dev, layout);
                if (status != SUCCESS) {
                    std::cout << "\tTest Forward Failed" << std::endl << std::endl;
                    if (!returnCode) returnCode = status;
                } else {
                    std::cout << "\tTest Forward Passed" << std::endl << std::endl;
                }
                std::cout << "\tRunning with single precision complex-to-real 3-D FFT:" << std::endl;
                std::cout << "\tData layout: " << (layout == ROW_MAJOR ? "row-major" : "column-major") << std::endl;
                status = run_dft_backward_example(my_dev, layout);
                if (status != SUCCESS) {
                    std::cout << "\tTest Backward Failed" << std::endl << std::endl;
                    if (!returnCode) returnCode = status;
                } else {
                    std::cout << "\tTest Backward Passed" << std::endl << std::endl;
                }
            }
        } else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it] << " devices found; Fail on missing devices is enabled." << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping " << sycl_device_names[*it] << " tests." << std::endl << std::endl;
#endif
        }
    }

    mkl_free_buffers();
    return returnCode;
}
