Commit 20966fcb authored by Ingo Müller's avatar Ingo Müller
Browse files

[mlir][linalg][transform][python] Add mix-in for BufferizeToAllocOp.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D157704
parent 691a2fab
Loading
Loading
Loading
Loading
+34 −0
Original line number Diff line number Diff line
@@ -84,6 +84,40 @@ def _get_int_int_array_attr(
    return ArrayAttr.get(values)


class BufferizeToAllocationOp:
    """Specialization for BufferizeToAllocationOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        memory_space: Optional[int | str | Attribute] = None,
        memcpy_op: Optional[str] = None,
        alloc_op: Optional[str] = None,
        bufferize_destination_only: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        # No other types are allowed, so hard-code those here.
        allocated_buffer_type = transform.AnyValueType.get()
        new_ops_type = transform.AnyOpType.get()

        if isinstance(memory_space, int):
            memory_space = str(memory_space)
        if isinstance(memory_space, str):
            memory_space = Attribute.parse(memory_space)

        super().__init__(
            allocated_buffer_type,
            new_ops_type,
            target,
            memory_space=memory_space,
            memcpy_op=memcpy_op,
            alloc_op=alloc_op,
            bufferize_destination_only=bufferize_destination_only,
        )


class DecomposeOp:
    """Specialization for DecomposeOp class."""

+36 −0
Original line number Diff line number Diff line
@@ -18,6 +18,42 @@ def run(f):
    return f


@run
def testBufferizeToAllocationOpCompact():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
    )
    with InsertionPoint(sequence.body):
        structured.BufferizeToAllocationOp(sequence.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation


@run
def testBufferizeToAllocationOpArgs():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
    )
    with InsertionPoint(sequence.body):
        structured.BufferizeToAllocationOp(
            sequence.bodyTarget,
            memory_space=3,
            memcpy_op="memref.copy",
            alloc_op="memref.alloca",
            bufferize_destination_only=True,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation
    # CHECK-SAME: alloc_op = "memref.alloca"
    # CHECK-SAME: bufferize_destination_only
    # CHECK-SAME: memcpy_op = "memref.copy"
    # CHECK-SAME: memory_space = 3


@run
def testDecompose():
    sequence = transform.SequenceOp(