Commit 875aabf9 authored by Dmitry I. Lyakh

Updated tensor network queue workflow in GraphExecutor and CuQuantumExecutor


Signed-off-by: Dmitry I. Lyakh <quant4me@gmail.com>
parent 67452e37
@@ -42,8 +42,11 @@ namespace exatn {
namespace runtime {
struct TensorDescriptor {
std::vector<int32_t> modes;
std::vector<int64_t> extents;
std::vector<int32_t> modes; //indices associated with tensor dimensions
std::vector<int64_t> extents; //tensor dimension extents
void * body_ptr = nullptr; //pointer to the tensor body image
std::size_t volume = 0; //tensor body volume
cudaDataType_t data_type; //tensor element data type
};
struct TensorNetworkReq {
@@ -55,6 +58,8 @@ struct TensorNetworkReq {
cutensornetContractionOptimizerInfo_t opt_info;
cutensornetContractionPlan_t comp_plan;
cudaStream_t stream;
cutensornetComputeType_t compute_type;
TensorNetworkQueue::ExecStat exec_status = TensorNetworkQueue::ExecStat::Idle;
};
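For orientation, a minimal sketch (not part of this commit) of how the extended TensorDescriptor could be populated for a rank-3 float tensor; the body pointer and extents are hypothetical and would normally come from the tensor data access function:

```cpp
// Hypothetical example: fill the extended TensorDescriptor for a 2 x 3 x 4
// float tensor; `body` is assumed to point to its body image in memory.
TensorDescriptor descr;
descr.modes     = {0, 1, 2};   // index labels of the three tensor dimensions
descr.extents   = {2, 3, 4};   // dimension extents
descr.body_ptr  = body;        // pointer to the tensor body image
descr.volume    = 2 * 3 * 4;   // total number of tensor elements
descr.data_type = CUDA_R_32F;  // element data type (32-bit real)
```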
@@ -69,14 +74,13 @@ CuQuantumExecutor::CuQuantumExecutor(TensorImplFunc tensor_data_access_func):
int num_gpus = 0;
auto error_code = talshDeviceCount(DEV_NVIDIA_GPU,&num_gpus); assert(error_code == TALSH_SUCCESS);
for(int i = 0; i < num_gpus; ++i){
if(talshDeviceState(i,DEV_NVIDIA_GPU) >= DEV_ON) gpus_.emplace_back(i);
if(talshDeviceState(i,DEV_NVIDIA_GPU) >= DEV_ON) gpu_attr_.emplace_back(std::make_pair(i,DeviceAttr{}));
}
std::cout << "#DEBUG(exatn::runtime::CuQuantumExecutor): Number of available GPUs = " << gpus_.size() << std::endl;
std::cout << "#DEBUG(exatn::runtime::CuQuantumExecutor): Number of available GPUs = " << gpu_attr_.size() << std::endl;
ctn_handles_.resize(gpus_.size());
for(const auto & gpu_id: gpus_){
HANDLE_CUDA_ERROR(cudaSetDevice(gpu_id));
HANDLE_CTN_ERROR(cutensornetCreate((cutensornetHandle_t*)(&ctn_handles_[gpu_id])));
for(const auto & gpu: gpu_attr_){
HANDLE_CUDA_ERROR(cudaSetDevice(gpu.first));
HANDLE_CTN_ERROR(cutensornetCreate((cutensornetHandle_t*)(&gpu.second.cutn_handle)));
}
std::cout << "#DEBUG(exatn::runtime::CuQuantumExecutor): Created cuTensorNet contexts for all available GPUs" << std::endl;
}
@@ -85,43 +89,46 @@ CuQuantumExecutor::CuQuantumExecutor(TensorImplFunc tensor_data_access_func):
CuQuantumExecutor::~CuQuantumExecutor()
{
bool success = sync(); assert(success);
for(const auto & gpu_id: gpus_){
HANDLE_CUDA_ERROR(cudaSetDevice(gpu_id));
HANDLE_CTN_ERROR(cutensornetDestroy((cutensornetHandle_t)(ctn_handles_[gpu_id])));
for(const auto & gpu: gpu_attr_){
HANDLE_CUDA_ERROR(cudaSetDevice(gpu.first));
HANDLE_CTN_ERROR(cutensornetDestroy((cutensornetHandle_t)(gpu.second.cutn_handle)));
}
std::cout << "#DEBUG(exatn::runtime::CuQuantumExecutor): Destroyed cuTensorNet contexts for all available GPUs" << std::endl;
ctn_handles_.clear();
gpus_.clear();
gpu_attr_.clear();
}
int CuQuantumExecutor::execute(std::shared_ptr<numerics::TensorNetwork> network,
const TensorOpExecHandle exec_handle)
TensorNetworkQueue::ExecStat CuQuantumExecutor::execute(std::shared_ptr<numerics::TensorNetwork> network,
const TensorOpExecHandle exec_handle)
{
int error_code = 0;
//`Finish
return error_code;
}
bool CuQuantumExecutor::executing(const TensorOpExecHandle exec_handle)
{
auto iter = active_networks_.find(exec_handle);
return (iter != active_networks_.end());
assert(network);
TensorNetworkQueue::ExecStat exec_stat = TensorNetworkQueue::ExecStat::None;
auto res = active_networks_.emplace(std::make_pair(exec_handle, new TensorNetworkReq{}));
if(res.second){
auto tn_req = res.first->second;
tn_req->network = network;
exec_stat = tn_req->exec_status;
//`Finish
}else{
std::cout << "#WARNING(exatn::runtime::CuQuantumExecutor): execute: Repeated tensor network submission detected!\n";
}
return exec_stat;
}
bool CuQuantumExecutor::sync(const TensorOpExecHandle exec_handle,
int * error_code,
bool wait)
TensorNetworkQueue::ExecStat CuQuantumExecutor::sync(const TensorOpExecHandle exec_handle,
int * error_code,
bool wait)
{
bool synced = true;
*error_code = 0;
TensorNetworkQueue::ExecStat exec_stat = TensorNetworkQueue::ExecStat::None;
auto iter = active_networks_.find(exec_handle);
if(iter != active_networks_.end()){
auto tn_req = iter->second;
exec_stat = tn_req->exec_status;
//`Finish
}
return synced;
return exec_stat;
}
@@ -47,27 +47,38 @@ public:
CuQuantumExecutor & operator=(CuQuantumExecutor &&) noexcept = delete;
virtual ~CuQuantumExecutor();
int execute(std::shared_ptr<numerics::TensorNetwork> network,
const TensorOpExecHandle exec_handle);
bool executing(const TensorOpExecHandle exec_handle);
bool sync(const TensorOpExecHandle exec_handle,
int * error_code,
bool wait = true);
/** Submits a tensor network for execution via CuQuantumExecutor.
The associated tensor network execution handle can be used
for progressing and completing the tensor network execution. **/
TensorNetworkQueue::ExecStat execute(std::shared_ptr<numerics::TensorNetwork> network,
const TensorOpExecHandle exec_handle);
/** Synchronizes on the progress of the tensor network execution.
If wait = TRUE, waits until completion, otherwise just tests the progress.
Returns the current status of the tensor network execution. **/
TensorNetworkQueue::ExecStat sync(const TensorOpExecHandle exec_handle,
int * error_code,
bool wait = true);
/** Synchronizes execution of all submitted tensor networks to completion. **/
bool sync();
protected:
struct DeviceAttr{
void * buffer_ptr = nullptr;
std::size_t buffer_size = 0;
void * workspace_ptr = nullptr;
std::size_t workspace_size = 0;
void * cutn_handle; //cutensornetHandle_t = void*
};
/** Currently processed tensor networks **/
std::unordered_map<TensorOpExecHandle,std::shared_ptr<TensorNetworkReq>> active_networks_;
/** GPU Ids available to the current process **/
std::vector<int> gpus_;
/** cuTensorNet contexts for all available GPUs **/
std::vector<void*> ctn_handles_; //cutensornetHandle_t = void*
/** Attributes of all GPUs available to the current process **/
std::vector<std::pair<int,DeviceAttr>> gpu_attr_; //{gpu_id, gpu_attributes}
/** Tensor data access function **/
TensorImplFunc tensor_data_access_func_;
TensorImplFunc tensor_data_access_func_; //numerics::Tensor --> {tensor_body_ptr, size_in_bytes}
};
} //namespace runtime
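A hedged sketch of a possible follow-up (not implemented in this commit): filling the new DeviceAttr scratch fields by allocating a buffer on each GPU recorded in gpu_attr_, reusing the HANDLE_CUDA_ERROR macro from cuquantum_executor.cpp; the buffer size is a placeholder:

```cpp
// Hypothetical sketch: allocate per-GPU scratch memory and record it in the
// corresponding DeviceAttr entry (1 GiB is an assumed placeholder size).
constexpr std::size_t kBufferSize = std::size_t{1} << 30;
for(auto & gpu: gpu_attr_){
  HANDLE_CUDA_ERROR(cudaSetDevice(gpu.first));
  HANDLE_CUDA_ERROR(cudaMalloc(&(gpu.second.buffer_ptr), kBufferSize));
  gpu.second.buffer_size = kBufferSize;
}
```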
@@ -36,7 +36,7 @@ public:
enum class ExecStat {
None, //no execution status
Idle, //submitted but execution has not yet started
Preparing, //preparation for execution has started
Preparing, //preparation for execution has started (loading data, planning)
Executing, //actual execution (numerical computation) has started
Completed //execution completed
};
@@ -122,6 +122,22 @@ public:
return exec_stat;
}
/** Updates the execution status associated with
the given tensor network execution handle.
Returns the previous execution status. **/
ExecStat updateExecStatus(const TensorOpExecHandle exec_handle,
ExecStat new_exec_stat) {
auto exec_stat = ExecStat::None;
lock();
auto iter = tn_exec_stat_.find(exec_handle);
if(iter != tn_exec_stat_.cend()){
exec_stat = iter->second;
iter->second = new_exec_stat;
}
unlock();
return exec_stat;
}
/** Returns the constant iterator to the current tensor network. **/
ConstTensorNetworkQueueIterator getCurrent() {
return current_network_;
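To illustrate the intended use of updateExecStatus (a sketch, not code from this commit): the executor side advances a submitted network through the ExecStat states as work on it progresses, for example:

```cpp
// Hypothetical sketch: `queue` is a TensorNetworkQueue and `handle` is the
// TensorOpExecHandle of a network that was already appended to it (Idle).
using ExecStat = TensorNetworkQueue::ExecStat;
auto prev = queue.updateExecStatus(handle, ExecStat::Preparing); // planning begins
// ... cuTensorNet pathfinding and contraction planning would run here ...
prev = queue.updateExecStatus(handle, ExecStat::Executing);      // contraction launched
// ... asynchronous numerical contraction on the GPU stream ...
prev = queue.updateExecStatus(handle, ExecStat::Completed);      // results are ready
```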
@@ -39,7 +39,7 @@ void LazyGraphExecutor::resetNodeExecutor(std::shared_ptr<TensorNodeExecutor> no
);
}
#endif
return;
return;
}
@@ -287,23 +287,25 @@ void LazyGraphExecutor::execute(TensorNetworkQueue & tensor_network_queue) {
tensor_network_queue.reset();
bool not_over = !tensor_network_queue.isOver();
while(not_over){
int error_code = 0;
const auto current = tensor_network_queue.getCurrent();
const auto exec_handle = current->second;
if(cuquantum_executor_->executing(exec_handle)){
int error_code = 0;
synced = cuquantum_executor_->sync(exec_handle,&error_code,false);
assert(error_code == 0);
if(synced){
tensor_network_queue.remove();
std::cout << "#DEBUG(exatn::runtime::LazyGraphExecutor::execute): Completed tensor network execution via cuQuantum\n";
not_over = !tensor_network_queue.isOver();
}else{
not_over = tensor_network_queue.next();
auto exec_stat = cuquantum_executor_->sync(exec_handle,&error_code,false);
assert(error_code == 0);
if(exec_stat == TensorNetworkQueue::ExecStat::None){
exec_stat = cuquantum_executor_->execute(current->first,exec_handle);
if(exec_stat != TensorNetworkQueue::ExecStat::None){
auto prev_exec_stat = tensor_network_queue.updateExecStatus(exec_handle,exec_stat);
std::cout << "#DEBUG(exatn::runtime::LazyGraphExecutor::execute): Submitted tensor network to cuQuantum\n";
}
not_over = tensor_network_queue.next();
}else if(exec_stat == TensorNetworkQueue::ExecStat::Completed){
auto prev_exec_stat = tensor_network_queue.updateExecStatus(exec_handle,exec_stat);
tensor_network_queue.remove();
std::cout << "#DEBUG(exatn::runtime::LazyGraphExecutor::execute): Completed tensor network execution via cuQuantum\n";
not_over = !tensor_network_queue.isOver();
}else{
auto error_code = cuquantum_executor_->execute(current->first,exec_handle);
assert(error_code == 0);
std::cout << "#DEBUG(exatn::runtime::LazyGraphExecutor::execute): Submitted tensor network to cuQuantum\n";
auto prev_exec_stat = tensor_network_queue.updateExecStatus(exec_handle,exec_stat);
not_over = tensor_network_queue.next();
}
}
@@ -966,19 +966,19 @@ int TalshNodeExecutor::execute(numerics::TensorOpFetch & op,
switch(tens_elem_type){
case(talsh::REAL32):
assert(tens_body_r4 != nullptr);
error_code = MPI_Irecv((void*)(&(tens_body_r4[base])),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
error_code = MPI_Irecv((void*)(&tens_body_r4[base]),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
break;
case(talsh::REAL64):
assert(tens_body_r8 != nullptr);
error_code = MPI_Irecv((void*)(&(tens_body_r8[base])),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
error_code = MPI_Irecv((void*)(&tens_body_r8[base]),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
break;
case(talsh::COMPLEX32):
assert(tens_body_c4 != nullptr);
error_code = MPI_Irecv((void*)(&(tens_body_c4[base])),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
error_code = MPI_Irecv((void*)(&tens_body_c4[base]),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
break;
case(talsh::COMPLEX64):
assert(tens_body_c8 != nullptr);
error_code = MPI_Irecv((void*)(&(tens_body_c8[base])),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
error_code = MPI_Irecv((void*)(&tens_body_c8[base]),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
break;
}
if(error_code != MPI_SUCCESS) break;
@@ -1055,19 +1055,19 @@ int TalshNodeExecutor::execute(numerics::TensorOpUpload & op,
switch(tens_elem_type){
case(talsh::REAL32):
assert(tens_body_r4 != nullptr);
error_code = MPI_Isend((const void*)(&(tens_body_r4[base])),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
error_code = MPI_Isend((const void*)(&tens_body_r4[base]),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
break;
case(talsh::REAL64):
assert(tens_body_r8 != nullptr);
error_code = MPI_Isend((const void*)(&(tens_body_r8[base])),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
error_code = MPI_Isend((const void*)(&tens_body_r8[base]),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
break;
case(talsh::COMPLEX32):
assert(tens_body_c4 != nullptr);
error_code = MPI_Isend((const void*)(&(tens_body_c4[base])),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
error_code = MPI_Isend((const void*)(&tens_body_c4[base]),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
break;
case(talsh::COMPLEX64):
assert(tens_body_c8 != nullptr);
error_code = MPI_Isend((const void*)(&(tens_body_c8[base])),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
error_code = MPI_Isend((const void*)(&tens_body_c8[base]),count,mpi_data_kind,remote_rank,mesg_tag,communicator,mpi_req);
break;
}
if(error_code != MPI_SUCCESS) break;
@@ -1132,19 +1132,19 @@ int TalshNodeExecutor::execute(numerics::TensorOpBroadcast & op,
switch(tens_elem_type){
case(talsh::REAL32):
assert(tens_body_r4 != nullptr);
error_code = MPI_Bcast((void*)(&(tens_body_r4[base])),count,mpi_data_kind,root_rank,communicator);
error_code = MPI_Bcast((void*)(&tens_body_r4[base]),count,mpi_data_kind,root_rank,communicator);
break;
case(talsh::REAL64):
assert(tens_body_r8 != nullptr);
error_code = MPI_Bcast((void*)(&(tens_body_r8[base])),count,mpi_data_kind,root_rank,communicator);
error_code = MPI_Bcast((void*)(&tens_body_r8[base]),count,mpi_data_kind,root_rank,communicator);
break;
case(talsh::COMPLEX32):
assert(tens_body_c4 != nullptr);
error_code = MPI_Bcast((void*)(&(tens_body_c4[base])),count,mpi_data_kind,root_rank,communicator);
error_code = MPI_Bcast((void*)(&tens_body_c4[base]),count,mpi_data_kind,root_rank,communicator);
break;
case(talsh::COMPLEX64):
assert(tens_body_c8 != nullptr);
error_code = MPI_Bcast((void*)(&(tens_body_c8[base])),count,mpi_data_kind,root_rank,communicator);
error_code = MPI_Bcast((void*)(&tens_body_c8[base]),count,mpi_data_kind,root_rank,communicator);
break;
}
if(error_code != MPI_SUCCESS) break;
@@ -1209,19 +1209,19 @@ int TalshNodeExecutor::execute(numerics::TensorOpAllreduce & op,
switch(tens_elem_type){
case(talsh::REAL32):
assert(tens_body_r4 != nullptr);
error_code = MPI_Allreduce(MPI_IN_PLACE,(void*)(&(tens_body_r4[base])),count,mpi_data_kind,MPI_SUM,communicator);
error_code = MPI_Allreduce(MPI_IN_PLACE,(void*)(&tens_body_r4[base]),count,mpi_data_kind,MPI_SUM,communicator);
break;
case(talsh::REAL64):
assert(tens_body_r8 != nullptr);
error_code = MPI_Allreduce(MPI_IN_PLACE,(void*)(&(tens_body_r8[base])),count,mpi_data_kind,MPI_SUM,communicator);
error_code = MPI_Allreduce(MPI_IN_PLACE,(void*)(&tens_body_r8[base]),count,mpi_data_kind,MPI_SUM,communicator);
break;
case(talsh::COMPLEX32):
assert(tens_body_c4 != nullptr);
error_code = MPI_Allreduce(MPI_IN_PLACE,(void*)(&(tens_body_c4[base])),count,mpi_data_kind,MPI_SUM,communicator);
error_code = MPI_Allreduce(MPI_IN_PLACE,(void*)(&tens_body_c4[base]),count,mpi_data_kind,MPI_SUM,communicator);
break;
case(talsh::COMPLEX64):
assert(tens_body_c8 != nullptr);
error_code = MPI_Allreduce(MPI_IN_PLACE,(void*)(&(tens_body_c8[base])),count,mpi_data_kind,MPI_SUM,communicator);
error_code = MPI_Allreduce(MPI_IN_PLACE,(void*)(&tens_body_c8[base]),count,mpi_data_kind,MPI_SUM,communicator);
break;
}
if(error_code != MPI_SUCCESS) break;
@@ -1463,8 +1463,53 @@ const void * TalshNodeExecutor::getTensorImage(const numerics::Tensor & tensor,
int device_kind, int device_id,
std::size_t * size) const
{
//`Implement
return nullptr;
const auto tensor_hash = tensor.getTensorHash();
auto tens_pos = tensors_.find(tensor_hash);
if(tens_pos == tensors_.end()){
std::cout << "#ERROR(exatn::runtime::node_executor_talsh): getTensorImage: Tensor not found:\n";
tensor.printIt();
assert(false);
}
//tens_pos->second.resetTensorShapeToReduced();
auto & tens = *(tens_pos->second.talsh_tensor);
assert(device_kind == DEV_HOST && device_id == 0); //`temporary
auto synced = tens.sync(device_kind,device_id,nullptr,true); assert(synced);
void * tens_body = nullptr;
float * tens_body_r4 = nullptr;
double * tens_body_r8 = nullptr;
std::complex<float> * tens_body_c4 = nullptr;
std::complex<double> * tens_body_c8 = nullptr;
bool access_granted = false;
const int tens_elem_type = tens.getElementType();
switch(tens_elem_type){
case(talsh::REAL32):
access_granted = tens.getDataAccessHost(&tens_body_r4);
if(access_granted) tens_body = static_cast<void*>(tens_body_r4);
break;
case(talsh::REAL64):
access_granted = tens.getDataAccessHost(&tens_body_r8);
if(access_granted) tens_body = static_cast<void*>(tens_body_r8);
break;
case(talsh::COMPLEX32):
access_granted = tens.getDataAccessHost(&tens_body_c4);
if(access_granted) tens_body = static_cast<void*>(tens_body_c4);
break;
case(talsh::COMPLEX64):
access_granted = tens.getDataAccessHost(&tens_body_c8);
if(access_granted) tens_body = static_cast<void*>(tens_body_c8);
break;
default:
std::cout << "#ERROR(exatn::runtime::node_executor_talsh): getTensorImage: Unknown TAL-SH data kind: "
<< tens_elem_type << std::endl;
tens.print();
assert(false);
}
if(size != nullptr){
*size = tens.getSize();
assert(*size > 0);
}
return tens_body;
}
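The new getTensorImage accessor is the natural backend for the tensor data access function consumed by CuQuantumExecutor; a hedged wiring sketch, assuming the TensorImplFunc callable mirrors the getTensorImage signature (per the comment in cuquantum_executor.hpp):

```cpp
// Hypothetical wiring sketch: expose the TAL-SH host image of a tensor to
// CuQuantumExecutor via a lambda; `talsh_executor` is a TalshNodeExecutor.
auto tensor_data_access = [&talsh_executor](const numerics::Tensor & tensor,
                                            int device_kind, int device_id,
                                            std::size_t * size) -> const void * {
  return talsh_executor.getTensorImage(tensor, device_kind, device_id, size);
};
// cuquantum_executor = std::make_shared<CuQuantumExecutor>(tensor_data_access);
```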