!===============================================================================
! Copyright 2024 Intel Corporation.
!
! This software and the related documents are Intel copyrighted  materials,  and
! your use of  them is  governed by the  express license  under which  they were
! provided to you (License).  Unless the License provides otherwise, you may not
! use, modify, copy, publish, distribute,  disclose or transmit this software or
! the related documents without Intel's prior written permission.
!
! This software and the related documents  are provided as  is,  with no express
! or implied  warranties,  other  than those  that are  expressly stated  in the
! License.
!===============================================================================

!  Content:
!      Intel(R) oneAPI Math Kernel Library (oneMKL)
!      FORTRAN OpenMP offload example for SAXPY_BATCH
!*******************************************************************************

include "mkl_omp_offload.f90"
include "common_blas.f90"

!> OpenMP offload example for the group-batched SAXPY_BATCH routine.
!>
!> The program builds two groups of single-precision vectors, runs
!> saxpy_batch on the host into y_ref, runs the same call through
!> `!$omp dispatch` on device copies of x and y, and finally checks y against
!> y_ref with scheck_batch_vector.  It prints "PASSED" and stops normally when
!> the check returns zero; otherwise it stops with exit code 1.
program saxpy_batch_example
#if defined(MKL_ILP64)
use onemkl_blas_omp_offload_ilp64
#else
use onemkl_blas_omp_offload_lp64
#endif
use common_blas
use, intrinsic :: ISO_C_BINDING

implicit none

! Per-group parameters (one entry per group).
real :: alpha(2)
integer :: n(2)
integer :: incx(2), incy(2)
integer :: group_size(2)

! Vector storage: one column per vector in the batch.
real,allocatable,target :: x(:,:), y(:,:)
real,allocatable :: y_ref(:,:)

! Addresses of each column, stored as integers: host addresses first,
! device addresses in the *_dev arrays.
integer(KIND=C_SIZE_T),allocatable :: x_array(:), y_array(:), y_ref_array(:)
integer(KIND=C_SIZE_T),allocatable :: x_array_dev(:), y_array_dev(:)

! Pointers used to obtain the device address of one column at a time.
real,pointer :: tmp_x(:), tmp_y(:)

integer :: group_count, total_batch_size, i
integer :: max_size_x, max_size_y
integer :: passed
integer :: alloc_stat

group_count = 2
total_batch_size = 0
max_size_x = 0
max_size_y = 0

! Set up the groups and track the largest storage any vector needs
! (n*inc elements) so that all columns can share one leading dimension.
do i = 1, group_count
  n(i) = i + 10
  alpha(i) = 1.4
  incx(i) = 1
  incy(i) = 2

  group_size(i) = 4 + i
  total_batch_size = total_batch_size + group_size(i)

  max_size_x = max(max_size_x, n(i)*incx(i))
  max_size_y = max(max_size_y, n(i)*incy(i))
end do

! stat= is required here: without it a failed allocate terminates the
! program immediately and the error reporting below could never run.
allocate(x(max_size_x,total_batch_size), &
         y(max_size_y,total_batch_size), &
         y_ref(max_size_y,total_batch_size), stat=alloc_stat)
if (alloc_stat /= 0) then
  print *, "Cannot allocate vectors"
  print *, 'Error: cannot allocate memory'
  stop 1
end if

allocate(x_array(total_batch_size), &
         y_array(total_batch_size), &
         y_ref_array(total_batch_size), stat=alloc_stat)
if (alloc_stat /= 0) then
  print *, "Cannot allocate array of pointers"
  print *, 'Error: cannot allocate memory'
  stop 1
end if

allocate(x_array_dev(total_batch_size), &
         y_array_dev(total_batch_size), stat=alloc_stat)
if (alloc_stat /= 0) then
  print *, "Cannot allocate array of device pointers"
  print *, 'Error: cannot allocate memory'
  stop 1
end if

! Fill x and y, and keep a copy of the initial y in y_ref for the host run.
call sinit_batch_vector(max_size_x, n, incx, x, group_size, group_count)
call sinit_batch_vector(max_size_y, n, incy, y, group_size, group_count)
call scopy_batch_vector(max_size_y, n, incy, y, y_ref, group_size, group_count)

! Record the host address of the first element of every column
! (LOC is a compiler extension returning the address as an integer).
do i = 1, total_batch_size
  x_array(i) = LOC(x(1,i))
  y_array(i) = LOC(y(1,i))
  y_ref_array(i) = LOC(y_ref(1,i))
end do

! Host reference computation: the result lands in y_ref.
call saxpy_batch(n, alpha, x_array, incx, y_ref_array, incy, group_count, group_size)

! map each vector to the device and store the device pointers into arrays
do i = 1, total_batch_size
!$omp target enter data map(to:x(:,i),y(:,i))
  tmp_x => x(:,i)
  tmp_y => y(:,i)
!$omp target data use_device_addr(tmp_x,tmp_y)
  x_array_dev(i) = LOC(tmp_x)
  y_array_dev(i) = LOC(tmp_y)
!$omp end target data
end do

! Offloaded computation: the pointer arrays themselves are mapped so the
! dispatched call receives arrays of device addresses.
!$omp target data map(to:x_array_dev) map(tofrom:y_array_dev)
!$omp dispatch
call saxpy_batch(n, alpha, x_array_dev, incx, y_array_dev, incy, group_count, group_size)
!$omp end target data

! Copy every vector back to the host and end its device mapping.
do i = 1, total_batch_size
!$omp target exit data map(from:x(:,i),y(:,i))
end do

! Zero means the device result in y agrees with the host result in y_ref.
passed = scheck_batch_vector(max_size_y, n, incy, y, y_ref, group_size, group_count)

deallocate(x, y, y_ref)
deallocate(x_array, y_array, y_ref_array)
deallocate(x_array_dev, y_array_dev)

if (passed /= 0) stop 1

print *, "PASSED"
stop
end program saxpy_batch_example
