/*******************************************************************************
* Copyright 2020-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates usage of DPC++ API for oneMKL RNG service
*       functionality
*
*       Functions list:
*           oneapi::mkl::rng::skip_ahead
*           oneapi::mkl::rng::leapfrog
*           oneapi::mkl::rng::save_state
*           oneapi::mkl::rng::load_state
*           oneapi::mkl::rng::get_state_size
*
*
*******************************************************************************/

// stl includes
#include <iostream>
#include <vector>
#include <math.h>

#include <sycl/sycl.hpp>
#include "oneapi/mkl.hpp"

// local includes
#include "../include/common_for_rng_examples.hpp"

// Example parameters:
//   n       - number of random values generated by each example
//   s       - number of engine copies (sub-streams) used by the skip_ahead/leapfrog examples
//   ns      - number of values generated by each engine copy
//   n_print - value limit passed to print_output() when displaying results
constexpr std::size_t n = 100;
constexpr std::size_t s = 10;
constexpr std::size_t ns = 10;
constexpr std::size_t n_print = 100;

// The skip_ahead and leapfrog examples fill an n-element buffer with s portions of ns
// values each, so these sizes must agree.
static_assert(n == s * ns, "n must equal s * ns");

// Demonstrates oneapi::mkl::rng::skip_ahead.
// s copies of a philox4x32x10 engine are created, and copy i is advanced by i * ns elements.
// Generating ns values from every copy then reproduces, portion by portion, the same
// n-element sequence as a single reference engine.
// Returns true when the two sequences match; false on mismatch or on a SYCL/oneMKL exception.
template <typename RealType>
bool run_skip_ahead_example(sycl::queue& queue) {
    std::cout << "\n\tRun skip_ahead example" << std::endl;

    oneapi::mkl::rng::philox4x32x10 engine(queue);

    std::vector<oneapi::mkl::rng::philox4x32x10> engine_vec;
    engine_vec.reserve(s);

    for (std::size_t i = 0; i < s; i++) {
        // copy reference engine to engine_vec[i]
        engine_vec.push_back(oneapi::mkl::rng::philox4x32x10{ engine });
        // advance the copy past the i * ns values produced by the preceding copies
        oneapi::mkl::rng::skip_ahead(engine_vec[i], i * ns);
    }

    // prepare array for random numbers
    sycl::usm_allocator<RealType, sycl::usm::alloc::shared> allocator(queue);
    std::vector<RealType, decltype(allocator)> r_ref(n, allocator);
    std::vector<RealType, decltype(allocator)> r(n, allocator);

    oneapi::mkl::rng::uniform<RealType> distr{};
    // events[0..s-1] belong to the engine copies, events[s] to the reference engine
    sycl::event events[s + 1];

    try {
        // fill r_ref with n random numbers
        events[s] = oneapi::mkl::rng::generate(distr, engine, n, r_ref.data());

        // fill r with random numbers by portions of ns
        for (std::size_t i = 0; i < s; i++) {
            events[i] = oneapi::mkl::rng::generate(distr, engine_vec[i], ns, r.data() + i * ns);
        }
        for (std::size_t i = 0; i < s + 1; i++) {
            events[i].wait_and_throw();
        }
    }
    catch (sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl
                  << "Error code: " << e.code().value() << std::endl;
        return false;
    }
    catch (oneapi::mkl::exception const& e) {
        std::cout << "\toneMKL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl;
        return false;
    }
    std::cout << "\t\tOutput:" << std::endl;
    print_output(r.data(), r.size(), n_print);
    std::cout << "\t\tReference output:" << std::endl;
    print_output(r_ref.data(), r_ref.size(), n_print);

    // validation: the concatenated portions must equal the reference sequence element-wise
    for (std::size_t i = 0; i < n; i++) {
        if (r[i] != r_ref[i]) {
            std::cout << "Fail at " << i << " element" << std::endl;
            return false;
        }
    }
    std::cout << "Success" << std::endl;
    return true;
}

// Demonstrates the extended oneapi::mkl::rng::skip_ahead overload, which takes the skip count
// as a list of 64-bit words.
// engine_1 is advanced by 2^76 elements through 2^14 calls of the basic overload, each
// skipping 2^62 elements. engine_2, a copy of the same initial engine, is advanced by 2^76
// elements in a single call of the extended overload. Both engines must then produce
// identical output.
// Returns true on a match; false on mismatch or on a SYCL/oneMKL exception.
template <typename RealType>
bool run_skip_ahead_ex_example(sycl::queue& queue) {
    std::cout << "\n\tRun skip_ahead extended example" << std::endl;

    oneapi::mkl::rng::mrg32k3a engine_1(queue);

    oneapi::mkl::rng::mrg32k3a engine_2(engine_1);

    // To skip 2^76 elements with the basic skip_ahead function, it has to be called 2^14 times
    //    with nskip equal to 2^62 (2^14 * 2^62 = 2^76).
    // Powers of two are built with exact integer shifts instead of floating-point pow().
    constexpr std::uint64_t nskip = std::uint64_t{ 1 } << 62;
    constexpr std::uint64_t skip_times = std::uint64_t{ 1 } << 14;

    for (std::uint64_t i = 0; i < skip_times; i++) {
        oneapi::mkl::rng::skip_ahead(engine_1, nskip);
    }
    // To skip 2^76 elements with the extended skip_ahead function, pass nskip as 64-bit words:
    //        nskip = 2^76 = 0 + 2^12 * 2^64
    //    in general case:
    //        nskip = params[0] + params[1] * 2^64 + params[2] * 2^128 + ...
    constexpr std::uint64_t nskip_high = std::uint64_t{ 1 } << 12;
    oneapi::mkl::rng::skip_ahead(engine_2, { std::uint64_t{ 0 }, nskip_high });

    // prepare array for random numbers
    sycl::usm_allocator<RealType, sycl::usm::alloc::shared> allocator(queue);
    std::vector<RealType, decltype(allocator)> r_ref(n, allocator);
    std::vector<RealType, decltype(allocator)> r(n, allocator);
    oneapi::mkl::rng::uniform<RealType> distr{};

    try {
        auto event_1 = oneapi::mkl::rng::generate(distr, engine_1, n, r_ref.data());
        auto event_2 = oneapi::mkl::rng::generate(distr, engine_2, n, r.data());
        event_1.wait_and_throw();
        event_2.wait_and_throw();
    }
    catch (sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl
                  << "Error code: " << e.code().value() << std::endl;
        return false;
    }
    catch (oneapi::mkl::exception const& e) {
        std::cout << "\toneMKL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl;
        return false;
    }
    std::cout << "\t\tOutput:" << std::endl;
    print_output(r.data(), r.size(), n_print);
    std::cout << "\t\tReference output:" << std::endl;
    print_output(r_ref.data(), r_ref.size(), n_print);

    // validation: both ways of skipping 2^76 elements must yield the same sequence
    for (std::size_t i = 0; i < n; i++) {
        if (r[i] != r_ref[i]) {
            std::cout << "Fail at " << i << " element" << std::endl;
            return false;
        }
    }
    std::cout << "Success" << std::endl;
    return true;
}

// Demonstrates oneapi::mkl::rng::leapfrog.
// s copies of an mcg31m1 engine are created, and copy i is set up with leapfrog(idx = i,
// stride = s), so it yields the reference elements i, i + s, i + 2s, ...
// Each copy generates ns values into r[i * ns .. i * ns + ns - 1]. The check verifies that
// these match the strided reference sequence.
// Returns true on a match; false on mismatch or on a SYCL/oneMKL exception.
template <typename RealType>
bool run_leapfrog_example(sycl::queue& queue) {
    std::cout << "\n\tRun leapfrog example:" << std::endl;

    oneapi::mkl::rng::mcg31m1 engine(queue);

    std::vector<oneapi::mkl::rng::mcg31m1> engine_vec;
    engine_vec.reserve(s);

    for (std::size_t i = 0; i < s; i++) {
        // copy reference engine to engine_vec[i]
        engine_vec.push_back(oneapi::mkl::rng::mcg31m1{ engine });
        // make copy i produce every s-th element of the reference sequence, starting at i
        oneapi::mkl::rng::leapfrog(engine_vec[i], i, s);
    }

    // prepare array for random numbers
    sycl::usm_allocator<RealType, sycl::usm::alloc::shared> allocator(queue);
    std::vector<RealType, decltype(allocator)> r_ref(n, allocator);
    std::vector<RealType, decltype(allocator)> r(n, allocator);

    oneapi::mkl::rng::uniform<RealType> distr{};
    // events[0..s-1] belong to the engine copies, events[s] to the reference engine
    sycl::event events[s + 1];

    try {
        // fill r_ref with n random numbers
        events[s] = oneapi::mkl::rng::generate(distr, engine, n, r_ref.data());

        // fill r with random numbers by portions of ns
        for (std::size_t i = 0; i < s; i++) {
            events[i] = oneapi::mkl::rng::generate(distr, engine_vec[i], ns, r.data() + i * ns);
        }
        for (std::size_t i = 0; i < s + 1; i++) {
            events[i].wait_and_throw();
        }
    }
    catch (sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl
                  << "Error code: " << e.code().value() << std::endl;
        return false;
    }
    catch (oneapi::mkl::exception const& e) {
        std::cout << "\toneMKL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl;
        return false;
    }
    std::cout << "\t\tOutput:" << std::endl;
    print_output(r.data(), r.size(), n_print);
    std::cout << "\t\tReference output:" << std::endl;
    print_output(r_ref.data(), r_ref.size(), n_print);

    // validation: value k of stream i (stored at r[i * ns + k]) must equal reference
    //    element i + k * s, because the leapfrog stride is s (the number of streams)
    for (std::size_t i = 0; i < s; i++) {
        for (std::size_t k = 0; k < ns; k++) {
            const std::size_t out_idx = i * ns + k;
            if (r[out_idx] != r_ref[k * s + i]) {
                std::cout << "Fail at " << out_idx << " element" << std::endl;
                return false;
            }
        }
    }
    std::cout << "Success" << std::endl;
    return true;
}

// Demonstrates saving an engine state to host memory (get_state_size + save_state) and
// restoring it into a new engine (load_state).
// After the state is saved, the original engine generates n reference values. The restored
// engine must then generate exactly the same n values.
// Returns true on a match; false on mismatch or on a SYCL/oneMKL exception.
template <typename RealType>
bool run_save_load_state_memory_example(sycl::queue& queue) {
    std::cout << "\n\tRun save/load state to memory example" << std::endl;

    oneapi::mkl::rng::philox4x32x10 engine(queue);

    // prepare array for random numbers
    sycl::usm_allocator<RealType, sycl::usm::alloc::shared> allocator(queue);
    std::vector<RealType, decltype(allocator)> r_ref(n, allocator);
    std::vector<RealType, decltype(allocator)> r(n, allocator);

    oneapi::mkl::rng::uniform<RealType> distr{};

    try {
        // advance the engine by n values before its state is saved;
        //    the values written to r_ref here are overwritten below
        auto event = oneapi::mkl::rng::generate(distr, engine, n, r_ref.data());
        event.wait_and_throw();

        // check how many bytes are needed to save the engine state
        std::int64_t state_size = oneapi::mkl::rng::get_state_size(engine);

        // allocate required memory
        std::vector<std::uint8_t> mem(static_cast<std::size_t>(state_size));
        oneapi::mkl::rng::save_state(engine, mem.data());

        // fill r_ref with n random numbers
        auto event_fill = oneapi::mkl::rng::generate(distr, engine, n, r_ref.data());
        event_fill.wait_and_throw();

        std::cout << "\t\tReference output:" << std::endl;
        print_output(r_ref.data(), r_ref.size(), n_print);

        // load engine state from memory, new oneapi::mkl::rng::philox4x32x10 object is created
        auto loaded_engine =
            oneapi::mkl::rng::load_state<oneapi::mkl::rng::philox4x32x10>(queue, mem.data());

        // fill r with n random numbers using loaded_engine
        auto event_load = oneapi::mkl::rng::generate(distr, loaded_engine, n, r.data());
        event_load.wait_and_throw();

        std::cout << "\t\tOutput from loaded engine:" << std::endl;
        print_output(r.data(), r.size(), n_print);
    }
    // NOTE(review): these handlers also catch failures from get_state_size/save_state/load_state,
    //    although the printed messages mention only generate().
    catch (sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl
                  << "Error code: " << e.code().value() << std::endl;
        return false;
    }
    catch (oneapi::mkl::exception const& e) {
        std::cout << "\toneMKL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl;
        return false;
    }

    // validation: the restored engine must reproduce the reference sequence exactly
    for (std::size_t i = 0; i < n; i++) {
        if (r[i] != r_ref[i]) {
            std::cout << "Fail at " << i << " element" << std::endl;
            return false;
        }
    }
    std::cout << "Success" << std::endl;

    return true;
}

// Writes the example banner to std::cout. The banner lists the oneMKL RNG service APIs
// demonstrated in this file, preceded and followed by an empty line.
void print_example_banner() {
    static const char* const banner_lines[] = {
        "",
        "########################################################################",
        "# Demonstrate service functionality usage for random number generators:",
        "# ",
        "# Using apis:",
        "#   oneapi::mkl::rng::skip_ahead",
        "#   oneapi::mkl::rng::leapfrog",
        "#   oneapi::mkl::rng::save_state",
        "#   oneapi::mkl::rng::load_state",
        "#   oneapi::mkl::rng::get_state_size",
        "# ",
        "########################################################################",
    };

    for (const char* line : banner_lines) {
        std::cout << line << std::endl;
    }
    // trailing blank line separating the banner from the example output
    std::cout << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//

int main(int argc, char** argv) {
    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";
            sycl::queue queue(my_dev, exception_handler);

            if (!run_skip_ahead_example<float>(queue) || !run_skip_ahead_ex_example<float>(queue) ||
                !run_leapfrog_example<float>(queue) ||
                !run_save_load_state_memory_example<float>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices is enabled.\n";
            std::cout << "FAILED" << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }
    std::cout << "PASSED" << std::endl;
    return 0;
}
