Commit baf456c9 authored by Nichols, Stephen
Browse files

Changes to be committed:

	new file:   matrix_ops.f90

Forgot to add this file ..
parent ba958195
Loading
Loading
Loading
Loading

matrix_ops.f90

0 → 100644
+431 −0
Original line number Diff line number Diff line
      module matrix_operations
      use omp_lib
      use global_parameters
      implicit none

      contains
! =========================================================================
        subroutine matrix_vect_mult(neqn,aa,bb,cc)
        implicit none

#ifdef USE_OPENMP_OFFLOAD
        !$omp declare target
#endif

        ! Dense matrix-vector product: cc = aa * bb.
        !
        ! neqn : order of the system
        ! aa   : (neqn x neqn) matrix
        ! bb   : vector of length neqn
        ! cc   : result vector of length neqn (overwritten)
        !
        ! The product is built one column of aa at a time so that the inner
        ! loop walks down a column. Every entry of cc still receives its
        ! terms in the order col = 1,2,...,neqn.

        integer, intent(in) :: neqn
        real*8, intent(in) :: aa(neqn,neqn),bb(neqn)
        real*8, intent(out) :: cc(neqn)

        integer :: row,col

        ! seed each entry with the contribution of the first column
        do row = 1,neqn
            cc(row) = aa(row,1)*bb(1)
        end do

        ! add the contributions of the remaining columns
        do col = 2,neqn
            do row = 1,neqn
                cc(row) = cc(row) + aa(row,col)*bb(col)
            end do
        end do

        return
        end subroutine matrix_vect_mult
! =========================================================================
        subroutine matrix_matrix_mult(neqn,aa,bb,cc)
        implicit none

#ifdef USE_OPENMP_OFFLOAD
        !$omp declare target
#endif

        ! Dense matrix-matrix product: cc = aa * bb.
        !
        ! neqn : order of the (square) matrices
        ! aa   : left factor,  (neqn x neqn)
        ! bb   : right factor, (neqn x neqn)
        ! cc   : product,      (neqn x neqn), overwritten
        !
        ! Each entry is an independent inner product accumulated in the
        ! order inner = 1,2,...,neqn. The entries are visited column by
        ! column so that the store into cc moves down a column.

        integer, intent(in) :: neqn
        real*8, intent(in) :: aa(neqn,neqn),bb(neqn,neqn)
        real*8, intent(out) :: cc(neqn,neqn)

        integer :: row,col,inner
        real*8 :: acc

        do col = 1,neqn
            do row = 1,neqn
                acc = aa(row,1)*bb(1,col)
                do inner = 2,neqn
                    acc = acc + aa(row,inner)*bb(inner,col)
                end do
                cc(row,col) = acc
            end do
        end do

        return
        end subroutine matrix_matrix_mult
! =========================================================================
        subroutine small_value(neqn,aa)
        implicit none

#ifdef USE_OPENMP_OFFLOAD
        !$omp declare target
#endif

        ! Flush negligible entries of a square matrix to exactly zero.
        !
        ! neqn : order of the matrix
        ! aa   : (neqn x neqn) matrix, modified in place; every entry whose
        !        magnitude is strictly below the threshold becomes 0.d0
        !
        ! The threshold is a double precision constant on purpose. Written
        ! as the default-real literal 1.0e-15 it is first rounded to the
        ! default real kind (single precision unless the compiler is told to
        ! promote default reals), which moves the cut-off slightly above
        ! 1.0d-15 before it is compared with the double precision data.

        integer, intent(in) :: neqn
        real*8, intent(inout) :: aa(neqn,neqn)

        real*8, parameter :: threshold = 1.0d-15
        integer :: i,j

        do j = 1,neqn
            do i = 1,neqn
                if (dabs(aa(i,j)) .lt. threshold) then
                    aa(i,j) = 0.d0
                endif
            enddo
        enddo

        return
        end subroutine small_value
! =========================================================================
      subroutine ludecomp(neqn,ppivot,aa,matrix_info)
      implicit none

#ifdef USE_OPENMP_OFFLOAD
      !$omp declare target
#endif

      ! In-place LU decomposition: Gaussian elimination with row pivoting
      ! reduces a (neqn x neqn) matrix to upper triangular form.
      !
      ! neqn        : order of the matrix
      ! ppivot      : on exit, the row permutation; logical row m of the
      !               factorization is stored in physical row ppivot(m)
      ! aa          : on entry the matrix; on exit the eliminated positions
      !               hold the multipliers and the rest holds the reduced
      !               (upper triangular) part, both addressed through ppivot
      ! matrix_info : handed to check_matrix once the elimination is done
      !
      ! Rows are never physically exchanged; only ppivot is re-mapped.
      !
      ! NOTE(review): matrix_info is not reset here, so whatever value the
      !               caller passes in is what check_matrix starts from.

      integer, intent(in) :: neqn
      integer, intent(inout) :: ppivot(neqn)
      real *8, intent(inout) :: aa(neqn,neqn)
      integer, intent(inout) :: matrix_info

      real*8, parameter :: tiny_pivot = 1.d-16
      real*8, parameter :: safe_divisor = 1.d+16
      integer :: step,cand,best,irow,jcol
      integer :: prow,trow,held
      real*8 :: largest,divisor,factor,pivot_entry

      ! start from the identity permutation
      do step = 1,neqn
          ppivot(step) = step
      end do

      do step = 1,neqn-1

          ! locate the entry of largest magnitude in this column among
          ! the rows that have not been used as a pivot yet
          best = step
          largest = dabs(aa(ppivot(step),step))
          do cand = step+1,neqn
              if (dabs(aa(ppivot(cand),step)) .gt. largest) then
                  best = cand
                  largest = dabs(aa(ppivot(cand),step))
              endif
          end do

          ! bring that row to the pivot position by re-mapping only
          if (best .ne. step) then
              held = ppivot(step)
              ppivot(step) = ppivot(best)
              ppivot(best) = held
          endif

          prow = ppivot(step)

          ! A (near) singular pivot does not trigger an early return: a
          ! harmless divisor is substituted so the elimination can finish,
          ! and the problem is reported by check_matrix afterwards.
          ! NOTE: OpenMP has difficulty building GPU code with logic to
          !       return early
          if (dabs(aa(prow,step)) .gt. tiny_pivot) then
              divisor = aa(prow,step)
          else
              divisor = safe_divisor
          endif

          ! turn the sub-diagonal part of the column into multipliers
          ! (exact zeros are left untouched)
          do irow = step+1,neqn
              trow = ppivot(irow)
              if (aa(trow,step) .ne. 0.d0) aa(trow,step) = aa(trow,step)/divisor
          end do

          ! update the trailing sub-matrix one column at a time, skipping
          ! work whenever one of the two factors is exactly zero
          do jcol = step+1,neqn
              pivot_entry = aa(prow,jcol)
              if (pivot_entry .ne. 0.d0) then
                  do irow = step+1,neqn
                      trow = ppivot(irow)
                      factor = aa(trow,step)
                      if (factor .ne. 0.d0) then
                          aa(trow,jcol) = aa(trow,jcol) - factor*pivot_entry
                      endif
                  end do
              endif
          end do

      end do

      ! examine the pivoted diagonal for singularity / poor scaling
      call check_matrix(neqn,ppivot,aa,matrix_info)

      return
      end subroutine ludecomp

! =========================================================================

      subroutine check_matrix(neqn,ppivot,aa,matrix_info)
      implicit none

#ifdef USE_OPENMP_OFFLOAD
      !$omp declare target
#endif

      ! Examine the pivoted diagonal aa(ppivot(l),l), l = 1..neqn, and
      ! report problems through matrix_info.
      !
      ! neqn        : order of the matrix
      ! ppivot      : row permutation (only read here)
      ! aa          : (neqn x neqn) matrix (only read here)
      ! matrix_info : left unchanged when no problem is found, otherwise
      !                 1 : some diagonal magnitude is <= 1.0e-16
      !                 2 : min/max diagonal magnitude ratio is <= 1.0e-15
      !               Code 2 is only considered when matrix_info is 0 after
      !               the singularity scan.
      !
      ! No messages are printed; the outcome is reported only through
      ! matrix_info.
      !
      ! NOTE(review): the earlier version also had a branch assigning
      !               matrix_info = 3 ("all diagonal values <= 1.0e-16,
      !               needs better scaling"). It sat inside the
      !               matrix_info == 0 test, which cannot hold once any
      !               diagonal entry is <= 1.0e-16, so it could never
      !               execute and is not reproduced. Confirm whether code 3
      !               is expected by any caller.

      integer, intent(in) ::  neqn
      integer, intent(inout) :: ppivot(neqn)
      real *8, intent(inout) :: aa(neqn,neqn)
      integer, intent(inout) :: matrix_info

      real*8, parameter :: singular_tol = 1.d-16
      real*8, parameter :: ratio_tol = 1.d-15
      integer :: l
      real *8 :: absdiag,min_val,max_val

      ! Seed the running minimum with the largest representable value so
      ! that diagonals whose magnitudes all exceed 1.0e+16 still produce
      ! their true minimum (a fixed seed of 1.0e+16 would win instead and
      ! corrupt the ratio below).
      max_val = 0.d0
      min_val = huge(min_val)

      do l = 1,neqn
         absdiag = dabs(aa(ppivot(l),l))
         max_val = max(absdiag,max_val)
         min_val = min(absdiag,min_val)

         ! check for singularity at every location
         if (absdiag .le. singular_tol) then
              matrix_info = 1
         endif
      enddo

      ! Poor scaling between the diagonal values leads to poor
      ! conditioning. When this point is reached with matrix_info == 0
      ! every diagonal magnitude exceeded singular_tol, so max_val is
      ! positive; the neqn test keeps the empty case from dividing by zero.
      if ((matrix_info .eq. 0) .and. (neqn .ge. 1)) then
         if (min_val/max_val .le. ratio_tol) then
             matrix_info = 2
         endif
      endif

      return
      end subroutine check_matrix

! =========================================================================
      subroutine fwd_bk_sub(neqn,nn,i,j,k)
      implicit none

#ifdef USE_OPENMP_OFFLOAD
      !$omp declare target
#endif

      ! Forward elimination followed by back substitution on the vector
      ! z(:,i,j,k), using the factored matrix a(:,:,i,j,k) and the row
      ! map pivot(:,i,j,k).
      !
      ! neqn  : order of the system
      ! nn    : first elimination step applied to z; steps before nn are
      !         skipped
      ! i,j,k : trailing indices selecting which a / z / pivot to use
      !
      ! a, z and pivot are not declared in this routine.
      ! NOTE(review): presumably they come from module global_parameters
      !               - confirm.
      ! z is overwritten; a and pivot are only read.

      integer, intent(in) :: neqn,nn,i,j,k
      integer :: step,other
      integer :: srow,trow
      real*8 :: rhs

      ! replay on z the eliminations that were performed on a
      ! (quasi-forward substitution pass)
      do step = nn,neqn-1
          srow = pivot(step,i,j,k)
          rhs = z(srow,i,j,k)
          if (rhs .eq. 0.d0) cycle           !! nothing to propagate
          do other = step+1,neqn
              trow = pivot(other,i,j,k)
              if (a(trow,step,i,j,k) .ne. 0.d0) then
                  z(trow,i,j,k) = z(trow,i,j,k) - a(trow,step,i,j,k)*rhs
              endif
          end do
      end do

      ! column oriented back substitution on the upper triangular part:
      ! finalize one unknown, then remove its contribution from the rows
      ! above it
      do step = neqn,1,-1
          srow = pivot(step,i,j,k)
          if (z(srow,i,j,k) .eq. 0.d0) cycle  !! unknown stays zero
          rhs = z(srow,i,j,k)/a(srow,step,i,j,k)
          z(srow,i,j,k) = rhs
          do other = step-1,1,-1
              trow = pivot(other,i,j,k)
              if (a(trow,step,i,j,k) .ne. 0.d0) then
                  z(trow,i,j,k) = z(trow,i,j,k) - a(trow,step,i,j,k)*rhs
              endif
          end do
      end do

      return
      end subroutine fwd_bk_sub

! =========================================================================
      subroutine find_inverse(neqn,i,j,k)
      implicit none

#ifdef USE_OPENMP_OFFLOAD
      !$omp declare target
#endif

      ! Replace the LU-factored matrix a(:,:,i,j,k) by the inverse of the
      ! matrix that was factored.
      !
      ! neqn  : order of the matrix
      ! i,j,k : trailing indices selecting which a / ao / z / pivot to use
      !
      ! The inverse is built one column at a time in ao(:,:,i,j,k) by
      ! solving against columns of the identity matrix, and is copied into
      ! a at the end. z(:,i,j,k) is used as scratch space. The row map
      ! pivot(:,i,j,k) recorded during the LU decomposition is used by the
      ! substitution and then undone while storing each column, so the
      ! result is the inverse of the original, un-pivoted matrix.
      !
      ! a, ao, z and pivot are not declared in this routine.
      ! NOTE(review): presumably they come from module global_parameters
      !               - confirm.

      integer, intent(in) :: neqn,i,j,k
      integer :: col,row
      integer :: src

      do col = 1,neqn
          src = pivot(col,i,j,k)

          ! load z with column src of the identity matrix
          do row = 1,neqn
              z(row,i,j,k) = 0.d0
          end do
          z(src,i,j,k) = 1.d0

          ! solve for one column of the inverse using the pivoted a.
          ! Starting the elimination at step "col" is only valid because z
          ! is a column of the identity; for a general right-hand side the
          ! start index must be 1.
          call fwd_bk_sub(neqn,col,i,j,k)

          ! store the solution as column src of the inverse, reading z
          ! through the row map so that the rows are un-pivoted as well
          do row = 1,neqn
              ao(row,src,i,j,k) = z(pivot(row,i,j,k),i,j,k)
          end do
      end do

      ! hand the finished inverse back in a
      do col = 1,neqn
          do row = 1,neqn
              a(row,col,i,j,k) = ao(row,col,i,j,k)
          end do
      end do

      return
      end subroutine find_inverse

! =========================================================================

      subroutine matrix_invert(neqn,i,j,k,matrix_info)
      implicit none

#ifdef USE_OPENMP_OFFLOAD
      !$omp declare target
#endif

      ! Invert the matrix a(:,:,i,j,k) in place with the routines of this
      ! module: LU decomposition first, then column-by-column inversion.
      !
      ! neqn        : order of the matrix
      ! i,j,k       : trailing indices selecting which a / pivot to use
      ! matrix_info : status handed to ludecomp; when it is positive after
      !               the decomposition the inverse is not computed and a
      !               is left holding the factorization

      integer, intent(in) ::  i,j,k,neqn
      integer, intent(inout) :: matrix_info

      ! factor a, recording the row map in pivot
      call ludecomp(neqn,pivot(1,i,j,k),a(1,1,i,j,k),matrix_info)

      !! only build the inverse when the matrix passed the checks
      if (matrix_info .le. 0) then
         call find_inverse(neqn,i,j,k)
      endif

      return
      end subroutine matrix_invert

! ====================================================================

  subroutine solve(neqn,i,j,k,matrix_info)
  implicit none

#ifdef USE_OPENMP_OFFLOAD
  !$omp declare target
#endif

  ! Solve the linear system whose matrix is a(:,:,i,j,k) and whose
  ! right-hand side is z(:,i,j,k), using the routines of this module.
  !
  ! neqn        : order of the system
  ! i,j,k       : trailing indices selecting which a / z / pivot to use
  ! matrix_info : status handed to ludecomp; when it is positive after
  !               the decomposition no substitution is performed
  !
  ! On a successful exit z(:,i,j,k) holds the solution in natural row
  ! order, a(:,:,i,j,k) holds the LU factors, and pivot(:,i,j,k) has been
  ! rewritten during the un-pivoting pass below.
  !
  ! a, z and pivot are not declared in this routine.
  ! NOTE(review): presumably they come from module global_parameters
  !               - confirm.

  integer, intent(in) ::  neqn,i,j,k
  integer, intent(inout) :: matrix_info

  real*8 :: held
  integer :: pos,scan,partner
  integer :: src,dst

  ! lu decomp of the matrix
  call ludecomp(neqn,pivot(1,i,j,k),a(1,1,i,j,k),matrix_info)

  !! nothing more is done if there is a problem with the matrix
  if (matrix_info .le. 0) then

     call fwd_bk_sub(neqn,1,i,j,k)

     ! undo the pivoting without an additional array: walk the row map
     ! and swap entries of z until every position maps to itself
     do pos = 1,neqn-1
        src = pivot(pos,i,j,k)
        if (src .eq. pos) cycle

        ! find the later position that currently maps to row "pos"
        partner = 0
        do scan = pos+1,neqn
           if (pivot(scan,i,j,k) .eq. pos) partner = scan
        end do
        dst = pivot(partner,i,j,k)

        ! exchange the two entries of the solution vector
        held = z(src,i,j,k)
        z(src,i,j,k) = z(dst,i,j,k)
        z(dst,i,j,k) = held

        ! record the exchange in the row map
        pivot(partner,i,j,k) = src
        pivot(pos,i,j,k) = pos
     end do

  endif

  return
  end subroutine solve

! ====================================================================

      end module Matrix_Operations