Commit 04d2017e authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Merge branch 'debug_fix' into 'main'

Fix an error where the constant was built with the wrong time in atan nodes with literal constants.

See merge request !44
parents fe10e4af b7f081a9
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -815,7 +815,7 @@ namespace graph {
    template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
    shared_leaf<T, SAFE_MATH> atan(const L l,
                                  shared_leaf<T, SAFE_MATH> r) {
        return atan(constant<T, SAFE_MATH> (static_cast<L> (l)), r);
        return atan(constant<T, SAFE_MATH> (static_cast<T> (l)), r);
    }

//------------------------------------------------------------------------------
@@ -831,7 +831,7 @@ namespace graph {
    template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
    shared_leaf<T, SAFE_MATH> atan(shared_leaf<T, SAFE_MATH> l,
                                  const R r) {
        return atan(l, constant<T, SAFE_MATH> (static_cast<R> (r)));
        return atan(l, constant<T, SAFE_MATH> (static_cast<T> (r)));
    }

///  Convenience type alias for shared add nodes.
+0 −1
Original line number Diff line number Diff line
@@ -76,7 +76,6 @@ template<std::floating_point T> void test_erfi(const T tolarance) {
                       "Real parts don't match.");
            }
        } else if (!std::isinf(std::real(test)) && !std::isinf(std::imag(test))) {
            std::cout << std::abs(static_cast<T> (1) - test/gold) << std::endl;
            assert(std::abs(static_cast<T> (1) - test/gold) <= tolarance &&
                   "Results don't match.");
        }
+0 −2
Original line number Diff line number Diff line
@@ -59,7 +59,6 @@ template<jit::float_scalar T> void compile(graph::input_nodes<T> inputs,
    source.add_kernel("test_kernel", inputs, outputs, setters);

    source.compile();
    source.print_source();

    auto run = source.create_kernel_call("test_kernel", inputs, outputs, 1);
    run();
@@ -68,7 +67,6 @@ template<jit::float_scalar T> void compile(graph::input_nodes<T> inputs,
    source.copy_to_host(outputs.back(), &result);

    const T diff = std::abs(result - expected);
    std::cout << expected << " " << result << " " << diff << " " << tolarance << std::endl;
    check(diff, tolarance);
}

+2 −2
Original line number Diff line number Diff line
@@ -85,11 +85,11 @@ template<jit::float_scalar T> void test_tan() {
///  @tparam T Base type of the calculation.
//------------------------------------------------------------------------------
template<jit::float_scalar T> void test_atan() {
    assert(graph::constant_cast(graph::atan(graph::constant(static_cast<T> (10.0)),
    assert(graph::constant_cast(graph::atan(10.0,
                                            graph::constant(static_cast<T> (11.0)))).get() &&
           "Expected constant");
    assert(graph::constant_cast(graph::atan(graph::zero<T> (),
                                            graph::constant(static_cast<T> (11.0)))).get() &&
                                            11.0)).get() &&
           "Expected constant");

    auto x = graph::variable<T> (1, "");