Commit a954d1ac authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Fixed MPS tensor creation with MPI



Need to use self process group to match the new ExaTN tensor resident rule.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent f33d0720
Loading
Loading
Loading
Loading
+87 −54
Original line number Diff line number Diff line
@@ -2040,7 +2040,7 @@ void ExatnMpsVisitor::applyTwoQubitGate(xacc::Instruction& in_gateInstruction)
        m_tensorNetwork = std::make_shared<exatn::TensorNetwork>(m_tensorNetwork->getName(), mpsString, buildTensorMap());
        {
            // Truncate SVD tensors:
            truncateSvdTensors(q1TensorName, q2TensorName, m_svdCutoff);
            truncateSvdTensors(q1TensorName, q2TensorName, m_svdCutoff, m_selfProcessGroup.get());
        }

        // Rebuild the tensor network since the qubit tensors have been changed after SVD truncation
@@ -2093,6 +2093,7 @@ void ExatnMpsVisitor::applyTwoQubitGate(xacc::Instruction& in_gateInstruction)
        rebuildTensorNetwork();
        // Now, the *remote* tensor has been initialized, process gate as normal:
        // Apply gate
        xacc::info("Process [" + std::to_string(m_rank) + "]: Process gate: " + in_gateInstruction.toString());
        processTwoQubitGate();
        // Done: Send tensor to the neighbor process
        // Send tensor forward
@@ -2387,7 +2388,7 @@ std::vector<uint8_t> ExatnMpsVisitor::getMeasureSample(const std::vector<size_t>
    return resultBitString;
}

void ExatnMpsVisitor::truncateSvdTensors(const std::string& in_leftTensorName, const std::string& in_rightTensorName, double in_eps)
void ExatnMpsVisitor::truncateSvdTensors(const std::string& in_leftTensorName, const std::string& in_rightTensorName, double in_eps, exatn::ProcessGroup *in_processGroup)
{
    int lhsTensorId = -1;
    int rhsTensorId = -1;
@@ -2470,24 +2471,43 @@ void ExatnMpsVisitor::truncateSvdTensors(const std::string& in_leftTensorName, c
    assert(newBondDim > 0);
    if (newBondDim < bondDim)
    {
      // xacc::info("Truncate SVD bond.");
      auto leftShape = lhsTensor->getDimExtents();
      auto rightShape = rhsTensor->getDimExtents();
      leftShape[lhsBondId] = newBondDim;
      rightShape[rhsBondId] = newBondDim;

      // Create two new tensors:
        const std::string newLhsTensorName = in_leftTensorName + "_" + std::to_string(lhsTensor->getTensorHash());
        const bool newLhsCreated = exatn::createTensorSync(newLhsTensorName, exatn::TensorElementType::COMPLEX64, leftShape);
      const std::string newLhsTensorName =
          in_leftTensorName + "_" + std::to_string(lhsTensor->getTensorHash());
      const bool newLhsCreated =
          in_processGroup
              ? exatn::createTensorSync(*in_processGroup, newLhsTensorName,
                                        exatn::TensorElementType::COMPLEX64,
                                        leftShape)
              : exatn::createTensorSync(newLhsTensorName,
                                        exatn::TensorElementType::COMPLEX64,
                                        leftShape);
      assert(newLhsCreated);

        const std::string newRhsTensorName = in_rightTensorName + "_" + std::to_string(rhsTensor->getTensorHash());
        const bool newRhsCreated = exatn::createTensorSync(newRhsTensorName, exatn::TensorElementType::COMPLEX64, rightShape);
      const std::string newRhsTensorName =
          in_rightTensorName + "_" + std::to_string(rhsTensor->getTensorHash());
      const bool newRhsCreated =
          in_processGroup
              ? exatn::createTensorSync(*in_processGroup, newRhsTensorName,
                                        exatn::TensorElementType::COMPLEX64,
                                        rightShape)
              : exatn::createTensorSync(newRhsTensorName,
                                        exatn::TensorElementType::COMPLEX64,
                                        rightShape);
      assert(newRhsCreated);

      // Take the slices:
        const bool lhsSliceOk = exatn::extractTensorSliceSync(in_leftTensorName, newLhsTensorName);
      const bool lhsSliceOk =
          exatn::extractTensorSliceSync(in_leftTensorName, newLhsTensorName);
      assert(lhsSliceOk);
        const bool rhsSliceOk = exatn::extractTensorSliceSync(in_rightTensorName, newRhsTensorName);
      const bool rhsSliceOk =
          exatn::extractTensorSliceSync(in_rightTensorName, newRhsTensorName);
      assert(rhsSliceOk);

      // Destroy the two original tensors:
@@ -2497,19 +2517,30 @@ void ExatnMpsVisitor::truncateSvdTensors(const std::string& in_leftTensorName, c
      assert(rhsDestroyed);

      // Rename new tensors to the old name
        const auto renameNumericTensor = [](const std::string& oldTensorName, const std::string& newTensorName){
      const auto renameNumericTensor = [&in_processGroup](
                                           const std::string &oldTensorName,
                                           const std::string &newTensorName) {
        auto tensor = exatn::getTensor(oldTensorName);
        assert(tensor);
        auto talsh_tensor = exatn::getLocalTensor(oldTensorName);
        assert(talsh_tensor);
        const std::complex<double> *body_ptr;
            const bool access_granted = talsh_tensor->getDataAccessHostConst(&body_ptr);
        const bool access_granted =
            talsh_tensor->getDataAccessHostConst(&body_ptr);
        assert(access_granted);
        std::vector<std::complex<double>> newData;
        newData.assign(body_ptr, body_ptr + talsh_tensor->getVolume());
            const bool newTensorCreated = exatn::createTensorSync(newTensorName, exatn::TensorElementType::COMPLEX64, tensor->getShape());
        const bool newTensorCreated =
            in_processGroup
                ? exatn::createTensorSync(*in_processGroup, newTensorName,
                                          exatn::TensorElementType::COMPLEX64,
                                          tensor->getShape())
                : exatn::createTensorSync(newTensorName,
                                          exatn::TensorElementType::COMPLEX64,
                                          tensor->getShape());
        assert(newTensorCreated);
            const bool newTensorInitialized = exatn::initTensorDataSync(newTensorName, newData);
        const bool newTensorInitialized =
            exatn::initTensorDataSync(newTensorName, newData);
        assert(newTensorInitialized);
        // Destroy the two original tensor:
        const bool tensorDestroyed = exatn::destroyTensorSync(oldTensorName);
@@ -2520,9 +2551,11 @@ void ExatnMpsVisitor::truncateSvdTensors(const std::string& in_leftTensorName, c
      renameNumericTensor(newRhsTensorName, in_rightTensorName);

      // Debug:
        // std::cout << "[DEBUG] Bond dim (" << in_leftTensorName << ", " << in_rightTensorName << "): " << bondDim << " -> " << newBondDim << "\n";
      // std::cout << "[DEBUG] Bond dim (" << in_leftTensorName << ", " <<
      // in_rightTensorName << "): " << bondDim << " -> " << newBondDim << "\n";
      std::stringstream logSs;
        logSs << "[SVD] Bond dim (" << in_leftTensorName << ", " << in_rightTensorName << "): " << bondDim << " -> " << newBondDim;
      logSs << "[SVD] Bond dim (" << in_leftTensorName << ", "
            << in_rightTensorName << "): " << bondDim << " -> " << newBondDim;
      xacc::info(logSs.str());
    }
}
+4 −1
Original line number Diff line number Diff line
@@ -120,7 +120,10 @@ private:
    std::vector<uint8_t> getMeasureSample(const std::vector<size_t>& in_qubitIdx);
    void printStateVec();
    // Truncate the bond dimension between two tensors that are decomposed by SVD
    void truncateSvdTensors(const std::string& in_leftTensorName, const std::string& in_rightTensorName, double in_eps = std::numeric_limits<double>::min());
    void truncateSvdTensors(const std::string &in_leftTensorName,
                            const std::string &in_rightTensorName,
                            double in_eps = std::numeric_limits<double>::min(),
                            exatn::ProcessGroup *in_processGroup = nullptr);
    std::vector<std::complex<double>> computeWaveFuncSlice(const exatn::numerics::TensorNetwork& in_tensorNetwork, const std::vector<int>& bitString, const exatn::ProcessGroup& in_processGroup) const;
    double computeStateVectorNorm(const exatn::numerics::TensorNetwork& in_tensorNetwork, const exatn::ProcessGroup& in_processGroup) const;