Commit 73baec79 authored by Dmitry I. Lyakh's avatar Dmitry I. Lyakh
Browse files

Implemented Tensor ctor based on tensor contraction.

parent dd178e2f
......@@ -34,13 +34,56 @@ name_(name)
{
}
Tensor::Tensor(const std::string & name, //tensor name
const Tensor & left_tensor, //left tensor
const Tensor & right_tensor, //right tensor
int index_pattern[]):
Tensor::Tensor(const std::string & name, //tensor name
const Tensor & left_tensor, //left tensor
const Tensor & right_tensor, //right tensor
const std::vector<TensorLeg> & contraction): //tensor contraction pattern
name_(name)
{
//`Finish
//Import shape/signature of the input tensors:
auto left_rank = left_tensor.getRank();
TensorShape left_shape = left_tensor.getShape();
TensorSignature left_signa = left_tensor.getSignature();
auto right_rank = right_tensor.getRank();
TensorShape right_shape = right_tensor.getShape();
TensorSignature right_signa = right_tensor.getSignature();
//Extract the output tensor dimensions:
if(left_rank + right_rank > 0){
unsigned int out_mode = 0;
unsigned int inp_mode = 0;
unsigned int argt = 1; if(left_rank == 0) argt = 2;
unsigned int max_out_dim = 0;
unsigned int contr[left_rank+right_rank][2] = {0};
for(const auto & leg: contraction){
auto tens_id = leg.getTensorId();
if(tens_id == 0){ //uncontracted leg of either input tensor
unsigned int out_dim = leg.getDimensionId(); //output tensor mode id
if(out_dim > max_out_dim) max_out_dim = out_dim;
contr[out_dim][0] = argt; //input tensor argument: {1,2}
contr[out_dim][1] = inp_mode; //input tensor mode id
++out_mode;
}else{
assert(tens_id == 1 || tens_id == 2); //checking validity of argument <contraction>
}
++inp_mode;
if(argt == 1 && inp_mode == left_rank){inp_mode = 0; argt = 2;};
}
assert(max_out_dim < out_mode);
//Form the output tensor shape/signature:
for(unsigned int i = 0; i <= max_out_dim; ++i){
inp_mode = contr[i][1];
if(contr[i][0] == 1){
shape_.appendDimension(left_tensor.getDimExtent(inp_mode));
signature_.appendDimension(left_tensor.getDimSpaceAttr(inp_mode));
}else if(contr[i][0] == 2){
shape_.appendDimension(right_tensor.getDimExtent(inp_mode));
signature_.appendDimension(right_tensor.getDimSpaceAttr(inp_mode));
}else{
std::cout << "#ERROR(Tensor::Tensor): Invalid function argument: contraction: Missing output tensor mode!" << std::endl;
assert(false); //missing output tensor dimension
}
}
}
}
void Tensor::printIt() const
......
......@@ -34,6 +34,7 @@ Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
#include "tensor_basic.hpp"
#include "tensor_shape.hpp"
#include "tensor_signature.hpp"
#include "tensor_leg.hpp"
#include <assert.h>
......@@ -77,11 +78,16 @@ public:
const std::vector<T> & extents); //tensor dimension extents
/** Create a rank-0 tensor (scalar). **/
Tensor(const std::string & name); //tensor name
/** Create a tensor by contracting two other tensors. **/
Tensor(const std::string & name, //tensor name
const Tensor & left_tensor, //left tensor
const Tensor & right_tensor, //right tensor
int index_pattern[]); //index contraction pattern
/** Create a tensor by contracting two other tensors.
The vectors of tensor legs specify the tensor contraction pattern:
contraction.size() = left_rank + right_rank;
Output tensor id = 0;
Left input tensor id = 1;
Right input tensor id = 2. **/
Tensor(const std::string & name, //tensor name
const Tensor & left_tensor, //left tensor
const Tensor & right_tensor, //right tensor
const std::vector<TensorLeg> & contraction); //tensor contraction pattern
Tensor(const Tensor & tensor) = default;
Tensor & operator=(const Tensor & tensor) = default;
......
/** ExaTN::Numerics: Tensor shape
REVISION: 2019/07/02
REVISION: 2019/07/08
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -48,6 +48,13 @@ const std::vector<DimExtent> & TensorShape::getDimExtents() const
return extents_;
}
void TensorShape::resetDimension(unsigned int dim_id, DimExtent extent)
{
assert(dim_id < extents_.size()); //debug
extents_[dim_id] = extent;
return;
}
void TensorShape::deleteDimension(unsigned int dim_id)
{
assert(dim_id < extents_.size());
......
/** ExaTN::Numerics: Tensor shape
REVISION: 2019/07/02
REVISION: 2019/07/08
Copyright (C) 2018-2018 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2018 Oak Ridge National Laboratory (UT-Battelle) **/
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
/** Rationale:
(a) Tensor shape is an ordered set of tensor dimension extents.
......@@ -54,6 +54,9 @@ public:
/** Get the extents of all tensor dimensions. **/
const std::vector<DimExtent> & getDimExtents() const;
/** Resets a specific dimension. **/
void resetDimension(unsigned int dim_id, DimExtent extent);
/** Deletes a specific dimension, reducing the shape rank by one. **/
void deleteDimension(unsigned int dim_id);
......
/** ExaTN::Numerics: Tensor signature
REVISION: 2019/07/02
REVISION: 2019/07/08
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -70,6 +70,13 @@ std::pair<SpaceId,SubspaceId> TensorSignature::getDimSpaceAttr(unsigned int dim_
return subspaces_[dim_id];
}
void TensorSignature::resetDimension(unsigned int dim_id, std::pair<SpaceId,SubspaceId> subspace)
{
assert(dim_id < subspaces_.size()); //debug
subspaces_[dim_id] = subspace;
return;
}
void TensorSignature::deleteDimension(unsigned int dim_id)
{
assert(dim_id < subspaces_.size());
......
/** ExaTN::Numerics: Tensor signature
REVISION: 2019/07/02
REVISION: 2019/07/08
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -59,6 +59,9 @@ public:
SubspaceId getDimSubspaceId(unsigned int dim_id) const;
std::pair<SpaceId,SubspaceId> getDimSpaceAttr(unsigned int dim_id) const;
/** Resets a specific subspace. **/
void resetDimension(unsigned int dim_id, std::pair<SpaceId,SubspaceId> subspace);
/** Deletes a specific subspace, reducing the signature rank by one. **/
void deleteDimension(unsigned int dim_id);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment