Commit 3a3a09f6 authored by Alex Zinenko's avatar Alex Zinenko
Browse files

[mlir][python] Provide more convenient wrappers for std.ConstantOp

Constructing a ConstantOp using the default-generated API is verbose and
requires to specify the constant type twice: for the result type of the
operation and for the type of the attribute. It also requires to explicitly
construct the attribute. Provide custom constructors that take the type once
and accept a raw value instead of the attribute. This requires dynamic dispatch
based on type in the constructor. Also provide the corresponding accessors to
raw values.

In addition, provide a "refinement" class ConstantIndexOp similar to what
exists in C++. Unlike other "op view" Python classes, operations cannot be
automatically downcasted to this class since it does not correspond to a
specific operation name. It only exists to simplify construction of the
operation.

Depends On D110946

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110947
parent ed9e52f3
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -136,7 +136,9 @@ declare_mlir_dialect_python_bindings(
  ADD_TO_PARENT MLIRPythonSources.Dialects
  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
  TD_FILE dialects/StandardOps.td
  SOURCES dialects/std.py
  SOURCES
    dialects/std.py
    dialects/_std_ops_ext.py
  DIALECT_NAME std)

declare_mlir_dialect_python_bindings(
+71 −0
Original line number Diff line number Diff line
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
  from ..ir import *
  from .builtin import FuncOp
  from ._ods_common import get_default_loc_context as _get_default_loc_context

  from typing import Any, List, Optional, Union
except ImportError as e:
  raise RuntimeError("Error loading imports from extension module") from e


def _isa(obj: Any, cls: type):
  try:
    cls(obj)
  except ValueError:
    return False
  return True


def _is_any_of(obj: Any, classes: List[type]):
  return any(_isa(obj, cls) for cls in classes)


def _is_integer_like_type(type: Type):
  return _is_any_of(type, [IntegerType, IndexType])


def _is_float_type(type: Type):
  return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])


class ConstantOp:
  """Specialization for the constant op class."""

  def __init__(self,
               result: Type,
               value: Union[int, float, Attribute],
               *,
               loc=None,
               ip=None):
    if isinstance(value, int):
      super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
    elif isinstance(value, float):
      super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
    else:
      super().__init__(result, value, loc=loc, ip=ip)

  @classmethod
  def create_index(cls, value: int, *, loc=None, ip=None):
    """Create an index-typed constant."""
    return cls(
        IndexType.get(context=_get_default_loc_context(loc)),
        value,
        loc=loc,
        ip=ip)

  @property
  def type(self):
    return self.results[0].type

  @property
  def literal_value(self) -> Union[int, float]:
    if _is_integer_like_type(self.type):
      return IntegerAttr(self.value).value
    elif _is_float_type(self.type):
      return FloatAttr(self.value).value
    else:
      raise ValueError("only integer and float constants have literal values")
+64 −0
Original line number Diff line number Diff line
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import std


def constructAndPrintInModule(f):
  print("\nTEST:", f.__name__)
  with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
      f()
    print(module)
  return f

# CHECK-LABEL: TEST: testConstantOp

@constructAndPrintInModule
def testConstantOp():
  c1 = std.ConstantOp(IntegerType.get_signless(32), 42)
  c2 = std.ConstantOp(IntegerType.get_signless(64), 100)
  c3 = std.ConstantOp(F32Type.get(), 3.14)
  c4 = std.ConstantOp(F64Type.get(), 1.23)
  # CHECK: 42
  print(c1.literal_value)

  # CHECK: 100
  print(c2.literal_value)

  # CHECK: 3.140000104904175
  print(c3.literal_value)

  # CHECK: 1.23
  print(c4.literal_value)

# CHECK: = constant 42 : i32
# CHECK: = constant 100 : i64
# CHECK: = constant 3.140000e+00 : f32
# CHECK: = constant 1.230000e+00 : f64

# CHECK-LABEL: TEST: testVectorConstantOp
@constructAndPrintInModule
def testVectorConstantOp():
  int_type = IntegerType.get_signless(32)
  vec_type = VectorType.get([2, 2], int_type)
  c1 = std.ConstantOp(vec_type,
                      DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
  try:
    print(c1.literal_value)
  except ValueError as e:
    assert "only integer and float constants have literal values" in str(e)
  else:
    assert False

# CHECK: = constant dense<42> : vector<2x2xi32>

# CHECK-LABEL: TEST: testConstantIndexOp
@constructAndPrintInModule
def testConstantIndexOp():
  c1 = std.ConstantOp.create_index(10)
  # CHECK: 10
  print(c1.literal_value)

# CHECK: = constant 10 : index