Unverified Commit f08fe1f1 authored by Abhinav271828's avatar Abhinav271828 Committed by GitHub
Browse files

[MLIR][Presburger] Implement matrix inverse (#67382)

Shift the `determinant()` function from LinearTransform to Matrix.
Implement a FracMatrix class, inheriting from Matrix<Fraction>, for inverses.
Implement inverse for FracMatrix and intInverse for IntMatrix.
Make Matrix internals protected instead of private so that Int/FracMatrix can access them.
parent 080fb3e5
Loading
Loading
Loading
Loading
+0 −4
Original line number Diff line number Diff line
@@ -50,10 +50,6 @@ public:
    return matrix.postMultiplyWithColumn(colVec);
  }

  // Compute the determinant of the transform by converting it to row echelon
  // form and then taking the product of the diagonal.
  MPInt determinant();

private:
  IntMatrix matrix;
};
+34 −8
Original line number Diff line number Diff line
@@ -189,7 +189,7 @@ public:
  /// invariants satisfied.
  bool hasConsistentState() const;

private:
protected:
  /// The current number of rows, columns, and reserved columns. The underlying
  /// data vector is viewed as an nRows x nReservedColumns matrix, of which the
  /// first nColumns columns are currently in use, and the remaining are
@@ -210,13 +210,7 @@ public:
            unsigned reservedColumns = 0)
      : Matrix<MPInt>(rows, columns, reservedRows, reservedColumns){};

  IntMatrix(Matrix<MPInt> m)
      : Matrix<MPInt>(m.getNumRows(), m.getNumColumns(), m.getNumReservedRows(),
                      m.getNumReservedColumns()) {
    for (unsigned i = 0; i < m.getNumRows(); i++)
      for (unsigned j = 0; j < m.getNumColumns(); j++)
        at(i, j) = m(i, j);
  };
  IntMatrix(Matrix<MPInt> m) : Matrix<MPInt>(std::move(m)){};

  /// Return the identity matrix of the specified dimension.
  static IntMatrix identity(unsigned dimension);
@@ -239,6 +233,38 @@ public:
  /// Divide the columns of the specified row by their GCD.
  /// Returns the GCD of the columns of the specified row.
  MPInt normalizeRow(unsigned row);

  // Compute the determinant of the matrix (cubic time).
  // Stores the integer inverse of the matrix in the pointer
  // passed (if any). The pointer is unchanged if the inverse
  // does not exist, which happens iff det = 0.
  // For a matrix M, the integer inverse is the matrix M' such that
  // M x M' = M'  M = det(M) x I.
  // Assert-fails if the matrix is not square.
  MPInt determinant(IntMatrix *inverse = nullptr) const;
};

// An inherited class for rational matrices, with no new data attributes.
// This class is for functionality that only applies to matrices of fractions.
class FracMatrix : public Matrix<Fraction> {
public:
  FracMatrix(unsigned rows, unsigned columns, unsigned reservedRows = 0,
             unsigned reservedColumns = 0)
      : Matrix<Fraction>(rows, columns, reservedRows, reservedColumns){};

  FracMatrix(Matrix<Fraction> m) : Matrix<Fraction>(std::move(m)){};

  explicit FracMatrix(IntMatrix m);

  /// Return the identity matrix of the specified dimension.
  static FracMatrix identity(unsigned dimension);

  // Compute the determinant of the matrix (cubic time).
  // Stores the inverse of the matrix in the pointer
  // passed (if any). The pointer is unchanged if the inverse
  // does not exist, which happens iff det = 0.
  // Assert-fails if the matrix is not square.
  Fraction determinant(FracMatrix *inverse = nullptr) const;
};

} // namespace presburger
+116 −0
Original line number Diff line number Diff line
@@ -433,3 +433,119 @@ MPInt IntMatrix::normalizeRow(unsigned row, unsigned cols) {
MPInt IntMatrix::normalizeRow(unsigned row) {
  return normalizeRow(row, getNumColumns());
}

MPInt IntMatrix::determinant(IntMatrix *inverse) const {
  assert(nRows == nColumns &&
         "determinant can only be calculated for square matrices!");

  FracMatrix m(*this);

  FracMatrix fracInverse(nRows, nColumns);
  MPInt detM = m.determinant(&fracInverse).getAsInteger();

  if (detM == 0)
    return MPInt(0);

  *inverse = IntMatrix(nRows, nColumns);
  for (unsigned i = 0; i < nRows; i++)
    for (unsigned j = 0; j < nColumns; j++)
      inverse->at(i, j) = (fracInverse.at(i, j) * detM).getAsInteger();

  return detM;
}

FracMatrix FracMatrix::identity(unsigned dimension) {
  return Matrix::identity(dimension);
}

FracMatrix::FracMatrix(IntMatrix m)
    : FracMatrix(m.getNumRows(), m.getNumColumns()) {
  for (unsigned i = 0; i < m.getNumRows(); i++)
    for (unsigned j = 0; j < m.getNumColumns(); j++)
      this->at(i, j) = m.at(i, j);
}

Fraction FracMatrix::determinant(FracMatrix *inverse) const {
  assert(nRows == nColumns &&
         "determinant can only be calculated for square matrices!");

  FracMatrix m(*this);
  FracMatrix tempInv(nRows, nColumns);
  if (inverse)
    tempInv = FracMatrix::identity(nRows);

  Fraction a, b;
  // Make the matrix into upper triangular form using
  // gaussian elimination with row operations.
  // If inverse is required, we apply more operations
  // to turn the matrix into diagonal form. We apply
  // the same operations to the inverse matrix,
  // which is initially identity.
  // Either way, the product of the diagonal elements
  // is then the determinant.
  for (unsigned i = 0; i < nRows; i++) {
    if (m(i, i) == 0)
      // First ensure that the diagonal
      // element is nonzero, by swapping
      // it with a nonzero row.
      for (unsigned j = i + 1; j < nRows; j++) {
        if (m(j, i) != 0) {
          m.swapRows(j, i);
          if (inverse)
            tempInv.swapRows(j, i);
          break;
        }
      }

    b = m.at(i, i);
    if (b == 0)
      return 0;

    // Set all elements above the
    // diagonal to zero.
    if (inverse) {
      for (unsigned j = 0; j < i; j++) {
        if (m.at(j, i) == 0)
          continue;
        a = m.at(j, i);
        // Set element (j, i) to zero
        // by subtracting the ith row,
        // appropriately scaled.
        m.addToRow(i, j, -a / b);
        tempInv.addToRow(i, j, -a / b);
      }
    }

    // Set all elements below the
    // diagonal to zero.
    for (unsigned j = i + 1; j < nRows; j++) {
      if (m.at(j, i) == 0)
        continue;
      a = m.at(j, i);
      // Set element (j, i) to zero
      // by subtracting the ith row,
      // appropriately scaled.
      m.addToRow(i, j, -a / b);
      if (inverse)
        tempInv.addToRow(i, j, -a / b);
    }
  }

  // Now only diagonal elements of m are nonzero, but they are
  // not necessarily 1. To get the true inverse, we should
  // normalize them and apply the same scale to the inverse matrix.
  // For efficiency we skip scaling m and just scale tempInv appropriately.
  if (inverse) {
    for (unsigned i = 0; i < nRows; i++)
      for (unsigned j = 0; j < nRows; j++)
        tempInv.at(i, j) = tempInv.at(i, j) / m(i, i);

    *inverse = std::move(tempInv);
  }

  Fraction determinant = 1;
  for (unsigned i = 0; i < nRows; i++)
    determinant *= m.at(i, i);

  return determinant;
}
 No newline at end of file
+68 −6
Original line number Diff line number Diff line
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Presburger/Matrix.h"
#include "mlir/Analysis/Presburger/Fraction.h"
#include "./Utils.h"
#include "mlir/Analysis/Presburger/Fraction.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>

@@ -210,7 +210,8 @@ TEST(MatrixTest, computeHermiteNormalForm) {
  {
    // Hermite form of a unimodular matrix is the identity matrix.
    IntMatrix mat = makeIntMatrix(3, 3, {{2, 3, 6}, {3, 2, 3}, {17, 11, 16}});
    IntMatrix hermiteForm = makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
    IntMatrix hermiteForm =
        makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
    checkHermiteNormalForm(mat, hermiteForm);
  }

@@ -241,10 +242,71 @@ TEST(MatrixTest, computeHermiteNormalForm) {
  }

  {
    IntMatrix mat =
        makeIntMatrix(3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
    IntMatrix hermiteForm =
        makeIntMatrix(3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
    IntMatrix mat = makeIntMatrix(
        3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
    IntMatrix hermiteForm = makeIntMatrix(
        3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
    checkHermiteNormalForm(mat, hermiteForm);
  }
}

TEST(MatrixTest, inverse) {
  FracMatrix mat = makeFracMatrix(
      2, 2, {{Fraction(2), Fraction(1)}, {Fraction(7), Fraction(0)}});
  FracMatrix inverse = makeFracMatrix(
      2, 2, {{Fraction(0), Fraction(1, 7)}, {Fraction(1), Fraction(-2, 7)}});

  FracMatrix inv(2, 2);
  mat.determinant(&inv);

  EXPECT_EQ_FRAC_MATRIX(inv, inverse);

  mat = makeFracMatrix(
      2, 2, {{Fraction(0), Fraction(1)}, {Fraction(0), Fraction(2)}});
  Fraction det = mat.determinant(nullptr);

  EXPECT_EQ(det, Fraction(0));

  mat = makeFracMatrix(3, 3,
                       {{Fraction(1), Fraction(2), Fraction(3)},
                        {Fraction(4), Fraction(8), Fraction(6)},
                        {Fraction(7), Fraction(8), Fraction(6)}});
  inverse = makeFracMatrix(3, 3,
                           {{Fraction(0), Fraction(-1, 3), Fraction(1, 3)},
                            {Fraction(-1, 2), Fraction(5, 12), Fraction(-1, 6)},
                            {Fraction(2, 3), Fraction(-1, 6), Fraction(0)}});

  mat.determinant(&inv);
  EXPECT_EQ_FRAC_MATRIX(inv, inverse);

  mat = makeFracMatrix(0, 0, {});
  mat.determinant(&inv);
}

TEST(MatrixTest, intInverse) {
  IntMatrix mat = makeIntMatrix(2, 2, {{2, 1}, {7, 0}});
  IntMatrix inverse = makeIntMatrix(2, 2, {{0, -1}, {-7, 2}});

  IntMatrix inv(2, 2);
  mat.determinant(&inv);

  EXPECT_EQ_INT_MATRIX(inv, inverse);

  mat = makeIntMatrix(
      4, 4, {{4, 14, 11, 3}, {13, 5, 14, 12}, {13, 9, 7, 14}, {2, 3, 12, 7}});
  inverse = makeIntMatrix(4, 4,
                          {{155, 1636, -579, -1713},
                           {725, -743, 537, -111},
                           {210, 735, -855, 360},
                           {-715, -1409, 1401, 1482}});

  mat.determinant(&inv);

  EXPECT_EQ_INT_MATRIX(inv, inverse);

  mat = makeIntMatrix(2, 2, {{0, 0}, {1, 2}});

  MPInt det = mat.determinant(&inv);

  EXPECT_EQ(det, 0);
}
+23 −5
Original line number Diff line number Diff line
@@ -14,10 +14,10 @@
#define MLIR_UNITTESTS_ANALYSIS_PRESBURGER_UTILS_H

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/Matrix.h"
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Matrix.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LLVM.h"

@@ -40,9 +40,9 @@ inline IntMatrix makeIntMatrix(unsigned numRow, unsigned numColumns,
  return results;
}

inline Matrix<Fraction> makeFracMatrix(unsigned numRow, unsigned numColumns,
inline FracMatrix makeFracMatrix(unsigned numRow, unsigned numColumns,
                                 ArrayRef<SmallVector<Fraction, 8>> matrix) {
  Matrix<Fraction> results(numRow, numColumns);
  FracMatrix results(numRow, numColumns);
  assert(matrix.size() == numRow);
  for (unsigned i = 0; i < numRow; ++i) {
    assert(matrix[i].size() == numColumns &&
@@ -53,6 +53,24 @@ inline Matrix<Fraction> makeFracMatrix(unsigned numRow, unsigned numColumns,
  return results;
}

inline void EXPECT_EQ_INT_MATRIX(IntMatrix a, IntMatrix b) {
  EXPECT_EQ(a.getNumRows(), b.getNumRows());
  EXPECT_EQ(a.getNumColumns(), b.getNumColumns());

  for (unsigned row = 0; row < a.getNumRows(); row++)
    for (unsigned col = 0; col < a.getNumColumns(); col++)
      EXPECT_EQ(a(row, col), b(row, col));
}

inline void EXPECT_EQ_FRAC_MATRIX(FracMatrix a, FracMatrix b) {
  EXPECT_EQ(a.getNumRows(), b.getNumRows());
  EXPECT_EQ(a.getNumColumns(), b.getNumColumns());

  for (unsigned row = 0; row < a.getNumRows(); row++)
    for (unsigned col = 0; col < a.getNumColumns(); col++)
      EXPECT_EQ(a(row, col), b(row, col));
}

/// lhs and rhs represent non-negative integers or positive infinity. The
/// infinity case corresponds to when the Optional is empty.
inline bool infinityOrUInt64LE(std::optional<MPInt> lhs,