diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b78a956221ac6f07c66d0c287f7431ca56d6f1f..27ba07811f407be57c6867352afbb64e63dd0a60 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.21) -project (rays CXX) +project (graph_framework CXX) #------------------------------------------------------------------------------- # Build Options @@ -12,6 +12,8 @@ option (USE_CONSTANT_CACHE "Cache the value of constantants in kernel registers. option (SHOW_USE_COUNT "Add a comment showing the use count in kernel sources." OFF) option (USE_INDEX_CACHE "Cache index values instead of computing them every time." OFF) option (USE_VERBOSE "Verbose jit option." OFF) +option (BUILD_C_BINDING "Build C interface." OFF) +option (BUILD_Fortran_BINDING "Build Fortran interface." OFF) #------------------------------------------------------------------------------- # Set the cmake module path. @@ -40,6 +42,9 @@ if (${APPLE}) if (${USE_METAL}) enable_language (OBJCXX) + if (${BUILD_C_BINDING}) + enable_language (OBJC) + endif () add_library (metal_lib INTERFACE) target_link_libraries (metal_lib @@ -54,7 +59,8 @@ if (${APPLE}) ) target_compile_options (metal_lib INTERFACE - -fobjc-arc + $<$:-fobjc-arc> + $<$:-fobjc-arc> ) endif () else () @@ -311,9 +317,22 @@ endif () #------------------------------------------------------------------------------- # Setup targets #------------------------------------------------------------------------------- - add_subdirectory (graph_framework) +if (${BUILD_Fortran_BINDING}) + set (BUILD_C_BINDING ON CACHE STRING "Build C interface." FORCE) +endif () + +if (${BUILD_C_BINDING}) + enable_language (C) + add_subdirectory (graph_c_binding) +endif () + +if (${BUILD_Fortran_BINDING}) + enable_language (Fortran) + add_subdirectory (graph_fortran_binding) +endif () + #------------------------------------------------------------------------------- # Setup testing #------------------------------------------------------------------------------- @@ -322,24 +341,41 @@ enable_testing () #------------------------------------------------------------------------------- # Tool setup #------------------------------------------------------------------------------- -macro (add_tool_target target) +macro (add_tool_target target lang) add_executable (${target}) target_sources (${target} PRIVATE - $ + $ ) if (${USE_METAL}) - set_source_files_properties (${CMAKE_CURRENT_SOURCE_DIR}/${target}.cpp + if (${lang} STREQUAL "cpp") + set_source_files_properties (${CMAKE_CURRENT_SOURCE_DIR}/${target}.${lang} + PROPERTIES + LANGUAGE OBJCXX + ) + elseif (${lang} STREQUAL "c") + set_source_files_properties (${CMAKE_CURRENT_SOURCE_DIR}/${target}.${lang} + PROPERTIES + LANGUAGE OBJC + ) + endif () + endif () + if (${lang} STREQUAL "c") + set_source_files_properties (${CMAKE_CURRENT_SOURCE_DIR}/${target}.${lang} PROPERTIES - LANGUAGE OBJCXX + SKIP_PRECOMPILE_HEADERS ON ) endif () target_link_libraries (${target} PUBLIC - rays + graph_framework ) + + if (${USE_PCH} AND ${BUILD_C_BINDING}) + target_precompile_headers (${target} REUSE_FROM graph_c) + endif () endmacro () add_subdirectory (graph_driver) @@ -350,15 +386,19 @@ add_subdirectory (graph_korc) #------------------------------------------------------------------------------- # Define macro function to register tests. #------------------------------------------------------------------------------- -macro (add_test_target target) - add_tool_target (${target}) +macro (add_test_target target lang) + add_tool_target (${target} ${lang}) add_test (NAME ${target} COMMAND ${target} ) if (${USE_PCH}) - target_precompile_headers (${target} REUSE_FROM xrays) + if (${BUILD_C_BINDING}) + target_precompile_headers (${target} REUSE_FROM graph_c) + else () + target_precompile_headers (${target} REUSE_FROM xrays) + endif () endif () endmacro () diff --git a/graph_benchmark/CMakeLists.txt b/graph_benchmark/CMakeLists.txt index 7f1430e522f6da280eb718d75152a0e2aed175dc..b794d9addbe6716ccb31e744b27d8c75d6a30b1c 100644 --- a/graph_benchmark/CMakeLists.txt +++ b/graph_benchmark/CMakeLists.txt @@ -1,5 +1,5 @@ -add_tool_target (xrays_bench) +add_tool_target (xrays_bench cpp) -if (${USE_PCH}) +if (${USE_PCH} AND NOT ${BUILD_C_BINDING}) target_precompile_headers (xrays_bench REUSE_FROM xrays) endif () diff --git a/graph_c_binding/CMakeLists.txt b/graph_c_binding/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..84816fcf8406a6442e02b6c73bad1b8cc40f9360 --- /dev/null +++ b/graph_c_binding/CMakeLists.txt @@ -0,0 +1,24 @@ +add_library (graph_c) + +target_include_directories (graph_c + PUBLIC + $ +) +target_link_libraries (graph_c + PUBLIC + graph_framework +) +target_sources (graph_c + PRIVATE + $ +) +if (${USE_METAL}) + set_source_files_properties (${CMAKE_CURRENT_SOURCE_DIR}/graph_c_binding.cpp + PROPERTIES + LANGUAGE OBJCXX + ) +endif () +target_compile_features (graph_c + PUBLIC + c_std_17 +) diff --git a/graph_c_binding/graph_c_binding.cpp b/graph_c_binding/graph_c_binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f006a2fc2543fd54523e1fb2e23484696c3b93b --- /dev/null +++ b/graph_c_binding/graph_c_binding.cpp @@ -0,0 +1,3261 @@ +//------------------------------------------------------------------------------ +/// @file graph_c_binding.cpp +/// @brief Implimentation of the c binding library. +//------------------------------------------------------------------------------ + +#include "graph_c_binding.h" + +#include "../graph_framework/register.hpp" +#include "../graph_framework/node.hpp" +#include "../graph_framework/workflow.hpp" +#include "../graph_framework/arithmetic.hpp" +#include "../graph_framework/math.hpp" +#include "../graph_framework/trigonometry.hpp" +#include "../graph_framework/piecewise.hpp" + +//------------------------------------------------------------------------------ +/// @brief C context with specific type. +//------------------------------------------------------------------------------ +template +struct graph_c_context_type : public graph_c_context { +/// Variables nodes. + std::map> nodes; +/// Workflow manager. + workflow::manager work; + +//------------------------------------------------------------------------------ +/// @brief Construct a typed c context. +//------------------------------------------------------------------------------ + graph_c_context_type() : work(0) {} +}; + +extern "C" { +//------------------------------------------------------------------------------ +/// @brief Construct a C context. +/// +/// @param[in] type Base type. +/// @param[in] use_safe_math Control is safe math is used. +/// @returns A contructed C context. +//------------------------------------------------------------------------------ + graph_c_context *graph_construct_context(const enum graph_type type, + const bool use_safe_math) { + graph_c_context *temp; + switch (type) { + case FLOAT: + if (use_safe_math) { + temp = new graph_c_context_type (); + } else { + temp = new graph_c_context_type (); + } + break; + + case DOUBLE: + if (use_safe_math) { + temp = new graph_c_context_type (); + } else { + temp = new graph_c_context_type (); + } + break; + + case COMPLEX_FLOAT: + if (use_safe_math) { + temp = new graph_c_context_type, true> (); + } else { + temp = new graph_c_context_type> (); + } + break; + + case COMPLEX_DOUBLE: + if (use_safe_math) { + temp = new graph_c_context_type, true> (); + } else { + temp = new graph_c_context_type> (); + } + break; + } + + temp->type = type; + temp->safe_math = use_safe_math; + return temp; + } + +//------------------------------------------------------------------------------ +/// @brief Destroy C context. +/// +/// @param[in,out] c The c context to delete. +//------------------------------------------------------------------------------ + void graph_destroy_context(graph_c_context *c) { + delete c; + } + +//------------------------------------------------------------------------------ +/// @brief Create variable node. +/// +/// @param[in] c The graph C context. +/// @param[in] size Size of the data buffer. +/// @param[in] symbol Symbol of the variable used in equations. +/// @returns The created variable. +//------------------------------------------------------------------------------ + graph_node graph_variable(STRUCT_TAG graph_c_context *c, + const size_t size, + const char *symbol) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::variable (size, symbol); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::variable (size, symbol); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::variable (size, symbol); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::variable (size, symbol); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::variable, true> (size, symbol); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::variable> (size, symbol); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::variable, true> (size, symbol); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::variable> (size, symbol); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create constant node. +/// +/// @param[in] c The graph C context. +/// @param[in] value The value to create the constant. +/// @returns The created constant. +//------------------------------------------------------------------------------ + graph_node graph_constant(STRUCT_TAG graph_c_context *c, + const double value) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::constant (value); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::constant (value); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::constant (value); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::constant (value); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::constant, true> (std::complex (value)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::constant> (std::complex (value)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::constant, true> (std::complex (value)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::constant> (std::complex (value)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Set a variable value. +/// +/// @param[in] c The graph C context. +/// @param[in] var The variable to set. +/// @param[in] source The source pointer. +//------------------------------------------------------------------------------ + void graph_set_variable(STRUCT_TAG graph_c_context *c, + graph_node var, + const void *source) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::variable_cast(d->nodes[var]); + if (temp.get()) { + std::memcpy(temp->data(), source, sizeof(float)*temp->size()); + } else { + std::cerr << "Node is not a variable."; + } + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::variable_cast(d->nodes[var]); + if (temp.get()) { + std::memcpy(temp->data(), source, sizeof(float)*temp->size()); + } else { + std::cerr << "Node is not a variable."; + } + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::variable_cast(d->nodes[var]); + if (temp.get()) { + std::memcpy(temp->data(), source, sizeof(double)*temp->size()); + } else { + std::cerr << "Node is not a variable."; + } + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::variable_cast(d->nodes[var]); + if (temp.get()) { + std::memcpy(temp->data(), source, sizeof(double)*temp->size()); + } else { + std::cerr << "Node is not a variable."; + } + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::variable_cast(d->nodes[var]); + if (temp.get()) { + std::memcpy(temp->data(), source, sizeof(std::complex)*temp->size()); + } else { + std::cerr << "Node is not a variable."; + } + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::variable_cast(d->nodes[var]); + if (temp.get()) { + std::memcpy(temp->data(), source, sizeof(std::complex)*temp->size()); + } else { + std::cerr << "Node is not a variable."; + } + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::variable_cast(d->nodes[var]); + if (temp.get()) { + std::memcpy(temp->data(), source, sizeof(std::complex)*temp->size()); + } else { + std::cerr << "Node is not a variable."; + } + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::variable_cast(d->nodes[var]); + if (temp.get()) { + std::memcpy(temp->data(), source, sizeof(std::complex)*temp->size()); + } else { + std::cerr << "Node is not a variable."; + } + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Create complex constant node. +/// +/// @param[in] c The graph C context. +/// @param[in] real_value The real component. +/// @param[in] img_value The imaginary component. +/// @returns The complex constant. +//------------------------------------------------------------------------------ + graph_node graph_constant_c(STRUCT_TAG graph_c_context *c, + const double real_value, + const double img_value) { + switch (c->type) { + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::constant, true> (std::complex (real_value, img_value)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::constant> (std::complex (real_value, img_value)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::constant, true> (std::complex (real_value, img_value)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::constant> (std::complex (real_value, img_value)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case FLOAT: + case DOUBLE: + std::cerr << "Error: Context is non-complex." << std::endl; + exit(1); + } + } + +//------------------------------------------------------------------------------ +/// @brief Create a pseudo variable. +/// +/// @param[in] c The graph C context. +/// @param[in] var The variable to set. +/// @returns THe pseudo variable. +//------------------------------------------------------------------------------ + graph_node graph_pseudo_variable(STRUCT_TAG graph_c_context *c, + graph_node var) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::pseudo_variable (d->nodes[var]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::pseudo_variable (d->nodes[var]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::pseudo_variable (d->nodes[var]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::pseudo_variable (d->nodes[var]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::pseudo_variable, true> (d->nodes[var]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::pseudo_variable> (d->nodes[var]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::pseudo_variable, true> (d->nodes[var]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::pseudo_variable> (d->nodes[var]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Remove pseudo. +/// +/// @param[in] c The graph C context. +/// @param[in] var The variable to set. +/// @returns The graph with pseudo variables removed. +//------------------------------------------------------------------------------ + graph_node graph_remove_pseudo(STRUCT_TAG graph_c_context *c, + graph_node var) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[var]->remove_pseudo(); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[var]->remove_pseudo(); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[var]->remove_pseudo(); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[var]->remove_pseudo(); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[var]->remove_pseudo(); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[var]->remove_pseudo(); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[var]->remove_pseudo(); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[var]->remove_pseudo(); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//****************************************************************************** +// Arithmetic +//****************************************************************************** +//------------------------------------------------------------------------------ +/// @brief Create add node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns left + right +//------------------------------------------------------------------------------ + graph_node graph_add(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left] + d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left] + d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left] + d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left] + d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[left] + d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[left] + d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[left] + d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[left] + d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create Substract node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns left - right +//------------------------------------------------------------------------------ + graph_node graph_sub(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left] - d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left] - d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left] - d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left] - d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[left] - d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[left] - d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[left] - d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[left] - d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create Multiply node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns left*right +//------------------------------------------------------------------------------ + graph_node graph_mul(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left]*d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left]*d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left]*d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left]*d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[left]*d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[left]*d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[left]*d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[left]*d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create Divide node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns left/right +//------------------------------------------------------------------------------ + graph_node graph_div(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left]/d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left]/d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left]/d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = d->nodes[left]/d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[left]/d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[left]/d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = d->nodes[left]/d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = d->nodes[left]/d->nodes[right]; + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//****************************************************************************** +// Math +//****************************************************************************** +//------------------------------------------------------------------------------ +/// @brief Create Sqrt node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The left opperand. +/// @returns sqrt(arg) +//------------------------------------------------------------------------------ + graph_node graph_sqrt(STRUCT_TAG graph_c_context *c, + graph_node arg) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::sqrt(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::sqrt(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::sqrt(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::sqrt(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::sqrt(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::sqrt(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::sqrt(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::sqrt(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create exp node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The left opperand. +/// @returns exp(arg) +//------------------------------------------------------------------------------ + graph_node graph_exp(STRUCT_TAG graph_c_context *c, + graph_node arg) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::exp(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::exp(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::exp(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::exp(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::exp(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::exp(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::exp(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::exp(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create log node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The left opperand. +/// @returns log(arg) +//------------------------------------------------------------------------------ + graph_node graph_log(STRUCT_TAG graph_c_context *c, + graph_node arg) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::log(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::log(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::log(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::log(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::log(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::log(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::log(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::log(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create Pow node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns pow(left, right) +//------------------------------------------------------------------------------ + graph_node graph_pow(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::pow(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::pow(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::pow(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::pow(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::pow(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::pow(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::pow(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::pow(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create imaginary error function node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The left opperand. +/// @returns erfi(arg) +//------------------------------------------------------------------------------ + graph_node graph_erfi(STRUCT_TAG graph_c_context *c, + graph_node arg) { + switch (c->type) { + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::erfi(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::erfi(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::erfi(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::erfi(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case FLOAT: + case DOUBLE: + std::cerr << "Error: Imaginary error function requires complex context." << std::endl; + exit(1); + } + } + +//****************************************************************************** +// Trigonometry +//****************************************************************************** +//------------------------------------------------------------------------------ +/// @brief Create sine node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The left opperand. +/// @returns sin(arg) +//------------------------------------------------------------------------------ + graph_node graph_sin(STRUCT_TAG graph_c_context *c, + graph_node arg) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::sin(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::sin(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::sin(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::sin(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::sin(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::sin(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::sin(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::sin(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create cosine node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The left opperand. +/// @returns sin(arg) +//------------------------------------------------------------------------------ + graph_node graph_cos(STRUCT_TAG graph_c_context *c, + graph_node arg) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::cos(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::cos(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::cos(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::cos(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::cos(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::cos(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::cos(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::cos(d->nodes[arg]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create arctangent node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns atan(left, right) +//------------------------------------------------------------------------------ + graph_node graph_atan(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::atan (d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::atan(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::atan(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::atan(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::atan(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::atan(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::atan(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::atan(d->nodes[left], d->nodes[right]); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//****************************************************************************** +// Random +//****************************************************************************** +//------------------------------------------------------------------------------ +/// @brief Construct a random state node. +/// +/// @param[in] c The graph C context. +/// @param[in] seed Intial random seed. +/// @returns A random state node. +//------------------------------------------------------------------------------ + graph_node graph_random_state(STRUCT_TAG graph_c_context *c, + const uint32_t seed) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::random_state (jit::context::random_state_size, + seed); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::random_state (jit::context::random_state_size, + seed); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::random_state (jit::context::random_state_size, + seed); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::random_state (jit::context::random_state_size, + seed); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::random_state, true> (jit::context, true>::random_state_size, + seed); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::random_state> (jit::context>::random_state_size, seed); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::random_state, true> (jit::context, true>::random_state_size, + seed); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::random_state> (jit::context>::random_state_size, seed); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create random node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg A random state node. +/// @returns random(state) +//------------------------------------------------------------------------------ + graph_node graph_random(STRUCT_TAG graph_c_context *c, + graph_node arg) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto state = graph::random_state_cast(d->nodes[arg]); + if (state.get()) { + auto temp = graph::random(state); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + std::cerr << "Arg failed cast to state." << std::endl; + exit(1); + } + } else { + auto d = reinterpret_cast *> (c); + auto state = graph::random_state_cast(d->nodes[arg]); + if (state.get()) { + auto temp = graph::random(state); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + std::cerr << "Arg failed cast to state." << std::endl; + exit(1); + } + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto state = graph::random_state_cast(d->nodes[arg]); + if (state.get()) { + auto temp = graph::random(state); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + std::cerr << "Arg failed cast to state." << std::endl; + exit(1); + } + } else { + auto d = reinterpret_cast *> (c); + auto state = graph::random_state_cast(d->nodes[arg]); + if (state.get()) { + auto temp = graph::random(state); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + std::cerr << "Arg failed cast to state." << std::endl; + exit(1); + } + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto state = graph::random_state_cast(d->nodes[arg]); + if (state.get()) { + auto temp = graph::random(state); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + std::cerr << "Arg failed cast to state." << std::endl; + exit(1); + } + } else { + auto d = reinterpret_cast> *> (c); + auto state = graph::random_state_cast(d->nodes[arg]); + if (state.get()) { + auto temp = graph::random(state); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + std::cerr << "Arg failed cast to state." << std::endl; + exit(1); + } + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto state = graph::random_state_cast(d->nodes[arg]); + if (state.get()) { + auto temp = graph::random(state); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + std::cerr << "Arg failed cast to state." << std::endl; + exit(1); + } + } else { + auto d = reinterpret_cast> *> (c); + auto state = graph::random_state_cast(d->nodes[arg]); + if (state.get()) { + auto temp = graph::random(state); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + std::cerr << "Arg failed cast to state." << std::endl; + exit(1); + } + } + } + } + +//****************************************************************************** +// Piecewise +//****************************************************************************** +//------------------------------------------------------------------------------ +/// @brief Create 1D piecewise node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The function argument. +/// @param[in] scale Scale factor argument. +/// @param[in] offset Offset factor argument. +/// @param[in] source Source buffer to fill elements. +/// @param[in] source_size Number of elements in the source buffer. +/// @returns A 1D piecewise node. +//------------------------------------------------------------------------------ + graph_node graph_piecewise_1D(STRUCT_TAG graph_c_context *c, + graph_node arg, + const double scale, + const double offset, + const void *source, + const size_t source_size) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + backend::buffer buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(float)*source_size); + auto temp = graph::piecewise_1D(buffer, d->nodes[arg], + static_cast (scale), + static_cast (offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + backend::buffer buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(float)*source_size); + auto temp = graph::piecewise_1D(buffer, d->nodes[arg], + static_cast (scale), + static_cast (offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + backend::buffer buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(double)*source_size); + auto temp = graph::piecewise_1D(buffer, d->nodes[arg], scale, offset); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + backend::buffer buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(double)*source_size); + auto temp = graph::piecewise_1D(buffer, d->nodes[arg], scale, offset); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + backend::buffer> buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(std::complex)*source_size); + auto temp = graph::piecewise_1D(buffer, d->nodes[arg], + std::complex (scale), + std::complex (offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + backend::buffer> buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(std::complex)*source_size); + auto temp = graph::piecewise_1D(buffer, d->nodes[arg], + std::complex (scale), + std::complex (offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + backend::buffer> buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(std::complex)*source_size); + auto temp = graph::piecewise_1D(buffer, d->nodes[arg], + std::complex (scale), + std::complex (offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + backend::buffer> buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(std::complex)*source_size); + auto temp = graph::piecewise_1D(buffer, d->nodes[arg], + std::complex (scale), + std::complex (offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//------------------------------------------------------------------------------ +/// @brief Create 2D piecewise node. +/// +/// @param[in] c The graph C context. +/// @param[in] num_cols Number of columns. +/// @param[in] x_arg The function x argument. +/// @param[in] x_scale Scale factor x argument. +/// @param[in] x_offset Offset factor x argument. +/// @param[in] y_arg The function y argument. +/// @param[in] y_scale Scale factor y argument. +/// @param[in] y_offset Offset factor y argument. +/// @param[in] source Source buffer to fill elements. +/// @param[in] source_size Number of elements in the source buffer. +/// @returns A 2D piecewise node. +//------------------------------------------------------------------------------ + graph_node graph_piecewise_2D(STRUCT_TAG graph_c_context *c, + const size_t num_cols, + graph_node x_arg, + const double x_scale, + const double x_offset, + graph_node y_arg, + const double y_scale, + const double y_offset, + const void *source, + const size_t source_size) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + backend::buffer buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(float)*source_size); + auto temp = graph::piecewise_2D(buffer, num_cols, + d->nodes[x_arg], + static_cast (x_scale), + static_cast (x_offset), + d->nodes[y_arg], + static_cast (y_scale), + static_cast (y_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + backend::buffer buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(float)*source_size); + auto temp = graph::piecewise_2D(buffer, num_cols, + d->nodes[x_arg], + static_cast (x_scale), + static_cast (x_offset), + d->nodes[y_arg], + static_cast (y_scale), + static_cast (y_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + backend::buffer buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(double)*source_size); + auto temp = graph::piecewise_2D(buffer, num_cols, + d->nodes[x_arg], y_scale, y_offset, + d->nodes[y_arg], y_scale, y_offset); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + backend::buffer buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(double)*source_size); + auto temp = graph::piecewise_2D(buffer, num_cols, + d->nodes[x_arg], y_scale, y_offset, + d->nodes[y_arg], y_scale, y_offset); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + backend::buffer> buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(std::complex)*source_size); + auto temp = graph::piecewise_2D(buffer, num_cols, + d->nodes[x_arg], + std::complex (x_scale), + std::complex (x_offset), + d->nodes[y_arg], + std::complex (y_scale), + std::complex (y_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + backend::buffer> buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(std::complex)*source_size); + auto temp = graph::piecewise_2D(buffer, num_cols, + d->nodes[x_arg], + std::complex (x_scale), + std::complex (x_offset), + d->nodes[y_arg], + std::complex (y_scale), + std::complex (y_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + backend::buffer> buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(std::complex)*source_size); + auto temp = graph::piecewise_2D(buffer, num_cols, + d->nodes[x_arg], + std::complex (x_scale), + std::complex (x_offset), + d->nodes[y_arg], + std::complex (y_scale), + std::complex (y_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + backend::buffer> buffer(source_size); + std::memcpy(buffer.data(), source, sizeof(std::complex)*source_size); + auto temp = graph::piecewise_2D(buffer, num_cols, + d->nodes[x_arg], + std::complex (x_scale), + std::complex (x_offset), + d->nodes[y_arg], + std::complex (y_scale), + std::complex (y_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + +//****************************************************************************** +// JIT +//****************************************************************************** +//------------------------------------------------------------------------------ +/// @brief Create 2D piecewise node with complex arguments. +/// +/// @param[in] c The graph C context. +/// @returns The number of concurrent devices. +//------------------------------------------------------------------------------ + size_t graph_get_max_concurrency(graph_c_context *c) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + return jit::context::max_concurrency(); + } else { + return jit::context::max_concurrency(); + } + + case DOUBLE: + if (c->safe_math) { + return jit::context::max_concurrency(); + } else { + return jit::context::max_concurrency(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + return jit::context, true>::max_concurrency(); + } else { + return jit::context>::max_concurrency(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + return jit::context, true>::max_concurrency(); + } else { + return jit::context>::max_concurrency(); + } + } + } + +//****************************************************************************** +// Workflows +//****************************************************************************** +//------------------------------------------------------------------------------ +/// @brief Choose the device number. +/// +/// @param[in] c The graph C context. +/// @param[in] num The device number. +//------------------------------------------------------------------------------ + void graph_set_device_number(STRUCT_TAG graph_c_context *c, + const size_t num) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work = workflow::manager (num); + } else { + auto d = reinterpret_cast *> (c); + d->work = workflow::manager (num); + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work = workflow::manager (num); + } else { + auto d = reinterpret_cast *> (c); + d->work = workflow::manager (num); + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work = workflow::manager, true> (num); + } else { + auto d = reinterpret_cast> *> (c); + d->work = workflow::manager> (num); + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work = workflow::manager, true> (num); + } else { + auto d = reinterpret_cast> *> (c); + d->work = workflow::manager> (num); + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Add pre workflow item. +/// +/// @param[in] c The graph C context. +/// @param[in] inputs Array of input nodes. +/// @param[in] num_inputs Number of inputs. +/// @param[in] outputs Array of output nodes. +/// @param[in] num_outputs Number of outputs. +/// @param[in] map_inputs Array of map input nodes. +/// @param[in] map_outputs Array of map output nodes. +/// @param[in] num_maps Number of maps. +/// @param[in] random_state Optional random state, can be NULL if not used. +/// @param[in] name Name for the kernel. +/// @param[in] size Number of elements to operate on. +//------------------------------------------------------------------------------ + void graph_add_pre_item(STRUCT_TAG graph_c_context *c, + graph_node *inputs, size_t num_inputs, + graph_node *outputs, size_t num_outputs, + graph_node *map_inputs, + graph_node *map_outputs, size_t num_maps, + graph_node random_state, + const char *name, + const size_t size) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Preitem input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Preitem map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_preitem(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_preitem(in, out, map, NULL, name, size); + } + } else { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Preitem input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Preitem map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_preitem(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_preitem(in, out, map, NULL, name, size); + } + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Preitem input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Preitem map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_preitem(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_preitem(in, out, map, NULL, name, size); + } + } else { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Preitem input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Preitem map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_preitem(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_preitem(in, out, map, NULL, name, size); + } + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + graph::input_nodes, true> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Preitem input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes, true> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes, true> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Preitem map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_preitem(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_preitem(in, out, map, NULL, name, size); + } + } else { + auto d = reinterpret_cast> *> (c); + graph::input_nodes> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Preitem input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Preitem map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_preitem(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_preitem(in, out, map, NULL, name, size); + } + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + graph::input_nodes, true> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Preitem input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes, true> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes, true> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Preitem map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_preitem(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_preitem(in, out, map, NULL, name, size); + } + } else { + auto d = reinterpret_cast> *> (c); + graph::input_nodes> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Preitem input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Preitem map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_preitem(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_preitem(in, out, map, NULL, name, size); + } + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Add workflow item. +/// +/// @param[in] c The graph C context. +/// @param[in] inputs Array of input nodes. +/// @param[in] num_inputs Number of inputs. +/// @param[in] outputs Array of output nodes. +/// @param[in] num_outputs Number of outputs. +/// @param[in] map_inputs Array of map input nodes. +/// @param[in] map_outputs Array of map output nodes. +/// @param[in] num_maps Number of maps. +/// @param[in] random_state Optional random state, can be NULL if not used. +/// @param[in] name Name for the kernel. +/// @param[in] size Number of elements to operate on. +//------------------------------------------------------------------------------ + void graph_add_item(STRUCT_TAG graph_c_context *c, + graph_node *inputs, size_t num_inputs, + graph_node *outputs, size_t num_outputs, + graph_node *map_inputs, + graph_node *map_outputs, size_t num_maps, + graph_node random_state, + const char *name, + const size_t size) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_item(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_item(in, out, map, NULL, name, size); + } + } else { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_item(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_item(in, out, map, NULL, name, size); + } + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_item(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_item(in, out, map, NULL, name, size); + } + } else { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_item(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_item(in, out, map, NULL, name, size); + } + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + graph::input_nodes, true> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes, true> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes, true> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_item(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_item(in, out, map, NULL, name, size); + } + } else { + auto d = reinterpret_cast> *> (c); + graph::input_nodes> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_item(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_item(in, out, map, NULL, name, size); + } + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + graph::input_nodes, true> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes, true> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes, true> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_item(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_item(in, out, map, NULL, name, size); + } + } else { + auto d = reinterpret_cast> *> (c); + graph::input_nodes> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_item(in, out, map, rand, name, size); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_item(in, out, map, NULL, name, size); + } + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Add a converge item. +/// +/// @param[in] c The graph C context. +/// @param[in] inputs Array of input nodes. +/// @param[in] num_inputs Number of inputs. +/// @param[in] outputs Array of output nodes. +/// @param[in] num_outputs Number of outputs. +/// @param[in] map_inputs Array of map input nodes. +/// @param[in] map_outputs Array of map output nodes. +/// @param[in] num_maps Number of maps. +/// @param[in] random_state Optional random state, can be NULL if not used. +/// @param[in] name Name for the kernel. +/// @param[in] size Number of elements to operate on. +/// @param[in] tol Tolarance to converge the function to. +/// @param[in] max_iter Maximum number of iterations before giving up. +//------------------------------------------------------------------------------ + void graph_add_converge_item(STRUCT_TAG graph_c_context *c, + graph_node *inputs, size_t num_inputs, + graph_node *outputs, size_t num_outputs, + graph_node *map_inputs, + graph_node *map_outputs, size_t num_maps, + graph_node random_state, + const char *name, + const size_t size, + const double tol, + const size_t max_iter) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_converge_item(in, out, map, rand, name, + size, tol, max_iter); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_converge_item(in, out, map, NULL, name, + size, tol, max_iter); + } + } else { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_converge_item(in, out, map, rand, name, + size, tol, max_iter); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_converge_item(in, out, map, NULL, name, + size, tol, max_iter); + } + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_converge_item(in, out, map, rand, name, + size, tol, max_iter); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_converge_item(in, out, map, NULL, name, + size, tol, max_iter); + } + } else { + auto d = reinterpret_cast *> (c); + graph::input_nodes in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_converge_item(in, out, map, rand, name, + size, tol, max_iter); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_converge_item(in, out, map, NULL, name, + size, tol, max_iter); + } + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + graph::input_nodes, true> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes, true> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes, true> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_converge_item(in, out, map, rand, name, + size, tol, max_iter); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_converge_item(in, out, map, NULL, name, + size, tol, max_iter); + } + } else { + auto d = reinterpret_cast> *> (c); + graph::input_nodes> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_converge_item(in, out, map, rand, name, + size, tol, max_iter); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_converge_item(in, out, map, NULL, name, + size, tol, max_iter); + } + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + graph::input_nodes, true> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes, true> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes, true> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_converge_item(in, out, map, rand, name, + size, tol, max_iter); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_converge_item(in, out, map, NULL, name, + size, tol, max_iter); + } + } else { + auto d = reinterpret_cast> *> (c); + graph::input_nodes> in; + for (size_t i = 0; i < num_inputs; i++) { + auto temp = graph::variable_cast(d->nodes[inputs[i]]); + if (temp.get()) { + in.push_back(temp); + } else { + std::cerr << "Work input " << i << " is not a variable." << std::endl; + exit(1); + } + } + graph::output_nodes> out; + for (size_t i = 0; i < num_outputs; i++) { + out.push_back(d->nodes[outputs[i]]); + } + graph::map_nodes> map; + for (size_t i = 0; i < num_maps; i++) { + auto temp = graph::variable_cast(d->nodes[map_inputs[i]]); + if (temp.get()) { + map.push_back({d->nodes[map_outputs[i]], temp}); + } else { + std::cerr << "Work map input " << i << " is not a variable." << std::endl; + exit(1); + } + } + if (random_state) { + auto rand = graph::random_state_cast(d->nodes[random_state]); + if (rand.get()) { + d->work.add_converge_item(in, out, map, rand, name, + size, tol, max_iter); + } else { + std::cerr << "Invalid random state." << std::endl; + exit(1); + } + } else { + d->work.add_converge_item(in, out, map, NULL, name, + size, tol, max_iter); + } + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Compile the work items +/// +/// @param[in] c The graph C context. +//------------------------------------------------------------------------------ + void graph_compile(graph_c_context *c) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.compile(); + } else { + auto d = reinterpret_cast *> (c); + d->work.compile(); + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.compile(); + } else { + auto d = reinterpret_cast *> (c); + d->work.compile(); + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.compile(); + } else { + auto d = reinterpret_cast> *> (c); + d->work.compile(); + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.compile(); + } else { + auto d = reinterpret_cast> *> (c); + d->work.compile(); + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Run pre work items. +/// +/// @param[in] c The graph C context. +//------------------------------------------------------------------------------ + void graph_pre_run(graph_c_context *c) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.pre_run(); + } else { + auto d = reinterpret_cast *> (c); + d->work.pre_run(); + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.pre_run(); + } else { + auto d = reinterpret_cast *> (c); + d->work.pre_run(); + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.pre_run(); + } else { + auto d = reinterpret_cast> *> (c); + d->work.pre_run(); + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.pre_run(); + } else { + auto d = reinterpret_cast> *> (c); + d->work.pre_run(); + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Run work items. +/// +/// @param[in] c The graph C context. +//------------------------------------------------------------------------------ + void graph_run(graph_c_context *c) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.run(); + } else { + auto d = reinterpret_cast *> (c); + d->work.run(); + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.run(); + } else { + auto d = reinterpret_cast *> (c); + d->work.run(); + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.run(); + } else { + auto d = reinterpret_cast> *> (c); + d->work.run(); + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.run(); + } else { + auto d = reinterpret_cast> *> (c); + d->work.run(); + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Wait for work items to complete. +/// +/// @param[in] c The graph C context. +//------------------------------------------------------------------------------ + void graph_wait(graph_c_context *c) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.wait(); + } else { + auto d = reinterpret_cast *> (c); + d->work.wait(); + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.wait(); + } else { + auto d = reinterpret_cast *> (c); + d->work.wait(); + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.wait(); + } else { + auto d = reinterpret_cast> *> (c); + d->work.wait(); + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.wait(); + } else { + auto d = reinterpret_cast> *> (c); + d->work.wait(); + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Copy data to a device buffer. +/// +/// @param[in] c The graph C context. +/// @param[in] node Node to copy to. +/// @param[in] source Source to copy from. +//------------------------------------------------------------------------------ + void graph_copy_to_device(STRUCT_TAG graph_c_context *c, + graph_node node, + void *source) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.copy_to_device(d->nodes[node], static_cast (source)); + } else { + auto d = reinterpret_cast *> (c); + d->work.copy_to_device(d->nodes[node], static_cast (source)); + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.copy_to_device(d->nodes[node], static_cast (source)); + } else { + auto d = reinterpret_cast *> (c); + d->work.copy_to_device(d->nodes[node], static_cast (source)); + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.copy_to_device(d->nodes[node], static_cast *> (source)); + } else { + auto d = reinterpret_cast> *> (c); + d->work.copy_to_device(d->nodes[node], static_cast *> (source)); + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.copy_to_device(d->nodes[node], static_cast *> (source)); + } else { + auto d = reinterpret_cast> *> (c); + d->work.copy_to_device(d->nodes[node], static_cast *> (source)); + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Copy data to a host buffer. +/// +/// @param[in] c The graph C context. +/// @param[in] node Node to copy from. +/// @param[in] destination Host side buffer to copy to. +//------------------------------------------------------------------------------ + void graph_copy_to_host(STRUCT_TAG graph_c_context *c, + graph_node node, + void *destination) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.copy_to_host(d->nodes[node], static_cast (destination)); + } else { + auto d = reinterpret_cast *> (c); + d->work.copy_to_host(d->nodes[node], static_cast (destination)); + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + d->work.copy_to_host(d->nodes[node], static_cast (destination)); + } else { + auto d = reinterpret_cast *> (c); + d->work.copy_to_host(d->nodes[node], static_cast (destination)); + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.copy_to_host(d->nodes[node], static_cast *> (destination)); + } else { + auto d = reinterpret_cast> *> (c); + d->work.copy_to_host(d->nodes[node], static_cast *> (destination)); + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + d->work.copy_to_host(d->nodes[node], static_cast *> (destination)); + } else { + auto d = reinterpret_cast> *> (c); + d->work.copy_to_host(d->nodes[node], static_cast *> (destination)); + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Print a value from nodes. +/// +/// @param[in] c The graph C context. +/// @param[in] index Particle index to print. +/// @param[in] nodes Nodes to print. +/// @param[in] num_nodes Number of nodes. +//------------------------------------------------------------------------------ + void graph_print(STRUCT_TAG graph_c_context *c, + const size_t index, + graph_node *nodes, + const size_t num_nodes) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + graph::output_nodes out; + for (size_t i = 0; i < num_nodes; i++) { + out.push_back(d->nodes[nodes[i]]); + } + d->work.print(index, out); + } else { + auto d = reinterpret_cast *> (c); + graph::output_nodes out; + for (size_t i = 0; i < num_nodes; i++) { + out.push_back(d->nodes[nodes[i]]); + } + d->work.print(index, out); + } + break; + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + graph::output_nodes out; + for (size_t i = 0; i < num_nodes; i++) { + out.push_back(d->nodes[nodes[i]]); + } + d->work.print(index, out); + } else { + auto d = reinterpret_cast *> (c); + graph::output_nodes out; + for (size_t i = 0; i < num_nodes; i++) { + out.push_back(d->nodes[nodes[i]]); + } + d->work.print(index, out); + } + break; + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + graph::output_nodes, true> out; + for (size_t i = 0; i < num_nodes; i++) { + out.push_back(d->nodes[nodes[i]]); + } + d->work.print(index, out); + } else { + auto d = reinterpret_cast> *> (c); + graph::output_nodes> out; + for (size_t i = 0; i < num_nodes; i++) { + out.push_back(d->nodes[nodes[i]]); + } + d->work.print(index, out); + } + break; + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + graph::output_nodes, true> out; + for (size_t i = 0; i < num_nodes; i++) { + out.push_back(d->nodes[nodes[i]]); + } + d->work.print(index, out); + } else { + auto d = reinterpret_cast> *> (c); + graph::output_nodes> out; + for (size_t i = 0; i < num_nodes; i++) { + out.push_back(d->nodes[nodes[i]]); + } + d->work.print(index, out); + } + break; + } + } + +//------------------------------------------------------------------------------ +/// @brief Take derivative ∂f∂x. +/// +/// @param[in] c The graph C context. +/// @param[in] fnode The function expression to take the derivative of. +/// @param[in] xnode The expression to take the derivative with respect to. +//------------------------------------------------------------------------------ + graph_node graph_df(STRUCT_TAG graph_c_context *c, + graph_node fnode, + graph_node xnode) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto dfdx = d->nodes[fnode]->df(d->nodes[xnode]); + d->nodes[dfdx.get()] = dfdx; + return dfdx.get(); + } else { + auto d = reinterpret_cast *> (c); + auto dfdx = d->nodes[fnode]->df(d->nodes[xnode]); + d->nodes[dfdx.get()] = dfdx; + return dfdx.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto dfdx = d->nodes[fnode]->df(d->nodes[xnode]); + d->nodes[dfdx.get()] = dfdx; + return dfdx.get(); + } else { + auto d = reinterpret_cast *> (c); + auto dfdx = d->nodes[fnode]->df(d->nodes[xnode]); + d->nodes[dfdx.get()] = dfdx; + return dfdx.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto dfdx = d->nodes[fnode]->df(d->nodes[xnode]); + d->nodes[dfdx.get()] = dfdx; + return dfdx.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto dfdx = d->nodes[fnode]->df(d->nodes[xnode]); + d->nodes[dfdx.get()] = dfdx; + return dfdx.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto dfdx = d->nodes[fnode]->df(d->nodes[xnode]); + d->nodes[dfdx.get()] = dfdx; + return dfdx.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto dfdx = d->nodes[fnode]->df(d->nodes[xnode]); + d->nodes[dfdx.get()] = dfdx; + return dfdx.get(); + } + } + } +} diff --git a/graph_c_binding/graph_c_binding.h b/graph_c_binding/graph_c_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..6cfd25891c2fe223a97bc33ab02fb71d5977f937 --- /dev/null +++ b/graph_c_binding/graph_c_binding.h @@ -0,0 +1,498 @@ +//------------------------------------------------------------------------------ +/// @file graph_c_binding.h +/// @brief Header file for the c binding library. +//------------------------------------------------------------------------------ + +#ifndef graph_c_binding_h +#define graph_c_binding_h + +#include +#include +#include + +#ifdef USE_METAL +#define START_GPU @autoreleasepool { +#define END_GPU } +#else +#define START_GPU +#define END_GPU +#endif + +#ifdef __cplusplus +extern "C" { +#define STRUCT_TAG +#else +#define STRUCT_TAG struct +#endif +/// Graph node type for C interface. + typedef void * graph_node; + +//------------------------------------------------------------------------------ +/// @brief Graph base type. +//------------------------------------------------------------------------------ + enum graph_type { + FLOAT, + DOUBLE, + COMPLEX_FLOAT, + COMPLEX_DOUBLE + }; + +//------------------------------------------------------------------------------ +/// @brief graph_c_context type. +//------------------------------------------------------------------------------ + struct graph_c_context { +/// Type of the context. + enum graph_type type; +/// Uses safe math. + bool safe_math; + }; + +//------------------------------------------------------------------------------ +/// @brief Construct a C context. +/// +/// @param[in] type Base type. +/// @param[in] use_safe_math Control is safe math is used. +/// @returns A contructed C context. +//------------------------------------------------------------------------------ + STRUCT_TAG graph_c_context *graph_construct_context(const enum graph_type type, + const bool use_safe_math); + +//------------------------------------------------------------------------------ +/// @brief Destroy C context. +/// +/// @param[in,out] c The c context to delete. +//------------------------------------------------------------------------------ + void graph_destroy_context(STRUCT_TAG graph_c_context *c); + +//------------------------------------------------------------------------------ +/// @brief Create variable node. +/// +/// @param[in] c The graph C context. +/// @param[in] size Size of the data buffer. +/// @param[in] symbol Symbol of the variable used in equations. +/// @returns The created variable. +//------------------------------------------------------------------------------ + graph_node graph_variable(STRUCT_TAG graph_c_context *c, + const size_t size, + const char *symbol); + +//------------------------------------------------------------------------------ +/// @brief Create constant node. +/// +/// @param[in] c The graph C context. +/// @param[in] value The value to create the constant. +/// @returns The created constant. +//------------------------------------------------------------------------------ + graph_node graph_constant(STRUCT_TAG graph_c_context *c, + const double value); + +//------------------------------------------------------------------------------ +/// @brief Set a variable value. +/// +/// @param[in] c The graph C context. +/// @param[in] var The variable to set. +/// @param[in] source The source pointer. +//------------------------------------------------------------------------------ + void graph_set_variable(STRUCT_TAG graph_c_context *c, + graph_node var, + const void *source); + +//------------------------------------------------------------------------------ +/// @brief Create complex constant node. +/// +/// @param[in] c The graph C context. +/// @param[in] real_value The real component. +/// @param[in] img_value The imaginary component. +/// @returns The complex constant. +//------------------------------------------------------------------------------ + graph_node graph_constant_c(STRUCT_TAG graph_c_context *c, + const double real_value, + const double img_value); + +//------------------------------------------------------------------------------ +/// @brief Create a pseudo variable. +/// +/// @param[in] c The graph C context. +/// @param[in] var The variable to set. +/// @returns The pseudo variable. +//------------------------------------------------------------------------------ + graph_node graph_pseudo_variable(STRUCT_TAG graph_c_context *c, + graph_node var); + +//------------------------------------------------------------------------------ +/// @brief Remove pseudo. +/// +/// @param[in] c The graph C context. +/// @param[in] var The graph to remove pseudo variables. +/// @returns The graph with pseudo variables removed. +//------------------------------------------------------------------------------ + graph_node graph_remove_pseudo(STRUCT_TAG graph_c_context *c, + graph_node var); + +//------------------------------------------------------------------------------ +/// @brief Create Addition node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns left + right +//------------------------------------------------------------------------------ + graph_node graph_add(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right); + +//------------------------------------------------------------------------------ +/// @brief Create Substract node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns left - right +//------------------------------------------------------------------------------ + graph_node graph_sub(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right); + +//------------------------------------------------------------------------------ +/// @brief Create Multiply node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns left*right +//------------------------------------------------------------------------------ + graph_node graph_mul(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right); + +//------------------------------------------------------------------------------ +/// @brief Create Divide node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns left/right +//------------------------------------------------------------------------------ + graph_node graph_div(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right); + +//------------------------------------------------------------------------------ +/// @brief Create Sqrt node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The function argument. +/// @returns sqrt(arg) +//------------------------------------------------------------------------------ + graph_node graph_sqrt(STRUCT_TAG graph_c_context *c, + graph_node arg); + +//------------------------------------------------------------------------------ +/// @brief Create exp node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The function argument. +/// @returns exp(arg) +//------------------------------------------------------------------------------ + graph_node graph_exp(STRUCT_TAG graph_c_context *c, + graph_node arg); + +//------------------------------------------------------------------------------ +/// @brief Create log node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The function argument. +/// @returns log(arg) +//------------------------------------------------------------------------------ + graph_node graph_log(STRUCT_TAG graph_c_context *c, + graph_node arg); + +//------------------------------------------------------------------------------ +/// @brief Create Pow node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns pow(left, right) +//------------------------------------------------------------------------------ + graph_node graph_pow(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right); + +//------------------------------------------------------------------------------ +/// @brief Create imaginary error function node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The function argument. +/// @returns erfi(arg) +//------------------------------------------------------------------------------ + graph_node graph_erfi(STRUCT_TAG graph_c_context *c, + graph_node arg); + +//------------------------------------------------------------------------------ +/// @brief Create sine node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The function argument. +/// @returns sin(arg) +//------------------------------------------------------------------------------ + graph_node graph_sin(STRUCT_TAG graph_c_context *c, + graph_node arg); + +//------------------------------------------------------------------------------ +/// @brief Create cosine node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The function argument. +/// @returns sin(arg) +//------------------------------------------------------------------------------ + graph_node graph_cos(STRUCT_TAG graph_c_context *c, + graph_node arg); + +//------------------------------------------------------------------------------ +/// @brief Create arctangent node. +/// +/// @param[in] c The graph C context. +/// @param[in] left The left opperand. +/// @param[in] right The right opperand. +/// @returns atan(left, right) +//------------------------------------------------------------------------------ + graph_node graph_atan(STRUCT_TAG graph_c_context *c, + graph_node left, + graph_node right); + +//------------------------------------------------------------------------------ +/// @brief Construct a random state node. +/// +/// @param[in] c The graph C context. +/// @param[in] seed Intial random seed. +/// @returns A random state node. +//------------------------------------------------------------------------------ + graph_node graph_random_state(STRUCT_TAG graph_c_context *c, + const uint32_t seed); + +//------------------------------------------------------------------------------ +/// @brief Create random node. +/// +/// @param[in] c The graph C context. +/// @param[in] state A random state node. +/// @returns random(state) +//------------------------------------------------------------------------------ + graph_node graph_random(STRUCT_TAG graph_c_context *c, + graph_node state); + +//------------------------------------------------------------------------------ +/// @brief Create 1D piecewise node. +/// +/// @param[in] c The graph C context. +/// @param[in] arg The function argument. +/// @param[in] scale Scale factor argument. +/// @param[in] offset Offset factor argument. +/// @param[in] source Source buffer to fill elements. +/// @param[in] source_size Number of elements in the source buffer. +/// @returns A 1D piecewise node. +//------------------------------------------------------------------------------ + graph_node graph_piecewise_1D(STRUCT_TAG graph_c_context *c, + graph_node arg, + const double scale, + const double offset, + const void *source, + const size_t source_size); + +//------------------------------------------------------------------------------ +/// @brief Create 2D piecewise node. +/// +/// @param[in] c The graph C context. +/// @param[in] num_cols Number of columns. +/// @param[in] x_arg The function x argument. +/// @param[in] x_scale Scale factor x argument. +/// @param[in] x_offset Offset factor x argument. +/// @param[in] y_arg The function y argument. +/// @param[in] y_scale Scale factor y argument. +/// @param[in] y_offset Offset factor y argument. +/// @param[in] source Source buffer to fill elements. +/// @param[in] source_size Number of elements in the source buffer. +/// @returns A 2D piecewise node. +//------------------------------------------------------------------------------ + graph_node graph_piecewise_2D(STRUCT_TAG graph_c_context *c, + const size_t num_cols, + graph_node x_arg, + const double x_scale, + const double x_offset, + graph_node y_arg, + const double y_scale, + const double y_offset, + const void *source, + const size_t source_size); + +//------------------------------------------------------------------------------ +/// @brief Create 2D piecewise node with complex arguments. +/// +/// @param[in] c The graph C context. +/// @returns The number of concurrent devices. +//------------------------------------------------------------------------------ + size_t graph_get_max_concurrency(STRUCT_TAG graph_c_context *c); + +//------------------------------------------------------------------------------ +/// @brief Choose the device number. +/// +/// @param[in] c The graph C context. +/// @param[in] num The device number. +//------------------------------------------------------------------------------ + void graph_set_device_number(STRUCT_TAG graph_c_context *c, + const size_t num); + +//------------------------------------------------------------------------------ +/// @brief Add pre workflow item. +/// +/// @param[in] c The graph C context. +/// @param[in] inputs Array of input nodes. +/// @param[in] num_inputs Number of inputs. +/// @param[in] outputs Array of output nodes. +/// @param[in] num_outputs Number of outputs. +/// @param[in] map_inputs Array of map input nodes. +/// @param[in] map_outputs Array of map output nodes. +/// @param[in] num_maps Number of maps. +/// @param[in] random_state Optional random state, can be NULL if not used. +/// @param[in] name Name for the kernel. +/// @param[in] size Number of elements to operate on. +//------------------------------------------------------------------------------ + void graph_add_pre_item(STRUCT_TAG graph_c_context *c, + graph_node *inputs, size_t num_inputs, + graph_node *outputs, size_t num_outputs, + graph_node *map_inputs, + graph_node *map_outputs, size_t num_maps, + graph_node random_state, + const char *name, + const size_t size); + +//------------------------------------------------------------------------------ +/// @brief Add workflow item. +/// +/// @param[in] c The graph C context. +/// @param[in] inputs Array of input nodes. +/// @param[in] num_inputs Number of inputs. +/// @param[in] outputs Array of output nodes. +/// @param[in] num_outputs Number of outputs. +/// @param[in] map_inputs Array of map input nodes. +/// @param[in] map_outputs Array of map output nodes. +/// @param[in] num_maps Number of maps. +/// @param[in] random_state Optional random state, can be NULL if not used. +/// @param[in] name Name for the kernel. +/// @param[in] size Number of elements to operate on. +//------------------------------------------------------------------------------ + void graph_add_item(STRUCT_TAG graph_c_context *c, + graph_node *inputs, size_t num_inputs, + graph_node *outputs, size_t num_outputs, + graph_node *map_inputs, + graph_node *map_outputs, size_t num_maps, + graph_node random_state, + const char *name, + const size_t size); + +//------------------------------------------------------------------------------ +/// @brief Add a converge item. +/// +/// @param[in] c The graph C context. +/// @param[in] inputs Array of input nodes. +/// @param[in] num_inputs Number of inputs. +/// @param[in] outputs Array of output nodes. +/// @param[in] num_outputs Number of outputs. +/// @param[in] map_inputs Array of map input nodes. +/// @param[in] map_outputs Array of map output nodes. +/// @param[in] num_maps Number of maps. +/// @param[in] random_state Optional random state, can be NULL if not used. +/// @param[in] name Name for the kernel. +/// @param[in] size Number of elements to operate on. +/// @param[in] tol Tolarance to converge the function to. +/// @param[in] max_iter Maximum number of iterations before giving up. +//------------------------------------------------------------------------------ + void graph_add_converge_item(STRUCT_TAG graph_c_context *c, + graph_node *inputs, size_t num_inputs, + graph_node *outputs, size_t num_outputs, + graph_node *map_inputs, + graph_node *map_outputs, size_t num_maps, + graph_node random_state, + const char *name, + const size_t size, + const double tol, + const size_t max_iter); + +//------------------------------------------------------------------------------ +/// @brief Compile the work items. +/// +/// @param[in] c The graph C context. +//------------------------------------------------------------------------------ + void graph_compile(STRUCT_TAG graph_c_context *c); + +//------------------------------------------------------------------------------ +/// @brief Run pre work items. +/// +/// @param[in] c The graph C context. +//------------------------------------------------------------------------------ + void graph_pre_run(STRUCT_TAG graph_c_context *c); + +//------------------------------------------------------------------------------ +/// @brief Run work items. +/// +/// @param[in] c The graph C context. +//------------------------------------------------------------------------------ + void graph_run(STRUCT_TAG graph_c_context *c); + +//------------------------------------------------------------------------------ +/// @brief Wait for work items to complete. +/// +/// @param[in] c The graph C context. +//------------------------------------------------------------------------------ + void graph_wait(STRUCT_TAG graph_c_context *c); + +//------------------------------------------------------------------------------ +/// @brief Copy data to a device buffer. +/// +/// @param[in] c The graph C context. +/// @param[in] node Node to copy to. +/// @param[in] source Source to copy from. +//------------------------------------------------------------------------------ + void graph_copy_to_device(STRUCT_TAG graph_c_context *c, + graph_node node, + void *source); + +//------------------------------------------------------------------------------ +/// @brief Copy data to a host buffer. +/// +/// @param[in] c The graph C context. +/// @param[in] node Node to copy from. +/// @param[in] destination Host side buffer to copy to. +//------------------------------------------------------------------------------ + void graph_copy_to_host(STRUCT_TAG graph_c_context *c, + graph_node node, + void *destination); + +//------------------------------------------------------------------------------ +/// @brief Print a value from nodes. +/// +/// @param[in] c The graph C context. +/// @param[in] index Particle index to print. +/// @param[in] nodes Nodes to print. +/// @param[in] num_nodes Number of nodes. +//------------------------------------------------------------------------------ + void graph_print(STRUCT_TAG graph_c_context *c, + const size_t index, + graph_node *nodes, + const size_t num_nodes); + +//------------------------------------------------------------------------------ +/// @brief Take derivative ∂f∂x. +/// +/// @param[in] c The graph C context. +/// @param[in] fnode The function expression to take the derivative of. +/// @param[in] xnode The expression to take the derivative with respect to. +//------------------------------------------------------------------------------ + graph_node graph_df(STRUCT_TAG graph_c_context *c, + graph_node fnode, + graph_node xnode); +#ifdef __cplusplus +} +#endif + +#endif /* graph_c_binding_h */ diff --git a/graph_driver/CMakeLists.txt b/graph_driver/CMakeLists.txt index a4d81912bc5c52483ec74a44ab25277dfad80fae..c7e4df683fbffb46b117b821b3df377432f28dd8 100644 --- a/graph_driver/CMakeLists.txt +++ b/graph_driver/CMakeLists.txt @@ -1,4 +1,4 @@ -add_tool_target (xrays) +add_tool_target (xrays cpp) add_test ( NAME xrays_test diff --git a/graph_fortran_binding/CMakeLists.txt b/graph_fortran_binding/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5e249069413074a668c861dbc490b067ef05f16f --- /dev/null +++ b/graph_fortran_binding/CMakeLists.txt @@ -0,0 +1,19 @@ +add_library (graph_f) + +target_include_directories (graph_f + PUBLIC + $ +) +target_link_libraries (graph_f + PUBLIC + graph_c +) +target_compile_options (graph_f + PUBLIC + $<$:-cpp> +) +target_sources (graph_f + PRIVATE + $ +) + diff --git a/graph_fortran_binding/graph_fortran_binding.f90 b/graph_fortran_binding/graph_fortran_binding.f90 new file mode 100644 index 0000000000000000000000000000000000000000..bc539bc0e9f83ab5ee2427bd9691d1a07c87d25f --- /dev/null +++ b/graph_fortran_binding/graph_fortran_binding.f90 @@ -0,0 +1,2228 @@ +!------------------------------------------------------------------------------- +!> @file graph_fortran_binding.f90 +!> @brief Implimentation of the Fortran binding library. +! +! Note separating the Doxygen comment block here so the detailed description is +! found in the Module not the file. +! +!> Module contains subroutines for calling this from fortran. +!------------------------------------------------------------------------------- + MODULE graph_fortran + USE, INTRINSIC :: ISO_C_BINDING + + IMPLICIT NONE + +!> A null array for empty + INTEGER(C_INTPTR_T), DIMENSION(0) :: graph_null_array +!> A + +!------------------------------------------------------------------------------- +!> @brief Class object for the binding. +!------------------------------------------------------------------------------- + TYPE :: graph_context +#ifdef USE_METAL +!> The auto release pool context. + TYPE(C_PTR) :: arp_context +#endif +!> The graph c context. + TYPE(C_PTR) :: c_context + CONTAINS + FINAL :: graph_destruct + PROCEDURE :: variable => graph_context_variable + PROCEDURE :: constant_real => graph_context_constant_real + PROCEDURE :: constant_complex => graph_context_constant_complex + GENERIC :: constant => constant_real, constant_complex + PROCEDURE :: set_variable_float => graph_context_set_variable_float + PROCEDURE :: set_variable_double => graph_context_set_variable_double + PROCEDURE :: set_variable_cfloat => graph_context_set_variable_cfloat + PROCEDURE :: set_variable_cdouble => graph_context_set_variable_cdouble + GENERIC :: set_variable => set_variable_float, & + set_variable_double, & + set_variable_cfloat, & + set_variable_cdouble + PROCEDURE :: pseudo_variable => graph_context_pseudo_variable + PROCEDURE :: remove_pseudo => graph_context_remove_pseudo + PROCEDURE :: add => graph_context_add + PROCEDURE :: sub => graph_context_sub + PROCEDURE :: mul => graph_context_mul + PROCEDURE :: div => graph_context_div + PROCEDURE :: sqrt => graph_context_sqrt + PROCEDURE :: exp => graph_context_exp + PROCEDURE :: log => graph_context_log + PROCEDURE :: pow => graph_context_pow + PROCEDURE :: erfi => graph_context_erfi + PROCEDURE :: sin => graph_context_sin + PROCEDURE :: cos => graph_context_cos + PROCEDURE :: atan => graph_context_atan + PROCEDURE :: random_state => graph_context_random_state + PROCEDURE :: random => graph_context_random + PROCEDURE :: piecewise_1D_float => graph_context_piecewise_1D_float + PROCEDURE :: piecewise_1D_double => graph_context_piecewise_1D_double + PROCEDURE :: piecewise_1D_cfloat => graph_context_piecewise_1D_cfloat + PROCEDURE :: piecewise_1D_cdouble => & + graph_context_piecewise_1D_cdouble + GENERIC :: piecewise_1D => piecewise_1D_float, & + piecewise_1D_double, & + piecewise_1D_cfloat, & + piecewise_1D_cdouble + PROCEDURE :: piecewise_2D_float => graph_context_piecewise_2D_float + PROCEDURE :: piecewise_2D_double => graph_context_piecewise_2D_double + PROCEDURE :: piecewise_2D_cfloat => graph_context_piecewise_2D_cfloat + PROCEDURE :: piecewise_2D_cdouble => graph_context_piecewise_2D_cdouble + GENERIC :: piecewise_2D => piecewise_2D_float, & + piecewise_2D_double, & + piecewise_2D_cfloat, & + piecewise_2D_cdouble + PROCEDURE :: get_max_concurrency => graph_context_get_max_concurrency + PROCEDURE :: set_device_number => graph_context_set_device_number + PROCEDURE :: add_pre_item => graph_context_add_pre_item + PROCEDURE :: add_item => graph_context_add_item + PROCEDURE :: add_converge_item => graph_context_add_converge_item + PROCEDURE :: df => graph_context_df + PROCEDURE :: compile => graph_context_compile + PROCEDURE :: pre_run => graph_context_pre_run + PROCEDURE :: run => graph_context_run + PROCEDURE :: wait => graph_context_wait + PROCEDURE :: copy_to_device_float => graph_context_copy_to_device_float + PROCEDURE :: copy_to_device_double => & + graph_context_copy_to_device_double + PROCEDURE :: copy_to_device_cfloat => & + graph_context_copy_to_device_cfloat + PROCEDURE :: copy_to_device_cdouble => & + graph_context_copy_to_device_cdouble + GENERIC :: copy_to_device => copy_to_device_float, & + copy_to_device_double, & + copy_to_device_cfloat, & + copy_to_device_cdouble + PROCEDURE :: copy_to_host_float => graph_context_copy_to_host_float + PROCEDURE :: copy_to_host_double => graph_context_copy_to_host_double + PROCEDURE :: copy_to_host_cfloat => graph_context_copy_to_host_cfloat + PROCEDURE :: copy_to_host_cdouble => graph_context_copy_to_host_cdouble + GENERIC :: copy_to_host => copy_to_host_float, & + copy_to_host_double, & + copy_to_host_cfloat, & + copy_to_host_cdouble + PROCEDURE :: print => graph_context_print + END TYPE + +!******************************************************************************* +! ENUMERATED TYPES +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief +!------------------------------------------------------------------------------- + ENUM, BIND(C) + ENUMERATOR :: FLOAT_T + ENUMERATOR :: DOUBLE_T + ENUMERATOR :: COMPLEX_FLOAT_T + ENUMERATOR :: COMPLEX_DOUBLE_T + END ENUM + +!******************************************************************************* +! INTERFACE BLOCKS +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Interface for the graph_context constructor with float type. +!------------------------------------------------------------------------------- + INTERFACE graph_float_context + MODULE PROCEDURE graph_construct_float + END INTERFACE + +!------------------------------------------------------------------------------- +!> @brief Interface for the graph_context constructor with double type. +!------------------------------------------------------------------------------- + INTERFACE graph_double_context + MODULE PROCEDURE graph_construct_double + END INTERFACE + +!------------------------------------------------------------------------------- +!> @brief Interface for the graph_context constructor with complex float type. +!------------------------------------------------------------------------------- + INTERFACE graph_complex_float_context + MODULE PROCEDURE graph_construct_complex_float + END INTERFACE + +!------------------------------------------------------------------------------- +!> @brief Interface for the graph_context constructor with complex double type. +!------------------------------------------------------------------------------- + INTERFACE graph_complex_double_context + MODULE PROCEDURE graph_construct_complex_double + END INTERFACE + +!******************************************************************************* +! C Binding Interface. +!******************************************************************************* + INTERFACE +!------------------------------------------------------------------------------- +!> @brief Auto release pool push interface. +!> +!> @returns An auto release pool context. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION objc_autoreleasePoolPush() & + BIND(C, NAME='objc_autoreleasePoolPush') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Auto release pool pop interface. +!> +!> @param[in,out] ctx Auto Release pool context. +!------------------------------------------------------------------------------- + SUBROUTINE objc_autoreleasePoolPop(ctx) & + BIND(C, NAME='objc_autoreleasePoolPop') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: ctx + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Construct a C context. +!> +!> @param[in] c_type The type of the context @ref graph_type. +!> @param[in] use_safe_math C context uses safemath. +!> @returns The constructed C context. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_construct_context(c_type, use_safe_math) & + BIND(C, NAME='graph_construct_context') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + INTEGER(C_INT), VALUE :: c_type + LOGICAL(C_BOOL), VALUE :: use_safe_math + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Destroy C context. +!> +!> @param[in] c The c context to delete. +!------------------------------------------------------------------------------- + SUBROUTINE graph_destroy_context(c) & + BIND(C, NAME='graph_destroy_context') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Create a variable node. +!> +!> @param[in,out] c The c context. +!> @param[in] size Size of the data buffer. +!> @param[in] symbol Symbol of the variable used in equations. +!> @returns The created variable. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_variable(c, size, symbol) & + BIND(C, NAME='graph_variable') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + INTEGER(C_LONG), VALUE :: size + CHARACTER(kind=C_CHAR), DIMENSION(*) :: symbol + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create a constant node. +!> +!> @param[in,out] c The c context. +!> @param[in] value Value of the constant. +!> @returns The created constant. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_constant(c, value) & + BIND(C, NAME='graph_constant') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + REAL(C_DOUBLE), VALUE :: value + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Set a variable value. +!> +!> @param[in,out] c The c context. +!> @param[in] var Variable to set. +!> @param[in] value The buffer to the variable with. +!------------------------------------------------------------------------------- + SUBROUTINE graph_set_variable(c, var, value) & + BIND(C, NAME='graph_set_variable') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: var + INTEGER(C_INTPTR_T), VALUE :: value + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Create a constant node with complex values. +!> +!> @param[in] c The graph C context. +!> @param[in] real_value The real component. +!> @param[in] img_value The imaginary component. +!> @returns The complex constant. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_constant_c(c, real_value, img_value) & + BIND(C, NAME='graph_constant_c') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + REAL(C_DOUBLE), VALUE :: real_value + REAL(C_DOUBLE), VALUE :: img_value + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create a pseudo variable node. +!> +!> @param[in] c The graph C context. +!> @param[in] var The variable to set. +!> @returns The pseudo variable. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_pseudo_variable(c, var) & + BIND(C, NAME='graph_pseudo_variable') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: var + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Remove pseudo. +!> +!> @param[in] c The graph C context. +!> @param[in] var The graph to remove pseudo variables. +!> @returns The graph with pseudo variables removed. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_remove_pseudo(c, var) & + BIND(C, NAME='graph_remove_pseudo') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: var + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Addition node. +!> +!> @param[in] c The graph C context. +!> @param[in] left The left opperand. +!> @param[in] right The right opperand. +!> @returns left + right +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_add(c, left, right) & + BIND(C, NAME='graph_add') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: left + TYPE(C_PTR), VALUE :: right + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Substract node. +!> +!> @param[in] c The graph C context. +!> @param[in] left The left opperand. +!> @param[in] right The right opperand. +!> @returns left - right +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_sub(c, left, right) & + BIND(C, NAME='graph_sub') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: left + TYPE(C_PTR), VALUE :: right + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Multiply node. +!> +!> @param[in] c The graph C context. +!> @param[in] left The left opperand. +!> @param[in] right The right opperand. +!> @returns left*right +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_mul(c, left, right) & + BIND(C, NAME='graph_mul') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: left + TYPE(C_PTR), VALUE :: right + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Divide node. +!> +!> @param[in] c The graph C context. +!> @param[in] left The left opperand. +!> @param[in] right The right opperand. +!> @returns left/right +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_div(c, left, right) & + BIND(C, NAME='graph_div') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: left + TYPE(C_PTR), VALUE :: right + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Sqrt node. +!> +!> @param[in] c The graph C context. +!> @param[in] arg The function argument. +!> @returns sqrt(arg) +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_sqrt(c, arg) & + BIND(C, NAME='graph_sqrt') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: arg + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Exp node. +!> +!> @param[in] c The graph C context. +!> @param[in] arg The function argument. +!> @returns exp(arg) +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_exp(c, arg) & + BIND(C, NAME='graph_exp') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: arg + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Log node. +!> +!> @param[in] c The graph C context. +!> @param[in] arg The function argument. +!> @returns log(arg) +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_log(c, arg) & + BIND(C, NAME='graph_log') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: arg + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create pow node. +!> +!> @param[in] c The graph C context. +!> @param[in] left The left opperand. +!> @param[in] right The right opperand. +!> @returns pow(left, right) +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_pow(c, left, right) & + BIND(C, NAME='graph_pow') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: left + TYPE(C_PTR), VALUE :: right + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Erfi node. +!> +!> @param[in] c The graph C context. +!> @param[in] arg The function argument. +!> @returns erfi(arg) +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_erfi(c, arg) & + BIND(C, NAME='graph_erfi') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: arg + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Sine node. +!> +!> @param[in] c The graph C context. +!> @param[in] arg The function argument. +!> @returns sin(arg) +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_sin(c, arg) & + BIND(C, NAME='graph_sin') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: arg + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Cosine node. +!> +!> @param[in] c The graph C context. +!> @param[in] arg The function argument. +!> @returns cos(arg) +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_cos(c, arg) & + BIND(C, NAME='graph_cos') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: arg + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create atan node. +!> +!> @param[in] c The graph C context. +!> @param[in] left The left opperand. +!> @param[in] right The right opperand. +!> @returns pow(left, right) +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_atan(c, left, right) & + BIND(C, NAME='graph_atan') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: left + TYPE(C_PTR), VALUE :: right + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Construct a random state node. +!> +!> @param[in] c The graph C context. +!> @param[in] seed Intial random seed. +!> @returns A random state node. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_random_state(c, seed) & + BIND(C, NAME='graph_random_state') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + INTEGER(C_INT32_T), VALUE :: seed + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Random node. +!> +!> @param[in] c The graph C context. +!> @param[in] state A random state node. +!> @returns random(state) +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_random(c, state) & + BIND(C, NAME='graph_random') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: state + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create 1D piecewise node with complex double buffer. +!> +!> @param[in] c The graph C context. +!> @param[in] arg The left opperand. +!> @param[in] scale Scale factor argument. +!> @param[in] offset Offset factor argument. +!> @param[in] source Source buffer to fill elements. +!> @param[in] source_size Number of elements in the source buffer. +!> @returns A 1D piecewise node. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_piecewise_1D(c, arg, scale, offset, & + source, source_size) & + BIND(C, NAME='graph_piecewise_1D') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: arg + REAL(C_DOUBLE), VALUE :: scale + REAL(C_DOUBLE), VALUE :: offset + INTEGER(C_INTPTR_T), VALUE :: source + INTEGER(C_LONG), VALUE :: source_size + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create 2D piecewise node. +!> +!> @param[in] c The graph C context. +!> @param[in] num_cols Number of columns. +!> @param[in] x_arg The function x argument. +!> @param[in] x_scale Scale factor x argument. +!> @param[in] x_offset Offset factor x argument. +!> @param[in] y_arg The function y argument. +!> @param[in] y_scale Scale factor y argument. +!> @param[in] y_offset Offset factor y argument. +!> @param[in] source Source buffer to fill elements. +!> @param[in] source_size Number of elements in the source buffer. +!> @returns A 2D piecewise node. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_piecewise_2D(c, num_cols, & + x_arg, x_scale, x_offset, & + y_arg, y_scale, y_offset, & + source, source_size) & + BIND(C, NAME='graph_piecewise_2D') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + INTEGER(C_LONG), VALUE :: num_cols + TYPE(C_PTR), VALUE :: x_arg + REAL(C_DOUBLE), VALUE :: x_scale + REAL(C_DOUBLE), VALUE :: x_offset + TYPE(C_PTR), VALUE :: y_arg + REAL(C_DOUBLE), VALUE :: y_scale + REAL(C_DOUBLE), VALUE :: y_offset + INTEGER(C_INTPTR_T), VALUE :: source + INTEGER(C_LONG), VALUE :: source_size + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Get the maximum number of concurrent devices. +!> +!> @param[in] c The graph C context. +!> @returns The number of devices. +!------------------------------------------------------------------------------- + INTEGER(C_LONG) FUNCTION graph_get_max_concurrency(c) & + BIND(C, NAME='graph_get_max_concurrency') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Choose the device number. +!> +!> @param[in] c The graph C context. +!> @param[in] num The device number. +!------------------------------------------------------------------------------- + SUBROUTINE graph_set_device_number(c, num) & + BIND(C, NAME='graph_set_device_number') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + INTEGER(C_LONG), VALUE :: num + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Add pre workflow item. +!> +!> @param[in] c The graph C context. +!> @param[in] inputs Array of input nodes. +!> @param[in] num_inputs Number of inputs. +!> @param[in] outputs Array of output nodes. +!> @param[in] num_outputs Number of outputs. +!> @param[in] map_inputs Array of map input nodes. +!> @param[in] map_outputs Array of map output nodes. +!> @param[in] num_maps Number of maps. +!> @param[in] random_state Optional random state, can be NULL if not used. +!> @param[in] name Name for the kernel. +!> @param[in] num_particles Number of elements to operate on. +!------------------------------------------------------------------------------- + SUBROUTINE graph_add_pre_item(c, inputs, num_inputs, & + outputs, num_outputs, & + map_inputs, map_outputs, num_maps, & + random_state, name, num_particles) & + BIND(C, NAME='graph_add_pre_item') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + INTEGER(C_INTPTR_T), VALUE :: inputs + INTEGER(C_LONG), VALUE :: num_inputs + INTEGER(C_INTPTR_T), VALUE :: outputs + INTEGER(C_LONG), VALUE :: num_outputs + INTEGER(C_INTPTR_T), VALUE :: map_inputs + INTEGER(C_INTPTR_T), VALUE :: map_outputs + INTEGER(C_LONG), VALUE :: num_maps + TYPE(C_PTR), VALUE :: random_state + CHARACTER(kind=C_CHAR), DIMENSION(*) :: name + INTEGER(C_LONG), VALUE :: num_particles + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Add workflow item. +!> +!> @param[in] c The graph C context. +!> @param[in] inputs Array of input nodes. +!> @param[in] num_inputs Number of inputs. +!> @param[in] outputs Array of output nodes. +!> @param[in] num_outputs Number of outputs. +!> @param[in] map_inputs Array of map input nodes. +!> @param[in] map_outputs Array of map output nodes. +!> @param[in] num_maps Number of maps. +!> @param[in] random_state Optional random state, can be NULL if not used. +!> @param[in] name Name for the kernel. +!> @param[in] num_particles Number of elements to operate on. +!------------------------------------------------------------------------------- + SUBROUTINE graph_add_item(c, inputs, num_inputs, & + outputs, num_outputs, & + map_inputs, map_outputs, num_maps, & + random_state, name, num_particles) & + BIND(C, NAME='graph_add_item') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + INTEGER(C_INTPTR_T), VALUE :: inputs + INTEGER(C_LONG), VALUE :: num_inputs + INTEGER(C_INTPTR_T), VALUE :: outputs + INTEGER(C_LONG), VALUE :: num_outputs + INTEGER(C_INTPTR_T), VALUE :: map_inputs + INTEGER(C_INTPTR_T), VALUE :: map_outputs + INTEGER(C_LONG), VALUE :: num_maps + TYPE(C_PTR), VALUE :: random_state + CHARACTER(kind=C_CHAR), DIMENSION(*) :: name + INTEGER(C_LONG), VALUE :: num_particles + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Add workflow converge item. +!> +!> @param[in] c The graph C context. +!> @param[in] inputs Array of input nodes. +!> @param[in] num_inputs Number of inputs. +!> @param[in] outputs Array of output nodes. +!> @param[in] num_outputs Number of outputs. +!> @param[in] map_inputs Array of map input nodes. +!> @param[in] map_outputs Array of map output nodes. +!> @param[in] num_maps Number of maps. +!> @param[in] random_state Optional random state, can be NULL if not used. +!> @param[in] name Name for the kernel. +!> @param[in] num_particles Number of elements to operate on. +!> @param[in] tol Tolarance to converge the function to. +!> @param[in] max_iter Maximum number of iterations before giving up. +!------------------------------------------------------------------------------- + SUBROUTINE graph_add_converge_item(c, inputs, num_inputs, & + outputs, num_outputs, & + map_inputs, map_outputs, num_maps, & + random_state, name, num_particles, & + tol, max_iter) & + BIND(C, NAME='graph_add_converge_item') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + INTEGER(C_INTPTR_T), VALUE :: inputs + INTEGER(C_LONG), VALUE :: num_inputs + INTEGER(C_INTPTR_T), VALUE :: outputs + INTEGER(C_LONG), VALUE :: num_outputs + INTEGER(C_INTPTR_T), VALUE :: map_inputs + INTEGER(C_INTPTR_T), VALUE :: map_outputs + INTEGER(C_LONG), VALUE :: num_maps + TYPE(C_PTR), VALUE :: random_state + CHARACTER(kind=C_CHAR), DIMENSION(*) :: name + INTEGER(C_LONG), VALUE :: num_particles + REAL(C_DOUBLE), VALUE :: tol + INTEGER(C_LONG), VALUE :: max_iter + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Compile the work items. +!> +!> @param[in] c The graph C context. +!------------------------------------------------------------------------------- + SUBROUTINE graph_compile(c) & + BIND(C, NAME='graph_compile') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Run pre work items. +!> +!> @param[in] c The graph C context. +!------------------------------------------------------------------------------- + SUBROUTINE graph_pre_run(c) & + BIND(C, NAME='graph_pre_run') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Run work items. +!> +!> @param[in] c The graph C context. +!------------------------------------------------------------------------------- + SUBROUTINE graph_run(c) & + BIND(C, NAME='graph_run') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Wait for work items to complete. +!> +!> @param[in] c The graph C context. +!------------------------------------------------------------------------------- + SUBROUTINE graph_wait(c) & + BIND(C, NAME='graph_wait') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Take derivative ∂f∂x. +!> +!> @param[in] c The graph C context. +!> @param[in] fnode The function expression to take the derivative of. +!> @param[in] xnode The expression to take the derivative with respect to. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_df(c, fnode, xnode) & + BIND(C, NAME='graph_df') + USE, INTRINSIC :: ISO_C_BINDING + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: fnode + TYPE(C_PTR), VALUE :: xnode + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Copy data to a device buffer. +!> +!> @param[in] c The c context. +!> @param[in] node Node to copy to. +!> @param[in] source Source to copy from. +!------------------------------------------------------------------------------- + SUBROUTINE graph_copy_to_device(c, node, source) & + BIND(C, NAME='graph_copy_to_device') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: node + INTEGER(C_LONG), VALUE :: source + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Copy data to a host buffer. +!> +!> @param[in] c The graph C context. +!> @param[in] node Node to copy from. +!> @param[in,out] destination Host side buffer to copy to. +!------------------------------------------------------------------------------- + SUBROUTINE graph_copy_to_host(c, node, destination) & + BIND(C, NAME='graph_copy_to_host') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: node + INTEGER(C_LONG), VALUE :: destination + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Print a value from nodes. +!> +!> @param[in] c The graph C context. +!> @param[in] index Particle index to print. +!> @param[in] nodes Nodes to print. +!> @param[in] num_nodes Number of nodes. +!------------------------------------------------------------------------------- + SUBROUTINE graph_print(c, index, nodes, num_nodes) & + BIND(C, NAME='graph_print') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + INTEGER(C_LONG), VALUE :: index + INTEGER(C_INTPTR_T), VALUE :: nodes + INTEGER(C_LONG), VALUE :: num_nodes + END SUBROUTINE + + END INTERFACE + + CONTAINS + +!******************************************************************************* +! Utilities +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Convert a node to the pointer value. +!> +!> @return The pointer value. +!------------------------------------------------------------------------------- + FUNCTION graph_ptr(node) + + IMPLICIT NONE + +! Declare Arguments + INTEGER(C_INTPTR_T) :: graph_ptr + TYPE(C_PTR), INTENT(IN) :: node + +! Start of executable code. + graph_ptr = TRANSFER(node, 0_C_INTPTR_T) + + END FUNCTION + +!******************************************************************************* +! CONSTRUCTION SUBROUTINES +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Construct a @ref graph_context object with float type. +!> +!> Allocate memory for the @ref graph_context and initalize the c context with +!> a double type. +!> +!> @param[in] use_safe_math Optional use safe math. +!------------------------------------------------------------------------------- + FUNCTION graph_construct_float(use_safe_math) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), POINTER :: graph_construct_float + LOGICAL(C_BOOL), INTENT(IN) :: use_safe_math + +! Start of executable code. + ALLOCATE(graph_construct_float) +#ifdef USE_METAL + graph_construct_float%arp_context = objc_autoreleasePoolPush() +#endif + graph_construct_float%c_context = & + graph_construct_context(FLOAT_T, use_safe_math) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Construct a @ref graph_context object with double type. +!> +!> Allocate memory for the @ref graph_context and initalize the c context with +!> a double type. +!> +!> @param[in] use_safe_math Use safe math. +!------------------------------------------------------------------------------- + FUNCTION graph_construct_double(use_safe_math) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), POINTER :: graph_construct_double + LOGICAL(C_BOOL), INTENT(IN) :: use_safe_math + +! Start of executable code. + ALLOCATE(graph_construct_double) +#ifdef USE_METAL + graph_construct_double%arp_context = objc_autoreleasePoolPush() +#endif + graph_construct_double%c_context = & + graph_construct_context(DOUBLE_T, use_safe_math) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Construct a @ref graph_context object with complex float type. +!> +!> Allocate memory for the @ref graph_context and initalize the c context with +!> a complex float type. +!> +!> @param[in] use_safe_math Use safe math. +!------------------------------------------------------------------------------- + FUNCTION graph_construct_complex_float(use_safe_math) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), POINTER :: graph_construct_complex_float + LOGICAL(C_BOOL), INTENT(IN) :: use_safe_math + +! Start of executable code. + ALLOCATE(graph_construct_complex_float) +#ifdef USE_METAL + graph_construct_complex_float%arp_context = objc_autoreleasePoolPush() +#endif + graph_construct_complex_float%c_context = & + graph_construct_context(COMPLEX_FLOAT_T, use_safe_math) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Construct a @ref graph_context object with complex double type. +!> +!> Allocate memory for the @ref graph_context and initalize the c context with +!> a complex double type. +!> +!> @param[in] use_safe_math Use safe math. +!------------------------------------------------------------------------------- + FUNCTION graph_construct_complex_double(use_safe_math) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), POINTER :: graph_construct_complex_double + LOGICAL(C_BOOL), INTENT(IN) :: use_safe_math + +! Start of executable code. + ALLOCATE(graph_construct_complex_double) +#ifdef USE_METAL + graph_construct_complex_double%arp_context = objc_autoreleasePoolPush() +#endif + graph_construct_complex_double%c_context = & + graph_construct_context(COMPLEX_DOUBLE_T, use_safe_math) + + END FUNCTION + +!******************************************************************************* +! DESTRUCTION SUBROUTINES +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Deconstruct a @ref graph_context object. +!> +!> Deallocate memory and unitialize a @ref graph_context object. +!> +!> @param[in,out] this A @ref graph_context instance. +!------------------------------------------------------------------------------- + SUBROUTINE graph_destruct(this) + + IMPLICIT NONE + +! Declare Arguments + TYPE(graph_context), INTENT(INOUT) :: this + +! Start of executable. +#ifdef USE_METAL + CALL objc_autoreleasePoolPop(this%arp_context) +#endif + CALL graph_destroy_context(this%c_context) + + END SUBROUTINE + +!******************************************************************************* +! Basic Nodes +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Create variable node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] size Size of the data buffer. +!> @param[in] symbol Symbol of the variable. +!------------------------------------------------------------------------------- + FUNCTION graph_context_variable(this, size, symbol) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_variable + CLASS(graph_context), INTENT(INOUT) :: this + INTEGER(C_LONG), INTENT(IN) :: size + CHARACTER(kind=C_CHAR,len=*), INTENT(IN) :: symbol + +! Start of executable. + graph_context_variable = graph_variable(this%c_context, size, symbol) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create variable node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] value Size of the data buffer. +!------------------------------------------------------------------------------- + FUNCTION graph_context_constant_real(this, value) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_constant_real + CLASS(graph_context), INTENT(INOUT) :: this + REAL(C_DOUBLE), INTENT(IN) :: value + +! Start of executable. + graph_context_constant_real = graph_constant(this%c_context, value) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Set the value of a variable float types. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] var The variable to set +!> @param[in] value THe value to set. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_set_variable_float(this, var, value) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: var + REAL(C_FLOAT), DIMENSION(:), INTENT(IN) :: value + +! Start of executable. + CALL graph_set_variable(this%c_context, var, LOC(value)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Set the value of a variable double types. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] var The variable to set +!> @param[in] value THe value to set. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_set_variable_double(this, var, value) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: var + REAL(C_DOUBLE), DIMENSION(:), INTENT(IN) :: value + +! Start of executable. + CALL graph_set_variable(this%c_context, var, LOC(value)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Set the value of a variable complex float types. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] var The variable to set +!> @param[in] value THe value to set. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_set_variable_cfloat(this, var, value) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: var + COMPLEX(C_FLOAT_COMPLEX), DIMENSION(:), INTENT(IN) :: value + +! Start of executable. + CALL graph_set_variable(this%c_context, var, LOC(value)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Set the value of a variable complex double types. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] var The variable to set +!> @param[in] value THe value to set. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_set_variable_cdouble(this, var, value) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: var + COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(:), INTENT(IN) :: value + +! Start of executable. + CALL graph_set_variable(this%c_context, var, LOC(value)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Create variable node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] real_value The real component. +!> @param[in] img_value The imaginary component. +!------------------------------------------------------------------------------- + FUNCTION graph_context_constant_complex(this, real_value, img_value) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_constant_complex + CLASS(graph_context), INTENT(INOUT) :: this + REAL(C_DOUBLE), INTENT(IN) :: real_value + REAL(C_DOUBLE), INTENT(IN) :: img_value + +! Start of executable. + graph_context_constant_complex = graph_constant_c(this%c_context, & + real_value, img_value) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create variable node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] var The variable to set. +!> @returns The pseudo variable. +!------------------------------------------------------------------------------- + FUNCTION graph_context_pseudo_variable(this, var) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_pseudo_variable + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: var + +! Start of executable. + graph_context_pseudo_variable = graph_pseudo_variable(this%c_context, var) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Remove pseudo. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] var The graph to remove pseudo variables. +!> @returns The graph with pseudo variables removed. +!------------------------------------------------------------------------------- + FUNCTION graph_context_remove_pseudo(this, var) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_remove_pseudo + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: var + +! Start of executable. + graph_context_remove_pseudo = graph_remove_pseudo(this%c_context, var) + + END FUNCTION + +!******************************************************************************* +! Arithmetic Nodes +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Create Addition node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] left The graph to remove pseudo variables. +!> @param[in] right The graph to remove pseudo variables. +!> @returns left + right +!------------------------------------------------------------------------------- + FUNCTION graph_context_add(this, left, right) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_add + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: left + TYPE(C_PTR), INTENT(IN) :: right + +! Start of executable. + graph_context_add = graph_add(this%c_context, left, right) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Subtract node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] left The graph to remove pseudo variables. +!> @param[in] right The graph to remove pseudo variables. +!> @returns left - right +!------------------------------------------------------------------------------- + FUNCTION graph_context_sub(this, left, right) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_sub + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: left + TYPE(C_PTR), INTENT(IN) :: right + +! Start of executable. + graph_context_sub = graph_sub(this%c_context, left, right) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Multiply node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] left The graph to remove pseudo variables. +!> @param[in] right The graph to remove pseudo variables. +!> @returns left*right +!------------------------------------------------------------------------------- + FUNCTION graph_context_mul(this, left, right) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_mul + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: left + TYPE(C_PTR), INTENT(IN) :: right + +! Start of executable. + graph_context_mul = graph_mul(this%c_context, left, right) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Divide node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] left The graph to remove pseudo variables. +!> @param[in] right The graph to remove pseudo variables. +!> @returns left/right +!------------------------------------------------------------------------------- + FUNCTION graph_context_div(this, left, right) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_div + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: left + TYPE(C_PTR), INTENT(IN) :: right + +! Start of executable. + graph_context_div = graph_div(this%c_context, left, right) + + END FUNCTION + +!******************************************************************************* +! Math Nodes +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Create Sqrt node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @returns sqrt(arg) +!------------------------------------------------------------------------------- + FUNCTION graph_context_sqrt(this, arg) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_sqrt + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + +! Start of executable. + graph_context_sqrt = graph_sqrt(this%c_context, arg) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Exp node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @returns exp(arg) +!------------------------------------------------------------------------------- + FUNCTION graph_context_exp(this, arg) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_exp + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + +! Start of executable. + graph_context_exp = graph_exp(this%c_context, arg) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Log node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @returns log(arg) +!------------------------------------------------------------------------------- + FUNCTION graph_context_log(this, arg) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_log + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + +! Start of executable. + graph_context_log = graph_log(this%c_context, arg) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Pow node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] left The graph to remove pseudo variables. +!> @param[in] right The graph to remove pseudo variables. +!> @returns pow(left, right) +!------------------------------------------------------------------------------- + FUNCTION graph_context_pow(this, left, right) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_pow + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: left + TYPE(C_PTR), INTENT(IN) :: right + +! Start of executable. + graph_context_pow = graph_pow(this%c_context, left, right) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create erfi node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @returns erfi(arg) +!------------------------------------------------------------------------------- + FUNCTION graph_context_erfi(this, arg) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_erfi + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + +! Start of executable. + graph_context_erfi = graph_erfi(this%c_context, arg) + + END FUNCTION + +!******************************************************************************* +! Trigonometry Nodes +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Create Sine node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @returns sin(arg) +!------------------------------------------------------------------------------- + FUNCTION graph_context_sin(this, arg) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_sin + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + +! Start of executable. + graph_context_sin = graph_sin(this%c_context, arg) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create Cosine node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @returns cos(arg) +!------------------------------------------------------------------------------- + FUNCTION graph_context_cos(this, arg) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_cos + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + +! Start of executable. + graph_context_cos = graph_cos(this%c_context, arg) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create atan node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] left The graph to remove pseudo variables. +!> @param[in] right The graph to remove pseudo variables. +!> @returns atan(left, right) +!------------------------------------------------------------------------------- + FUNCTION graph_context_atan(this, left, right) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_atan + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: left + TYPE(C_PTR), INTENT(IN) :: right + +! Start of executable. + graph_context_atan = graph_atan(this%c_context, left, right) + + END FUNCTION + +!******************************************************************************* +! Random Nodes +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Get random size. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] seed Intial random seed. +!> @returns The random size. +!------------------------------------------------------------------------------- + FUNCTION graph_context_random_state(this, seed) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_random_state + CLASS(graph_context), INTENT(INOUT) :: this + INTEGER(C_INT32_T), INTENT(IN) :: seed + +! Start of executable. + graph_context_random_state = graph_random_state(this%c_context, seed) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create random node. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] state A random state node. +!> @returns random(state) +!------------------------------------------------------------------------------- + FUNCTION graph_context_random(this, state) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_random + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: state + +! Start of executable. + graph_context_random = graph_random(this%c_context, state) + + END FUNCTION + +!******************************************************************************* +! Piecewise Nodes +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Create 1D piecewise node with float buffer. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @param[in] scale Scale factor argument. +!> @param[in] offset Offset factor argument. +!> @param[in] source Source buffer to fill elements. +!> @returns random(state) +!------------------------------------------------------------------------------- + FUNCTION graph_context_piecewise_1D_float(this, arg, scale, offset, & + source) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_piecewise_1D_float + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + REAL(C_DOUBLE) :: scale + REAL(C_DOUBLE) :: offset + REAL(C_FLOAT), DIMENSION(:) :: source + +! Start of executable. + graph_context_piecewise_1D_float = & + graph_piecewise_1D(this%c_context, arg, scale, offset, LOC(source), & + INT(SIZE(source), KIND=C_LONG)) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create 1D piecewise node with double buffer. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @param[in] scale Scale factor argument. +!> @param[in] offset Offset factor argument. +!> @param[in] source Source buffer to fill elements. +!> @returns random(state) +!------------------------------------------------------------------------------- + FUNCTION graph_context_piecewise_1D_double(this, arg, scale, offset, & + source) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_piecewise_1D_double + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + REAL(C_DOUBLE) :: scale + REAL(C_DOUBLE) :: offset + REAL(C_DOUBLE), DIMENSION(:) :: source + +! Start of executable. + graph_context_piecewise_1D_double = & + graph_piecewise_1D(this%c_context, arg, scale, offset, LOC(source), & + INT(SIZE(source), KIND=C_LONG)) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create 1D piecewise node with complex float buffer. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @param[in] scale Scale factor argument. +!> @param[in] offset Offset factor argument. +!> @param[in] source Source buffer to fill elements. +!> @returns random(state) +!------------------------------------------------------------------------------- + FUNCTION graph_context_piecewise_1D_cfloat(this, arg, scale, offset, & + source) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_piecewise_1D_cfloat + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + REAL(C_DOUBLE) :: scale + REAL(C_DOUBLE) :: offset + COMPLEX(C_FLOAT_COMPLEX), DIMENSION(:) :: source + +! Start of executable. + graph_context_piecewise_1D_cfloat = & + graph_piecewise_1D(this%c_context, arg, scale, offset, LOC(source), & + INT(SIZE(source), KIND=C_LONG)) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create 1D piecewise node with complex double buffer. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] arg The function argument. +!> @param[in] scale Scale factor argument. +!> @param[in] offset Offset factor argument. +!> @param[in] source Source buffer to fill elements. +!> @returns random(state) +!------------------------------------------------------------------------------- + FUNCTION graph_context_piecewise_1D_cdouble(this, arg, scale, offset, & + source) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_piecewise_1D_cdouble + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: arg + REAL(C_DOUBLE) :: scale + REAL(C_DOUBLE) :: offset + COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(:) :: source + +! Start of executable. + graph_context_piecewise_1D_cdouble = & + graph_piecewise_1D(this%c_context, arg, scale, offset, LOC(source), & + INT(SIZE(source), KIND=C_LONG)) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create 2D piecewise node with float buffer. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] x_arg The function x argument. +!> @param[in] x_scale Scale factor for x argument. +!> @param[in] x_offset Offset factor for x argument. +!> @param[in] y_arg The function y argument. +!> @param[in] y_scale Scale factor for y argument. +!> @param[in] y_offset Offset factor for y argument. +!> @param[in] source Source buffer to fill elements. +!> @returns random(state) +!------------------------------------------------------------------------------- + FUNCTION graph_context_piecewise_2D_float(this, & + x_arg, x_scale, x_offset, & + y_arg, y_scale, y_offset, & + source) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_piecewise_2D_float + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: x_arg + REAL(C_DOUBLE) :: x_scale + REAL(C_DOUBLE) :: x_offset + TYPE(C_PTR), INTENT(IN) :: y_arg + REAL(C_DOUBLE) :: y_scale + REAL(C_DOUBLE) :: y_offset + REAL(C_FLOAT), DIMENSION(:,:) :: source + +! Start of executable. + graph_context_piecewise_2D_float = & + graph_piecewise_2D(this%c_context, & + INT(SIZE(source, 1), KIND=C_LONG), & + x_arg, x_scale, x_offset, & + y_arg, y_scale, y_offset, & + LOC(source), INT(SIZE(source), KIND=C_LONG)) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create 2D piecewise node with double buffer. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] x_arg The function x argument. +!> @param[in] x_scale Scale factor for x argument. +!> @param[in] x_offset Offset factor for x argument. +!> @param[in] y_arg The function y argument. +!> @param[in] y_scale Scale factor for y argument. +!> @param[in] y_offset Offset factor for y argument. +!> @param[in] source Source buffer to fill elements. +!> @returns random(state) +!------------------------------------------------------------------------------- + FUNCTION graph_context_piecewise_2D_double(this, & + x_arg, x_scale, x_offset, & + y_arg, y_scale, y_offset, & + source) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_piecewise_2D_double + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: x_arg + REAL(C_DOUBLE) :: x_scale + REAL(C_DOUBLE) :: x_offset + TYPE(C_PTR), INTENT(IN) :: y_arg + REAL(C_DOUBLE) :: y_scale + REAL(C_DOUBLE) :: y_offset + REAL(C_DOUBLE), DIMENSION(:,:) :: source + +! Start of executable. + graph_context_piecewise_2D_double = & + graph_piecewise_2D(this%c_context, & + INT(SIZE(source, 1), KIND=C_LONG), & + x_arg, x_scale, x_offset, & + y_arg, y_scale, y_offset, & + LOC(source), INT(SIZE(source), KIND=C_LONG)) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create 2D piecewise node with complex float buffer. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] x_arg The function x argument. +!> @param[in] x_scale Scale factor for x argument. +!> @param[in] x_offset Offset factor for x argument. +!> @param[in] y_arg The function y argument. +!> @param[in] y_scale Scale factor for y argument. +!> @param[in] y_offset Offset factor for y argument. +!> @param[in] source Source buffer to fill elements. +!> @returns random(state) +!------------------------------------------------------------------------------- + FUNCTION graph_context_piecewise_2D_cfloat(this, & + x_arg, x_scale, x_offset, & + y_arg, y_scale, y_offset, & + source) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_piecewise_2D_cfloat + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: x_arg + REAL(C_DOUBLE) :: x_scale + REAL(C_DOUBLE) :: x_offset + TYPE(C_PTR), INTENT(IN) :: y_arg + REAL(C_DOUBLE) :: y_scale + REAL(C_DOUBLE) :: y_offset + COMPLEX(C_FLOAT_COMPLEX), DIMENSION(:,:) :: source + +! Start of executable. + graph_context_piecewise_2D_cfloat = & + graph_piecewise_2D(this%c_context, & + INT(SIZE(source, 1), KIND=C_LONG), & + x_arg, x_scale, x_offset, & + y_arg, y_scale, y_offset, & + LOC(source), INT(SIZE(source), KIND=C_LONG)) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Create 2D piecewise node with complex double buffer. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] x_arg The function x argument. +!> @param[in] x_scale Scale factor for x argument. +!> @param[in] x_offset Offset factor for x argument. +!> @param[in] y_arg The function y argument. +!> @param[in] y_scale Scale factor for y argument. +!> @param[in] y_offset Offset factor for y argument. +!> @param[in] source Source buffer to fill elements. +!> @returns random(state) +!------------------------------------------------------------------------------- + FUNCTION graph_context_piecewise_2D_cdouble(this, & + x_arg, x_scale, x_offset, & + y_arg, y_scale, y_offset, & + source) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_piecewise_2D_cdouble + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: x_arg + REAL(C_DOUBLE) :: x_scale + REAL(C_DOUBLE) :: x_offset + TYPE(C_PTR), INTENT(IN) :: y_arg + REAL(C_DOUBLE) :: y_scale + REAL(C_DOUBLE) :: y_offset + COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(:,:) :: source + +! Start of executable. + graph_context_piecewise_2D_cdouble = & + graph_piecewise_2D(this%c_context, & + INT(SIZE(source, 1), KIND=C_LONG), & + x_arg, x_scale, x_offset, & + y_arg, y_scale, y_offset, & + LOC(source), INT(SIZE(source), KIND=C_LONG)) + + END FUNCTION + +!******************************************************************************* +! JIT +!******************************************************************************* +!------------------------------------------------------------------------------- +!> @brief Get the maximum number of concurrent devices. +!> +!> @param[in] this @ref graph_context instance. +!> @returns The number of devices. +!------------------------------------------------------------------------------- + FUNCTION graph_context_get_max_concurrency(this) + + IMPLICIT NONE + +! Declare Arguments + INTEGER(C_LONG) :: graph_context_get_max_concurrency + CLASS(graph_context), INTENT(IN) :: this + +! Start of executable. + graph_context_get_max_concurrency = & + graph_get_max_concurrency(this%c_context) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Choose the device number. +!> +!> @param[in] this @ref graph_context instance. +!> @param[in] num The device number. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_set_device_number(this, num) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(INOUT) :: this + INTEGER(C_LONG), INTENT(IN) :: num + +! Start of executable. + CALL graph_set_device_number(this%c_context, num) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Add pre workflow item. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] inputs Array of input nodes. +!> @param[in] outputs Array of output nodes. +!> @param[in] map_inputs Array of map input nodes. +!> @param[in] map_outputs Array of map output nodes. +!> @param[in] random_state Optional random state, can be NULL if not used. +!> @param[in] name Name for the kernel. +!> @param[in] num_particles Number of elements to operate on. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_add_pre_item(this, inputs, outputs, & + map_inputs, map_outputs, & + random_state, name, num_particles) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(INOUT) :: this + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: inputs + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: outputs + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: map_inputs + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: map_outputs + TYPE(C_PTR), INTENT(IN) :: random_state + CHARACTER(kind=C_CHAR,len=*), INTENT(IN) :: name + INTEGER(C_LONG), INTENT(IN) :: num_particles + +! Start of executable. + CALL graph_add_pre_item(this%c_context, & + LOC(inputs), INT(SIZE(inputs), KIND=C_LONG), & + LOC(outputs), INT(SIZE(outputs), KIND=C_LONG), & + LOC(map_inputs), LOC(map_outputs), & + INT(SIZE(map_inputs), KIND=C_LONG), & + random_state, name, num_particles) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Add workflow item. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] inputs Array of input nodes. +!> @param[in] outputs Array of output nodes. +!> @param[in] map_inputs Array of map input nodes. +!> @param[in] map_outputs Array of map output nodes. +!> @param[in] random_state Optional random state, can be NULL if not used. +!> @param[in] name Name for the kernel. +!> @param[in] num_particles Number of elements to operate on. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_add_item(this, inputs, outputs, & + map_inputs, map_outputs, & + random_state, name, num_particles) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(INOUT) :: this + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: inputs + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: outputs + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: map_inputs + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: map_outputs + TYPE(C_PTR), INTENT(IN) :: random_state + CHARACTER(kind=C_CHAR,len=*), INTENT(IN) :: name + INTEGER(C_LONG), INTENT(IN) :: num_particles + +! Start of executable. + CALL graph_add_item(this%c_context, & + LOC(inputs), INT(SIZE(inputs), KIND=C_LONG), & + LOC(outputs), INT(SIZE(outputs), KIND=C_LONG), & + LOC(map_inputs), LOC(map_outputs), & + INT(SIZE(map_inputs), KIND=C_LONG), & + random_state, name, num_particles) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Add workflow converge item. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] inputs Array of input nodes. +!> @param[in] outputs Array of output nodes. +!> @param[in] map_inputs Array of map input nodes. +!> @param[in] map_outputs Array of map output nodes. +!> @param[in] random_state Optional random state, can be NULL if not used. +!> @param[in] name Name for the kernel. +!> @param[in] num_particles Number of elements to operate on. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_add_converge_item(this, inputs, outputs, & + map_inputs, map_outputs, & + random_state, name, & + num_particles, tol, max_iter) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(INOUT) :: this + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: inputs + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: outputs + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: map_inputs + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: map_outputs + TYPE(C_PTR), INTENT(IN) :: random_state + CHARACTER(kind=C_CHAR,len=*), INTENT(IN) :: name + INTEGER(C_LONG), INTENT(IN) :: num_particles + REAL(C_DOUBLE), VALUE :: tol + INTEGER(C_LONG), VALUE :: max_iter + +! Start of executable. + CALL graph_add_converge_item(this%c_context, LOC(inputs), & + INT(SIZE(inputs), KIND=C_LONG), & + LOC(outputs), & + INT(SIZE(outputs), KIND=C_LONG), & + LOC(map_inputs), LOC(map_outputs), & + INT(SIZE(map_inputs), KIND=C_LONG), & + random_state, name, num_particles, & + tol, max_iter) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Compile the work items. +!> +!> @param[in] this @ref graph_context instance. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_compile(this) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + +! Start of executable. + CALL graph_compile(this%c_context) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Run pre work items. +!> +!> @param[in] this @ref graph_context instance. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_pre_run(this) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + +! Start of executable. + CALL graph_pre_run(this%c_context) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Run work items. +!> +!> @param[in] this @ref graph_context instance. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_run(this) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + +! Start of executable. + CALL graph_run(this%c_context) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Wait for work items to complete. +!> +!> @param[in] this @ref graph_context instance. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_wait(this) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + +! Start of executable. + CALL graph_wait(this%c_context) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Take derivative ∂f∂x. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] fnode The function expression to take the derivative of. +!> @param[in] xnode The expression to take the derivative with respect to. +!------------------------------------------------------------------------------- + FUNCTION graph_context_df(this, fnode, xnode) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_df + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: fnode + TYPE(C_PTR), INTENT(IN) :: xnode + +! Start of executable. + graph_context_df = graph_df(this%c_context, fnode, xnode) + + END FUNCTION + +!------------------------------------------------------------------------------- +!> @brief Copy float data to a device buffer. +!> +!> @param[in] this @ref graph_context instance. +!> @param[in] node Node to copy to. +!> @param[in] source Source to copy from. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_copy_to_device_float(this, node, source) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + TYPE(C_PTR), INTENT(IN) :: node + REAL(C_FLOAT), DIMENSION(:), INTENT(IN) :: source + +! Start of executable. + CALL graph_copy_to_device(this%c_context, node, LOC(source)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Copy double data to a device buffer. +!> +!> @param[in] this @ref graph_context instance. +!> @param[in] node Node to copy to. +!> @param[in] source Source to copy from. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_copy_to_device_double(this, node, source) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + TYPE(C_PTR), INTENT(IN) :: node + REAL(C_DOUBLE), DIMENSION(:), INTENT(IN) :: source + +! Start of executable. + CALL graph_copy_to_device(this%c_context, node, LOC(source)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Copy complex float data to a device buffer. +!> +!> @param[in] this @ref graph_context instance. +!> @param[in] node Node to copy to. +!> @param[in] source Source to copy from. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_copy_to_device_cfloat(this, node, source) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + TYPE(C_PTR), INTENT(IN) :: node + COMPLEX(C_FLOAT_COMPLEX), DIMENSION(:), INTENT(IN) :: source + +! Start of executable. + CALL graph_copy_to_device(this%c_context, node, LOC(source)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Copy complex double data to a device buffer. +!> +!> @param[in] this @ref graph_context instance. +!> @param[in] node Node to copy to. +!> @param[in] source Source to copy from. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_copy_to_device_cdouble(this, node, source) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + TYPE(C_PTR), INTENT(IN) :: node + COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(:), INTENT(IN) :: source + +! Start of executable. + CALL graph_copy_to_device(this%c_context, node, LOC(source)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Copy data to a host float buffer. +!> +!> @param[in] this @ref graph_context instance. +!> @param[in] node Node to copy from. +!> @param[in,out] destination Host side buffer to copy to. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_copy_to_host_float(this, node, destination) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + TYPE(C_PTR), VALUE :: node + REAL(C_FLOAT), DIMENSION(:), INTENT(INOUT) :: destination + +! Start of executable. + CALL graph_copy_to_host(this%c_context, node, LOC(destination)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Copy data to a host double buffer. +!> +!> @param[in] c The graph C context. +!> @param[in] node Node to copy from. +!> @param[in,out] destination Host side buffer to copy to. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_copy_to_host_double(this, node, destination) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + TYPE(C_PTR), VALUE :: node + REAL(C_DOUBLE), DIMENSION(:), INTENT(INOUT) :: destination + +! Start of executable. + CALL graph_copy_to_host(this%c_context, node, LOC(destination)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Copy data to a host complex float buffer. +!> +!> @param[in] c The graph C context. +!> @param[in] node Node to copy from. +!> @param[in,out] destination Host side buffer to copy to. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_copy_to_host_cfloat(this, node, destination) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + TYPE(C_PTR), VALUE :: node + COMPLEX(C_FLOAT_COMPLEX), DIMENSION(:), INTENT(INOUT) :: destination + +! Start of executable. + CALL graph_copy_to_host(this%c_context, node, LOC(destination)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Copy data to a host complex double buffer. +!> +!> @param[in] c The graph C context. +!> @param[in] node Node to copy from. +!> @param[in,out] destination Host side buffer to copy to. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_copy_to_host_cdouble(this, node, destination) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + TYPE(C_PTR), VALUE :: node + COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(:), INTENT(INOUT) :: destination + +! Start of executable. + CALL graph_copy_to_host(this%c_context, node, LOC(destination)) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Print a value from nodes. +!> +!> @param[in] c The graph C context. +!> @param[in] index Particle index to print. +!> @param[in] nodes Nodes to print. +!------------------------------------------------------------------------------- + SUBROUTINE graph_context_print(this, index, nodes) + + IMPLICIT NONE + +! Declare Arguments + CLASS(graph_context), INTENT(IN) :: this + INTEGER(C_LONG), INTENT(IN) :: index + INTEGER(C_INTPTR_T), DIMENSION(:), INTENT(IN) :: nodes + +! Start of executable. + CALL graph_print(this%c_context, index, LOC(nodes), & + INT(SIZE(nodes), KIND=C_LONG)) + + END SUBROUTINE + + END MODULE diff --git a/graph_framework.xcodeproj/project.pbxproj b/graph_framework.xcodeproj/project.pbxproj index b230c5f3522efc8bb2b624d23eac949cc6bcab1b..593bf91bdda794160090abef0f0420cd1abb14ab 100644 --- a/graph_framework.xcodeproj/project.pbxproj +++ b/graph_framework.xcodeproj/project.pbxproj @@ -45,6 +45,7 @@ C7B676082AA9023F005AB34C /* xrays_bench.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C7B676072AA9023F005AB34C /* xrays_bench.cpp */; }; C7D12D9A2DBAB31F00925420 /* random.hpp in Headers */ = {isa = PBXBuildFile; fileRef = C7D12D992DBAB31F00925420 /* random.hpp */; }; C7D371132A0595A40074676E /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C71342682947F36100672AD4 /* Metal.framework */; }; + C7DC9EEC2E39790100524F6F /* graph_c_binding.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C7DC9EE22E39768300524F6F /* graph_c_binding.cpp */; }; C7E5644528A2A1AA000F31A2 /* backend_test.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C7931E7328074F540033B488 /* backend_test.cpp */; }; C7E5645128A2A1DD000F31A2 /* dispersion_test.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C7931E6B28073BCA0033B488 /* dispersion_test.cpp */; }; C7E5645D28A2A21D000F31A2 /* solver_test.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C7931E6C28073BCA0033B488 /* solver_test.cpp */; }; @@ -149,6 +150,13 @@ remoteGlobalIDString = C79141A522DA9BF200E0BA0D; remoteInfo = graph_framework; }; + C7DC9EED2E39791C00524F6F /* PBXContainerItemProxy */ = { + isa = PBXContainerItemProxy; + containerPortal = C791419E22DA9BF200E0BA0D /* Project object */; + proxyType = 1; + remoteGlobalIDString = C79141A522DA9BF200E0BA0D; + remoteInfo = graph_framework; + }; /* End PBXContainerItemProxy section */ /* Begin PBXCopyFilesBuildPhase section */ @@ -370,6 +378,9 @@ C7931E7128073BF30033B488 /* CMakeLists.txt */ = {isa = PBXFileReference; lastKnownFileType = text; path = CMakeLists.txt; sourceTree = ""; }; C7931E7228073BFC0033B488 /* CMakeLists.txt */ = {isa = PBXFileReference; lastKnownFileType = text; path = CMakeLists.txt; sourceTree = ""; }; C7931E7328074F540033B488 /* backend_test.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.objcpp; path = backend_test.cpp; sourceTree = ""; }; + C7AE06632E3C285000586BCD /* CMakeLists.txt */ = {isa = PBXFileReference; lastKnownFileType = text; path = CMakeLists.txt; sourceTree = ""; }; + C7AE06642E3C285000586BCD /* graph_fortran_binding.f90 */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.fortran.f90; path = graph_fortran_binding.f90; sourceTree = ""; }; + C7AE06662E3C2AEE00586BCD /* f_binding_test.f90 */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.fortran.f90; path = f_binding_test.f90; sourceTree = ""; }; C7B676072AA9023F005AB34C /* xrays_bench.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.objcpp; fileEncoding = 4; path = xrays_bench.cpp; sourceTree = ""; }; C7B676092AA90243005AB34C /* CMakeLists.txt */ = {isa = PBXFileReference; lastKnownFileType = text; path = CMakeLists.txt; sourceTree = ""; }; C7B677D829E45C9500D3ADC6 /* backend.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = backend.hpp; sourceTree = ""; }; @@ -378,6 +389,12 @@ C7CEA0052948EB0F00F61D09 /* cuda_context.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = cuda_context.hpp; sourceTree = ""; }; C7D12D992DBAB31F00925420 /* random.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = random.hpp; sourceTree = ""; }; C7D3C5B02C654AD3008AD8C6 /* efit_test.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.objcpp; path = efit_test.cpp; sourceTree = ""; }; + C7DC9EE02E39768300524F6F /* CMakeLists.txt */ = {isa = PBXFileReference; lastKnownFileType = text; path = CMakeLists.txt; sourceTree = ""; }; + C7DC9EE12E39768300524F6F /* graph_c_binding.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = graph_c_binding.h; sourceTree = ""; }; + C7DC9EE22E39768300524F6F /* graph_c_binding.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.objcpp; path = graph_c_binding.cpp; sourceTree = ""; }; + C7DC9EE82E39789900524F6F /* libgraph_c.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libgraph_c.a; sourceTree = BUILT_PRODUCTS_DIR; }; + C7DC9EEF2E397BE600524F6F /* Cocoa.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Cocoa.framework; path = System/Library/Frameworks/Cocoa.framework; sourceTree = SDKROOT; }; + C7DC9EF12E3A688F00524F6F /* c_binding_test.c */ = {isa = PBXFileReference; explicitFileType = sourcecode.c.objc; path = c_binding_test.c; sourceTree = ""; }; C7E134492A3CB3EC0083F6A7 /* output.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = output.hpp; sourceTree = ""; }; C7E5643E28A2A16F000F31A2 /* backend_test */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = backend_test; sourceTree = BUILT_PRODUCTS_DIR; }; C7E5644A28A2A1C5000F31A2 /* dispersion_test */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = dispersion_test; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -470,6 +487,13 @@ ); runOnlyForDeploymentPostprocessing = 0; }; + C7DC9EE62E39789900524F6F /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; C7E5643B28A2A16F000F31A2 /* Frameworks */ = { isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; @@ -542,6 +566,7 @@ C71342672947F36100672AD4 /* Frameworks */ = { isa = PBXGroup; children = ( + C7DC9EEF2E397BE600524F6F /* Cocoa.framework */, C71342682947F36100672AD4 /* Metal.framework */, ); name = Frameworks; @@ -598,6 +623,8 @@ C7931E7228073BFC0033B488 /* CMakeLists.txt */, C79141AD22DA9C0600E0BA0D /* graph_framework */, C7931E6928073BCA0033B488 /* graph_tests */, + C7DC9EE32E39768300524F6F /* graph_c_binding */, + C7AE06652E3C285000586BCD /* graph_fortran_binding */, C79141B422DAAD0C00E0BA0D /* graph_driver */, C74DF4582AA8BC7300319113 /* graph_benchmark */, C736E6B02C9B52CA00AAE3C0 /* graph_playground */, @@ -631,6 +658,7 @@ C736E6A42C9B526500AAE3C0 /* graph_playground */, C78F3D8F2DC41ACA002E3D94 /* random_test */, C78F3D9D2DC41B26002E3D94 /* graph_korc */, + C7DC9EE82E39789900524F6F /* libgraph_c.a */, ); name = Products; sourceTree = ""; @@ -694,10 +722,31 @@ C73BBE6929F7117E0027BB7F /* trigonometry_test.cpp */, C73BBE7D29F816E60027BB7F /* piecewise_test.cpp */, C78F3D8A2DC122C7002E3D94 /* random_test.cpp */, + C7DC9EF12E3A688F00524F6F /* c_binding_test.c */, + C7AE06662E3C2AEE00586BCD /* f_binding_test.f90 */, ); path = graph_tests; sourceTree = ""; }; + C7AE06652E3C285000586BCD /* graph_fortran_binding */ = { + isa = PBXGroup; + children = ( + C7AE06632E3C285000586BCD /* CMakeLists.txt */, + C7AE06642E3C285000586BCD /* graph_fortran_binding.f90 */, + ); + path = graph_fortran_binding; + sourceTree = ""; + }; + C7DC9EE32E39768300524F6F /* graph_c_binding */ = { + isa = PBXGroup; + children = ( + C7DC9EE02E39768300524F6F /* CMakeLists.txt */, + C7DC9EE12E39768300524F6F /* graph_c_binding.h */, + C7DC9EE22E39768300524F6F /* graph_c_binding.cpp */, + ); + path = graph_c_binding; + sourceTree = ""; + }; /* End PBXGroup section */ /* Begin PBXHeadersBuildPhase section */ @@ -728,6 +777,13 @@ ); runOnlyForDeploymentPostprocessing = 0; }; + C7DC9EE42E39789900524F6F /* Headers */ = { + isa = PBXHeadersBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; /* End PBXHeadersBuildPhase section */ /* Begin PBXNativeTarget section */ @@ -911,6 +967,26 @@ productReference = C79141B322DAAD0C00E0BA0D /* graph_driver */; productType = "com.apple.product-type.tool"; }; + C7DC9EE72E39789900524F6F /* graph_c */ = { + isa = PBXNativeTarget; + buildConfigurationList = C7DC9EE92E39789900524F6F /* Build configuration list for PBXNativeTarget "graph_c" */; + buildPhases = ( + C7DC9EE42E39789900524F6F /* Headers */, + C7DC9EE52E39789900524F6F /* Sources */, + C7DC9EE62E39789900524F6F /* Frameworks */, + ); + buildRules = ( + ); + dependencies = ( + C7DC9EEE2E39791C00524F6F /* PBXTargetDependency */, + ); + name = graph_c; + packageProductDependencies = ( + ); + productName = graph_c; + productReference = C7DC9EE82E39789900524F6F /* libgraph_c.a */; + productType = "com.apple.product-type.library.static"; + }; C7E5643D28A2A16F000F31A2 /* backend_test */ = { isa = PBXNativeTarget; buildConfigurationList = C7E5644228A2A16F000F31A2 /* Build configuration list for PBXNativeTarget "backend_test" */; @@ -1080,7 +1156,7 @@ isa = PBXProject; attributes = { BuildIndependentTargetsInParallel = YES; - LastUpgradeCheck = 1540; + LastUpgradeCheck = 1640; ORGANIZATIONNAME = "Cianciosa, Mark R."; TargetAttributes = { C7170CB82C66A10D003274E2 = { @@ -1113,6 +1189,9 @@ C79141B222DAAD0C00E0BA0D = { CreatedOnToolsVersion = 10.2.1; }; + C7DC9EE72E39789900524F6F = { + CreatedOnToolsVersion = 16.4; + }; C7E5643D28A2A16F000F31A2 = { CreatedOnToolsVersion = 13.4; }; @@ -1174,6 +1253,7 @@ C736E6A32C9B526500AAE3C0 /* graph_playground */, C78F3D8E2DC41ACA002E3D94 /* random_test */, C78F3D9C2DC41B26002E3D94 /* graph_korc */, + C7DC9EE72E39789900524F6F /* graph_c */, ); }; /* End PBXProject section */ @@ -1258,6 +1338,14 @@ ); runOnlyForDeploymentPostprocessing = 0; }; + C7DC9EE52E39789900524F6F /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C7DC9EEC2E39790100524F6F /* graph_c_binding.cpp in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; C7E5643A28A2A16F000F31A2 /* Sources */ = { isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; @@ -1398,6 +1486,11 @@ target = C79141A522DA9BF200E0BA0D /* graph_framework */; targetProxy = C74DF47B2AA8BD6600319113 /* PBXContainerItemProxy */; }; + C7DC9EEE2E39791C00524F6F /* PBXTargetDependency */ = { + isa = PBXTargetDependency; + target = C79141A522DA9BF200E0BA0D /* graph_framework */; + targetProxy = C7DC9EED2E39791C00524F6F /* PBXContainerItemProxy */; + }; /* End PBXTargetDependency section */ /* Begin XCBuildConfiguration section */ @@ -1715,7 +1808,7 @@ ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; - GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_C_LANGUAGE_STANDARD = c17; GCC_DYNAMIC_NO_PIC = NO; GCC_NO_COMMON_BLOCKS = YES; GCC_OPTIMIZATION_LEVEL = 0; @@ -1889,7 +1982,7 @@ ENABLE_NS_ASSERTIONS = NO; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; - GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_C_LANGUAGE_STANDARD = c17; GCC_NO_COMMON_BLOCKS = YES; GCC_PREPROCESSOR_DEFINITIONS = ( "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", @@ -2085,6 +2178,40 @@ }; name = Release; }; + C7DC9EEA2E39789900524F6F /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CODE_SIGN_STYLE = Automatic; + EXECUTABLE_PREFIX = lib; + GCC_C_LANGUAGE_STANDARD = gnu23; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 15.5; + PRODUCT_NAME = "$(TARGET_NAME)"; + SKIP_INSTALL = YES; + }; + name = Debug; + }; + C7DC9EEB2E39789900524F6F /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CODE_SIGN_STYLE = Automatic; + EXECUTABLE_PREFIX = lib; + GCC_C_LANGUAGE_STANDARD = gnu23; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 15.5; + PRODUCT_NAME = "$(TARGET_NAME)"; + SKIP_INSTALL = YES; + }; + name = Release; + }; C7E5644328A2A16F000F31A2 /* Debug */ = { isa = XCBuildConfiguration; buildSettings = { @@ -2438,6 +2565,15 @@ defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; + C7DC9EE92E39789900524F6F /* Build configuration list for PBXNativeTarget "graph_c" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C7DC9EEA2E39789900524F6F /* Debug */, + C7DC9EEB2E39789900524F6F /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; C7E5644228A2A16F000F31A2 /* Build configuration list for PBXNativeTarget "backend_test" */ = { isa = XCConfigurationList; buildConfigurations = ( diff --git a/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme b/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme index 42cdf873640301daae963a2257bf63901bc07bd2..073c3f31895590fd36ea39ae80bbde7ca64b8baa 100644 --- a/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme +++ b/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme @@ -1,6 +1,6 @@ :CXX_ARGS="-I${CMAKE_CURRENT_SOURCE_DIR}${jit_include_paths} -fgnuc-version=4.2.1 -std=gnu++2a"> $<$:CXX_ARGS="-I${CMAKE_CURRENT_SOURCE_DIR}${jit_include_paths} -std=gnu++2a -fno-use-cxa-atexit"> @@ -28,11 +28,11 @@ target_compile_definitions (rays $,USE_VERBOSE=true,USE_VERBOSE=false> ) -target_include_directories (rays +target_include_directories (graph_framework INTERFACE $ ) -target_link_libraries (rays +target_link_libraries (graph_framework INTERFACE sanitizer gpu_lib @@ -41,7 +41,7 @@ target_link_libraries (rays $<$:pthread> llvm_dep ) -target_precompile_headers (rays +target_precompile_headers (graph_framework INTERFACE $<$:$> $<$:$> diff --git a/graph_framework/cpu_context.hpp b/graph_framework/cpu_context.hpp index 66a3fd5957f41de464134fc78ce252e7e264ce95..37b0e297076b1c2d8e313ef2040d7a016f06274d 100644 --- a/graph_framework/cpu_context.hpp +++ b/graph_framework/cpu_context.hpp @@ -12,6 +12,7 @@ #include #include #include +#include // Clang headers will define IBAction and IBOutlet these so undefine them here. #undef IBAction @@ -448,22 +449,31 @@ namespace gpu { } source_buffer << ") {" << std::endl; + std::unordered_set used_args; for (size_t i = 0, ie = inputs.size(); i < ie; i++) { - source_buffer << " "; - if (is_constant[i]) { - source_buffer << "const "; + if (!used_args.contains(inputs[i].get())) { + source_buffer << " "; + if (is_constant[i]) { + source_buffer << "const "; + } + jit::add_type (source_buffer); + source_buffer << " *" << jit::to_string('v', inputs[i].get()) + << " = args[" + << reinterpret_cast (inputs[i].get()) + << "];" << std::endl; + used_args.insert(inputs[i].get()); } - jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('v', inputs[i].get()) - << " = args[" << reinterpret_cast (inputs[i].get()) - << "];" << std::endl; } for (auto &output : outputs) { - source_buffer << " "; - jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('o', output.get()) - << " = args[" << reinterpret_cast (output.get()) - << "];" << std::endl; + if (!used_args.contains(output.get())) { + source_buffer << " "; + jit::add_type (source_buffer); + source_buffer << " *" << jit::to_string('o', output.get()) + << " = args[" + << reinterpret_cast (output.get()) + << "];" << std::endl; + used_args.insert(output.get()); + } } if (state.get()) { registers[state.get()] = jit::to_string('r', state.get()); @@ -509,56 +519,65 @@ namespace gpu { jit::register_map ®isters, jit::register_map &indices, const jit::register_usage &usage) { + std::unordered_set out_registers; for (auto &[out, in] : setters) { - graph::shared_leaf a = out->compile(source_buffer, - registers, - indices, - usage); - source_buffer << " " << jit::to_string('v', in.get()); - source_buffer << "[i] = "; - if constexpr (SAFE_MATH) { - if constexpr (jit::complex_scalar) { - jit::add_type (source_buffer); - source_buffer << " ("; - source_buffer << "isnan(real(" << registers[a.get()] - << ")) ? 0.0 : real(" << registers[a.get()] - << "), "; - source_buffer << "isnan(imag(" << registers[a.get()] - << ")) ? 0.0 : imag(" << registers[a.get()] - << "));" << std::endl; + if (!out->is_match(in) && + !out_registers.contains(out.get())) { + graph::shared_leaf a = out->compile(source_buffer, + registers, + indices, + usage); + source_buffer << " " << jit::to_string('v', in.get()); + source_buffer << "[i] = "; + if constexpr (SAFE_MATH) { + if constexpr (jit::complex_scalar) { + jit::add_type (source_buffer); + source_buffer << " ("; + source_buffer << "isnan(real(" << registers[a.get()] + << ")) ? 0.0 : real(" << registers[a.get()] + << "), "; + source_buffer << "isnan(imag(" << registers[a.get()] + << ")) ? 0.0 : imag(" << registers[a.get()] + << "));" << std::endl; + } else { + source_buffer << "isnan(" << registers[a.get()] + << ") ? 0.0 : " << registers[a.get()] + << ";" << std::endl; + } } else { - source_buffer << "isnan(" << registers[a.get()] - << ") ? 0.0 : " << registers[a.get()] - << ";" << std::endl; + source_buffer << registers[a.get()] << ";" << std::endl; } - } else { - source_buffer << registers[a.get()] << ";" << std::endl; + out_registers.insert(out.get()); } } for (auto &out : outputs) { - graph::shared_leaf a = out->compile(source_buffer, - registers, - indices, - usage); - source_buffer << " " << jit::to_string('o', out.get()); - source_buffer << "[i] = "; - if constexpr (SAFE_MATH) { - if constexpr (jit::complex_scalar) { - jit::add_type (source_buffer); - source_buffer << " ("; - source_buffer << "isnan(real(" << registers[a.get()] - << ")) ? 0.0 : real(" << registers[a.get()] - << "), "; - source_buffer << "isnan(imag(" << registers[a.get()] - << ")) ? 0.0 : imag(" << registers[a.get()] - << "));" << std::endl; + if (!graph::variable_cast(out).get() && + !out_registers.contains(out.get())) { + graph::shared_leaf a = out->compile(source_buffer, + registers, + indices, + usage); + source_buffer << " " << jit::to_string('o', out.get()); + source_buffer << "[i] = "; + if constexpr (SAFE_MATH) { + if constexpr (jit::complex_scalar) { + jit::add_type (source_buffer); + source_buffer << " ("; + source_buffer << "isnan(real(" << registers[a.get()] + << ")) ? 0.0 : real(" << registers[a.get()] + << "), "; + source_buffer << "isnan(imag(" << registers[a.get()] + << ")) ? 0.0 : imag(" << registers[a.get()] + << "));" << std::endl; + } else { + source_buffer << "isnan(" << registers[a.get()] + << ") ? 0.0 : " << registers[a.get()] + << ";" << std::endl; + } } else { - source_buffer << "isnan(" << registers[a.get()] - << ") ? 0.0 : " << registers[a.get()] - << ";" << std::endl; + source_buffer << registers[a.get()] << ";" << std::endl; } - } else { - source_buffer << registers[a.get()] << ";" << std::endl; + out_registers.insert(out.get()); } } diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 288b42becbc0b620cfbdaf7105ac6a86071eed1b..c1d2cbb89067e8039a69b54cc596bff01610bb2c 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -8,6 +8,7 @@ #ifndef cuda_context_h #define cuda_context_h +#include #include #include @@ -337,8 +338,8 @@ namespace gpu { &backend[0], backend.size()*sizeof(T)), "cuMemcpyHtoD"); + buffers.push_back(reinterpret_cast (&kernel_arguments[input.get()])); } - buffers.push_back(reinterpret_cast (&kernel_arguments[input.get()])); } for (auto &output : outputs) { if (!kernel_arguments.contains(output.get())) { @@ -347,8 +348,8 @@ namespace gpu { num_rays*sizeof(T), CU_MEM_ATTACH_GLOBAL), "cuMemAllocManaged"); + buffers.push_back(reinterpret_cast (&kernel_arguments[output.get()])); } - buffers.push_back(reinterpret_cast (&kernel_arguments[output.get()])); } const size_t num_buffers = buffers.size(); @@ -713,6 +714,7 @@ namespace gpu { source_buffer << "extern \"C\" __global__ void " << name << "(" << std::endl; + std::unordered_set used_args; if (inputs.size()) { source_buffer << " "; if (is_constant[0]) { @@ -721,22 +723,26 @@ namespace gpu { jit::add_type (source_buffer); source_buffer << " * __restrict__ " << jit::to_string('v', inputs[0].get()); + sed_args.insert(inputs[0].get()); } for (size_t i = 1, ie = inputs.size(); i < ie; i++) { - source_buffer << ", // " << inputs[i - 1]->get_symbol() + if (!used_args.contains(inputs[i].get())) { + source_buffer << ", // " << inputs[i - 1]->get_symbol() #ifndef USE_INPUT_CACHE #ifdef SHOW_USE_COUNT - << " used " << usage.at(inputs[i - 1].get()) + << " used " << usage.at(inputs[i - 1].get()) #endif #endif - << std::endl; - source_buffer << " "; - if (is_constant[i]) { - source_buffer << "const "; + << std::endl; + source_buffer << " "; + if (is_constant[i]) { + source_buffer << "const "; + } + jit::add_type (source_buffer); + source_buffer << " * __restrict__ " + << jit::to_string('v', inputs[i].get()); + used_args.insert(inputs[i].get()); } - jit::add_type (source_buffer); - source_buffer << " * __restrict__ " - << jit::to_string('v', inputs[i].get()); } for (size_t i = 0, ie = outputs.size(); i < ie; i++) { if (i == 0) { @@ -755,10 +761,13 @@ namespace gpu { source_buffer << "," << std::endl; } - source_buffer << " "; - jit::add_type (source_buffer); - source_buffer << " * __restrict__ " - << jit::to_string('o', outputs[i].get()); + if (!used_args.contains(outputs[i].get())) { + source_buffer << " "; + jit::add_type (source_buffer); + source_buffer << " * __restrict__ " + << jit::to_string('o', outputs[i].get()); + used_args.insert(outputs[i].get()); + } } if (state.get()) { source_buffer << "," << std::endl @@ -846,65 +855,80 @@ namespace gpu { jit::register_map ®isters, jit::register_map &indices, const jit::register_usage &usage) { + std::unordered_set out_registers; for (auto &[out, in] : setters) { - graph::shared_leaf a = out->compile(source_buffer, - registers, - indices, - usage); - source_buffer << " " << jit::to_string('v', in.get()) - << "["; - if (state.get()) { - source_buffer << "offset[0] + "; - } - source_buffer << "index] = "; - if constexpr (SAFE_MATH) { - if constexpr (jit::complex_scalar) { - jit::add_type (source_buffer); - source_buffer << " ("; - source_buffer << "isnan(real(" << registers[a.get()] - << ")) ? 0.0 : real(" << registers[a.get()] - << "), "; - source_buffer << "isnan(imag(" << registers[a.get()] - << ")) ? 0.0 : imag(" << registers[a.get()] - << "));" << std::endl; + if (!out->is_match(in) && + !out_registers.contains(out.get())) { + graph::shared_leaf a = out->compile(source_buffer, + registers, + indices, + usage); + source_buffer << " " + << jit::to_string('v', in.get()) + << "["; + if (state.get()) { + source_buffer << "offset[0] + "; + } + source_buffer << "index] = "; + if constexpr (SAFE_MATH) { + if constexpr (jit::complex_scalar) { + jit::add_type (source_buffer); + source_buffer << " ("; + source_buffer << "isnan(real(" << registers[a.get()] + << ")) ? 0.0 : real(" + << registers[a.get()] + << "), "; + source_buffer << "isnan(imag(" << registers[a.get()] + << ")) ? 0.0 : imag(" + << registers[a.get()] + << "));" << std::endl; + } else { + source_buffer << "isnan(" << registers[a.get()] + << ") ? 0.0 : " << registers[a.get()] + << ";" << std::endl; + } } else { - source_buffer << "isnan(" << registers[a.get()] - << ") ? 0.0 : " << registers[a.get()] - << ";" << std::endl; + source_buffer << registers[a.get()] << ";" << std::endl; } - } else { - source_buffer << registers[a.get()] << ";" << std::endl; + out_registers.insert(out.get()); } } for (auto &out : outputs) { - graph::shared_leaf a = out->compile(source_buffer, - registers, - indices, - usage); - source_buffer << " " << jit::to_string('o', out.get()) - << "["; - if (state.get()) { - source_buffer << "offset[0] + "; - } - source_buffer << "index] = "; - if constexpr (SAFE_MATH) { - if constexpr (jit::complex_scalar) { - jit::add_type (source_buffer); - source_buffer << " ("; - source_buffer << "isnan(real(" << registers[a.get()] - << ")) ? 0.0 : real(" << registers[a.get()] - << "), "; - source_buffer << "isnan(imag(" << registers[a.get()] - << ")) ? 0.0 : imag(" << registers[a.get()] - << "));" << std::endl; + if (!graph::variable_cast(out).get() && + !out_registers.contains(out.get())) { + graph::shared_leaf a = out->compile(source_buffer, + registers, + indices, + usage); + source_buffer << " " + << jit::to_string('o', out.get()) + << "["; + if (state.get()) { + source_buffer << "offset[0] + "; + } + source_buffer << "index] = "; + if constexpr (SAFE_MATH) { + if constexpr (jit::complex_scalar) { + jit::add_type (source_buffer); + source_buffer << " ("; + source_buffer << "isnan(real(" << registers[a.get()] + << ")) ? 0.0 : real(" + << registers[a.get()] + << "), "; + source_buffer << "isnan(imag(" << registers[a.get()] + << ")) ? 0.0 : imag(" + << registers[a.get()] + << "));" << std::endl; + } else { + source_buffer << "isnan(" << registers[a.get()] + << ") ? 0.0 : " << registers[a.get()] + << ";" << std::endl; + } } else { - source_buffer << "isnan(" << registers[a.get()] - << ") ? 0.0 : " << registers[a.get()] - << ";" << std::endl; + source_buffer << registers[a.get()] << ";" << std::endl; } - } else { - source_buffer << registers[a.get()] << ";" << std::endl; + out_registers.insert(out.get()); } } diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index f9974741c2812b184a5157a88ad6a337ae680f73..d85ea9f1ff832b4457c3272c26af2106fdec47b8 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -8,6 +8,8 @@ #ifndef metal_context_h #define metal_context_h +#include + #import #include "random.hpp" @@ -143,15 +145,15 @@ namespace gpu { kernel_arguments[input.get()] = [device newBufferWithBytes:buffer.data() length:buffer.size()*buffer_element_size options:MTLResourceStorageModeShared]; + buffers.push_back(kernel_arguments[input.get()]); } - buffers.push_back(kernel_arguments[input.get()]); } for (graph::shared_leaf &output : outputs) { if (!kernel_arguments.contains(output.get())) { kernel_arguments[output.get()] = [device newBufferWithLength:num_rays*sizeof(float) options:MTLResourceStorageModeShared]; + buffers.push_back(kernel_arguments[output.get()]); } - buffers.push_back(kernel_arguments[output.get()]); } if (state.get()) { if (!kernel_arguments.contains(state.get())) { @@ -456,35 +458,43 @@ namespace gpu { bufferMutability[name] = std::vector (); + size_t buffer_count = 0; + std::unordered_set used_args; for (size_t i = 0, ie = inputs.size(); i < ie; i++) { - bufferMutability[name].push_back(is_constant[i] ? MTLMutabilityMutable : MTLMutabilityImmutable); - source_buffer << " " << (is_constant[i] ? "constant" : "device") - << " float *" - << jit::to_string('v', inputs[i].get()) - << " [[buffer(" << i << ")]], // " - << inputs[i]->get_symbol() + if (!used_args.contains(inputs[i].get())) { + bufferMutability[name].push_back(is_constant[i] ? MTLMutabilityMutable : MTLMutabilityImmutable); + source_buffer << " " << (is_constant[i] ? "constant" : "device") + << " float *" + << jit::to_string('v', inputs[i].get()) + << " [[buffer(" << buffer_count++ << ")]], // " + << inputs[i]->get_symbol() #ifndef USE_INPUT_CACHE #ifdef SHOW_USE_COUNT - << " used " << usage.at(inputs[i].get()) + << " used " << usage.at(inputs[i].get()) #endif #endif - << std::endl; + << std::endl; + used_args.insert(inputs[i].get()); + } } for (size_t i = 0, ie = outputs.size(); i < ie; i++) { - bufferMutability[name].push_back(MTLMutabilityMutable); - source_buffer << " device float *" - << jit::to_string('o', outputs[i].get()) - << " [[buffer(" << i + inputs.size() << ")]]," - << std::endl; + if (!used_args.contains(outputs[i].get())) { + bufferMutability[name].push_back(MTLMutabilityMutable); + source_buffer << " device float *" + << jit::to_string('o', outputs[i].get()) + << " [[buffer(" << buffer_count++ << ")]]," + << std::endl; + used_args.insert(outputs[i].get()); + } } if (state.get()) { bufferMutability[name].push_back(MTLMutabilityMutable); source_buffer << " device mt_state *" << jit::to_string('s', state.get()) - << " [[buffer(" << inputs.size() + outputs.size() << ")]]," + << " [[buffer(" << buffer_count++ << ")]]," << std::endl << " constant uint32_t &offset [[buffer(" - << inputs.size() + outputs.size() + 1 << ")]]," + << buffer_count++ << ")]]," << std::endl; } size_t index = 0; @@ -563,32 +573,42 @@ namespace gpu { jit::register_map ®isters, jit::register_map &indices, const jit::register_usage &usage) { + std::unordered_set out_registers; for (auto &[out, in] : setters) { - graph::shared_leaf a = out->compile(source_buffer, - registers, - indices, - usage); - source_buffer << " " << jit::to_string('v', in.get()) - << "[index] = "; - if constexpr (SAFE_MATH) { - source_buffer << "isnan(" << registers[a.get()] - << ") ? 0.0 : "; + if (!out->is_match(in) && + !out_registers.contains(out.get())) { + graph::shared_leaf a = out->compile(source_buffer, + registers, + indices, + usage); + source_buffer << " " + << jit::to_string('v', in.get()) + << "[index] = "; + if constexpr (SAFE_MATH) { + source_buffer << "isnan(" << registers[a.get()] + << ") ? 0.0 : "; + } + source_buffer << registers[a.get()] << ";" << std::endl; + out_registers.insert(out.get()); } - source_buffer << registers[a.get()] << ";" << std::endl; } for (auto &out : outputs) { - graph::shared_leaf a = out->compile(source_buffer, - registers, - indices, - usage); - source_buffer << " " << jit::to_string('o', out.get()) - << "[index] = "; - if constexpr (SAFE_MATH) { - source_buffer << "isnan(" << registers[a.get()] - << ") ? 0.0 : "; + if (!graph::variable_cast(out).get() && + !out_registers.contains(out.get())) { + graph::shared_leaf a = out->compile(source_buffer, + registers, + indices, + usage); + source_buffer << " " << jit::to_string('o', out.get()) + << "[index] = "; + if constexpr (SAFE_MATH) { + source_buffer << "isnan(" << registers[a.get()] + << ") ? 0.0 : "; + } + source_buffer << registers[a.get()] << ";" << std::endl; + out_registers.insert(out.get()); } - source_buffer << registers[a.get()] << ";" << std::endl; } source_buffer << " }" << std::endl << "}" << std::endl; diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index b6a1839086b94f04f09aa38f03e685a6059d2d64..5b236d24ac57df6c0db23c6a4a1fcdd3c73ab9c9 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -440,7 +440,7 @@ namespace graph { /// @returns The derivative of the node. //------------------------------------------------------------------------------ virtual shared_leaf df(shared_leaf x) { - return zero (); + return this->is_match(x) ? one () : zero (); } //------------------------------------------------------------------------------ @@ -1474,7 +1474,7 @@ namespace graph { /// @returns The derivative of the node. //------------------------------------------------------------------------------ virtual shared_leaf df(shared_leaf x) { - return constant (static_cast (this == x.get())); + return constant (static_cast (this->is_match(x))); } //------------------------------------------------------------------------------ diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 9147353e906e4f2bba47d0c96e69834c2ccb0765..de735f478632a9138ca73cdc4675525d568bad02 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -19,7 +19,7 @@ namespace graph { /// @param[in,out] stream String buffer stream. /// @param[in] register_name Reister for the argument. /// @param[in] length Dimension length of argument. -/// @param[in] scale Argument scale factor. +/// @param[in] scale Argument scale factor. /// @param[in] offset Argument offset factor. //------------------------------------------------------------------------------ template @@ -203,7 +203,7 @@ void compile_index(std::ostringstream &stream, /// @returns The derivative of the node. //------------------------------------------------------------------------------ virtual shared_leaf df(shared_leaf x) { - return zero (); + return constant (static_cast (this->is_match(x))); } //------------------------------------------------------------------------------ @@ -553,8 +553,10 @@ void compile_index(std::ostringstream &stream, /// @tparam T Base type of the calculation. /// @tparam SAFE_MATH Use safe math operations. /// -/// @param[in] d Data to initalize the piecewise constant. -/// @param[in] x Argument. +/// @param[in] d Data to initalize the piecewise constant. +/// @param[in] x Argument. +/// @param[in] scale Argument scale factor. +/// @param[in] offset Argument offset factor. /// @returns A reduced piecewise\_1D node. //------------------------------------------------------------------------------ template @@ -836,7 +838,7 @@ void compile_index(std::ostringstream &stream, /// @returns The derivative of the node. //------------------------------------------------------------------------------ virtual shared_leaf df(shared_leaf x) { - return zero (); + return constant (static_cast (this->is_match(x))); } //------------------------------------------------------------------------------ diff --git a/graph_framework/random.hpp b/graph_framework/random.hpp index f049a932bf8618ac9b2a25605a5d2168f1964093..1d0cb883ceb29d3b2c34e52982f017cf9dafb7f4 100644 --- a/graph_framework/random.hpp +++ b/graph_framework/random.hpp @@ -39,7 +39,6 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Construct a constant node from a vector. /// -/// @param[in] size Number of random states. /// @param[in] seed Inital random seed. //------------------------------------------------------------------------------ random_state_node(const size_t size, diff --git a/graph_framework/workflow.hpp b/graph_framework/workflow.hpp index c33f525f6958134fc7b74e4be4667e10c2d8fbc8..86dac56692e7ba7778d67af0dd8811029084d180 100644 --- a/graph_framework/workflow.hpp +++ b/graph_framework/workflow.hpp @@ -229,7 +229,7 @@ namespace workflow { } //------------------------------------------------------------------------------ -/// @brief Add a workflow item. +/// @brief Add a converge item. /// /// @param[in] in Input variables. /// @param[in] out Output nodes. @@ -237,7 +237,7 @@ namespace workflow { /// @param[in] state Random state node. /// @param[in] name Name of the workitem. /// @param[in] size Size of the workitem. -/// @param[in] tol Tolarance to solve the dispersion function to. +/// @param[in] tol Tolarance to converge the function to. /// @param[in] max_iter Maximum number of iterations before giving up. //------------------------------------------------------------------------------ void add_converge_item(graph::input_nodes in, @@ -297,8 +297,8 @@ namespace workflow { //------------------------------------------------------------------------------ /// @brief Copy buffer contents to the device. /// -/// @param[in] node Not to copy buffer to. -/// @param[in] destination Device side buffer to copy to. +/// @param[in] node Node to copy buffer to. +/// @param[in] destination Host side buffer to copy from. //------------------------------------------------------------------------------ void copy_to_device(graph::shared_leaf &node, T *destination) { @@ -330,7 +330,7 @@ namespace workflow { //------------------------------------------------------------------------------ /// @brief Check the value. /// -/// @param[in] index Ray index to check value for. +/// @param[in] index Particle index to check value for. /// @param[in] node Node to check the value for. /// @returns The value at the index. //------------------------------------------------------------------------------ diff --git a/graph_korc/CMakeLists.txt b/graph_korc/CMakeLists.txt index 460f9c5050e0038d920f41647722aeea212aa98b..4da42c13373b688d6b186d76b7d74b388a9ae1e5 100644 --- a/graph_korc/CMakeLists.txt +++ b/graph_korc/CMakeLists.txt @@ -1,6 +1,5 @@ -add_tool_target (xkorc) +add_tool_target (xkorc cpp) -if (${USE_PCH}) - target_precompile_headers (xkorc REUSE_FROM xrays) +if (${USE_PCH} AND NOT ${BUILD_C_BINDING}) + target_precompile_headers (xrays_bench REUSE_FROM xrays) endif () - diff --git a/graph_playground/CMakeLists.txt b/graph_playground/CMakeLists.txt index 6b2fbdd8d93c8cf8878b458842964a2b56dcaf40..b72bfd5c8b08e682b7c85b593627beb16c6697a7 100644 --- a/graph_playground/CMakeLists.txt +++ b/graph_playground/CMakeLists.txt @@ -1,5 +1,5 @@ -add_tool_target (xplayground) +add_tool_target (xplayground cpp) -if (${USE_PCH}) - target_precompile_headers (xplayground REUSE_FROM xrays) +if (${USE_PCH} AND NOT ${BUILD_C_BINDING}) + target_precompile_headers (xrays_bench REUSE_FROM xrays) endif () diff --git a/graph_tests/CMakeLists.txt b/graph_tests/CMakeLists.txt index b69e7f3738a0cacd8a63a5372f80a78701ea014f..5b090c3ad5f28eabc0202cd434dc631477fb1a36 100644 --- a/graph_tests/CMakeLists.txt +++ b/graph_tests/CMakeLists.txt @@ -1,17 +1,17 @@ -add_test_target (node_test) -add_test_target (arithmetic_test) -add_test_target (math_test) -add_test_target (dispersion_test) -add_test_target (solver_test) -add_test_target (backend_test) -add_test_target (vector_test) -add_test_target (physics_test) -add_test_target (jit_test) -add_test_target (trigonometry_test) -add_test_target (piecewise_test) -add_test_target (erfi_test) -add_test_target (efit_test) -add_test_target (random_test) +add_test_target (node_test cpp) +add_test_target (arithmetic_test cpp) +add_test_target (math_test cpp) +add_test_target (dispersion_test cpp) +add_test_target (solver_test cpp) +add_test_target (backend_test cpp) +add_test_target (vector_test cpp) +add_test_target (physics_test cpp) +add_test_target (jit_test cpp) +add_test_target (trigonometry_test cpp) +add_test_target (piecewise_test cpp) +add_test_target (erfi_test cpp) +add_test_target (efit_test cpp) +add_test_target (random_test cpp) target_compile_definitions (erfi_test PRIVATE @@ -22,3 +22,19 @@ target_compile_definitions (efit_test PRIVATE EFIT_GOLD_FILE="${CMAKE_CURRENT_SOURCE_DIR}/efit_gold.nc" ) + +if (${BUILD_C_BINDING}) + add_test_target (c_binding_test c) + target_link_libraries (c_binding_test + PRIVATE + graph_c + ) +endif () + +if (${BUILD_Fortran_BINDING}) + add_test_target (f_binding_test f90) + target_link_libraries (f_binding_test + PRIVATE + graph_f + ) +endif () diff --git a/graph_tests/c_binding_test.c b/graph_tests/c_binding_test.c new file mode 100644 index 0000000000000000000000000000000000000000..56be54f7c78f7f72e51f98daae7a3b5ab22db616 --- /dev/null +++ b/graph_tests/c_binding_test.c @@ -0,0 +1,374 @@ +//------------------------------------------------------------------------------ +/// @file c_binding_test.cpp +/// @brief Tests for c bindings. +//------------------------------------------------------------------------------ + +// Turn on asserts even in release builds. +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include +#include +#include + +#include "../graph_c_binding/graph_c_binding.h" + +//------------------------------------------------------------------------------ +/// @brief Run tests +/// +/// @param[in] type Type to run tests on. +/// @param[in] use_safe_math Use safe math. +//------------------------------------------------------------------------------ +void run_tests(const enum graph_type type, + const bool use_safe_math) { + struct graph_c_context *c_context = graph_construct_context(type, use_safe_math); + + graph_node x = graph_variable(c_context, 1, "x"); + graph_node m; + graph_node b; + if (type == FLOAT || type == DOUBLE) { + m = graph_constant(c_context, 0.5); + b = graph_constant(c_context, 0.2); + } else { + m = graph_constant_c(c_context, 0.5, 0.0); + b = graph_constant_c(c_context, 0.2, 0.0); + } + graph_node y = graph_add(c_context, graph_mul(c_context, m, x), b); + + graph_node px = graph_pseudo_variable(c_context, x); + assert(graph_remove_pseudo(c_context, px) == x && + "Expected to recieve x."); + + graph_node one = graph_constant(c_context, 1.0); + graph_node zero = graph_constant(c_context, 0.0); + assert(graph_sub(c_context, one, one) == zero && + "Expected to recieve zero."); + assert(graph_div(c_context, one, one) == one && + "Expected to recieve one."); + assert(graph_sqrt(c_context, one) == one && + "Expected to recieve one."); + assert(graph_exp(c_context, zero) == one && + "Expected to recieve one."); + assert(graph_log(c_context, one) == zero && + "Expected to recieve zero."); + assert(graph_pow(c_context, one, one) == one && + "Expected to recieve one."); + + if (type == COMPLEX_FLOAT || type == COMPLEX_DOUBLE) { + assert(graph_erfi(c_context, zero) == zero && + "Expected to recieve zero."); + } + + assert(graph_sin(c_context, zero) == zero && + "Expected to recieve zero."); + assert(graph_cos(c_context, zero) == one && + "Expected to recieve one."); + assert(graph_atan(c_context, one, zero) == zero && + "Expected to recieve zero."); + + graph_node dydx = graph_df(c_context, y, x); + graph_node dydm = graph_df(c_context, y, m); + graph_node dydb = graph_df(c_context, y, b); + graph_node dydy = graph_df(c_context, y, y); + + switch (c_context->type) { + case FLOAT: { + float value = 2.0; + graph_set_variable(c_context, x, &value); + break; + } + + case DOUBLE: { + double value = 2.0; + graph_set_variable(c_context, x, &value); + break; + } + + case COMPLEX_FLOAT: { + float complex value = CMPLXF(2.0, 0.0); + graph_set_variable(c_context, x, &value); + break; + } + + case COMPLEX_DOUBLE: { + double complex value = CMPLX(2.0, 0.0); + graph_set_variable(c_context, x, &value); + break; + } + } + + graph_node state = graph_random_state(c_context, 0); + graph_node rand = graph_random(c_context, state); + + const size_t max_device = graph_get_max_concurrency(c_context) - 1; + graph_set_device_number(c_context, max_device); + + graph_node inputs[1] = {x}; + graph_node outputs[5] = {y, dydx, dydm, dydb, dydy}; + graph_node *map_inputs = NULL; + graph_node *map_outputs = NULL; + + graph_node z = graph_variable(c_context, 1, "z"); + graph_node root = graph_sub(c_context, + graph_pow(c_context, z, + graph_constant(c_context, 3.0)), + graph_pow(c_context, z, + graph_constant(c_context, 2.0))); + graph_node root2 = graph_mul(c_context, root, root); + graph_node dz = graph_sub(c_context, z, + graph_div(c_context, root, + graph_df(c_context, root, z))); + + graph_node p1; + graph_node p2; + graph_node i = graph_variable(c_context, 1, "i"); + graph_node j = graph_variable(c_context, 1, "j"); + switch (c_context->type) { + case FLOAT: { + float value1[3] = {2.0, 4.0, 6.0}; + float value2[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + float value3 = 1.5; + float value4 = 2.5; + graph_set_variable(c_context, i, &value3); + graph_set_variable(c_context, j, &value4); + p1 = graph_piecewise_1D(c_context, i, 1.0, 0.0, value1, 3); + p2 = graph_piecewise_2D(c_context, 3, j, 1.0, 0.0, i, 1.0, 0.0, value2, 9); + break; + } + + case DOUBLE: { + double value1[3] = {2.0, 4.0, 6.0}; + double value2[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + double value3 = 1.5; + double value4 = 2.5; + graph_set_variable(c_context, i, &value3); + graph_set_variable(c_context, j, &value4); + p1 = graph_piecewise_1D(c_context, i, 1.0, 0.0, value1, 3); + p2 = graph_piecewise_2D(c_context, 3, j, 1.0, 0.0, i, 1.0, 0.0, value2, 9); + break; + } + + case COMPLEX_FLOAT: { + float complex value1[3] = {CMPLXF(2.0, 0.0), CMPLXF(4.0, 0.0), CMPLXF(6.0, 0.0)}; + float complex value2[9] = {CMPLXF(1.0, 0.0), CMPLXF(2.0, 0.0), CMPLXF(3.0, 0.0), + CMPLXF(4.0, 0.0), CMPLXF(5.0, 0.0), CMPLXF(6.0, 0.0), + CMPLXF(7.0, 0.0), CMPLXF(8.0, 0.0), CMPLXF(9.0, 0.0)}; + float complex value3 = CMPLXF(1.5, 0.0); + float complex value4 = CMPLXF(2.5, 0.0); + graph_set_variable(c_context, i, &value3); + graph_set_variable(c_context, j, &value4); + p1 = graph_piecewise_1D(c_context, i, 1.0, 0.0, value1, 3); + p2 = graph_piecewise_2D(c_context, 3, j, 1.0, 0.0, i, 1.0, 0.0, value2, 9); + break; + } + + case COMPLEX_DOUBLE: { + double complex value1[3] = {CMPLX(2.0, 0.0), CMPLX(4.0, 0.0), CMPLX(6.0, 0.0)}; + double complex value2[9] = {CMPLX(1.0, 0.0), CMPLX(2.0, 0.0), CMPLX(3.0, 0.0), + CMPLX(4.0, 0.0), CMPLX(5.0, 0.0), CMPLX(6.0, 0.0), + CMPLX(7.0, 0.0), CMPLX(8.0, 0.0), CMPLX(9.0, 0.0)}; + double complex value3 = CMPLXF(1.5, 0.0); + double complex value4 = CMPLXF(2.5, 0.0); + graph_set_variable(c_context, i, &value3); + graph_set_variable(c_context, j, &value4); + p1 = graph_piecewise_1D(c_context, i, 1.0, 0.0, value1, 3); + p2 = graph_piecewise_2D(c_context, 3, j, 1.0, 0.0, i, 1.0, 0.0, value2, 9); + break; + } + } + + graph_node inputs2[2] = {i, j}; + graph_node outputs2[2] = {p1, p2}; + graph_node *map_inputs2 = NULL; + graph_node *map_outputs2 = NULL; + + graph_add_pre_item(c_context, + NULL, 0, + &rand, 1, + NULL, NULL, 0, + state, + "c_binding_pre_kernel", 1); + graph_add_item(c_context, + inputs, 1, + outputs, 5, + map_inputs, map_outputs, 0, + NULL, "c_binding", 1); + graph_add_item(c_context, + inputs2, 2, + outputs2, 2, + map_inputs2, map_outputs2, 0, + NULL, "c_binding_piecewise", 1); + graph_add_converge_item(c_context, &z, 1, + &root2, 1, + &z, &dz, 1, + NULL, "c_binding_converge", 1, + 1.0E-30, 1000); + graph_compile(c_context); + switch (c_context->type) { + case FLOAT: { + float value = 10.0; + graph_copy_to_device(c_context, z, &value); + break; + } + + case DOUBLE: { + double value = 10.0; + graph_copy_to_device(c_context, z, &value); + break; + } + + case COMPLEX_FLOAT: { + float complex value = CMPLXF(10.0, 0.0); + graph_copy_to_device(c_context, z, &value); + break; + } + + case COMPLEX_DOUBLE: { + double complex value = CMPLX(10.0, 0.0); + graph_copy_to_device(c_context, z, &value); + break; + } + } + graph_pre_run(c_context); + graph_run(c_context); + graph_wait(c_context); + inputs2[0] = z; + inputs2[1] = y; + graph_print(c_context, 0, inputs2, 2); + + switch (c_context->type) { + case FLOAT: { + float value[9]; + graph_copy_to_host(c_context, y, value); + graph_copy_to_host(c_context, dydx, value + 1); + graph_copy_to_host(c_context, dydm, value + 2); + graph_copy_to_host(c_context, dydb, value + 3); + graph_copy_to_host(c_context, dydy, value + 4); + graph_copy_to_host(c_context, rand, value + 5); + graph_copy_to_host(c_context, z, value + 6); + graph_copy_to_host(c_context, p1, value + 7); + graph_copy_to_host(c_context, p2, value + 8); + assert(value[0] == 0.5f*2.0f + 0.2f && "Value of y does not match."); + assert(value[1] == 0.5f && "Value of dydx does not match."); + assert(value[2] == 2.0f && "Value of dydm does not match."); + assert(value[3] == 1.0f && "Value of dydb does not match."); + assert(value[4] == 1.0f && "Value of dydy does not match."); + if (c_context->safe_math) { + assert(value[5] == 2546248192.0f && "Value of rand does not match."); + } else { + assert(value[5] == 2357136128.0f && "Value of rand does not match."); + } + assert(value[6] == 1.0f && "Value of root does not match."); + assert(value[7] == 4.0f && "Value of p1 does not match."); + assert(value[8] == 8.0f && "Value of p2 does not match."); + break; + } + + case DOUBLE: { + double value[9]; + graph_copy_to_host(c_context, y, value); + graph_copy_to_host(c_context, dydx, value + 1); + graph_copy_to_host(c_context, dydm, value + 2); + graph_copy_to_host(c_context, dydb, value + 3); + graph_copy_to_host(c_context, dydy, value + 4); + graph_copy_to_host(c_context, rand, value + 5); + graph_copy_to_host(c_context, z, value + 6); + graph_copy_to_host(c_context, p1, value + 7); + graph_copy_to_host(c_context, p2, value + 8); + assert(value[0] == 0.5*2.0 + 0.2 && "Value of y does not match."); + assert(value[1] == 0.5 && "Value of dydx does not match."); + assert(value[2] == 2.0 && "Value of dydm does not match."); + assert(value[3] == 1.0 && "Value of dydb does not match."); + assert(value[4] == 1.0 && "Value of dydy does not match."); + if (c_context->safe_math) { + assert(value[5] == 2546248239.0 && "Value of rand does not match."); + } else { + assert(value[5] == 2357136044.0 && "Value of rand does not match."); + } + assert(value[6] == 1.0 && "Value of root does not match."); + assert(value[7] == 4.0 && "Value of p1 does not match."); + assert(value[8] == 8.0 && "Value of p2 does not match."); + break; + } + + case COMPLEX_FLOAT: { + float complex value[9]; + graph_copy_to_host(c_context, y, value); + graph_copy_to_host(c_context, dydx, value + 1); + graph_copy_to_host(c_context, dydm, value + 2); + graph_copy_to_host(c_context, dydb, value + 3); + graph_copy_to_host(c_context, dydy, value + 4); + graph_copy_to_host(c_context, rand, value + 5); + graph_copy_to_host(c_context, z, value + 6); + graph_copy_to_host(c_context, p1, value + 7); + graph_copy_to_host(c_context, p2, value + 8); + assert(crealf(value[0]) == 0.5f*2.0f + 0.2f && "Value of y does not match."); + assert(crealf(value[1]) == 0.5f && "Value of dydx does not match."); + assert(crealf(value[2]) == 2.0f && "Value of dydm does not match."); + assert(crealf(value[3]) == 1.0f && "Value of dydb does not match."); + assert(crealf(value[4]) == 1.0f && "Value of dydy does not match."); + if (c_context->safe_math) { + assert(crealf(value[5]) == 2546248192.0f && "Value of rand does not match."); + } else { + assert(crealf(value[5]) == 2357136128.0f && "Value of rand does not match."); + } + assert(crealf(value[6]) == 1.0f && "Value of root does not match."); + assert(crealf(value[7]) == 4.0f && "Value of p1 does not match."); + assert(crealf(value[8]) == 8.0f && "Value of p2 does not match."); + break; + } + + case COMPLEX_DOUBLE: { + double complex value[9]; + graph_copy_to_host(c_context, y, value); + graph_copy_to_host(c_context, dydx, value + 1); + graph_copy_to_host(c_context, dydm, value + 2); + graph_copy_to_host(c_context, dydb, value + 3); + graph_copy_to_host(c_context, dydy, value + 4); + graph_copy_to_host(c_context, rand, value + 5); + graph_copy_to_host(c_context, z, value + 6); + graph_copy_to_host(c_context, p1, value + 7); + graph_copy_to_host(c_context, p2, value + 8); + assert(creal(value[0]) == 0.5*2.0 + 0.2 && "Value of y does not match."); + assert(creal(value[1]) == 0.5 && "Value of dydx does not match."); + assert(creal(value[2]) == 2.0 && "Value of dydm does not match."); + assert(creal(value[3]) == 1.0 && "Value of dydb does not match."); + assert(creal(value[4]) == 1.0 && "Value of dydy does not match."); + if (c_context->safe_math) { + assert(creal(value[5]) == 2546248239.0 && "Value of rand does not match."); + } else { + assert(creal(value[5]) == 2357136044.0 && "Value of rand does not match."); + } + assert(creal(value[6]) == 1.0 && "Value of root does not match."); + assert(creal(value[7]) == 4.0 && "Value of p1 does not match."); + assert(creal(value[8]) == 8.0 && "Value of p2 does not match."); + break; + } + } + + graph_destroy_context(c_context); +} + +//------------------------------------------------------------------------------ +/// @brief Main program of the test. +/// +/// @param[in] argc Number of commandline arguments. +/// @param[in] argv Array of commandline arguments. +//------------------------------------------------------------------------------ +int main(int argc, const char * argv[]) { + START_GPU + (void)argc; + (void)argv; + + run_tests(FLOAT, false); + run_tests(FLOAT, true); + run_tests(DOUBLE, false); + run_tests(DOUBLE, true); + run_tests(COMPLEX_FLOAT, false); + run_tests(COMPLEX_FLOAT, true); + run_tests(COMPLEX_DOUBLE, false); + run_tests(COMPLEX_DOUBLE, true); + + END_GPU +} diff --git a/graph_tests/f_binding_test.f90 b/graph_tests/f_binding_test.f90 new file mode 100644 index 0000000000000000000000000000000000000000..8fa992d7664b3ad99c4331ba0d19497c000f6d48 --- /dev/null +++ b/graph_tests/f_binding_test.f90 @@ -0,0 +1,778 @@ +!------------------------------------------------------------------------------- +!> @file f_binding_test.f90 +!> @brief Test for fortran bindings. +!------------------------------------------------------------------------------- +!------------------------------------------------------------------------------- +!> @brief Main test program. +!------------------------------------------------------------------------------- + PROGRAM f_binding_test + USE, INTRINSIC :: ISO_C_BINDING + + IMPLICIT NONE + +! Define parameters. + LOGICAL(C_BOOL), PARAMETER :: c_true = .true. + LOGICAL(C_BOOL), PARAMETER :: c_false = .false. + +! Start of executable code. + CALL run_test_float(c_true) + CALL run_test_float(c_false) + CALL run_test_double(c_true) + CALL run_test_double(c_false) + CALL run_test_complex_float(c_true) + CALL run_test_complex_float(c_false) + CALL run_test_complex_double(c_true) + CALL run_test_complex_double(c_false) + + END PROGRAM + +!------------------------------------------------------------------------------- +!> @brief Assert check. +!> +!> If the assert check does not pass write error to standard error and exit. +!> +!> @param[in] test The check test. +!> @param[in] message Message to report if check fails. +!------------------------------------------------------------------------------- + SUBROUTINE assert(test, message) + USE, INTRINSIC :: ISO_FORTRAN_ENV, ONLY : error_unit + + IMPLICIT NONE + +! Declare Arguments + LOGICAL, INTENT(IN) :: test + CHARACTER(len=*) :: message + +! Start of executable code. + IF (.not.test) THEN + WRITE(error_unit,*) message + CALL exit(1) + END IF + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Run float tests. +!> +!> @param[in] c_type Base type of the calculation. +!> @param[in] use_safe_math Use safe math. +!------------------------------------------------------------------------------- + SUBROUTINE run_test_float(use_safe_math) + + USE graph_fortran + USE, INTRINSIC :: ISO_C_BINDING + + IMPLICIT NONE + +! Declare Arguments + LOGICAL(C_BOOL), INTENT(IN) :: use_safe_math + +! Local variables. + CLASS(graph_context), POINTER :: graph + TYPE(C_PTR) :: x + TYPE(C_PTR) :: m + TYPE(C_PTR) :: b + REAL(C_FLOAT), DIMENSION(1) :: value + TYPE(C_PTR) :: px + TYPE(C_PTR) :: y + TYPE(C_PTR) :: dydx + TYPE(C_PTR) :: dydm + TYPE(C_PTR) :: dydb + TYPE(C_PTR) :: dydy + TYPE(C_PTR) :: one + TYPE(C_PTR) :: zero + INTEGER(C_LONG) :: size + TYPE(C_PTR) :: rand + TYPE(C_PTR) :: state + REAL(C_FLOAT), DIMENSION(3) :: buffer1D + TYPE(C_PTR) :: p1 + TYPE(C_PTR) :: i + REAL(C_FLOAT), DIMENSION(3,3) :: buffer2D + TYPE(C_PTR) :: p2 + TYPE(C_PTR) :: j + TYPE(C_PTR) :: z + TYPE(C_PTR) :: root + TYPE(C_PTR) :: root2 + TYPE(C_PTR) :: dz + +! Start of executable code. + graph => graph_float_context(use_safe_math) + + x = graph%variable(1_C_LONG, 'x' // C_NULL_CHAR) + m = graph%constant(0.5_C_DOUBLE) + b = graph%constant(0.2_C_DOUBLE) + + value(1) = 2.0 + CALL graph%set_variable(x, value) + + px = graph%pseudo_variable(x) + CALL assert(graph_ptr(px) .ne. graph_ptr(x), & + 'Expected different nodes.') + CALL assert(graph_ptr(graph%remove_pseudo(px)) .eq. graph_ptr(x), & + 'Remove pseudo failed.') + + y = graph%add(graph%mul(m, x), b) + + dydx = graph%df(y, x); + dydm = graph%df(y, m); + dydb = graph%df(y, b); + dydy = graph%df(y, y); + + one = graph%constant(1.0_C_DOUBLE) + zero = graph%constant(0.0_C_DOUBLE) + + CALL assert(graph_ptr(graph%sub(one, one)) .eq. graph_ptr(zero), & + 'Expected 1 - 1 = 0.') + CALL assert(graph_ptr(graph%div(one, one)) .eq. graph_ptr(one), & + 'Expected 1/1 = 1.') + CALL assert(graph_ptr(graph%sqrt(one)) .eq. graph_ptr(one), & + 'Expected sqrt(1) = 1.') + CALL assert(graph_ptr(graph%exp(zero)) .eq. graph_ptr(one), & + 'Expected exp(0) = 1.') + CALL assert(graph_ptr(graph%log(one)) .eq. graph_ptr(zero), & + 'Expected log(1) = 0.') + CALL assert(graph_ptr(graph%pow(one, zero)) .eq. graph_ptr(one), & + 'Expected pow(1,0) = 1.') + CALL assert(graph_ptr(graph%sin(zero)) .eq. graph_ptr(zero), & + 'Expected sin(0) = 0.') + CALL assert(graph_ptr(graph%cos(zero)) .eq. graph_ptr(one), & + 'Expected cos(0) = 1.') + CALL assert(graph_ptr(graph%atan(one, zero)) .eq. graph_ptr(zero), & + 'Expected atan(one, zero) = zero.') + + state = graph%random_state(0) + rand = graph%random(state) + + i = graph%variable(1_C_LONG, 'i' // C_NULL_CHAR) + value(1) = 1.5 + CALL graph%set_variable(i, value) + buffer1D = (/ 2.0, 4.0, 6.0 /) + p1 = graph%piecewise_1D(i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer1D) + + j = graph%variable(1_C_LONG, 'j' // C_NULL_CHAR) + value(1) = 2.5 + CALL graph%set_variable(j, value) + buffer2D = RESHAPE((/ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 /), & + SHAPE(buffer2D)) + p2 = graph%piecewise_2D(j, 1.0_C_DOUBLE, 0.0_C_DOUBLE, & + i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer2D) + + z = graph%variable(1_C_LONG, 'z' // C_NULL_CHAR) + root = graph%sub(graph%pow(z, graph%constant(3.0_C_DOUBLE)), & + graph%pow(z, graph%constant(2.0_C_DOUBLE))) + root2 = graph%mul(root, root) + dz = graph%sub(z, graph%div(root, graph%df(root, z))) + + CALL graph%set_device_number(graph%get_max_concurrency() - 1) + + CALL graph%add_pre_item(graph_null_array, (/ graph_ptr(rand) /), & + graph_null_array, graph_null_array, state, & + 'f_binding_pre_kernel' // C_NULL_CHAR, & + 1_C_LONG) + CALL graph%add_item((/ graph_ptr(x) /), (/ & + graph_ptr(y), & + graph_ptr(dydx), & + graph_ptr(dydm), & + graph_ptr(dydb), & + graph_ptr(dydy) & + /), graph_null_array, graph_null_array, C_NULL_PTR, & + 'f_binding' // C_NULL_CHAR, 1_C_LONG) + CALL graph%add_item((/ graph_ptr(i), graph_ptr(j) /), & + (/ graph_ptr(p1), graph_ptr(p2) /), & + graph_null_array, graph_null_array, C_NULL_PTR, & + 'c_binding_piecewise' // C_NULL_CHAR, 1_C_LONG) + CALL graph%add_converge_item((/ graph_ptr(z) /), (/ graph_ptr(root2) /), & + (/ graph_ptr(z) /), (/ graph_ptr(dz) /), & + C_NULL_PTR, & + 'f_binding_converge' // C_NULL_CHAR, & + 1_C_LONG, 1.0E-30_C_DOUBLE, 1000_C_LONG) + CALL graph%compile() + value(1) = 10.0 + CALL graph%copy_to_device(z, value) + CALL graph%pre_run() + CALL graph%run() + CALL graph%wait() + CALL graph%print(0_C_LONG, (/ graph_ptr(z), graph_ptr(y) /)) + + CALL graph%copy_to_host(y, value) + CALL assert(value(1) .eq. 0.5_C_FLOAT*2.0_C_FLOAT + 0.2_C_FLOAT, & + 'Value of y does not match.') + CALL graph%copy_to_host(dydx, value) + CALL assert(value(1) .eq. 0.5_C_FLOAT, 'Value of dydx does not match.') + CALL graph%copy_to_host(dydm, value) + CALL assert(value(1) .eq. 2.0_C_FLOAT, 'Value of dydm does not match.') + CALL graph%copy_to_host(dydb, value) + CALL assert(value(1) .eq. 1.0_C_FLOAT, 'Value of dydb does not match.') + CALL graph%copy_to_host(dydy, value) + CALL assert(value(1) .eq. 1.0_C_FLOAT, 'Value of dydy does not match.') + CALL graph%copy_to_host(rand, value) + IF (use_safe_math) THEN + CALL assert(value(1) .eq. 2546248192.0_C_FLOAT, & + 'Value of rand does not match.') + ELSE + CALL assert(value(1) .eq. 2357136128.0_C_FLOAT, & + 'Value of rand does not match.') + END IF + CALL graph%copy_to_host(z, value) + CALL assert(value(1) .eq. 1.0_C_FLOAT, 'Value of root does not match.') + CALL graph%copy_to_host(p1, value) + CALL assert(value(1) .eq. 4.0_C_FLOAT, 'Value of p1 does not match.') + CALL graph%copy_to_host(p2, value) + CALL assert(value(1) .eq. 8.0_C_FLOAT, 'Value of p2 does not match.') + + DEALLOCATE(graph) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Run double tests. +!> +!> @param[in] c_type Base type of the calculation. +!> @param[in] use_safe_math Use safe math. +!------------------------------------------------------------------------------- + SUBROUTINE run_test_double(use_safe_math) + + USE graph_fortran + USE, INTRINSIC :: ISO_C_BINDING + + IMPLICIT NONE + +! Declare Arguments + LOGICAL(C_BOOL), INTENT(IN) :: use_safe_math + +! Local variables. + CLASS(graph_context), POINTER :: graph + TYPE(C_PTR) :: x + TYPE(C_PTR) :: m + TYPE(C_PTR) :: b + REAL(C_DOUBLE), DIMENSION(1) :: value + TYPE(C_PTR) :: px + TYPE(C_PTR) :: y + TYPE(C_PTR) :: dydx + TYPE(C_PTR) :: dydm + TYPE(C_PTR) :: dydb + TYPE(C_PTR) :: dydy + TYPE(C_PTR) :: one + TYPE(C_PTR) :: zero + INTEGER(C_LONG) :: size + TYPE(C_PTR) :: rand + TYPE(C_PTR) :: state + REAL(C_DOUBLE), DIMENSION(3) :: buffer1D + TYPE(C_PTR) :: p1 + TYPE(C_PTR) :: i + REAL(C_DOUBLE), DIMENSION(3,3) :: buffer2D + TYPE(C_PTR) :: p2 + TYPE(C_PTR) :: j + TYPE(C_PTR) :: z + TYPE(C_PTR) :: root + TYPE(C_PTR) :: root2 + TYPE(C_PTR) :: dz + +! Start of executable code. + graph => graph_double_context(use_safe_math) + + x = graph%variable(1_C_LONG, 'x' // C_NULL_CHAR) + m = graph%constant(0.5_C_DOUBLE) + b = graph%constant(0.2_C_DOUBLE) + + value(1) = 2.0 + CALL graph%set_variable(x, value) + + px = graph%pseudo_variable(x) + CALL assert(graph_ptr(px) .ne. graph_ptr(x), & + 'Expected different nodes.') + CALL assert(graph_ptr(graph%remove_pseudo(px)) .eq. graph_ptr(x), & + 'Remove pseudo failed.') + + y = graph%add(graph%mul(m, x), b) + + dydx = graph%df(y, x); + dydm = graph%df(y, m); + dydb = graph%df(y, b); + dydy = graph%df(y, y); + + one = graph%constant(1.0_C_DOUBLE) + zero = graph%constant(0.0_C_DOUBLE) + + CALL assert(graph_ptr(graph%sub(one, one)) .eq. graph_ptr(zero), & + 'Expected 1 - 1 = 0.') + CALL assert(graph_ptr(graph%div(one, one)) .eq. graph_ptr(one), & + 'Expected 1/1 = 1.') + CALL assert(graph_ptr(graph%sqrt(one)) .eq. graph_ptr(one), & + 'Expected sqrt(1) = 1.') + CALL assert(graph_ptr(graph%exp(zero)) .eq. graph_ptr(one), & + 'Expected exp(0) = 1.') + CALL assert(graph_ptr(graph%log(one)) .eq. graph_ptr(zero), & + 'Expected log(1) = 0.') + CALL assert(graph_ptr(graph%pow(one, zero)) .eq. graph_ptr(one), & + 'Expected pow(1, 0) = 1.') + CALL assert(graph_ptr(graph%sin(zero)) .eq. graph_ptr(zero), & + 'Expected sin(0) = 0.') + CALL assert(graph_ptr(graph%cos(zero)) .eq. graph_ptr(one), & + 'Expected cos(0) = 1.') + CALL assert(graph_ptr(graph%atan(one, zero)) .eq. graph_ptr(zero), & + 'Expected atan(one, zero) = zero.') + + state = graph%random_state(0) + rand = graph%random(state) + + i = graph%variable(1_C_LONG, 'i' // C_NULL_CHAR) + value(1) = 1.5 + CALL graph%set_variable(i, value) + buffer1D = (/ 2.0, 4.0, 6.0 /) + p1 = graph%piecewise_1D(i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer1D) + + j = graph%variable(1_C_LONG, 'j' // C_NULL_CHAR) + value(1) = 2.5 + CALL graph%set_variable(j, value) + buffer2D = RESHAPE((/ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 /), & + SHAPE(buffer2D)) + p2 = graph%piecewise_2D(j, 1.0_C_DOUBLE, 0.0_C_DOUBLE, & + i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer2D) + + z = graph%variable(1_C_LONG, 'z' // C_NULL_CHAR) + root = graph%sub(graph%pow(z, graph%constant(3.0_C_DOUBLE)), & + graph%pow(z, graph%constant(2.0_C_DOUBLE))) + root2 = graph%mul(root, root) + dz = graph%sub(z, graph%div(root, graph%df(root, z))) + + CALL graph%set_device_number(graph%get_max_concurrency() - 1) + + CALL graph%add_pre_item(graph_null_array, (/ graph_ptr(rand) /), & + graph_null_array, graph_null_array, state, & + 'f_binding_pre_kernel' // C_NULL_CHAR, & + 1_C_LONG) + CALL graph%add_item((/ graph_ptr(x) /), (/ & + graph_ptr(y), & + graph_ptr(dydx), & + graph_ptr(dydm), & + graph_ptr(dydb), & + graph_ptr(dydy) & + /), graph_null_array, graph_null_array, C_NULL_PTR, & + 'f_binding' // C_NULL_CHAR, 1_C_LONG) + CALL graph%add_item((/ graph_ptr(i), graph_ptr(j) /), & + (/ graph_ptr(p1), graph_ptr(p2) /), & + graph_null_array, graph_null_array, C_NULL_PTR, & + 'c_binding_piecewise' // C_NULL_CHAR, 1_C_LONG) + CALL graph%add_converge_item((/ graph_ptr(z) /), (/ graph_ptr(root2) /), & + (/ graph_ptr(z) /), (/ graph_ptr(dz) /), & + C_NULL_PTR, & + 'f_binding_converge' // C_NULL_CHAR, & + 1_C_LONG, 1.0E-30_C_DOUBLE, 1000_C_LONG) + CALL graph%compile() + value(1) = 10.0 + CALL graph%copy_to_device(z, value) + CALL graph%pre_run() + CALL graph%run() + CALL graph%wait() + CALL graph%print(0_C_LONG, (/ graph_ptr(z), graph_ptr(y) /)) + + CALL graph%copy_to_host(y, value) + CALL assert(value(1) .eq. 0.5_C_DOUBLE*2.0_C_DOUBLE + 0.2_C_DOUBLE, & + 'Value of y does not match.') + CALL graph%copy_to_host(dydx, value) + CALL assert(value(1) .eq. 0.5_C_DOUBLE, 'Value of dydx does not match.') + CALL graph%copy_to_host(dydm, value) + CALL assert(value(1) .eq. 2.0_C_DOUBLE, 'Value of dydm does not match.') + CALL graph%copy_to_host(dydb, value) + CALL assert(value(1) .eq. 1.0_C_DOUBLE, 'Value of dydb does not match.') + CALL graph%copy_to_host(dydy, value) + CALL assert(value(1) .eq. 1.0_C_DOUBLE, 'Value of dydy does not match.') + CALL graph%copy_to_host(rand, value) + IF (use_safe_math) THEN + CALL assert(value(1) .eq. 2546248239.0_C_DOUBLE, & + 'Value of rand does not match.') + ELSE + CALL assert(value(1) .eq. 2357136044.0_C_DOUBLE, & + 'Value of rand does not match.') + END IF + CALL graph%copy_to_host(z, value) + CALL assert(value(1) .eq. 1.0_C_DOUBLE, 'Value of root does not match.') + CALL graph%copy_to_host(p1, value) + CALL assert(value(1) .eq. 4.0_C_DOUBLE, 'Value of p1 does not match.') + CALL graph%copy_to_host(p2, value) + CALL assert(value(1) .eq. 8.0_C_DOUBLE, 'Value of p2 does not match.') + + DEALLOCATE(graph) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Run complex float tests. +!> +!> @param[in] c_type Base type of the calculation. +!> @param[in] use_safe_math Use safe math. +!------------------------------------------------------------------------------- + SUBROUTINE run_test_complex_float(use_safe_math) + + USE graph_fortran + USE, INTRINSIC :: ISO_C_BINDING + + IMPLICIT NONE + +! Declare Arguments + LOGICAL(C_BOOL), INTENT(IN) :: use_safe_math + +! Local variables. + CLASS(graph_context), POINTER :: graph + TYPE(C_PTR) :: x + TYPE(C_PTR) :: m + TYPE(C_PTR) :: b + COMPLEX(C_FLOAT_COMPLEX), DIMENSION(1) :: value + TYPE(C_PTR) :: px + TYPE(C_PTR) :: y + TYPE(C_PTR) :: dydx + TYPE(C_PTR) :: dydm + TYPE(C_PTR) :: dydb + TYPE(C_PTR) :: dydy + TYPE(C_PTR) :: one + TYPE(C_PTR) :: zero + INTEGER(C_LONG) :: size + TYPE(C_PTR) :: rand + TYPE(C_PTR) :: state + COMPLEX(C_FLOAT_COMPLEX), DIMENSION(3) :: buffer1D + TYPE(C_PTR) :: p1 + TYPE(C_PTR) :: i + COMPLEX(C_FLOAT_COMPLEX), DIMENSION(3,3) :: buffer2D + TYPE(C_PTR) :: p2 + TYPE(C_PTR) :: j + TYPE(C_PTR) :: z + TYPE(C_PTR) :: root + TYPE(C_PTR) :: root2 + TYPE(C_PTR) :: dz + +! Start of executable code. + graph => graph_complex_float_context(use_safe_math) + + x = graph%variable(1_C_LONG, 'x' // C_NULL_CHAR) + m = graph%constant(0.5_C_DOUBLE) + b = graph%constant(0.2_C_DOUBLE, 0.0_C_DOUBLE) + + value(1) = 2.0 + CALL graph%set_variable(x, value) + + px = graph%pseudo_variable(x) + CALL assert(graph_ptr(px) .ne. graph_ptr(x), & + 'Expected different nodes.') + CALL assert(graph_ptr(graph%remove_pseudo(px)) .eq. graph_ptr(x), & + 'Remove pseudo failed.') + + y = graph%add(graph%mul(m, x), b) + + dydx = graph%df(y, x); + dydm = graph%df(y, m); + dydb = graph%df(y, b); + dydy = graph%df(y, y); + + one = graph%constant(1.0_C_DOUBLE) + zero = graph%constant(0.0_C_DOUBLE) + + CALL assert(graph_ptr(graph%sub(one, one)) .eq. graph_ptr(zero), & + 'Expected 1 - 1 = 0.') + CALL assert(graph_ptr(graph%div(one, one)) .eq. graph_ptr(one), & + 'Expected 1/1 = 1.') + CALL assert(graph_ptr(graph%sqrt(one)) .eq. graph_ptr(one), & + 'Expected sqrt(1) = 1.') + CALL assert(graph_ptr(graph%exp(zero)) .eq. graph_ptr(one), & + 'Expected exp(0) = 1.') + CALL assert(graph_ptr(graph%log(one)) .eq. graph_ptr(zero), & + 'Expected log(1) = 0.') + CALL assert(graph_ptr(graph%pow(one, zero)) .eq. graph_ptr(one), & + 'Expected pow(1,0) = 1.') + CALL assert(graph_ptr(graph%erfi(zero)) .eq. graph_ptr(zero), & + 'Expected erfi(0) = 0.') + CALL assert(graph_ptr(graph%sin(zero)) .eq. graph_ptr(zero), & + 'Expected sin(0) = 0.') + CALL assert(graph_ptr(graph%cos(zero)) .eq. graph_ptr(one), & + 'Expected cos(0) = 1.') + CALL assert(graph_ptr(graph%atan(one, zero)) .eq. graph_ptr(zero), & + 'Expected atan(one, zero) = zero.') + + state = graph%random_state(0) + rand = graph%random(state) + + i = graph%variable(1_C_LONG, 'i' // C_NULL_CHAR) + value(1) = 1.5 + CALL graph%set_variable(i, value) + buffer1D = (/ CMPLX(2.0, 0.0), CMPLX(4.0, 0.0), CMPLX(6.0, 0.0) /) + p1 = graph%piecewise_1D(i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer1D) + + j = graph%variable(1_C_LONG, 'j' // C_NULL_CHAR) + value(1) = 2.5 + CALL graph%set_variable(j, value) + buffer2D = RESHAPE((/ CMPLX(1.0, 0.0), CMPLX(2.0, 0.0), CMPLX(3.0, 0.0), & + CMPLX(4.0, 0.0), CMPLX(5.0, 0.0), CMPLX(6.0, 0.0), & + CMPLX(7.0, 0.0), CMPLX(8.0, 0.0), CMPLX(9.0, 0.0) & + /), SHAPE(buffer2D)) + p2 = graph%piecewise_2D(j, 1.0_C_DOUBLE, 0.0_C_DOUBLE, & + i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer2D) + + z = graph%variable(1_C_LONG, 'z' // C_NULL_CHAR) + root = graph%sub(graph%pow(z, graph%constant(3.0_C_DOUBLE)), & + graph%pow(z, graph%constant(2.0_C_DOUBLE))) + root2 = graph%mul(root, root) + dz = graph%sub(z, graph%div(root, graph%df(root, z))) + + CALL graph%set_device_number(graph%get_max_concurrency() - 1) + + CALL graph%add_pre_item(graph_null_array, (/ graph_ptr(rand) /), & + graph_null_array, graph_null_array, state, & + 'c_binding_pre_kernel' // C_NULL_CHAR, & + 1_C_LONG) + CALL graph%add_item((/ graph_ptr(x) /), (/ & + graph_ptr(y), & + graph_ptr(dydx), & + graph_ptr(dydm), & + graph_ptr(dydb), & + graph_ptr(dydy) & + /), graph_null_array, graph_null_array, C_NULL_PTR, & + 'f_binding' // C_NULL_CHAR, 1_C_LONG) + CALL graph%add_item((/ graph_ptr(i), graph_ptr(j) /), & + (/ graph_ptr(p1), graph_ptr(p2) /), & + graph_null_array, graph_null_array, C_NULL_PTR, & + 'c_binding_piecewise' // C_NULL_CHAR, 1_C_LONG) + CALL graph%add_converge_item((/ graph_ptr(z) /), (/ graph_ptr(root2) /), & + (/ graph_ptr(z) /), (/ graph_ptr(dz) /), & + C_NULL_PTR, & + 'f_binding_converge' // C_NULL_CHAR, & + 1_C_LONG, 1.0E-30_C_DOUBLE, 1000_C_LONG) + CALL graph%compile() + value(1) = 10.0 + CALL graph%copy_to_device(z, value) + CALL graph%pre_run() + CALL graph%run() + CALL graph%wait() + CALL graph%print(0_C_LONG, (/ graph_ptr(z), graph_ptr(y) /)) + + CALL graph%copy_to_host(y, value) + CALL assert(REAL(value(1)) .eq. 0.5_C_FLOAT*2.0_C_FLOAT + 0.2_C_FLOAT, & + 'Value of y does not match.') + CALL graph%copy_to_host(dydx, value) + CALL assert(REAL(value(1)) .eq. 0.5_C_FLOAT, & + 'Value of dydx does not match.') + CALL graph%copy_to_host(dydm, value) + CALL assert(REAL(value(1)) .eq. 2.0_C_FLOAT, & + 'Value of dydm does not match.') + CALL graph%copy_to_host(dydb, value) + CALL assert(REAL(value(1)) .eq. 1.0_C_FLOAT, & + 'Value of dydb does not match.') + CALL graph%copy_to_host(dydy, value) + CALL assert(REAL(value(1)) .eq. 1.0_C_FLOAT, & + 'Value of dydy does not match.') + CALL graph%copy_to_host(rand, value) + IF (use_safe_math) THEN + CALL assert(REAL(value(1)) .eq. 2546248192.0_C_FLOAT, & + 'Value of rand does not match.') + ELSE + CALL assert(REAL(value(1)) .eq. 2357136128.0_C_FLOAT, & + 'Value of rand does not match.') + END IF + CALL graph%copy_to_host(z, value) + CALL assert(REAL(value(1)) .eq. 1.0_C_FLOAT, & + 'Value of root does not match.') + CALL graph%copy_to_host(p1, value) + CALL assert(REAL(value(1)) .eq. 4.0_C_FLOAT, & + 'Value of p1 does not match.') + CALL graph%copy_to_host(p2, value) + CALL assert(REAL(value(1)) .eq. 8.0_C_FLOAT, & + 'Value of p2 does not match.') + + DEALLOCATE(graph) + + END SUBROUTINE + +!------------------------------------------------------------------------------- +!> @brief Run double tests. +!> +!> @param[in] c_type Base type of the calculation. +!> @param[in] use_safe_math Use safe math. +!------------------------------------------------------------------------------- + SUBROUTINE run_test_complex_double(use_safe_math) + + USE graph_fortran + USE, INTRINSIC :: ISO_C_BINDING + + IMPLICIT NONE + +! Declare Arguments + LOGICAL(C_BOOL), INTENT(IN) :: use_safe_math + +! Local variables. + CLASS(graph_context), POINTER :: graph + TYPE(C_PTR) :: x + TYPE(C_PTR) :: m + TYPE(C_PTR) :: b + COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(1) :: value + TYPE(C_PTR) :: px + TYPE(C_PTR) :: y + TYPE(C_PTR) :: dydx + TYPE(C_PTR) :: dydm + TYPE(C_PTR) :: dydb + TYPE(C_PTR) :: dydy + TYPE(C_PTR) :: one + TYPE(C_PTR) :: zero + INTEGER(C_LONG) :: size + TYPE(C_PTR) :: rand + TYPE(C_PTR) :: state + COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(3) :: buffer1D + TYPE(C_PTR) :: p1 + TYPE(C_PTR) :: i + COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(3,3) :: buffer2D + TYPE(C_PTR) :: p2 + TYPE(C_PTR) :: j + TYPE(C_PTR) :: z + TYPE(C_PTR) :: root + TYPE(C_PTR) :: root2 + TYPE(C_PTR) :: dz + +! Start of executable code. + graph => graph_complex_double_context(use_safe_math) + + x = graph%variable(1_C_LONG, 'x' // C_NULL_CHAR) + m = graph%constant(0.5_C_DOUBLE) + b = graph%constant(0.2_C_DOUBLE, 0.0_C_DOUBLE) + + value(1) = 2.0 + CALL graph%set_variable(x, value) + + px = graph%pseudo_variable(x) + CALL assert(graph_ptr(px) .ne. graph_ptr(x), & + 'Expected different nodes.') + CALL assert(graph_ptr(graph%remove_pseudo(px)) .eq. graph_ptr(x), & + 'Remove pseudo failed.') + + y = graph%add(graph%mul(m, x), b) + + dydx = graph%df(y, x); + dydm = graph%df(y, m); + dydb = graph%df(y, b); + dydy = graph%df(y, y); + + one = graph%constant(1.0_C_DOUBLE) + zero = graph%constant(0.0_C_DOUBLE) + + CALL assert(graph_ptr(graph%sub(one, one)) .eq. graph_ptr(zero), & + 'Expected 1 - 1 = 0.') + CALL assert(graph_ptr(graph%div(one, one)) .eq. graph_ptr(one), & + 'Expected 1/1 = 1.') + CALL assert(graph_ptr(graph%sqrt(one)) .eq. graph_ptr(one), & + 'Expected sqrt(1) = 1.') + CALL assert(graph_ptr(graph%exp(zero)) .eq. graph_ptr(one), & + 'Expected exp(0) = 1.') + CALL assert(graph_ptr(graph%log(one)) .eq. graph_ptr(zero), & + 'Expected log(1) = 0.') + CALL assert(graph_ptr(graph%pow(one, zero)) .eq. graph_ptr(one), & + 'Expected pow(1,0) = 1.') + CALL assert(graph_ptr(graph%erfi(zero)) .eq. graph_ptr(zero), & + 'Expected erfi(0) = 0.') + CALL assert(graph_ptr(graph%sin(zero)) .eq. graph_ptr(zero), & + 'Expected sin(0) = 0.') + CALL assert(graph_ptr(graph%cos(zero)) .eq. graph_ptr(one), & + 'Expected cos(0) = 1.') + CALL assert(graph_ptr(graph%atan(one, zero)) .eq. graph_ptr(zero), & + 'Expected atan(one, zero) = zero.') + + state = graph%random_state(0) + rand = graph%random(state) + + i = graph%variable(1_C_LONG, 'i' // C_NULL_CHAR) + value(1) = 1.5 + CALL graph%set_variable(i, value) + buffer1D = (/ & + CMPLX(2.0, 0.0, KIND=C_DOUBLE), & + CMPLX(4.0, 0.0, KIND=C_DOUBLE), & + CMPLX(6.0, 0.0, KIND=C_DOUBLE) & + /) + p1 = graph%piecewise_1D(i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer1D) + + j = graph%variable(1_C_LONG, 'j' // C_NULL_CHAR) + value(1) = 2.5 + CALL graph%set_variable(j, value) + buffer2D = RESHAPE((/ & + CMPLX(1.0, 0.0, KIND=C_DOUBLE), & + CMPLX(2.0, 0.0, KIND=C_DOUBLE), & + CMPLX(3.0, 0.0, KIND=C_DOUBLE), & + CMPLX(4.0, 0.0, KIND=C_DOUBLE), & + CMPLX(5.0, 0.0, KIND=C_DOUBLE), & + CMPLX(6.0, 0.0, KIND=C_DOUBLE), & + CMPLX(7.0, 0.0, KIND=C_DOUBLE), & + CMPLX(8.0, 0.0, KIND=C_DOUBLE), & + CMPLX(9.0, 0.0, KIND=C_DOUBLE) & + /), SHAPE(buffer2D)) + p2 = graph%piecewise_2D(j, 1.0_C_DOUBLE, 0.0_C_DOUBLE, & + i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer2D) + + z = graph%variable(1_C_LONG, 'z' // C_NULL_CHAR) + root = graph%sub(graph%pow(z, graph%constant(3.0_C_DOUBLE)), & + graph%pow(z, graph%constant(2.0_C_DOUBLE))) + root2 = graph%mul(root, root) + dz = graph%sub(z, graph%div(root, graph%df(root, z))) + + CALL graph%set_device_number(graph%get_max_concurrency() - 1) + + CALL graph%add_pre_item(graph_null_array, (/ graph_ptr(rand) /), & + graph_null_array, graph_null_array, state, & + 'f_binding_pre_kernel' // C_NULL_CHAR, & + 1_C_LONG) + CALL graph%add_item((/ graph_ptr(x) /), (/ & + graph_ptr(y), & + graph_ptr(dydx), & + graph_ptr(dydm), & + graph_ptr(dydb), & + graph_ptr(dydy) & + /), graph_null_array, graph_null_array, C_NULL_PTR, & + 'f_binding' // C_NULL_CHAR, 1_C_LONG) + CALL graph%add_item((/ graph_ptr(i), graph_ptr(j) /), & + (/ graph_ptr(p1), graph_ptr(p2) /), & + graph_null_array, graph_null_array, C_NULL_PTR, & + 'c_binding_piecewise' // C_NULL_CHAR, 1_C_LONG) + CALL graph%add_converge_item((/ graph_ptr(z) /), (/ graph_ptr(root2) /), & + (/ graph_ptr(z) /), (/ graph_ptr(dz) /), & + C_NULL_PTR, & + 'f_binding_converge' // C_NULL_CHAR, & + 1_C_LONG, 1.0E-30_C_DOUBLE, 1000_C_LONG) + CALL graph%compile() + value(1) = 10.0 + CALL graph%copy_to_device(z, value) + CALL graph%pre_run() + CALL graph%run() + CALL graph%wait() + CALL graph%print(0_C_LONG, (/ graph_ptr(z), graph_ptr(y) /)) + + CALL graph%copy_to_host(y, value) + CALL assert(DBLE(value(1)) .eq. 0.5_C_DOUBLE*2.0_C_DOUBLE + & + 0.2_C_DOUBLE, & + 'Value of y does not match.') + CALL graph%copy_to_host(dydx, value) + CALL assert(DBLE(value(1)) .eq. 0.5_C_DOUBLE, & + 'Value of dydx does not match.') + CALL graph%copy_to_host(dydm, value) + CALL assert(DBLE(value(1)) .eq. 2.0_C_DOUBLE, & + 'Value of dydm does not match.') + CALL graph%copy_to_host(dydb, value) + CALL assert(DBLE(value(1)) .eq. 1.0_C_DOUBLE, & + 'Value of dydb does not match.') + CALL graph%copy_to_host(dydy, value) + CALL assert(DBLE(value(1)) .eq. 1.0_C_DOUBLE, & + 'Value of dydy does not match.') + CALL graph%copy_to_host(rand, value) + IF (use_safe_math) THEN + CALL assert(DBLE(value(1)) .eq. 2546248239.0_C_DOUBLE, & + 'Value of rand does not match.') + ELSE + CALL assert(DBLE(value(1)) .eq. 2357136044.0_C_DOUBLE, & + 'Value of rand does not match.') + END IF + CALL graph%copy_to_host(z, value) + CALL assert(DBLE(value(1)) .eq. 1.0_C_DOUBLE, & + 'Value of root does not match.') + CALL graph%copy_to_host(p1, value) + CALL assert(DBLE(value(1)) .eq. 4.0_C_DOUBLE, & + 'Value of p1 does not match.') + CALL graph%copy_to_host(p2, value) + CALL assert(DBLE(value(1)) .eq. 8.0_C_DOUBLE, & + 'Value of p2 does not match.') + + DEALLOCATE(graph) + + END SUBROUTINE diff --git a/graph_tests/node_test.cpp b/graph_tests/node_test.cpp index 04ff017e4fc92059f61b25526d4db5a4a91a311e..1d1470ab8ff2e64fa3d43edac80314faafb28272 100644 --- a/graph_tests/node_test.cpp +++ b/graph_tests/node_test.cpp @@ -36,7 +36,7 @@ void test_constant() { auto dzero = zero->df(zero); auto dzero_cast = graph::constant_cast(dzero); assert(dzero_cast.get() && "Expected a constant type for derivative."); - assert(dzero_cast->is(0.0) && "Constant value expeced zero."); + assert(dzero_cast->is(1.0) && "Constant value expeced one."); zero->set(static_cast (1.0)); assert(zero_cast->is(0.0) && "Constant value expeced zero.");