Commit a282f181 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Merge branch 'cuda_test_fix' into 'main'

Fix test case when using cuda backend.

See merge request !33
parents da6fe5e8 1d6183a3
Loading
Loading
Loading
Loading
+36 −18
Original line number Diff line number Diff line
@@ -68,6 +68,7 @@ template<jit::float_scalar T> void compile(graph::input_nodes<T> inputs,
    source.copy_to_host(outputs.back(), &result);

    const T diff = std::abs(result - expected);
    std::cout << expected << " " << result << " " << diff << " " << tolarance << std::endl;
    check(diff, tolarance);
}

@@ -208,10 +209,15 @@ template<jit::float_scalar T> void piecewise_1D() {
                 graph::variable_cast(b)},
                {graph::fma(p1, p3, p2)}, {},
                static_cast<T> (10.0), 0.0);
    if constexpr (jit::is_complex<T> ()) {
        compile<T> ({graph::variable_cast(a)},
                    {graph::pow(p1, p3)}, {},
                static_cast<T> (std::pow(static_cast<T> (2.0),
                                         static_cast<T> (4.0))), 0.0);
                    static_cast<T> (16.0), 2.0E-15);
    } else {
        compile<T> ({graph::variable_cast(a)},
                    {graph::pow(p1, p3)}, {},
                    static_cast<T> (16.0), 0.0);
    }
    if constexpr (jit::is_complex<T> ()) {
        compile<T> ({graph::variable_cast(a)},
                    {graph::atan(p1, p3)}, {},
@@ -430,11 +436,17 @@ template<jit::float_scalar T> void piecewise_2D() {
                 graph::variable_cast(by)},
                {graph::fma(p1, p3, p2)}, {},
                static_cast<T> (14.0), 0.0);
    if constexpr (jit::is_complex<T> ()) {
        compile<T> ({graph::variable_cast(ax),
                     graph::variable_cast(ay)},
                    {graph::pow(p1, p3)}, {},
                static_cast<T> (std::pow(static_cast<T> (2.0),
                                         static_cast<T> (4.0))), 0.0);
                     static_cast<T> (16.0), 2.0E-15);
    } else {
        compile<T> ({graph::variable_cast(ax),
                     graph::variable_cast(ay)},
                    {graph::pow(p1, p3)}, {},
                     static_cast<T> (16.0), 0.0);
    }
    if constexpr (jit::is_complex<T> ()) {
        compile<T> ({graph::variable_cast(ax),
                     graph::variable_cast(ay)},
@@ -525,11 +537,17 @@ template<jit::float_scalar T> void piecewise_2D() {
                 graph::variable_cast(by)},
                {graph::fma(p1, p5, p2)}, {},
                static_cast<T> (14.0), 0.0);
    if constexpr (jit::is_complex<T> ()) {
        compile<T> ({graph::variable_cast(ax),
                     graph::variable_cast(ay)},
                    {graph::pow(p1, p5)}, {},
                static_cast<T> (std::pow(static_cast<T> (2.0),
                                         static_cast<T> (4.0))), 0.0);
                    static_cast<T> (16.0), 2.0E-15);
    } else {
        compile<T> ({graph::variable_cast(ax),
                     graph::variable_cast(ay)},
                    {graph::pow(p1, p5)}, {},
                    static_cast<T> (16.0), 0.0);
    }
    if constexpr (jit::is_complex<T> ()) {
        compile<T> ({graph::variable_cast(ax),
                     graph::variable_cast(ay)},