/** ExaTN::Numerics: Tensor network
REVISION: 2021/12/22

Copyright (C) 2018-2021 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2021 Oak Ridge National Laboratory (UT-Battelle) **/
#include "tensor_network.hpp"
#include "tensor_symbol.hpp"
#include "contraction_seq_optimizer_factory.hpp"
#include "functor_init_val.hpp"

#include "metis_graph.hpp"

#include <iostream>
#include <string>
#include <vector>
#include <list>
#include <map>
#include <memory>
#include <algorithm>

namespace exatn{

namespace numerics{

//Tensor contraction sequence optimizers:
std::map<std::string,std::shared_ptr<ContractionSeqOptimizer>> optimizers;


//Helpers:
inline bool isIntermediateTensorName(const std::string & tensor_name)
{
 //Intermediate tensors carry one of the reserved name prefixes _x, _y, or _z:
 if(tensor_name.length() < 2 || tensor_name[0] != '_') return false;
 const char kind = tensor_name[1];
 return (kind == 'x' || kind == 'y' || kind == 'z');
}


inline bool isPureIntermediateTensorName(const std::string & tensor_name)
{
 //Pure intermediates carry prefix _x or _y (_z denotes the network output):
 if(tensor_name.length() < 2 || tensor_name[0] != '_') return false;
 const char kind = tensor_name[1];
 return (kind == 'x' || kind == 'y');
}


bool tensorNameIsIntermediate(const Tensor & tensor,
                              bool * network_output)
54
55
56
57
{
 bool res = false, out = false;
 const auto & tens_name = tensor.getName();
 if(tens_name.length() >= 2){
58
59
60
61
  out = (tens_name[0] == '_' && tens_name[1] == 'z');    //_z: output tensor of the tensor network
  res = (out ||                                          //output tensor is also considered intermediate
         (tens_name[0] == '_' && tens_name[1] == 'y') || //_y: intermediate tensor of the tensor network
         (tens_name[0] == '_' && tens_name[1] == 'x'));  //_x: intermediate tensor of the tensor network
62
63
64
65
66
67
 }
 if(network_output != nullptr) *network_output = out;
 return res;
}


68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
double getTensorContractionCost(const TensorConn & left_tensor, const TensorConn & right_tensor,
                                double * total_volume, double * diff_volume,
                                double * arithm_intensity, bool adjust_cost)
{
 //Estimates the cost (in FMA flops, no FMA prefactor) of contracting two tensors.
 //Optionally also returns the total volume of the participating tensors, the
 //volume difference (result volume minus input volumes), and the arithmetic intensity.
 const auto left_id = left_tensor.getTensorId();
 const auto left_rank = left_tensor.getNumLegs();
 const auto right_rank = right_tensor.getNumLegs();
 const auto & right_legs = right_tensor.getTensorLegs();
 double left_vol = 1.0;
 for(unsigned int i = 0; i < left_rank; ++i){
  left_vol *= static_cast<double>(left_tensor.getDimExtent(i));
 }
 double right_vol = 1.0, contr_vol = 1.0;
 for(unsigned int i = 0; i < right_rank; ++i){
  const double dim_ext = static_cast<double>(right_tensor.getDimExtent(i));
  if(right_legs[i].getTensorId() == left_id) contr_vol *= dim_ext; //contracted dimension
  right_vol *= dim_ext;
 }
 double flops = left_vol * right_vol / contr_vol; //FMA flops (no FMA prefactor)
 const double result_vol = flops / contr_vol;     //volume of the resulting tensor
 const double tot_vol = left_vol + right_vol + result_vol; //total volume of tensors
 if(total_volume != nullptr) *total_volume = tot_vol;
 if(diff_volume != nullptr) *diff_volume = result_vol - (left_vol + right_vol);
 if(arithm_intensity != nullptr) *arithm_intensity = flops / tot_vol;
 if(adjust_cost){ //increase the "effective" flop count if arithmetic intensity is low
  //`Finish: flops *= f(arithm_intensity): [max --> 1]
 }
 return flops;
}


void printContractionSequence(const std::list<numerics::ContrTriple> & contr_seq)
{
 //Prints the contraction sequence to standard output, ten triples per line.
 unsigned int printed = 0;
 for(const auto & contr: contr_seq){
  std::cout << "{" << contr.result_id << ":" << contr.left_id << "," << contr.right_id << "}";
  if(++printed == 10){
   std::cout << std::endl;
   printed = 0;
  }
 }
 if(printed != 0) std::cout << std::endl;
 return;
}


void printContractionSequence(std::ofstream & output_file, const std::list<numerics::ContrTriple> & contr_seq)
{
 //Prints the contraction sequence into the given file stream, ten triples per line.
 unsigned int printed = 0;
 for(const auto & contr: contr_seq){
  output_file << "{" << contr.result_id << ":" << contr.left_id << "," << contr.right_id << "}";
  if(++printed == 10){
   output_file << std::endl;
   printed = 0;
  }
 }
 if(printed != 0) output_file << std::endl;
 return;
}


122
//Main:
123
TensorNetwork::TensorNetwork():
 explicit_output_(0), finalized_(1), max_tensor_id_(0),
 contraction_seq_flops_(0.0), max_intermediate_presence_volume_(0.0),
 max_intermediate_volume_(0.0), max_intermediate_rank_(0), universal_indexing_(false)
{
 //Constructs an empty (finalized) tensor network with an implicit scalar output tensor:
 const auto emplaced = emplaceTensorConnDirect(false,
                                               0U, //output tensor (id = 0)
                                               std::make_shared<Tensor>("_smoky"),0U,std::vector<TensorLeg>{});
 if(!emplaced){
  std::cout << "#ERROR(exatn::numerics::TensorNetwork::TensorNetwork): Tensor id already in use!" << std::endl;
  assert(false);
 }
}


138
TensorNetwork::TensorNetwork(const std::string & name):
 explicit_output_(0), finalized_(1), name_(name), max_tensor_id_(0),
 contraction_seq_flops_(0.0), max_intermediate_presence_volume_(0.0),
 max_intermediate_volume_(0.0), max_intermediate_rank_(0), universal_indexing_(false)
{
 //Constructs a named empty (finalized) tensor network with a scalar output tensor:
 const auto emplaced = emplaceTensorConnDirect(false,
                                               0U, //output tensor (id = 0)
                                               std::make_shared<Tensor>(name),0U,std::vector<TensorLeg>{});
 if(!emplaced){
  std::cout << "#ERROR(exatn::numerics::TensorNetwork::TensorNetwork): Tensor id already in use!" << std::endl;
  assert(false);
 }
}

152
153
154
155

TensorNetwork::TensorNetwork(const std::string & name,
                             std::shared_ptr<Tensor> output_tensor,
                             const std::vector<TensorLeg> & output_legs):
 explicit_output_(1), finalized_(0), name_(name), max_tensor_id_(0),
 contraction_seq_flops_(0.0), max_intermediate_presence_volume_(0.0),
 max_intermediate_volume_(0.0), max_intermediate_rank_(0), universal_indexing_(false)
{
 //Constructs a tensor network from an explicitly provided output tensor and its legs
 //(the network stays unfinalized until the input tensors are appended):
 const auto emplaced = emplaceTensorConnDirect(false,
                                               0U, //output tensor (id = 0)
                                               output_tensor,0U,output_legs);
 if(!emplaced){
  std::cout << "#ERROR(exatn::numerics::TensorNetwork::TensorNetwork): Tensor id already in use!" << std::endl;
  assert(false);
 }
}

169

170
171
TensorNetwork::TensorNetwork(const std::string & name,
                             const std::string & tensor_network,
                             const std::map<std::string,std::shared_ptr<Tensor>> & tensors):
 explicit_output_(1), finalized_(0), name_(name), max_tensor_id_(0),
 contraction_seq_flops_(0.0), max_intermediate_presence_volume_(0.0),
 max_intermediate_volume_(0.0), max_intermediate_rank_(0), universal_indexing_(false)
{
 //Builds a tensor network from its symbolic specification (e.g. "D(a,b)+=L(a,c)*R(c,b)"),
 //resolving tensor names through the provided <tensors> map. The first symbolic tensor
 //is the output tensor (id = 0); the rest are input tensors (id = position).
 //Convert tensor hypernetwork into regular tensor network, if needed:
 //`Finish
 std::map<std::string,std::vector<TensorLeg>> index_map; //index label --> list of tensor legs associated with this index label
 std::vector<std::string> stensors; //individual tensors of the tensor network (symbolic)
 if(parse_tensor_network(tensor_network,stensors)){
  //Construct the index correspondence map:
  std::string tensor_name;
  std::vector<IndexLabel> indices;
  for(unsigned int i = 0; i < stensors.size(); ++i){
   bool conjugated;
   if(parse_tensor(stensors[i],tensor_name,indices,conjugated)){
    for(unsigned int j = 0; j < indices.size(); ++j){
     auto pos = index_map.find(indices[j].label);
     if(pos == index_map.end()){ //first occurrence of this index label
      auto res = index_map.emplace(std::make_pair(indices[j].label,std::vector<TensorLeg>{}));
      assert(res.second);
      pos = res.first;
     }
     pos->second.emplace_back(TensorLeg(i,j,indices[j].direction)); //directed index #j of tensor #i
    }
   }else{
    std::cout << "#ERROR(TensorNetwork::TensorNetwork): Invalid tensor in symbolic tensor network specification: " <<
     stensors[i] << std::endl;
    assert(false);
   }
   indices.clear();
   tensor_name.clear();
  }
  //Build the tensor network object:
  for(unsigned int i = 0; i < stensors.size(); ++i){
   bool conjugated;
   if(parse_tensor(stensors[i],tensor_name,indices,conjugated)){
    auto tensor = tensors.find(tensor_name);
    if(tensor != tensors.end()){
     //Gather the legs of tensor #i: all other tensor legs sharing its index labels:
     std::vector<TensorLeg> legs;
     for(unsigned int j = 0; j < indices.size(); ++j){
      auto pos = index_map.find(indices[j].label);
      assert(pos != index_map.end());
      const auto & inds = pos->second;
      for(const auto & ind: inds){
       if(ind.getTensorId() != i || ind.getDimensionId() != j) legs.emplace_back(ind);
      }
     }
     if(i == 0){ //output tensor of the tensor network
      assert(!conjugated); //output tensor must not appear complex conjugated
      auto res = emplaceTensorConnDirect(false,
                                         0U, //output tensor (id = 0)
                                         tensor->second,0U,legs);
      if(!res){
       std::cout << "#ERROR(exatn::numerics::TensorNetwork::TensorNetwork): Tensor id already in use!" << std::endl;
       assert(false);
      }
     }else{ //input tensor
      this->placeTensor(i,tensor->second,legs,conjugated);
     }
    }else{
     std::cout << "#ERROR(TensorNetwork::TensorNetwork): Unable to find tensor named " <<
      tensor_name << std::endl;
     assert(false);
    }
   }
  }
 }else{
  std::cout << "#ERROR(TensorNetwork::TensorNetwork): Invalid symbolic tensor network specification: " <<
   tensor_network << std::endl;
  assert(false);
 }
 bool finalized = this->finalize();
 assert(finalized);
}


252
253
TensorNetwork::TensorNetwork(const std::string & name,
                             std::shared_ptr<Tensor> output_tensor,
                             NetworkBuilder & builder,
                             bool tensor_operator):
 explicit_output_(1), finalized_(0), name_(name), max_tensor_id_(0),
 contraction_seq_flops_(0.0), max_intermediate_presence_volume_(0.0),
 max_intermediate_volume_(0.0), max_intermediate_rank_(0), universal_indexing_(false)
{
 //Constructs a tensor network by delegating to a network builder, starting from
 //an output tensor whose legs are initially dummies (fixed up after building):
 const auto emplaced = emplaceTensorConnDirect(false,
                                               0U, //output tensor (id = 0)
                                               output_tensor,0U,
                                               std::vector<TensorLeg>(output_tensor->getRank(),TensorLeg(0,0))); //dummy legs
 if(!emplaced){
  std::cout << "#ERROR(exatn::numerics::TensorNetwork::TensorNetwork): Tensor id already in use!" << std::endl;
  assert(false);
 }
 builder.build(*this,tensor_operator); //create and link input tensors of the tensor network
 finalized_ = 1;
 updateConnectionsFromInputTensors(); //update output tensor legs
}


274
275
276
277
278
279
280
281
TensorNetwork::TensorNetwork(const TensorNetwork & another,
                             bool replace_output,
                             const std::string & new_output_name)
{
 //Copy-constructs a tensor network, optionally replacing its output tensor
 //with a new one carrying the given name:
 *this = another;
 if(replace_output){
  this->resetOutputTensor(new_output_name);
 }
}

282

283
void TensorNetwork::printIt(bool with_tensor_hash) const
284
{
285
286
287
 std::cout << "TensorNetwork(" << name_
           << ")[rank = " << this->getRank()
           << ", size = " << this->getNumTensors() << "]{" << std::endl;
288
289
 for(const auto & kv: tensors_){
  std::cout << " ";
290
  kv.second.printIt(with_tensor_hash);
291
 }
292
293
294
295
 std::cout << "}" << std::endl;
 return;
}

296

297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
void TensorNetwork::printItFile(std::ofstream & output_file,
                                bool with_tensor_hash) const
{
 //Prints the tensor network composition into the given file stream.
 output_file << "TensorNetwork(" << name_
             << ")[rank = " << this->getRank()
             << ", size = " << this->getNumTensors() << "]{" << std::endl;
 for(const auto & entry: tensors_){
  output_file << " ";
  entry.second.printItFile(output_file,with_tensor_hash);
 }
 output_file << "}" << std::endl;
 return;
}


312
313
bool TensorNetwork::isEmpty() const
{
314
 return (tensors_.size() <= 1); //only output tensor exists => still empty
315
316
317
318
319
320
321
322
323
324
325
326
}


bool TensorNetwork::isExplicit() const
{
 //Whether the network was built with an explicitly provided output tensor:
 return explicit_output_ != 0;
}


bool TensorNetwork::isFinalized() const
{
 //Whether the network has been finalized (its structure frozen):
 return finalized_ != 0;
}

329

330
331
332
333
334
335
bool TensorNetwork::isValid()
{
 //Validity = full mutual consistency of all tensor leg connections:
 return checkConnections();
}


336
337
unsigned int TensorNetwork::getRank() const
{
338
 //assert(this->isFinalized());
339
340
341
342
 return tensors_.at(0).getNumLegs(); //output tensor
}


343
344
345
346
unsigned int TensorNetwork::getNumTensors() const
{
 //Number of input tensors (the output tensor is not counted):
 return static_cast<unsigned int>(tensors_.size() - 1);
}
347

348

349
unsigned int TensorNetwork::getMaxTensorId()
350
{
351
352
353
354
355
356
357
 if(max_tensor_id_ == 0){
  for(const auto & kv: tensors_) max_tensor_id_ = std::max(max_tensor_id_,kv.first);
 }
 return max_tensor_id_;
}


358
359
TensorElementType TensorNetwork::getTensorElementType() const
{
 //Returns the element type of the first input tensor with a defined element type,
 //or VOID if none of the input tensors has one:
 assert(this->isFinalized());
 for(const auto & entry: tensors_){
  if(entry.first == 0) continue; //skip the output tensor
  const auto elem_type = entry.second.getElementType();
  if(elem_type != TensorElementType::VOID) return elem_type;
 }
 return TensorElementType::VOID;
}


371
372
373
374
375
376
377
378
379
380
381
382
383
void TensorNetwork::updateMaxTensorIdOnAppend(unsigned int tensor_id)
{
 //Keeps the cached max tensor id in sync after appending a tensor:
 const auto curr_max_id = getMaxTensorId();
 max_tensor_id_ = std::max(curr_max_id,tensor_id);
 return;
}


void TensorNetwork::updateMaxTensorIdOnRemove(unsigned int tensor_id)
{
 //Removing the current max id invalidates the cache (0 = Undefined);
 //the value is recomputed lazily by getMaxTensorId():
 if(tensor_id != 0 && tensor_id == max_tensor_id_) max_tensor_id_ = 0;
 return;
}


389
390
391
392
393
394
395
396
397
398
void TensorNetwork::resetOutputTensor(const std::string & name)
{
 //Replaces the stored output tensor (id = 0) with a new one carrying the given name:
 assert(finalized_ != 0);
 auto out = tensors_.find(0);
 assert(out != tensors_.end());
 out->second.replaceStoredTensor(name);
 return;
}


399
400
401
402
403
404
405
406
407
408
409
void TensorNetwork::resetOutputTensor(const std::vector<unsigned int> & order,
                                      const std::string & name)
{
 //Replaces the stored output tensor (id = 0) with a new one whose dimensions
 //are permuted according to <order>, under the given name:
 assert(finalized_ != 0);
 auto out = tensors_.find(0);
 assert(out != tensors_.end());
 out->second.replaceStoredTensor(order,name);
 return;
}


410
411
412
413
414
const std::string & TensorNetwork::getName() const
{
 //Returns the name of the tensor network:
 return name_;
}

415

416
417
void TensorNetwork::rename(const std::string & name)
{
418
419
 assert(finalized_ != 0);
 resetOutputTensor();
420
421
422
423
424
 name_ = name;
 return;
}


425
TensorConn * TensorNetwork::getTensorConn(unsigned int tensor_id)
{
 //Returns the connected tensor with the given id, or nullptr if absent:
 auto iter = tensors_.find(tensor_id);
 return (iter == tensors_.end()) ? nullptr : &(iter->second);
}

432

433
434
435
436
437
438
439
440
441
442
443
std::vector<TensorConn*> TensorNetwork::getTensorConnAll()
{
 std::vector<TensorConn*> tensors(this->getNumTensors(),nullptr);
 unsigned int i = 0;
 for(auto & kv: tensors_){
  if(kv.first != 0) tensors[i++] = &(kv.second);
 }
 return tensors;
}


444
std::shared_ptr<Tensor> TensorNetwork::getTensor(unsigned int tensor_id, bool * conjugated) const
445
446
447
{
 auto it = tensors_.find(tensor_id);
 if(it == tensors_.end()) return std::shared_ptr<Tensor>(nullptr);
448
 if(conjugated != nullptr) *conjugated = (it->second).isComplexConjugated();
449
450
451
 return (it->second).getTensor();
}

452

453
const std::vector<TensorLeg> * TensorNetwork::getTensorConnections(unsigned int tensor_id) const
454
455
456
457
458
459
460
{
 auto it = tensors_.find(tensor_id);
 if(it == tensors_.end()) return nullptr;
 return &((it->second).getTensorLegs());
}


461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
std::list<unsigned int> TensorNetwork::getAdjacentTensors(unsigned int tensor_id) const
{
 //Collects the unique ids of input tensors directly connected to the given tensor
 //(the output tensor is never reported as a neighbor):
 std::list<unsigned int> neighbors;
 const auto * legs = this->getTensorConnections(tensor_id);
 if(legs != nullptr){
  for(const auto & leg: *legs){
   const auto other_id = leg.getTensorId();
   if(other_id == 0) continue; //ignore the output tensor
   const auto seen = std::find(neighbors.begin(),neighbors.end(),other_id);
   if(seen == neighbors.end()) neighbors.emplace_back(other_id);
  }
 }
 return neighbors;
}


478
bool TensorNetwork::finalize(bool check_validity)
479
{
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
 if(finalized_ == 0){
  if(this->isEmpty()){ //empty networks cannot be finalized
   std::cout << "#ERROR(TensorNetwork::finalize): Empty tensor network cannot be finalized!" << std::endl;
   return false;
  }
  finalized_ = 1;
  if(check_validity){
   if(!checkConnections()){
    finalized_ = 0;
    std::cout << "#ERROR(TensorNetwork::finalize): Invalid connectivity prevents tensor network finalization!" << std::endl;
    return false;
   }
  }
 }
 return true;
}


bool TensorNetwork::checkConnections(unsigned int tensor_id)
{
 //Verifies that every leg of the given tensor is mirrored by a matching
 //back leg on the connected tensor (with reversed direction):
 assert(finalized_ != 0); //tensor network must be in finalized state
 auto * tensor = this->getTensorConn(tensor_id);
 assert(tensor != nullptr); //invalid tensor_id
 const auto num_legs = tensor->getNumLegs();
 for(unsigned int leg_id = 0; leg_id < num_legs; ++leg_id){
  const auto & leg = tensor->getTensorLeg(leg_id);
  auto * partner = this->getTensorConn(leg.getTensorId());
  assert(partner != nullptr); //unable to find the linked tensor
  const auto & back_leg = partner->getTensorLeg(leg.getDimensionId());
  const bool mutual = (back_leg.getTensorId() == tensor_id) &&
                      (back_leg.getDimensionId() == leg_id) &&
                      (back_leg.getDirection() == reverseLegDirection(leg.getDirection()));
  if(!mutual) return false;
 }
 return true;
}


bool TensorNetwork::checkConnections()
{
 //Verifies the mutual consistency of all tensor connections in the network:
 assert(finalized_ != 0); //tensor network must be in finalized state
 for(const auto & entry: tensors_){
  if(!checkConnections(entry.first)) return false;
 }
 return true;
}


529
void TensorNetwork::updateConnections(unsigned int tensor_id)
530
{
531
532
533
534
535
536
537
538
539
540
541
542
543
544
 assert(finalized_ != 0); //tensor network must be in finalized state
 auto * tensor = this->getTensorConn(tensor_id);
 assert(tensor != nullptr); //invalid tensor_id
 auto tensor_rank = tensor->getNumLegs();
 for(unsigned int i = 0; i < tensor_rank; ++i){
  const auto & tensor_leg = tensor->getTensorLeg(i);
  auto other_tensor_id = tensor_leg.getTensorId();
  auto other_tensor_leg_id = tensor_leg.getDimensionId();
  auto * other_tensor = this->getTensorConn(other_tensor_id);
  assert(other_tensor != nullptr); //unable to find the linked tensor
  auto other_tensor_leg = other_tensor->getTensorLeg(other_tensor_leg_id);
  other_tensor_leg.resetTensorId(tensor_id);
  other_tensor_leg.resetDimensionId(i);
  other_tensor->resetLeg(other_tensor_leg_id,other_tensor_leg);
545
546
547
548
549
 }
 return;
}


550
551
552
553
554
555
556
557
558
void TensorNetwork::updateConnectionsFromInputTensors()
{
 //Re-establishes the back links from every input tensor (output tensor skipped):
 for(auto it = this->cbegin(); it != this->cend(); ++it){
  if(it->first != 0) updateConnections(it->first);
 }
 return;
}


559
560
561
562
563
564
565
void TensorNetwork::invalidateMaxTensorId()
{
 //Resets the cached max tensor id to Undefined (0); recomputed lazily later:
 max_tensor_id_ = 0;
 return;
}


566
567
void TensorNetwork::invalidateContractionSequence()
{
568
 split_tensors_.clear();
569
 split_indices_.clear();
570
571
572
 operations_.clear();
 contraction_seq_.clear();
 contraction_seq_flops_ = 0.0;
573
 max_intermediate_presence_volume_ = 0.0;
574
 max_intermediate_volume_ = 0.0;
575
 max_intermediate_rank_ = 0;
576
 universal_indexing_ = false;
577
578
579
580
 return;
}


581
582
583
584
585
586
587
588
589
590
591
592
593
void TensorNetwork::invalidateTensorOperationList()
{
 //Drops the cached tensor operation list and derived data, but keeps the
 //contraction sequence itself (contrast with invalidateContractionSequence):
 split_tensors_.clear();
 split_indices_.clear();
 operations_.clear();
 max_intermediate_presence_volume_ = 0.0;
 max_intermediate_volume_ = 0.0;
 max_intermediate_rank_ = 0;
 universal_indexing_ = false;
 return;
}


594
595
596
597
double TensorNetwork::determineContractionSequence(ContractionSeqOptimizer & contr_seq_optimizer)
{
 //Determines and caches a pairwise contraction sequence via the given optimizer;
 //returns the estimated FMA flop count. No-op if a sequence is already cached.
 assert(finalized_ != 0); //tensor network must be in finalized state
 if(contraction_seq_.empty()){
  //Intermediate tensor ids start right after the largest tensor id in the network:
  auto next_id = this->getMaxTensorId() + 1;
  auto id_generator = [next_id]() mutable {return next_id++;};
  contraction_seq_flops_ = contr_seq_optimizer.determineContractionSequence(*this,contraction_seq_,id_generator);
  max_intermediate_presence_volume_ = 0.0;
  max_intermediate_volume_ = 0.0;
  max_intermediate_rank_ = 0;
 }
 return contraction_seq_flops_;
}


609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
double TensorNetwork::determineContractionSequence(const std::string & contr_seq_opt_name)
{
 //Looks up (or creates and caches) the named contraction sequence optimizer,
 //then delegates to the optimizer-based overload:
 auto iter = optimizers.find(contr_seq_opt_name);
 if(iter == optimizers.end()){ //optimizer not cached yet
  auto & optimizer_factory = *(ContractionSeqOptimizerFactory::get());
  auto optimizer = optimizer_factory.createContractionSeqOptimizer(contr_seq_opt_name);
  if(optimizer){
   auto res = optimizers.emplace(std::make_pair(contr_seq_opt_name,
                                 std::shared_ptr<ContractionSeqOptimizer>(std::move(optimizer))));
   assert(res.second);
   iter = res.first;
  }else{
   std::cout << "#ERROR(TensorNetwork::determineContractionSequence): Invalid request: " <<
    "Tensor contraction sequence optimizer " << contr_seq_opt_name << " has not been registered before!" << std::endl;
   assert(false);
  }
 }
 return determineContractionSequence(*(iter->second));
}


void TensorNetwork::importContractionSequence(const std::list<ContrTriple> & contr_sequence,
                                              double fma_flops)
632
633
634
635
{
 assert(finalized_ != 0); //tensor network must be in finalized state
 contraction_seq_.clear();
 contraction_seq_ = contr_sequence;
636
 contraction_seq_flops_ = fma_flops; //flop count may be unknown yet (defaults to zero)
637
 max_intermediate_presence_volume_ = 0.0; //max cumulative volume of intermediates present at a time
638
639
 max_intermediate_volume_ = 0.0; //max intermediate tensor volume is unknown yet
 max_intermediate_rank_ = 0; //max intermediate tensor rank
640
641
642
643
 return;
}


644
645
646
647
648
649
650
651
652
653
654
655
656
657
void TensorNetwork::importContractionSequence(const std::vector<unsigned int> & contr_sequence_content,
                                              double fma_flops)
{
 //Imports a contraction sequence from its packed (plain vector) representation:
 assert(finalized_ != 0); //tensor network must be in finalized state
 contraction_seq_.clear();
 unpackContractionSequenceFromVector(contraction_seq_,contr_sequence_content);
 contraction_seq_flops_ = fma_flops; //flop count may be unknown yet (defaults to zero)
 max_intermediate_presence_volume_ = 0.0; //max cumulative volume of intermediates present at a time
 max_intermediate_volume_ = 0.0; //max intermediate tensor volume is unknown yet
 max_intermediate_rank_ = 0; //max intermediate tensor rank
 return;
}


658
const std::list<ContrTriple> & TensorNetwork::exportContractionSequence(double * fma_flops) const
659
{
660
 if(fma_flops != nullptr) *fma_flops = contraction_seq_flops_;
661
662
663
664
 return contraction_seq_;
}


665
666
667
inline IndexSplit splitDimension(std::pair<SpaceId,SubspaceId> space_attr, //helper
                                 DimExtent dim_extent,
                                 std::size_t num_segments)
{
 //Splits a dimension of the given extent into num_segments contiguous segments
 //whose extents differ by at most one (the larger segments come first).
 assert(dim_extent >= num_segments);
 IndexSplit split_info;
 if(space_attr.first == SOME_SPACE){ //anonymous vector space
  const std::size_t seg_size = dim_extent/num_segments;
  const std::size_t remainder = dim_extent - seg_size * num_segments;
  SubspaceId base = space_attr.second;
  for(std::size_t i = 0; i < num_segments; ++i){
   const DimExtent extent = (i < remainder) ? (seg_size + 1) : seg_size;
   split_info.emplace_back(std::pair<SubspaceId,DimExtent>{base,extent});
   base += extent;
  }
  assert(base == space_attr.second + dim_extent);
 }else{ //registered named vector space
  assert(false); //`Implement in future
 }
 return split_info;
}


688
689
void TensorNetwork::establishUniversalIndexNumeration()
{
690
691
692
693
694
695
696
 if(universal_indexing_) return;
 std::unordered_map<TensorHashType,std::string> intermediates; //tensor hash --> symbolic tensor intermediate with universal indices
 std::unordered_map<std::string,std::string> index_map; //old index name --> new index name
 std::vector<std::string> tens_operands; //extracted tensor operands
 std::vector<IndexLabel> indices,new_indices; //indices extracted from a tensor
 std::string tensor_name; //tensor name extracted from a tensor
 std::string new_pattern; //new tensor operation index pattern
697
698
699
 bool conjugated = false;
 bool output_tensor_done = false;
 int num_internal_indices = 0;
700
701
702
 //Update index patterns in all tensor operations in reverse order:
 for(auto op_iter = operations_.rbegin(); op_iter != operations_.rend(); ++op_iter){
  auto & op = *(*op_iter); //tensor operation
703
704
  const auto num_operands = op.getNumOperands();
  const auto num_operands_out = op.getNumOperandsOut();
705
  assert(num_operands <= 3 && num_operands_out <= 1); //`Only expecting regular tensor operations so far
706
707
  const auto & old_pattern = op.getIndexPattern();
  if(old_pattern.length() > 0){ //index pattern present
708
   assert(num_operands > 1 && num_operands_out == 1); //presence of index pattern assumes two or more operands
709
   //std::cout << "#DEBUG(TensorNetwork::establishUniversalIndexNumeration): Old pattern: " << old_pattern << std::endl; //debug
710
   tens_operands.clear();
711
712
713
714
715
716
717
718
719
   bool success = parse_tensor_network(old_pattern,tens_operands);
   if(success){
    //Process all tensor operands:
    if(tens_operands.size() == num_operands){
     new_pattern.clear();
     //Process the only output tensor operand (#0):
     tensor_name.clear(); indices.clear();
     success = parse_tensor(tens_operands[0],tensor_name,indices,conjugated);
     if(success){
720
721
722
      //std::cout << " New output tensor: " << tens_operands[0] << std::endl; //debug
      const auto & tens0 = *(op.getTensorOperand(0));
      auto tensor_hash = tens0.getTensorHash();
723
724
      //Pre-save the output tensor of the tensor network:
      if(!output_tensor_done){
725
       assert(!conjugated); //output tensor cannot be conjugated
726
727
       auto res = intermediates.emplace(std::make_pair(tensor_hash,
                   assemble_symbolic_tensor(tens0.getName(),indices,conjugated)));
728
       assert(res.second);
729
       //std::cout << " Saved tensor: " << res.first->second << std::endl; //debug
730
731
       output_tensor_done = true;
      }
732
733
734
735
736
737
738
739
740
741
742
743
744
745
      //Retreive the intermediate (output) tensor operand in a universal form:
      auto tens_iter = intermediates.find(tensor_hash); assert(tens_iter != intermediates.end());
      //std::cout << " Found intermediate: " << tens_iter->second << std::endl; //debug
      new_pattern += (tens_iter->second + "+="); //append universally indexed output tensor operand to the new index pattern
      //Establish uncontracted index remapping:
      index_map.clear(); tensor_name.clear(); new_indices.clear();
      success = parse_tensor(tens_iter->second,tensor_name,new_indices,conjugated); assert(success);
      //std::cout << " Sizes of indices: " << indices.size() << " " << new_indices.size() << std::endl; //debug
      assert(new_indices.size() == indices.size());
      for(auto it_old = indices.cbegin(),
               it_new = new_indices.cbegin(); it_old != indices.cend(); ++it_old,
                                                                        ++it_new){
       index_map.emplace(std::make_pair(it_old->label,it_new->label));
      }
746
747
      //Process input tensor operands:
      int num_contr_indices = 0;
748
      for(unsigned int op_num = 1; op_num < num_operands; ++op_num){ //`Assumes a single output tensor operand (#0)
749
       //std::cout << " New input tensor: " << tens_operands[op_num] << std::endl; //debug
750
751
752
       tensor_name.clear(); indices.clear();
       success = parse_tensor(tens_operands[op_num],tensor_name,indices,conjugated);
       if(success){
753
754
755
756
        const auto & tens = *(op.getTensorOperand(op_num));
        tensor_hash = tens.getTensorHash();
        //Update the numeration of contracted indices with global numbers and remap uncontracted indices:
        num_contr_indices = 0;
757
758
759
760
761
        for(auto & index: indices){
         if(index.label[0] == 'c'){ //contracted index requires global shift
          num_contr_indices++;
          auto old_number = std::stoi(index.label.substr(1));
          index.label = ("c" + std::to_string(num_internal_indices + old_number));
762
763
764
765
766
767
         }else if(index.label[0] == 'u'){ //uncontracted indices need remapping
          index.label = index_map[index.label];
         }else{
          std::cout << "#ERROR(exatn::numerics::TensorNetwork::establishUniversalIndexNumeration): "
                    << "Invalid index label encountered: " << index.label << std::endl;
          assert(false);
768
769
         }
        }
770
        const auto symb_tensor = assemble_symbolic_tensor(tens.getName(),indices,conjugated);
771
        if(isPureIntermediateTensorName(symb_tensor)){
772
         assert(!conjugated); //intermediate tensors do not appear conjugated
773
774
775
776
777
778
779
780
         auto res = intermediates.emplace(std::make_pair(tensor_hash,symb_tensor));
         if(!res.second){
          std::cout << "#ERROR(exatn::numerics::TensorNetwork::establishUniversalIndexNumeration): "
                    << "Intermediate tensor already saved previously: " << symb_tensor << std::endl;
          assert(false);
         }
         //std::cout << " Saved tensor: " << res.first->second << std::endl; //debug
        }
781
        if(op_num == 1){
782
         new_pattern += symb_tensor;
783
        }else if(op_num == 2){
784
         new_pattern += ("*" + symb_tensor);
785
        }else{
786
         assert(false); //`At most three tensor operands are expected so far
787
788
789
790
791
792
793
794
795
        }
       }else{
        std::cout << "#ERROR(exatn::numerics::TensorNetwork::establishUniversalIndexNumeration): "
                  << "Unable to parse tensor operand: " << tens_operands[op_num] << std::endl;
        assert(false);
       }
      }
      num_internal_indices += num_contr_indices;
      op.setIndexPattern(new_pattern);
796
      //std::cout << " New index pattern: " << new_pattern << std::endl; //debug
797
798
799
800
801
802
803
     }else{
      std::cout << "#ERROR(exatn::numerics::TensorNetwork::establishUniversalIndexNumeration): "
                << "Unable to parse tensor operand: " << tens_operands[0] << std::endl;
      assert(false);
     }
    }else{
     std::cout << "#ERROR(exatn::numerics::TensorNetwork::establishUniversalIndexNumeration): "
804
805
806
807
808
               << "Invalid number of tensor operands (" << tens_operands.size() << " VS " << num_operands
               << ") parsed from: " << old_pattern << ": ";
     for(const auto & operand: tens_operands) std::cout << operand << " ";
     std::cout << std::endl;
     op.printIt();
809
810
811
812
813
814
815
816
817
     assert(false);
    }
   }else{
    std::cout << "#ERROR(exatn::numerics::TensorNetwork::establishUniversalIndexNumeration): "
              << "Unable to parse tensor operation index pattern: " << old_pattern << std::endl;
    assert(false);
   }
  }
 }
818
 universal_indexing_ = true;
819
820
821
822
 return;
}


823
824
825
826
827
bool TensorNetwork::placeTensor(unsigned int tensor_id,                     //in: tensor id (unique within the tensor network)
                                std::shared_ptr<Tensor> tensor,             //in: appended tensor
                                const std::vector<TensorLeg> & connections, //in: tensor connections (fully specified)
                                bool conjugated,                            //in: complex conjugation flag for the appended tensor
                                bool leg_matching_check)                    //in: tensor leg matching check
828
829
{
 if(explicit_output_ == 0){
830
  std::cout << "#ERROR(TensorNetwork::placeTensor): Invalid request: " <<
831
832
833
834
   "Appending a tensor via explicit connections to the tensor network that is missing a full output tensor!" << std::endl;
  return false;
 }
 if(finalized_ != 0){
835
  std::cout << "#ERROR(TensorNetwork::placeTensor): Invalid request: " <<
836
837
   "Appending a tensor via explicit connections to the tensor network that has been finalized!" << std::endl;
  return false;
838
 }
839
 if(tensor_id == 0){
840
  std::cout << "#ERROR(TensorNetwork::placeTensor): Invalid request: " <<
841
842
843
   "Attempt to append an output tensor (id = 0) to a tensor network with an explicit output tensor!" << std::endl;
  return false;
 }
844
 //Check the validity of new connections:
845
846
847
848
849
850
851
852
 if(leg_matching_check){
  unsigned int mode = 0;
  for(const auto & leg: connections){
   const auto * tensconn = this->getTensorConn(leg.getTensorId());
   if(tensconn != nullptr){ //connected tensor is already in the tensor network
    const auto & tens_legs = tensconn->getTensorLegs();
    const auto & tens_leg = tens_legs[leg.getDimensionId()];
    if(tens_leg.getTensorId() != tensor_id || tens_leg.getDimensionId() != mode){
853
     std::cout << "#ERROR(TensorNetwork::placeTensor): Invalid argument: Connections are invalid: "
854
855
856
               << "Failed input leg: "; leg.printIt(); std::cout << std::endl;
     return false;
    }
857
   }
858
   ++mode;
859
860
861
  }
 }
 //Append the tensor to the tensor network:
862
863
864
865
 auto res = emplaceTensorConnDirect(true,
                                    tensor_id,
                                    tensor,tensor_id,connections,conjugated);
 if(!res){
866
  std::cout << "#ERROR(TensorNetwork::placeTensor): Invalid request: " <<
867
868
869
870
   "A tensor with id " << tensor_id << " already exists in the tensor network!" << std::endl;
  return false;
 }
 return true;
871
872
}

bool TensorNetwork::appendTensor(unsigned int tensor_id,
                                 std::shared_ptr<Tensor> tensor,
                                 const std::vector<std::pair<unsigned int, unsigned int>> & pairing,
                                 const std::vector<LegDirection> & leg_dir,
                                 bool conjugated)
879
{
880
 if(explicit_output_ != 0 && finalized_ == 0){
881
  std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid request: " <<
882
883
884
885
886
887
   "Appending a tensor via implicit pairing with the output tensor, but the tensor network is not finalized!" << std::endl;
  return false;
 }
 if(tensor_id == 0){
  std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid request: " <<
   "Attempt to append an output tensor (id = 0) to a finalized tensor network!" << std::endl;
888
889
  return false;
 }
890
891
 //Reset the output tensor to a new one:
 this->resetOutputTensor();
892
 //Check validity of leg pairing:
893
 auto tensor_rank = tensor->getRank();
894
895
896
897
898
 bool dir_present = (leg_dir.size() > 0);
 if(dir_present && leg_dir.size() != tensor_rank){
  std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid argument: Incomplete vector of leg directions!" << std::endl;
  return false;
 }
899
900
 auto * output_tensor = this->getTensorConn(0);
 assert(output_tensor != nullptr); //output tensor must be present
901
 auto output_rank = output_tensor->getNumLegs();
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
 if(output_rank > 0 && tensor_rank > 0){
  int ouf[output_rank] = {0};
  int tef[tensor_rank] = {0};
  for(const auto & link: pairing){
   if(link.first >= output_rank || link.second >= tensor_rank){
    std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid argument: Invalid leg pairing!" << std::endl;
    return false;
   }
   if(ouf[link.first]++ != 0){
    std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid argument: Pairing: Repeated output leg!" << std::endl;
    return false;
   }
   if(tef[link.second]++ != 0){
    std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid argument: Pairing: Repeated new tensor leg!" << std::endl;
    return false;
   }
  }
919
920
921
922
923
 }else{
  if(pairing.size() > 0){
   std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid argument: Pairing: Pairing on a scalar tensor!" << std::endl;
   return false;
  }
924
 }
925
 //Pair legs of the new tensor with the input tensors of the tensor network:
926
 if(tensor_rank > 0){ //true tensor
927
  std::vector<TensorLeg> new_tensor_legs(tensor_rank,TensorLeg(0,0)); //placeholders for legs
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
  if(pairing.size() > 0){
   std::vector<unsigned int> matched_output_legs(pairing.size(),0);
   unsigned int mode = 0;
   for(const auto & link: pairing){
    const auto & output_tensor_leg_id = link.first;
    const auto & tensor_leg_id = link.second;
    auto output_tensor_leg = output_tensor->getTensorLeg(output_tensor_leg_id);
    const auto input_tensor_id = output_tensor_leg.getTensorId();
    const auto input_tensor_leg_id = output_tensor_leg.getDimensionId();
    auto * input_tensor = this->getTensorConn(input_tensor_id);
    assert(input_tensor != nullptr);
    auto input_tensor_leg = input_tensor->getTensorLeg(input_tensor_leg_id);
    input_tensor_leg.resetTensorId(tensor_id);
    input_tensor_leg.resetDimensionId(tensor_leg_id);
    input_tensor->resetLeg(input_tensor_leg_id,input_tensor_leg);
    new_tensor_legs[tensor_leg_id].resetTensorId(input_tensor_id);
    new_tensor_legs[tensor_leg_id].resetDimensionId(input_tensor_leg_id);
    if(dir_present){
     new_tensor_legs[tensor_leg_id].resetDirection(leg_dir[tensor_leg_id]);
     if(input_tensor_leg.getDirection() != reverseLegDirection(new_tensor_legs[tensor_leg_id].getDirection())){
      std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid argument: Leg directions: Pairing leg direction mismatch!" << std::endl;
      return false;
     }
    }else{
     new_tensor_legs[tensor_leg_id].resetDirection(reverseLegDirection(input_tensor_leg.getDirection()));
953
    }
954
    matched_output_legs[mode++] = output_tensor_leg_id;
955
   }
956
957
958
   //Delete matched legs from the output tensor:
   output_tensor->deleteLegs(matched_output_legs);
   updateConnections(0); //update tensor network connections due to deletion of the matched output tensor legs
959
960
  }
  //Append unpaired legs of the new tensor to the output tensor of the network:
961
962
963
964
965
966
967
968
969
970
971
972
973
974
  output_rank = output_tensor->getNumLegs();
  unsigned int mode = 0;
  for(auto & leg: new_tensor_legs){
   if(leg.getTensorId() == 0){ //unpaired tensor leg
    LegDirection dir = LegDirection::UNDIRECT;
    if(dir_present) dir = leg_dir[mode];
    leg.resetDimensionId(output_rank);
    leg.resetDirection(dir);
    output_tensor->appendLeg(tensor->getDimSpaceAttr(mode),tensor->getDimExtent(mode),
                             TensorLeg(tensor_id,mode,reverseLegDirection(dir)));
    output_rank = output_tensor->getNumLegs();
   }
   ++mode;
  }
975
976
977
978
  auto res = emplaceTensorConnDirect(true,
                                     tensor_id,
                                     tensor,tensor_id,new_tensor_legs,conjugated);
  if(!res){
979
980
981
982
   std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid request: " <<
    "A tensor with id " << tensor_id << " already exists in the tensor network!" << std::endl;
   return false;
  }
983
 }else{ //scalar tensor
984
985
986
987
  auto res = emplaceTensorConnDirect(true,
                                     tensor_id,
                                     tensor,tensor_id,std::vector<TensorLeg>{},conjugated);
  if(!res){
988
989
990
991
992
   std::cout << "#ERROR(TensorNetwork::appendTensor): Invalid request: " <<
    "A tensor with id " << tensor_id << " already exists in the tensor network!" << std::endl;
   return false;
  }
 }
993
 invalidateContractionSequence(); //invalidate previously cached tensor contraction sequence
994
 finalized_ = 1; //implicit leg pairing always keeps the tensor network in a finalized state
995
996
997
 return true;
}

bool TensorNetwork::appendTensor(std::shared_ptr<Tensor> tensor,
                                 const std::vector<std::pair<unsigned int, unsigned int>> & pairing,