!===============================================================================
! Copyright 2021-2022 Intel Corporation.
!
! This software and the related documents are Intel copyrighted  materials,  and
! your use of  them is  governed by the  express license  under which  they were
! provided to you (License).  Unless the License provides otherwise, you may not
! use, modify, copy, publish, distribute,  disclose or transmit this software or
! the related documents without Intel's prior written permission.
!
! This software and the related documents  are provided as  is,  with no express
! or implied  warranties,  other  than those  that are  expressly stated  in the
! License.
!===============================================================================

! Content:
! A simple example of batch double-precision real-to-complex, complex-to-real 
! in-place 1D FFT using Intel(R) oneAPI Math Kernel Library (oneMKL) DFTI
!
!*****************************************************************************

include "mkl_dfti_omp_offload.f90"

program dp_real_1d_batch

  ! Driver: builds one DFTI descriptor for M in-place real 1D transforms of
  ! length N, runs a forward (real-to-complex) and a backward
  ! (complex-to-real) transform on the same buffer, and checks each result
  ! against the analytically expected single-harmonic signal.
  ! Terminates with exit code 0 when every step succeeded, 1 otherwise.

  ! The rename list hides the module's DFTI_DOUBLE under the name "forget"
  ! and makes DFTI_DOUBLE_R visible as DFTI_DOUBLE in this program.
  use MKL_DFTI_OMP_OFFLOAD, forget => DFTI_DOUBLE, DFTI_DOUBLE => DFTI_DOUBLE_R
  use omp_lib, ONLY : omp_get_num_devices
  use, intrinsic :: ISO_C_BINDING

  ! Every entity must be declared explicitly; this also applies to the
  ! internal procedures below through host association.
  implicit none

  ! Size of 1D transform
  integer, parameter :: N = 16

  ! Number of complex values stored per transform in CCE format
  integer, parameter :: halfNplus1 = N/2 + 1

  ! Number of transforms
  integer, parameter :: M = 5

  ! Arbitrary harmonic used to verify FFT
  integer, parameter :: H = 1

  ! Working precision is double precision
  integer, parameter :: WP = selected_real_kind(15,307)

  ! Execution status (0 means success); ignored_status receives the result of
  ! the descriptor release so that it cannot overwrite a pending error code
  integer :: status = 0, ignored_status

  ! The data array: M slots of 2*halfNplus1 reals each, used both as real
  ! input/output and as interleaved (re, im) complex output/input
  real(WP), allocatable :: x (:)

  ! DFTI descriptor handle
  type(DFTI_DESCRIPTOR), POINTER :: hand

  hand => null()

  print *,"Example dp_real_1d_batch"
  print *,"Batch forward and backward double-precision real-to-complex",      &
    &      " and complex-to-real in-place 1D transform"
  print *,"Configuration parameters:"
  print *,"DFTI_PRECISION      = DFTI_DOUBLE"
  print *,"DFTI_FORWARD_DOMAIN = DFTI_REAL"
  print *,"DFTI_DIMENSION      = 1"
  print '(" DFTI_NUMBER_OF_TRANSFORMS = ",I0)', M
  print '(" DFTI_LENGTHS        = /",I0,"/")', N

  print *,"Create DFTI descriptor"
  status = DftiCreateDescriptor(hand, DFTI_DOUBLE, DFTI_REAL, 1, N)
  if (0 /= status) goto 999

  print *,"Set DFTI descriptor for CCE storage "
  status = DftiSetValue(hand, DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX)
  if (0 /= status) goto 999

  print *,"Set DFTI descriptor for number of in-place transforms"
  status = DftiSetValue(hand, DFTI_NUMBER_OF_TRANSFORMS, M)
  if (0 /= status) goto 999

  ! The memory address of the first element must coincide on I/O
  ! --> different DFTI_INPUT_DISTANCE and DFTI_OUTPUT_DISTANCE to be set
  ! (storing a complex number requires to store two real numbers)
  print *,"Set DFTI descriptor for input distance for forward FFT"
  status = DftiSetValue(hand, DFTI_INPUT_DISTANCE, 2*halfNplus1)
  if (0 /= status) goto 999

  print *,"Set DFTI descriptor for output distance for forward FFT"
  status = DftiSetValue(hand, DFTI_OUTPUT_DISTANCE, halfNplus1)
  if (0 /= status) goto 999

  ! The commit is issued under an OpenMP dispatch construct so that the
  ! offload variant of the routine can be selected.
  print *,"Commit DFTI descriptor for forward FFT"
  !$omp dispatch
  status = DftiCommitDescriptor(hand)
  if (0 /= status) goto 999

  print *,"Allocate array for input/output data"
  allocate ( x(2*halfNplus1*M), STAT = status)
  if (0 /= status) goto 999

  print *,"Initialize input for real-to-complex forward FFT "
  call init_r(x, M, N, H)

  ! x is mapped to the device for the duration of the call and copied back
  ! when the target data region ends (map type tofrom).
  print *,"Compute forward transform"
  !$omp target data map(tofrom:x)
  !$omp dispatch
  status = DftiComputeForward(hand, x)
  !$omp end target data
  if (0 /= status) goto 999

  print *,"Verify the complex result"
  status = verify_c(x, M, N, H)
  if (0 /= status) goto 999

  ! DFTI_INPUT_DISTANCE and DFTI_OUTPUT_DISTANCE are to be reset for the
  ! backward transform
  print *,"Set DFTI descriptor for input distance for backward FFT"
  status = DftiSetValue(hand, DFTI_INPUT_DISTANCE, halfNplus1)
  if (0 /= status) goto 999

  print *,"Set DFTI descriptor for output distance for backward FFT"
  status = DftiSetValue(hand, DFTI_OUTPUT_DISTANCE, 2*halfNplus1)
  if (0 /= status) goto 999

  ! The changed distances only take effect after a new commit.
  print *,"Commit DFTI descriptor for backward FFT"
  !$omp dispatch
  status = DftiCommitDescriptor(hand)
  if (0 /= status) goto 999

  print *,"Initialize input for complex-to-real backward transform"
  call init_c(x, M, N, H)

  print *,"Compute backward transform"
  !$omp target data map(tofrom:x)
  !$omp dispatch
  status = DftiComputeBackward(hand, x)
  !$omp end target data
  if (0 /= status) goto 999

  print *,"Verify the result"
  status = verify_r(x, M, N, H)
  if (0 /= status) goto 999

  ! Common cleanup: reached by falling through on success and by the jump
  ! back from the error handler at label 999.
100 continue

  print *,"Release the DFTI descriptor"
  ignored_status = DftiFreeDescriptor(hand)

  if (allocated(x)) then
      print *,"Deallocate input data array"
      deallocate(x)
  endif

  if (status == 0) then
    print *,"TEST PASSED"
    call exit(0)
  else
    print *,"TEST FAILED"
    call exit(1)
  endif

  ! Error handler: report the failing status, then run the common cleanup.
999 print '("  Error, status = ",I0)', status
  goto 100

contains

  ! Compute mod(K*L,M) accurately
  pure real(WP) function moda(k,l,m)
    ! Returns mod(k*l, m) converted to real(WP).
    !
    ! The product is formed in an integer kind holding at least 18 decimal
    ! digits, so k*l cannot overflow the default integer kind before the
    ! remainder is taken. All MOD operands are converted to that same kind,
    ! as the standard requires both arguments to have identical type and kind.
    !
    ! k, l : factors of the product
    ! m    : modulus
    ! The result carries the sign of k*l (MOD semantics, not MODULO).
    integer, intent(in) :: k,l,m
    integer, parameter :: I8 = selected_int_kind(18)

    moda = real(mod(int(k, I8)*int(l, I8), int(m, I8)), WP)
  end function moda

  ! Initialize real array x to produce unit peaks at y(H) and y(N-H)
  subroutine init_r(x, M, N, H)
    ! Fills the real input of each of the M batch slots with the cosine
    !   factor * cos(2*pi*(k-1)*H/N) / N,   k = 1..N
    !
    ! x : data buffer; slot j starts at offset (j-1)*2*(N/2+1) and only its
    !     first N elements are written, the remaining elements of the slot
    !     are left untouched
    ! M : number of batch slots to initialize
    ! N : transform length
    ! H : harmonic index of the cosine
    integer, intent(in) :: M, N, H
    real(WP), intent(inout) :: x(:)

    real(WP), parameter :: TWOPI = 6.2831853071795864769_WP
    integer :: j, k, slot_len, base
    real(WP) :: factor
    real(WP) :: wave(N)   ! one period of the signal, shared by all batches

    ! When 2*(N-H) is a multiple of N the harmonics H and N-H coincide
    ! modulo N, so the amplitude must not be doubled.
    if (mod(2*(N - H), N) == 0) then
      factor = 1.0_WP
    else
      factor = 2.0_WP
    end if

    ! The samples do not depend on the batch index: evaluate them once.
    do k = 1, N
      wave(k) = factor * cos(TWOPI*moda(k-1, H, N)/N) / N
    end do

    ! Distance between consecutive batch slots, in reals
    slot_len = 2*(N/2 + 1)

    do j = 1, M
      base = (j-1)*slot_len
      x(base + 1 : base + N) = wave
    end do
  end subroutine init_r

  ! Verify that y(k) is unit peak at k = H
  integer function verify_c(y, M, N, H)
    ! Checks the complex spectrum stored in y as interleaved (re, im) pairs.
    ! For every batch, the first N/2+1 complex values are compared against
    ! 1 where the index matches harmonic H or -H modulo N, and 0 elsewhere.
    !
    ! y : buffer holding M slots of 2*(N/2+1) reals each
    ! M : number of batch slots to check
    ! N : transform length
    ! H : harmonic index of the expected peak
    !
    ! Returns 0 when every value is within the error threshold, 100 at the
    ! first value that is not (details of that value are printed).
    integer, intent(in) :: M, N, H
    real(WP), intent(in) :: y(:)

    integer :: j, k, halfNplus1, base
    real(WP) :: err, errthr, maxerr
    complex(WP) :: res_exp, res_got

    ! Note, this simple error bound doesn't take into account error of
    ! input data
    errthr = 2.5_WP * log(real(N, WP)) / log(2.0_WP) * EPSILON(1.0_WP)
    print '("  Check if err is below errthr ",G10.3)', errthr

    halfNplus1 = N/2 + 1
    maxerr = 0.0_WP
    do j = 1, M
      ! Offset of batch j inside y, counted in reals
      base = (j-1)*2*halfNplus1
      do k = 1, halfNplus1
        if (mod(k-1-H,N)==0 .OR. mod(1-k-H,N)==0) then
          res_exp = 1.0_WP
        else
          res_exp = 0.0_WP
        end if
        ! Build the complex value in the working kind
        res_got = cmplx(y(base + 2*k-1), y(base + 2*k), kind=WP)
        err = abs(res_got - res_exp)
        maxerr = max(err,maxerr)
        ! Written as .not.(err < errthr) so that a NaN error also fails
        if (.not.(err < errthr)) then
          print '(" Batch #",I0," y(",I0,"): ",$)', j, k
          print '(" expected (",G24.17,", ",G24.17,"),",$)', res_exp
          print '(" got (",G24.17,", ",G24.17,"),",$)', res_got
          print '(" err ",G10.3)', err
          print *," Verification FAILED"
          verify_c = 100
          return
        end if
      end do
    end do

    print '("  Verified,  maximum error was ",G10.3)', maxerr
    verify_c = 0
  end function verify_c
  
  ! Initialize complex array y to produce unit peaks at x(H)
  subroutine init_c(y, M, N, H)
    ! Fills each of the M batch slots with N/2+1 complex values, stored as
    ! interleaved (re, im) pairs:
    !   re =  cos(2*pi*(k-1)*H/N) / N
    !   im = -sin(2*pi*(k-1)*H/N) / N,   k = 1..N/2+1
    !
    ! y : data buffer; slot j occupies elements (j-1)*2*(N/2+1)+1 through
    !     j*2*(N/2+1)
    ! M : number of batch slots to initialize
    ! N : transform length
    ! H : harmonic index
    integer, intent(in) :: M, N, H
    real(WP), intent(inout) :: y(:)

    real(WP), parameter :: TWOPI = 6.2831853071795864769_WP
    integer :: j, k, halfNplus1, slot_len
    real(WP) :: TWOPI_phase
    real(WP) :: spectrum(2*(N/2 + 1))   ! one slot, shared by all batches

    halfNplus1 = N/2 + 1
    slot_len = 2*halfNplus1

    ! The values do not depend on the batch index: evaluate them once.
    do k = 1, halfNplus1
      TWOPI_phase = TWOPI*moda(k-1, H, N)/N
      spectrum(2*k-1) =  cos(TWOPI_phase)/N ! real part
      spectrum(2*k)   = -sin(TWOPI_phase)/N ! imaginary part
    end do

    do j = 1, M
      y((j-1)*slot_len + 1 : j*slot_len) = spectrum
    end do
  end subroutine init_c
  
  ! Verify that x(k) is unit peak at k = H
  integer function verify_r(x, M, N, H)
    ! Checks the real sequence stored in x. For every batch, the first N
    ! values of the slot are compared against 1 where the zero-based index
    ! equals H modulo N, and 0 elsewhere.
    !
    ! x : buffer holding M slots of 2*(N/2+1) reals each
    ! M : number of batch slots to check
    ! N : transform length
    ! H : index of the expected unit peak
    !
    ! Returns 0 when every value is within the error threshold, 100 at the
    ! first value that is not (details of that value are printed).
    integer, intent(in) :: M, N, H
    real(WP), intent(in) :: x(:)

    integer :: j, k, halfNplus1, base
    real(WP) :: err, errthr, maxerr
    real(WP) :: res_exp, res_got

    ! Note, this simple error bound doesn't take into account error of
    ! input data
    errthr = 2.5_WP * log(real(N, WP)) / log(2.0_WP) * EPSILON(1.0_WP)
    print '("  Check if err is below errthr ",G10.3)', errthr

    halfNplus1 = N/2 + 1
    maxerr = 0.0_WP
    do j = 1, M
      ! Offset of batch j inside x, counted in reals
      base = (j-1)*2*halfNplus1
      do k = 1, N
        if (mod(k-1-H,N)==0) then
          res_exp = 1.0_WP
        else
          res_exp = 0.0_WP
        end if
        res_got = x(base + k)
        err = abs(res_got - res_exp)
        maxerr = max(err,maxerr)
        ! Written as .not.(err < errthr) so that a NaN error also fails
        if (.not.(err < errthr)) then
          print '(" Batch #",I0," x(",I0,"): ",$)', j, k
          print '(" expected ",G24.17,",",$)', res_exp
          print '(" got ",G24.17,",",$)', res_got
          print '(" err ",G10.3)', err
          print *," Verification FAILED"
          verify_r = 100
          return
        end if
      end do
    end do
    print '("  Verified,  maximum error was ",G10.3)', maxerr
    verify_r = 0
  end function verify_r

end program dp_real_1d_batch
