Commit 691a2fab authored by Ingo Müller's avatar Ingo Müller
Browse files

[mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D157706
parent 030e315e
Loading
Loading
Loading
Loading
+60 −0
Original line number Diff line number Diff line
@@ -187,6 +187,66 @@ class InterchangeOp:
        )


class MapCopyToThreadsOp:
    """Specialization for MapCopyToThreadsOp class."""

    @overload
    def __init__(
        self,
        forall_op_type: Type,
        tiled_op_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        forall_op_type_or_target: Union[Operation, OpView, Type, Value],
        tiled_op_type_or_none: Optional[Type] = None,
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        if isinstance(forall_op_type_or_target, Type):
            forall_op_type = forall_op_type_or_target
            tiled_op_type = tiled_op_type_or_none
            target = target_or_none
        else:
            forall_op_type = transform.AnyOpType.get()
            tiled_op_type = transform.AnyOpType.get()
            target = forall_op_type_or_target

        super().__init__(
            forall_op_type,
            tiled_op_type,
            target,
            total_num_threads=total_num_threads,
            desired_bit_alignment=desired_bit_alignment,
            loc=loc,
            ip=ip,
        )


class MatchOp:
    """Specialization for MatchOp class."""

+38 −0
Original line number Diff line number Diff line
@@ -97,6 +97,44 @@ def testInterchange():
    # CHECK: iterator_interchange = [1, 0]


@run
def testMapCopyToThreadsOpCompact():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        structured.MapCopyToThreadsOp(
            sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME:  (!transform.any_op) -> (!transform.any_op, !transform.any_op)


@run
def testMapCopyToThreadsOpTypes():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        structured.MapCopyToThreadsOp(
            transform.OperationType.get("test.opA"),
            transform.OperationType.get("test.opB"),
            sequence.bodyTarget,
            total_num_threads=32,
            desired_bit_alignment=128,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME:  (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)


@run
def testMatchOpNamesString():
    sequence = transform.SequenceOp(