Commit 255a6909 authored by Alex Zinenko's avatar Alex Zinenko
Browse files

[mlir][python] Provide more convenient constructors for std.CallOp

The new constructor relies on type-based dynamic dispatch and allows one to
construct call operations given an object representing a FuncOp or its name as
a string, as opposed to requiring an explicitly constructed attribute.

Depends On D110947

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110948
parent 3a3a09f6
Loading
Loading
Loading
Loading
+10 −6
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
  from typing import Optional, Sequence
  from typing import Optional, Sequence, Union

  import inspect

@@ -82,8 +82,8 @@ class FuncOp:
    return self.attributes["sym_visibility"]

  @property
  def name(self):
    return self.attributes["sym_name"]
  def name(self) -> StringAttr:
    return StringAttr(self.attributes["sym_name"])

  @property
  def entry_block(self):
@@ -104,11 +104,15 @@ class FuncOp:

  @property
  def arg_attrs(self):
    return self.attributes[ARGUMENT_ATTRIBUTE_NAME]
    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

  @arg_attrs.setter
  def arg_attrs(self, attribute: ArrayAttr):
  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
    if isinstance(attribute, ArrayAttr):
      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
    else:
      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
          attribute, context=self.context)

  @property
  def arguments(self):
+70 −0
Original line number Diff line number Diff line
@@ -69,3 +69,73 @@ class ConstantOp:
      return FloatAttr(self.value).value
    else:
      raise ValueError("only integer and float constants have literal values")


class CallOp:
  """Specialization for the call op class."""

  def __init__(self,
               calleeOrResults: Union[FuncOp, List[Type]],
               argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
               arguments: Optional[List] = None,
               *,
               loc=None,
               ip=None):
    """Creates an call operation.

    The constructor accepts three different forms:

      1. A function op to be called followed by a list of arguments.
      2. A list of result types, followed by the name of the function to be
         called as string, following by a list of arguments.
      3. A list of result types, followed by the name of the function to be
         called as symbol reference attribute, followed by a list of arguments.

    For example

        f = builtin.FuncOp("foo", ...)
        std.CallOp(f, [args])
        std.CallOp([result_types], "foo", [args])

    In all cases, the location and insertion point may be specified as keyword
    arguments if not provided by the surrounding context managers.
    """

    # TODO: consider supporting constructor "overloads", e.g., through a custom
    # or pybind-provided metaclass.
    if isinstance(calleeOrResults, FuncOp):
      if not isinstance(argumentsOrCallee, list):
        raise ValueError(
            "when constructing a call to a function, expected " +
            "the second argument to be a list of call arguments, " +
            f"got {type(argumentsOrCallee)}")
      if arguments is not None:
        raise ValueError("unexpected third argument when constructing a call" +
                         "to a function")

      super().__init__(
          calleeOrResults.type.results,
          FlatSymbolRefAttr.get(
              calleeOrResults.name.value,
              context=_get_default_loc_context(loc)),
          argumentsOrCallee,
          loc=loc,
          ip=ip)
      return

    if isinstance(argumentsOrCallee, list):
      raise ValueError("when constructing a call to a function by name, " +
                       "expected the second argument to be a string or a " +
                       f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")

    if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
      super().__init__(
          calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
    elif isinstance(argumentsOrCallee, str):
      super().__init__(
          calleeOrResults,
          FlatSymbolRefAttr.get(
              argumentsOrCallee, context=_get_default_loc_context(loc)),
          arguments,
          loc=loc,
          ip=ip)
+15 −3
Original line number Diff line number Diff line
@@ -171,7 +171,7 @@ def testFuncArgumentAccess():
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(module.body):
      func = builtin.FuncOp("some_func", ([f32, f32], [f64, f64]))
      func = builtin.FuncOp("some_func", ([f32, f32], [f32, f32]))
      with InsertionPoint(func.add_entry_block()):
        std.ReturnOp(func.arguments)
      func.arg_attrs = ArrayAttr.get([
@@ -186,6 +186,14 @@ def testFuncArgumentAccess():
          DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
      ])

      other = builtin.FuncOp("other_func", ([f32, f32], []))
      with InsertionPoint(other.add_entry_block()):
        std.ReturnOp([])
      other.arg_attrs = [
          DictAttr.get({"foo": StringAttr.get("qux")}),
          DictAttr.get()
      ]

  # CHECK: [{baz, foo = "bar"}, {qux = []}]
  print(func.arg_attrs)

@@ -195,7 +203,11 @@ def testFuncArgumentAccess():
  # CHECK: func @some_func(
  # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
  # CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
  # CHECK: f64 {res1 = 4.200000e+01 : f32},
  # CHECK: f64 {res2 = 2.560000e+02 : f64})
  # CHECK: f32 {res1 = 4.200000e+01 : f32},
  # CHECK: f32 {res2 = 2.560000e+02 : f64})
  # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
  #
  # CHECK: func @other_func(
  # CHECK: %{{.*}}: f32 {foo = "qux"},
  # CHECK: %{{.*}}: f32)
  print(module)
+25 −0
Original line number Diff line number Diff line
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import std


@@ -62,3 +63,27 @@ def testConstantIndexOp():
  print(c1.literal_value)

# CHECK: = constant 10 : index

# CHECK-LABEL: TEST: testFunctionCalls
@constructAndPrintInModule
def testFunctionCalls():
  foo = builtin.FuncOp("foo", ([], []))
  bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
  qux = builtin.FuncOp("qux", ([], [F32Type.get()]))

  with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
    std.CallOp(foo, [])
    std.CallOp([IndexType.get()], "bar", [])
    std.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
    std.ReturnOp([])

# CHECK: func @foo()
# CHECK: func @bar() -> index
# CHECK: func @qux() -> f32
# CHECK: func @caller() {
# CHECK:   call @foo() : () -> ()
# CHECK:   %0 = call @bar() : () -> index
# CHECK:   %1 = call @qux() : () -> f32
# CHECK:   return
# CHECK: }