/* Copyright (C) 2023 Intel Corporation
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include <CL/sycl.hpp>
#include <ishmem.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <shmem.h>
#include <ishmemx.h>
#include "esp.h"

// Number of times main() repeats the GPU/host exchange.
constexpr int NUM_ITERATIONS = 500;

// NOTE(review): file-scope using-directive kept because the rest of this file
// relies on unqualified std names (cerr, cout, endl, hex, flush).
using namespace std;

// Name of this host, filled in by main() via gethostname() and used to tag
// diagnostic output.
char hostname[128];

// Returns from the enclosing void function, after printing a FATAL message
// to stderr, when the expression `p` is null. The argument is parenthesized
// so that compound expressions such as NULL_CHECK(a ? b : c) are compared
// against nullptr as a whole.
// NOTE: the early return does not release anything the caller allocated
// before the check.
#define NULL_CHECK(p) \
    do { \
        if ((p) == nullptr) { \
            fprintf(stderr, "FATAL: Could not allocate " #p "\n"); \
            return; \
        } \
    } while (0)

//
// A simple pattern which uses a PE's GPU to send data to another PE's GPU via ishmem calls.
// We then validate it, and create some SHMEM traffic to share between hosts
//  as an example of combining SHMEM with iSHMEM.
//
void simple_host_pe_exchange(const int num_array_elems)
{
    assert (num_array_elems > 0 && "Error: num_array_elems must be greater than 0.");

    int myPE = ishmem_my_pe();
    int npes = ishmem_n_pes();

    // OpenSHMEM buffers for intra-node
    int *pSymmetric    = (int *)shmem_malloc(num_array_elems * sizeof(int));
    int *pSymmetricSrc = (int *)shmem_malloc(num_array_elems * sizeof(int));
    int *pLocal     = (int *)malloc(num_array_elems * sizeof(int));
    NULL_CHECK(pSymmetric); NULL_CHECK(pLocal); NULL_CHECK(pSymmetricSrc);
    memset(pSymmetric, 0x666,  num_array_elems * sizeof(int));
    memset(pLocal, (0xDEAD << 16) + myPE,  num_array_elems * sizeof(int));

    // iSHMEM buffers for GPU offload
    int *pISHM_Src      = (int *) ishmem_malloc(num_array_elems * sizeof(int));
    int *pISHM_Sink     = (int *) ishmem_malloc(num_array_elems * sizeof(int));
    NULL_CHECK(pISHM_Src); NULL_CHECK(pISHM_Sink);

    sycl::queue q;
    int *pSYCL_Bounce   = (int *) sycl::malloc_host<int>(num_array_elems, q);
    int *pSYCL_Errors   = (int *) sycl::malloc_host<int>(1, q);
    NULL_CHECK(pSYCL_Bounce); NULL_CHECK(pSYCL_Errors);

    // Init the iSHMEM sink buffers.
    auto e_init = q.submit([&](sycl::handler &h)
    {
        h.parallel_for(sycl::nd_range<1>{num_array_elems, num_array_elems}, [=](sycl::nd_item<1> idx)
        {
            int i = idx.get_global_id()[0];
            pISHM_Src[i]  = (int) ((0xDEAD << 16) + myPE);
            pISHM_Sink[i] = (int) ((0xBEEF << 16) + myPE);
        });
    });
    e_init.wait_and_throw();
    ishmem_barrier_all();

    // Now have the GPUs send data.
    auto e_run = q.submit([&](sycl::handler &h)
    {
        h.single_task([=]()
        {
            int my_dev_pe = ishmem_my_pe();
            int my_dev_npes = ishmem_n_pes();
            ishmem_int_put(pISHM_Sink, pISHM_Src, num_array_elems, (my_dev_pe + 1) % my_dev_npes );  
        });
    });
    e_run.wait_and_throw();
    ishmem_barrier_all();

    // Have CPU init pSYCL_Errors
    *pSYCL_Errors = 0;

    // Now, have the GPU verify the results
    auto e_verify = q.submit([&](sycl::handler &h) 
    {
        h.single_task([=]()
        {
            for (int x = 0; x < num_array_elems; x++)
            {
                int destPE = (myPE == 0) ? npes - 1 : ((myPE - 1) % npes);
                if (pISHM_Sink[x] != (int) ((0xDEAD << 16) + destPE))
                {
                    *pSYCL_Errors = *pSYCL_Errors + 1;
                }
            }
        });
    });
    e_verify.wait_and_throw();
    ishmem_barrier_all();

    // Now have the CPU access the sycl buffer to check for errs
    if (*pSYCL_Errors > 0) {
        cerr << "[" << hostname << "][" << myPE << "]:ERROR: GPU reported " << *pSYCL_Errors << " errors." << endl << flush;
    }
    
    // Have the GPU copy to the host-accessible buffer
    memset(pSYCL_Bounce, 0x123, num_array_elems * sizeof(int));
    auto e_bounce = q.submit([&](sycl::handler &h)
    {
        h.memcpy(pSYCL_Bounce, pISHM_Sink, num_array_elems * sizeof(int));
    });
    e_bounce.wait_and_throw();
    ishmem_barrier_all();

    for (int x = 0; x < num_array_elems; x++)
    {
        if (pSYCL_Bounce[x] == 0x123)
        {
            cerr << endl << "[" << hostname << "][" << myPE << "]ERROR: Bounce Buffer [" << x << "] Read 0x" << hex << pSYCL_Bounce[x] << "." << endl << flush;
            ishmem_free(pISHM_Src);
            ishmem_free(pISHM_Sink);
            sycl::free(pSYCL_Errors, q);
            sycl::free(pSYCL_Bounce, q);
            shmem_free(pSymmetric);
            free(pLocal);
            return;
        }
    }

    // Now call host-based shmem_int_put() after copying to a SHMEM-safe symmetric buffer
    memcpy(pSymmetricSrc, pSYCL_Bounce, num_array_elems * sizeof(int));

    esp_enter("shmemPut");
    int mySHPE = shmem_my_pe();
    int numSHPEs = shmem_n_pes();
    int destPE = (myPE == 0) ? npes - 1 : ((myPE - 1) % npes);
    shmem_int_put(pSymmetric, pSymmetricSrc, num_array_elems, destPE);
    shmem_barrier_all();
    esp_exit("shmemPut");  

    // Finally, check the buffer
    for (int x = 0; x < num_array_elems; x++)
    {
        if (pSymmetric[x] == 0x666)
        {
            cerr << "[" << hostname << "][" << myPE << "]ERROR: Host read error." << endl << flush;
        } 
    }

    // free the world
    ishmem_free(pISHM_Src);
    ishmem_free(pISHM_Sink);
    sycl::free(pSYCL_Errors, q);
    sycl::free(pSYCL_Bounce, q);
    shmem_free(pSymmetric);
    free(pLocal);
}

int main(int argc, char **argv)
{
        gethostname(hostname, 127);

        ishmem_init();
        int myPE = shmem_my_pe();
	printf("\n [Hostnamme: %s][PE: %d]: sycl_sanity: Start\n", hostname, myPE); fflush(stdout);
    
        int npes = ishmem_n_pes();
        if (npes < 2) {
            cerr << "ERROR: This app supports > two PEs!" << endl;
            ishmem_finalize();
            return 1;
        }

        char myName[128];
        ishmem_info_get_name(myName);
        printf("myName = %s \n", myName); fflush(stdout);
        
        int maj, min;
        ishmem_info_get_version(&maj, &min);
        printf("iSHMEM Version: %d.%d \n", maj, min); fflush(stdout);
         
        // Enumerate the platform
	auto platforms = sycl::platform::get_platforms();
	for (auto &platform : platforms) 
	{
		cout << "Platform: " << platform.get_info<sycl::info::platform::name>() << endl << flush;
	        auto devices = platform.get_devices();
	        for (auto &device : devices) 
		{
			cout << "  Device: " << device.get_info<sycl::info::device::name>() << endl << flush;
		}
	}	    

        // Now do something with the GPUs and the host
        for (int x = 0; x < NUM_ITERATIONS; x++)
        {
            simple_host_pe_exchange(ishmem_n_pes());
        }

	ishmem_finalize();
	printf("sycl_sanity: End\n");
}

