Commit a492f112 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add inital tests for the JIT. Fix some bugs found along the way.

parent a314916c
Loading
Loading
Loading
Loading
+87 −0
Original line number Diff line number Diff line
@@ -19,6 +19,8 @@
		C7E5648128A2A2EE000F31A2 /* node_test.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C7931E6F28073BCA0033B488 /* node_test.cpp */; };
		C7E5648D28A2A333000F31A2 /* vector_test.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C7E7D02F283565A200E09896 /* vector_test.cpp */; };
		C7E5649928A2A360000F31A2 /* physics_test.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C725CD792840088000D0EDE2 /* physics_test.cpp */; };
		C7FA0E0A29590F9100A31E4D /* jit_test.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C7FA0DFD29590B7400A31E4D /* jit_test.cpp */; };
		C7FA0E0B29590F9F00A31E4D /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C71342682947F36100672AD4 /* Metal.framework */; };
/* End PBXBuildFile section */

/* Begin PBXCopyFilesBuildPhase section */
@@ -103,6 +105,15 @@
			);
			runOnlyForDeploymentPostprocessing = 1;
		};
		C7FA0E0129590EF300A31E4D /* CopyFiles */ = {
			isa = PBXCopyFilesBuildPhase;
			buildActionMask = 2147483647;
			dstPath = /usr/share/man/man1/;
			dstSubfolderSpec = 0;
			files = (
			);
			runOnlyForDeploymentPostprocessing = 1;
		};
/* End PBXCopyFilesBuildPhase section */

/* Begin PBXFileReference section */
@@ -145,6 +156,8 @@
		C7E5648628A2A324000F31A2 /* vector_test */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = vector_test; sourceTree = BUILT_PRODUCTS_DIR; };
		C7E5649228A2A34A000F31A2 /* physics_test */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = physics_test; sourceTree = BUILT_PRODUCTS_DIR; };
		C7E7D02F283565A200E09896 /* vector_test.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.objcpp; path = vector_test.cpp; sourceTree = "<group>"; };
		C7FA0DFD29590B7400A31E4D /* jit_test.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.objcpp; fileEncoding = 4; path = jit_test.cpp; sourceTree = "<group>"; };
		C7FA0E0329590EF300A31E4D /* jit_test */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = jit_test; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
@@ -219,6 +232,14 @@
			);
			runOnlyForDeploymentPostprocessing = 0;
		};
		C7FA0E0029590EF300A31E4D /* Frameworks */ = {
			isa = PBXFrameworksBuildPhase;
			buildActionMask = 2147483647;
			files = (
				C7FA0E0B29590F9F00A31E4D /* Metal.framework in Frameworks */,
			);
			runOnlyForDeploymentPostprocessing = 0;
		};
/* End PBXFrameworksBuildPhase section */

/* Begin PBXGroup section */
@@ -255,6 +276,7 @@
				C7E5647A28A2A2DA000F31A2 /* node_test */,
				C7E5648628A2A324000F31A2 /* vector_test */,
				C7E5649228A2A34A000F31A2 /* physics_test */,
				C7FA0E0329590EF300A31E4D /* jit_test */,
			);
			name = Products;
			sourceTree = "<group>";
@@ -303,6 +325,7 @@
				C7931E6F28073BCA0033B488 /* node_test.cpp */,
				C7E7D02F283565A200E09896 /* vector_test.cpp */,
				C725CD792840088000D0EDE2 /* physics_test.cpp */,
				C7FA0DFD29590B7400A31E4D /* jit_test.cpp */,
			);
			path = graph_tests;
			sourceTree = SOURCE_ROOT;
@@ -492,6 +515,23 @@
			productReference = C7E5649228A2A34A000F31A2 /* physics_test */;
			productType = "com.apple.product-type.tool";
		};
		C7FA0E0229590EF300A31E4D /* jit_test */ = {
			isa = PBXNativeTarget;
			buildConfigurationList = C7FA0E0729590EF300A31E4D /* Build configuration list for PBXNativeTarget "jit_test" */;
			buildPhases = (
				C7FA0DFF29590EF300A31E4D /* Sources */,
				C7FA0E0029590EF300A31E4D /* Frameworks */,
				C7FA0E0129590EF300A31E4D /* CopyFiles */,
			);
			buildRules = (
			);
			dependencies = (
			);
			name = jit_test;
			productName = jit_test;
			productReference = C7FA0E0329590EF300A31E4D /* jit_test */;
			productType = "com.apple.product-type.tool";
		};
/* End PBXNativeTarget section */

/* Begin PBXProject section */
@@ -531,6 +571,9 @@
					C7E5649128A2A34A000F31A2 = {
						CreatedOnToolsVersion = 13.4;
					};
					C7FA0E0229590EF300A31E4D = {
						CreatedOnToolsVersion = 14.1;
					};
				};
			};
			buildConfigurationList = C79141A122DA9BF200E0BA0D /* Build configuration list for PBXProject "graph_framework" */;
@@ -556,6 +599,7 @@
				C7E5647928A2A2DA000F31A2 /* node_test */,
				C7E5648528A2A324000F31A2 /* vector_test */,
				C7E5649128A2A34A000F31A2 /* physics_test */,
				C7FA0E0229590EF300A31E4D /* jit_test */,
			);
		};
/* End PBXProject section */
@@ -640,6 +684,14 @@
			);
			runOnlyForDeploymentPostprocessing = 0;
		};
		C7FA0DFF29590EF300A31E4D /* Sources */ = {
			isa = PBXSourcesBuildPhase;
			buildActionMask = 2147483647;
			files = (
				C7FA0E0A29590F9100A31E4D /* jit_test.cpp in Sources */,
			);
			runOnlyForDeploymentPostprocessing = 0;
		};
/* End PBXSourcesBuildPhase section */

/* Begin XCBuildConfiguration section */
@@ -1012,6 +1064,32 @@
			};
			name = Release;
		};
		C7FA0E0829590EF300A31E4D /* Debug */ = {
			isa = XCBuildConfiguration;
			buildSettings = {
				CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
				"CODE_SIGN_IDENTITY[sdk=macosx*]" = "-";
				CODE_SIGN_STYLE = Automatic;
				GCC_PREPROCESSOR_DEFINITIONS = (
					"DEBUG=1",
					"$(inherited)",
				);
				MACOSX_DEPLOYMENT_TARGET = 12.6;
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Debug;
		};
		C7FA0E0929590EF300A31E4D /* Release */ = {
			isa = XCBuildConfiguration;
			buildSettings = {
				CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
				"CODE_SIGN_IDENTITY[sdk=macosx*]" = "-";
				CODE_SIGN_STYLE = Automatic;
				MACOSX_DEPLOYMENT_TARGET = 12.6;
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Release;
		};
/* End XCBuildConfiguration section */

/* Begin XCConfigurationList section */
@@ -1114,6 +1192,15 @@
			defaultConfigurationIsVisible = 0;
			defaultConfigurationName = Release;
		};
		C7FA0E0729590EF300A31E4D /* Build configuration list for PBXNativeTarget "jit_test" */ = {
			isa = XCConfigurationList;
			buildConfigurations = (
				C7FA0E0829590EF300A31E4D /* Debug */,
				C7FA0E0929590EF300A31E4D /* Release */,
			);
			defaultConfigurationIsVisible = 0;
			defaultConfigurationName = Release;
		};
/* End XCConfigurationList section */
	};
	rootObject = C791419E22DA9BF200E0BA0D /* Project object */;
+1 −1
Original line number Diff line number Diff line
@@ -241,7 +241,7 @@ namespace jit {
            GPU_CONTEXT context;
            context.create_pipeline(source_buffer.str(), name,
                                    inputs, outputs,
                                    num_rays, num_steps, 0);
                                    num_rays, num_steps + 1, 0);
            
            const timeing::measure_diagnostic gpu_time("GPU Time");

+1 −1
Original line number Diff line number Diff line
@@ -31,7 +31,7 @@ namespace jit {
    template<class NODE>
    std::string to_string(const char prefix,
                          const NODE *pointer) {
        assert(prefix == 'r' || prefix == 'v' &&
        assert((prefix == 'r' || prefix == 'v' || prefix == 'o' ) &&
               "Expected a variable (v) or register (r) prefix.");
        std::stringstream stream;
        stream << prefix << "_" << reinterpret_cast<size_t> (pointer);
+1 −0
Original line number Diff line number Diff line
@@ -6,3 +6,4 @@ add_test_target (solver_test)
add_test_target (backend_test)
add_test_target (vector_test)
add_test_target (physics_test)
add_test_target (jit_test)
+139 −0
Original line number Diff line number Diff line
//------------------------------------------------------------------------------
///  @file jit_test.cpp
///  @brief Tests for the jit code.
//------------------------------------------------------------------------------

//  Turn on asserts even in release builds.
#ifdef NDEBUG
#undef NDEBUG
#endif

#include <cassert>

#include "../graph_framework/jit.hpp"
#include "../graph_framework/math.hpp"

template<typename BASE> void compile(const std::string name,
                                     graph::input_nodes<backend::cpu<BASE>> inputs,
                                     graph::output_nodes<backend::cpu<BASE>> outputs,
                                     graph::map_nodes<backend::cpu<BASE>> setters) {
    for (auto output : outputs) {
        output->to_latex();
        std::cout << std::endl;
    }

    jit::kernel<backend::cpu<BASE>> source(name, inputs, outputs, setters);
    
    source.compile(name, inputs, outputs, 1, 1);
    
    source.print();
}

//------------------------------------------------------------------------------
///  @brief Run tests with a specified backend.
//------------------------------------------------------------------------------
template<typename BASE> void run_tests() {
    auto v1 = graph::variable<backend::cpu<BASE>> (1, "v1");
    auto v2 = graph::variable<backend::cpu<BASE>> (1, "v2");
    auto v3 = graph::variable<backend::cpu<BASE>> (1, "v3");
    
    v1->set(2.0);
    v2->set(3.0);
    v3->set(4.0);
    
    auto add_node = v1 + v2;
    
    compile<BASE> ("add_kernel",
                   {graph::variable_cast(v1),
                    graph::variable_cast(v2)},
                   {add_node}, {});
    
    auto subtract_node = v1 - v2;
    
    compile<BASE> ("subtract_kernel",
                   {graph::variable_cast(v1),
                    graph::variable_cast(v2)},
                   {subtract_node}, {});
    
    auto multiply_node = v1*v2;
    
    compile<BASE> ("multiply_kernel",
                   {graph::variable_cast(v1),
                    graph::variable_cast(v2)},
                   {multiply_node}, {});
    
    auto divide_node = v1/v2;
    
    compile<BASE> ("divide_kernel",
                   {graph::variable_cast(v1),
                    graph::variable_cast(v2)},
                   {divide_node}, {});
    
    auto fma_node = graph::fma(v1, v2, v3);

    compile<BASE> ("fma_kernel",
                   {graph::variable_cast(v1),
                    graph::variable_cast(v2),
                    graph::variable_cast(v3)},
                   {fma_node}, {});

    auto sqrt_node = graph::sqrt(v2);
    
    compile<BASE> ("sqrt_kernel",
                   {graph::variable_cast(v2)},
                   {sqrt_node}, {});

    auto log_node = graph::log(v1);
    
    compile<BASE> ("log_kernel",
                   {graph::variable_cast(v1)},
                   {log_node}, {});

    auto exp_node = graph::exp(v2);
    
    compile<BASE> ("exp_kernel",
                   {graph::variable_cast(v2)},
                   {exp_node}, {});

    auto pow_node = graph::pow(v1, v2);
    
    compile<BASE> ("pow_kernel",
                   {graph::variable_cast(v1),
                    graph::variable_cast(v2)},
                   {pow_node}, {});

    auto divide_node_df = divide_node->df(v3);

    compile<BASE> ("divide_df_kernel",
                   {graph::variable_cast(v1),
                    graph::variable_cast(v2)},
                   {divide_node_df}, {});

    v1->set(0.0);
    v2->set(0.0);
    
    compile<BASE> ("divide_by_zero_kernel",
                   {graph::variable_cast(v1),
                    graph::variable_cast(v2)},
                   {divide_node}, {});
    
    v3->set(0.0);
    
    auto multiply_divide_node = v3*divide_node;
    
    compile<BASE> ("multiply_divide_by_zero_kernel",
                   {graph::variable_cast(v1),
                    graph::variable_cast(v2),
                    graph::variable_cast(v3)},
                   {multiply_divide_node}, {});
}

//------------------------------------------------------------------------------
///  @brief Main program of the test.
///
///  @param[in] argc Number of commandline arguments.
///  @param[in] argv Array of commandline arguments.
//------------------------------------------------------------------------------
int main(int argc, const char * argv[]) {
    run_tests<float> ();
}
Loading