Unverified Commit 23483d49 authored by Eugene Epshteyn's avatar Eugene Epshteyn Committed by GitHub
Browse files

[flang][Semantics] Break recursion on illegal recursive type in I/O check (#194284)

When an illegal recursive derived type (a non-POINTER/non-ALLOCATABLE
component whose type is the enclosing type itself, prohibited by F2023
C749) is used in an I/O list, the component-walking helpers
FindUnsafeIoDirectComponent() and FindInaccessibleComponent() recursed
through it forever and blew the stack.

The fix involves tracking the derived types currently on the recursion
path in a VisitedSymbolSet to detect loops.

Fixes #192387

Assisted-by: AI
parent eed9e970
Loading
Loading
Loading
Loading
+36 −11
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@
#include "flang/Semantics/expression.h"
#include "flang/Semantics/tools.h"
#include <unordered_map>
#include <unordered_set>

namespace Fortran::semantics {

@@ -1116,14 +1117,23 @@ void IoChecker::CheckForUselessIomsg() const {
  }
}

// Set of derived-type symbols already visited on the current recursion
// path of the component walks below.
using VisitedSymbolSet = std::unordered_set<const Symbol *>;

// Seeks out an allocatable or pointer ultimate component that is not
// nested in a nonallocatable/nonpointer component with a specific
// defined I/O procedure.
// nested in a nonallocatable/nonpointer component with a specific defined I/O
// procedure. The 'visited' set tracks derived types to break cycles caused by
// an illegal recursive type definition (F2023 C749).
static const Symbol *FindUnsafeIoDirectComponent(common::DefinedIo which,
    const DerivedTypeSpec &derived, const Scope &scope) {
    const DerivedTypeSpec &derived, const Scope &scope,
    VisitedSymbolSet &visited) {
  if (HasDefinedIo(which, derived, &scope)) {
    return nullptr;
  }
  if (!visited.insert(&derived.typeSymbol()).second) {
    return nullptr;
  }
  if (const Scope * dtScope{derived.scope()}) {
    for (const auto &pair : *dtScope) {
      const Symbol &symbol{*pair.second};
@@ -1134,9 +1144,8 @@ static const Symbol *FindUnsafeIoDirectComponent(common::DefinedIo which,
        if (const DeclTypeSpec * type{details->type()}) {
          if (type->category() == DeclTypeSpec::Category::TypeDerived) {
            const DerivedTypeSpec &componentDerived{type->derivedTypeSpec()};
            if (const Symbol *
                bad{FindUnsafeIoDirectComponent(
                    which, componentDerived, scope)}) {
            if (const Symbol *bad{FindUnsafeIoDirectComponent(
                    which, componentDerived, scope, visited)}) {
              return bad;
            }
          }
@@ -1147,11 +1156,22 @@ static const Symbol *FindUnsafeIoDirectComponent(common::DefinedIo which,
  return nullptr;
}

static const Symbol *FindUnsafeIoDirectComponent(common::DefinedIo which,
    const DerivedTypeSpec &derived, const Scope &scope) {
  VisitedSymbolSet visited;
  return FindUnsafeIoDirectComponent(which, derived, scope, visited);
}

// For a type that does not have a defined I/O subroutine, finds a direct
// component that is a witness to an accessibility violation outside the module
// in which the type was defined.
// in which the type was defined.  The 'visited' set tracks derived types to
// break cycles caused by an illegal recursive type definition (F2023 C749).
static const Symbol *FindInaccessibleComponent(common::DefinedIo which,
    const DerivedTypeSpec &derived, const Scope &scope) {
    const DerivedTypeSpec &derived, const Scope &scope,
    VisitedSymbolSet &visited) {
  if (!visited.insert(&derived.typeSymbol()).second) {
    return nullptr;
  }
  if (const Scope * dtScope{derived.scope()}) {
    if (const Scope * module{FindModuleContaining(*dtScope)}) {
      for (const auto &pair : *dtScope) {
@@ -1177,9 +1197,8 @@ static const Symbol *FindInaccessibleComponent(common::DefinedIo which,
            }
          }
          if (componentDerived) {
            if (const Symbol *
                bad{FindInaccessibleComponent(
                    which, *componentDerived, scope)}) {
            if (const Symbol *bad{FindInaccessibleComponent(
                    which, *componentDerived, scope, visited)}) {
              return bad;
            }
          }
@@ -1190,6 +1209,12 @@ static const Symbol *FindInaccessibleComponent(common::DefinedIo which,
  return nullptr;
}

static const Symbol *FindInaccessibleComponent(common::DefinedIo which,
    const DerivedTypeSpec &derived, const Scope &scope) {
  VisitedSymbolSet visited;
  return FindInaccessibleComponent(which, derived, scope, visited);
}

// Fortran 2018, 12.6.3 paragraphs 5 & 7
parser::Message *IoChecker::CheckForBadIoType(const evaluate::DynamicType &type,
    common::DefinedIo which, parser::CharBlock where) const {
+78 −0
Original line number Diff line number Diff line
@@ -74,3 +74,81 @@ module m4
  end subroutine
end module

! Regression test: an illegal recursive derived-type component used to cause
! infinite recursion in FindUnsafeIoDirectComponent when the object appeared
! in an I/O list (issue #192387).
subroutine test_recursive_io
  type t1
    !ERROR: Recursive use of the derived type requires POINTER or ALLOCATABLE
    type(t1) :: b
  end type t1
  type(t1) :: obj
  print *, obj
end subroutine

! Same regression covering the FindInaccessibleComponent walk: the type
! must be defined in a module and used in I/O outside that module so the
! recursive component traversal in FindInaccessibleComponent is reached.
module m_recursive
  type t2
    !ERROR: Recursive use of the derived type requires POINTER or ALLOCATABLE
    type(t2) :: b
  end type t2
end module
subroutine test_recursive_io_module
  use m_recursive
  type(t2) :: obj
  print *, obj
end subroutine

! Positive cases: a recursive type is legal when the recursive component
! is POINTER or ALLOCATABLE.  With defined I/O, an I/O list item of such
! a type is accepted without diagnostics.
module m_recursive_pointer
  type :: rp
    integer :: x
    type(rp), pointer :: next => null()
   contains
    procedure :: wuf_rp
    generic :: write(unformatted) => wuf_rp
  end type
 contains
  subroutine wuf_rp(dtv, unit, iostat, iomsg)
    class(rp), intent(in) :: dtv
    integer, intent(in) :: unit
    integer, intent(out) :: iostat
    character(*), intent(in out) :: iomsg
    write(unit) dtv%x
  end subroutine
end module
subroutine test_recursive_pointer_io(u)
  use m_recursive_pointer
  integer, intent(in) :: u
  type(rp) :: obj
  write(u) obj ! ok: defined I/O
end subroutine

module m_recursive_allocatable
  type :: ra
    integer :: x
    type(ra), allocatable :: next
   contains
    procedure :: wuf_ra
    generic :: write(unformatted) => wuf_ra
  end type
 contains
  subroutine wuf_ra(dtv, unit, iostat, iomsg)
    class(ra), intent(in) :: dtv
    integer, intent(in) :: unit
    integer, intent(out) :: iostat
    character(*), intent(in out) :: iomsg
    write(unit) dtv%x
  end subroutine
end module
subroutine test_recursive_allocatable_io(u)
  use m_recursive_allocatable
  integer, intent(in) :: u
  type(ra) :: obj
  write(u) obj ! ok: defined I/O
end subroutine