Commit 5097d2f3 authored by Nguyen, Thien's avatar Nguyen, Thien

Exposed a couple of ExaTN APIs to Python

- Merge tensors API that returns pattern string.

- Create tensor (in num server) from an existing tensor.

Also, added a simple example.

(This example was adapted from discussion with Andrew Chen)
Signed-off-by: Nguyen, Thien's avatarThien Nguyen <nguyentm@ornl.gov>
parent 5340cd73
import sys
from pathlib import Path
sys.path.insert(1, str(Path.home()) + '/.exatn')
import exatn
import numpy as np
# Demonstrate simple tensor network manipulation
exatn.createTensor('X', [2, 2], 0)
exatn.createTensor('Y', [2, 2], 0)
exatn.createTensor('Z', [2, 2], 0)
exatn.initTensorRnd('X')
exatn.initTensorRnd('Y')
exatn.initTensorRnd('Z')
tNet = exatn.TensorNetwork('test')
tNet.appendTensor(1, 'X')
tNet.appendTensor(2, 'Y')
tNet.appendTensor(3, 'Z')
# print tensor network
tNet.printIt()
tNetOriginal = exatn.TensorNetwork(tNet)
# Merge X and Y
pattern = tNet.mergeTensors(1, 2, 4)
print("After merge:")
tNet.printIt()
# Print the generic merge pattern
print(pattern)
# Create the merged tensor
pattern = pattern.replace("D", tNet.getTensor(4).getName())
pattern = pattern.replace("L", "X")
pattern = pattern.replace("R", "Y")
print(pattern)
# Perform calculation
exatn.createTensor(tNet.getTensor(4))
exatn.contractTensors(pattern)
# Evaluate the tensor network (after merging two tensors)
exatn.evaluate(tNet)
# Print root tensor
root = exatn.getLocalTensor(tNet.getTensor(0).getName())
print(root)
# Evaluate the *Original* network to make sure it is the same.
tNetOriginal.printIt()
exatn.evaluate(tNetOriginal)
rootOriginal = exatn.getLocalTensor(tNetOriginal.getTensor(0).getName())
print(rootOriginal)
assert(np.allclose(root, rootOriginal))
......@@ -155,8 +155,18 @@ py::class_<exatn::numerics::TensorExpansion,
.def("reorderOutputModes",
&exatn::numerics::TensorNetwork::reorderOutputModes, "")
.def("deleteTensor", &exatn::numerics::TensorNetwork::deleteTensor, "")
.def("mergeTensors", &exatn::numerics::TensorNetwork::mergeTensors, "");
.def("mergeTensors", &exatn::numerics::TensorNetwork::mergeTensors, "")
// Returns the merge pattern if valid. Otherwise, returns an empty string.
.def(
"mergeTensors",
[](TensorNetwork &network, unsigned int left_id, unsigned int right_id, unsigned int result_id) {
std::string pattern;
if (network.mergeTensors(left_id, right_id, result_id, &pattern)) {
return pattern;
}
return std::string();
},
"");
py::enum_<exatn::TensorElementType>(m, "DataType", py::arithmetic(), "")
.value("float32", exatn::TensorElementType::REAL32, "")
.value("float64", exatn::TensorElementType::REAL64, "")
......@@ -263,6 +273,11 @@ py::class_<exatn::numerics::TensorExpansion,
},
"");
m.def("createTensor", &createTensorWithDataNoNumServer, "");
// Create an existing declared tensor
m.def("createTensor", [](std::shared_ptr<Tensor> tensor) {
auto success = exatn::createTensor(tensor, tensor->getElementType());
return success;
});
m.def(
"registerTensorIsometry",
[](const std::string &name, const std::vector<unsigned int> &iso_dims) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment