/*******************************************************************************
* Copyright (C) 2023 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

// Expands to the operator form of `#pragma unroll`; placed on the line directly
// before a loop in the kernels of this file.
#define ESIMD_UNROLL _Pragma("unroll")

//
// ESIMD kernel for SpTRSV using ESB4 format with blockptr_st, blockptr_en, colind, values
//
// Note that in ESB format, blockptr_st points to the vector of the block where the lower or upper
// part starts, and blockptr_en points to where the lower or upper part ends for that block. There may
// still be elements within this set of vectors that are on the other side, so we must load them and
// then check with a mask while accumulating.
//



// Triangular-solve step for one complete block of block_size rows.
//
//   block           block index; the block covers rows
//                   [block * block_size, (block + 1) * block_size)
//   vec_st, vec_en  half-open range of vector indices to accumulate; vector j
//                   occupies entries [j * block_size, (j + 1) * block_size)
//                   of colind and values
//   colind, values  column indices and matrix entries
//   x               read at the block's rows (and overwritten there when isFused)
//   y               gathered by column index while accumulating, then
//                   overwritten at the block's rows with the solved values
//   diag            diagonal entries, read at the block's rows
//
// block_size, nrows, isFused and the cache-hint names (st, uc, ca) are taken
// from the enclosing scope.
auto trsv_esb4_esimd_kernel = [=](const local_int_t block,
                                  const local_int_t vec_st,
                                  const local_int_t vec_en,
                                  const local_int_t *colind,
                                  const double * values,
                                  double *x,
                                  double *y,
                                  const double *diag) SYCL_ESIMD_KERNEL {

    using dvec_t = esimd::simd<double, block_size>;
    using ivec_t = esimd::simd<local_int_t, block_size>;

    const local_int_t first_row = block * block_size;

    // Row index handled by each lane: first_row, first_row + 1, ...
    ivec_t lane_rows(first_row, 1);
    dvec_t zeros(0.0);
    dvec_t acc(0.0);

    dvec_t rhs      = esimd_lsc_block_load<double, local_int_t, block_size, st, uc>(x, first_row);
    dvec_t diag_vec = esimd_lsc_block_load<double, local_int_t, block_size, st, uc>(diag, first_row);
    dvec_t inv_diag = 1.0 / diag_vec;

#ifndef USE_TRSV_UNROLL_KERNELS
    ESIMD_UNROLL
    for (local_int_t vec = vec_st; vec < vec_en; ++vec) {
        const auto vec_off = vec * block_size;

        ivec_t col_ids = esimd_lsc_block_load<local_int_t, local_int_t, block_size, st, uc>(colind, vec_off);
#ifdef USE_LOWER
        // keep entries that are not fill-in and lie strictly below the diagonal
        esimd::simd_mask<block_size> tri_mask = (col_ids >= 0) && (col_ids < lane_rows);
#else
        // keep entries that are not fill-in, local, and strictly above the diagonal
        esimd::simd_mask<block_size> tri_mask = (col_ids >= 0) && (col_ids < nrows) && (col_ids > lane_rows);
#endif
        dvec_t y_gathered = esimd_lsc_gather<double, local_int_t, block_size, ca, ca>(y, col_ids, tri_mask, zeros);
        dvec_t coeffs     = esimd_lsc_block_load<double, local_int_t, block_size, st, uc>(values, vec_off);

        // zero the matrix entries of every lane rejected by the triangle mask
        coeffs.merge(zeros, !tri_mask);

        // acc += L * y  (or  U * y)
        acc += coeffs * y_gathered;
    }
#else
#ifdef USE_LOWER
    trsv_unroll_dispatch<block_size, ESBTRSVL>(vec_st, vec_en, acc, values, colind, y, nrows, lane_rows);
#else
    trsv_unroll_dispatch<block_size, ESBTRSVU>(vec_st, vec_en, acc, values, colind, y, nrows, lane_rows);
#endif
#endif

    if constexpr (isFused) {
#ifdef USE_LOWER
        // x <- y + x - acc, using the values of y currently stored at the block's rows
        dvec_t fused = esimd_lsc_block_load<double, local_int_t, block_size, st, uc>(y, first_row);
        fused += rhs;
        fused -= acc;
        esimd_lsc_block_store<double, local_int_t, block_size, uc, uc>(x, first_row, fused);
#else
        // x <- diag * x; the scaled value is also what enters the solve below
        rhs *= diag_vec;
        esimd_lsc_block_store<double, local_int_t, block_size, uc, uc>(x, first_row, rhs);
#endif
    }

    // y = (x - L*y) * invD  or  y = (x - U*y) * invD
    dvec_t solved = (rhs - acc) * inv_diag;
    esimd_lsc_block_store<double, local_int_t, block_size, uc, uc>(y, first_row, solved);
};



//
// Same kernel as above, but for the case where a block extends outside the color: a mask is
// provided to keep all loads and accumulations within the color.
//

// Triangular-solve step for one block of block_size rows in which only the
// lanes enabled in locmask are processed.
//
//   block           block index; lane i corresponds to row block * block_size + i
//   locmask         per-lane enable mask; every load, gather and scatter below
//                   is predicated on it (directly or through uplomask)
//   vec_st, vec_en  half-open range of vector indices to accumulate; vector j
//                   occupies entries [j * block_size, (j + 1) * block_size)
//                   of colind and values
//   colind, values  column indices and matrix entries
//   x               read at the enabled rows (and overwritten there when isFused)
//   y               gathered by column index while accumulating, then
//                   overwritten at the enabled rows with the solved values
//   diag            diagonal entries, read at the enabled rows
//
// block_size, nrows, isFused and the cache-hint names (st, uc, ca) are taken
// from the enclosing scope.
auto trsv_esb4_masked_esimd_kernel = [=](const local_int_t block,
                                         const esimd::simd_mask<block_size> &locmask,
                                         const local_int_t vec_st,
                                         const local_int_t vec_en,
                                         const local_int_t *colind,
                                         const double * values,
                                         double *x,
                                         double *y,
                                         const double *diag) SYCL_ESIMD_KERNEL {

    const local_int_t row_st = block * block_size;
    esimd::simd<double, block_size> t_vec(0.0), zero_vec(0.0);

    esimd::simd<local_int_t, block_size> iota(0,1); // 0, 1, 2, ... block_size-1

    // Row index of each lane (row_st, row_st + 1, ...). It serves both as the
    // element index for the per-row gathers/scatters and as the diagonal
    // position in the lower/upper test.
    esimd::simd<local_int_t, block_size> offset(row_st, 1);

    auto x_vec       = esimd_lsc_gather<double, local_int_t, block_size, st, uc>(x, offset, locmask, zero_vec);
    // The last argument is 1.0 rather than 0.0 — presumably the value given to
    // disabled lanes, so that the reciprocal below stays finite there
    // (confirm against the esimd_lsc_gather wrapper).
    auto diagvals    = esimd_lsc_gather<double, local_int_t, block_size, st, uc>(diag, offset, locmask, esimd::simd<double, block_size>(1.0));
    auto invDiagvals = 1.0 / diagvals;

    ESIMD_UNROLL
    for ( local_int_t j = vec_st; j < vec_en; ++j) {
        // This gather has no fill value, so disabled lanes of cols should not be
        // relied on; locmask is and-ed into uplomask below for that reason.
        esimd::simd<local_int_t, block_size> cols = esimd_lsc_gather<local_int_t, local_int_t, block_size, st, uc>(colind, j * block_size + iota, locmask);
#ifdef USE_LOWER
        esimd::simd_mask<block_size> uplomask = (cols < offset) && (cols >= 0) && locmask; // lower + not fill-in + enabled
#else
        esimd::simd_mask<block_size> uplomask = (cols > offset) && (cols < nrows) && (cols >= 0) && locmask; // upper + local + not fill-in + enabled
#endif
        esimd::simd<double, block_size> y_vec = esimd_lsc_gather<double, local_int_t, block_size, ca, ca>(y, cols, uplomask, zero_vec);
        // NOTE(review): unlike the unmasked kernel, vals is not zeroed outside
        // uplomask here; those lanes are multiplied by the zero_vec fill of y_vec.
        esimd::simd<double, block_size> vals  = esimd_lsc_gather<double, local_int_t, block_size, st, uc>(values, j * block_size + iota, locmask, zero_vec);

        // accumulate t_vec += L * y or += U*y with uplo mask
        t_vec += vals * y_vec;
    }

    if constexpr (isFused) {
#ifdef USE_LOWER
        // x <- y + x - t_vec, using the values of y currently stored at the enabled rows
        auto w_vec = esimd_lsc_gather<double, local_int_t, block_size, st, uc>(y, offset, locmask, zero_vec);
        w_vec = w_vec + x_vec - t_vec;
        esimd_lsc_scatter<double, local_int_t, block_size, uc, uc>(x, offset, w_vec, locmask);
#else
        // x <- diag * x; the scaled value is also what enters the solve below
        x_vec = diagvals * x_vec;
        esimd_lsc_scatter<double, local_int_t, block_size, uc, uc>(x, offset, x_vec, locmask);
#endif
    }

    // y = (x - L*y) * invD  or  y = (x - U*y) * invD
    esimd::simd<double, block_size> y_vec = (x_vec - t_vec) * invDiagvals;

    esimd_lsc_scatter<double, local_int_t, block_size, uc, uc>(y, offset, y_vec, locmask);

};
