Commit 207d4be4 authored by Arthur O'Dwyer's avatar Arthur O'Dwyer
Browse files

[libc++] [P0879] constexpr std::nth_element, and rewrite its tests.

This patch is more than just adding the `constexpr` keyword, because
the old code relied on `goto`, and `goto` is not constexpr-friendly.
Refactor to eliminate `goto`, and then mark it as constexpr in C++20.

I freely admit that the name `__nth_element_partloop` is bad;
I couldn't find any better name because I don't really know
what this loop is doing, conceptually. Vice versa, I think
`__nth_element_find_guard` has a decent name.

Now the only one we're still missing from P0879 is `sort`.

Differential Revision: https://reviews.llvm.org/D93557
parent 9cbef8c9
Loading
Loading
Loading
Loading
+92 −71
Original line number Diff line number Diff line
@@ -385,11 +385,11 @@ template <class InputIterator, class RandomAccessIterator, class Compare>
                      RandomAccessIterator result_first, RandomAccessIterator result_last, Compare comp);

template <class RandomAccessIterator>
    void
    constexpr void                    // constexpr in C++20
    nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last);

template <class RandomAccessIterator, class Compare>
    void
    constexpr void                    // constexpr in C++20
    nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last, Compare comp);

template <class ForwardIterator, class T>
@@ -3833,7 +3833,7 @@ is_sorted(_ForwardIterator __first, _ForwardIterator __last)
// stable, 2-3 compares, 0-2 swaps

template <class _Compare, class _ForwardIterator>
unsigned
_LIBCPP_CONSTEXPR_AFTER_CXX11 unsigned
__sort3(_ForwardIterator __x, _ForwardIterator __y, _ForwardIterator __z, _Compare __c)
{
    unsigned __r = 0;
@@ -3927,7 +3927,7 @@ __sort5(_ForwardIterator __x1, _ForwardIterator __x2, _ForwardIterator __x3,

// Assumes size > 0
template <class _Compare, class _BidirectionalIterator>
void
_LIBCPP_CONSTEXPR_AFTER_CXX11 void
__selection_sort(_BidirectionalIterator __first, _BidirectionalIterator __last, _Compare __comp)
{
    _BidirectionalIterator __lm1 = __last;
@@ -5280,7 +5280,69 @@ partial_sort_copy(_InputIterator __first, _InputIterator __last,
// nth_element

template<class _Compare, class _RandomAccessIterator>
void
_LIBCPP_CONSTEXPR_AFTER_CXX11 bool
__nth_element_find_guard(_RandomAccessIterator& __i, _RandomAccessIterator& __j,
                         _RandomAccessIterator& __m, _Compare __comp)
{
    // manually guard downward moving __j against __i
    while (--__j != __i)
    {
        if (__comp(*__j, *__m))
        {
            return true;
        }
    }
    return false;
}

template<class _Compare, class _RandomAccessIterator>
_LIBCPP_CONSTEXPR_AFTER_CXX11 bool
__nth_element_partloop(_RandomAccessIterator __first, _RandomAccessIterator __last,
                       _RandomAccessIterator& __i, _RandomAccessIterator& __j,
                       unsigned& __n_swaps, _Compare __comp)
{
    // *__first == *__m, *__m <= all other elements
    // Partition instead into [__first, __i) == *__first and *__first < [__i, __last)
    ++__i;  // __first + 1
    __j = __last;
    if (!__comp(*__first, *--__j))  // we need a guard if *__first == *(__last-1)
    {
        while (true)
        {
            if (__i == __j)
                return true;  // [__first, __last) all equivalent elements
            if (__comp(*__first, *__i))
            {
                swap(*__i, *__j);
                ++__n_swaps;
                ++__i;
                break;
            }
            ++__i;
        }
    }
    // [__first, __i) == *__first and *__first < [__j, __last) and __j == __last - 1
    if (__i == __j)
        return true;
    while (true)
    {
        while (!__comp(*__first, *__i))
            ++__i;
        while (__comp(*__first, *--__j))
            ;
        if (__i >= __j)
            break;
        swap(*__i, *__j);
        ++__n_swaps;
        ++__i;
    }
    // [__first, __i) == *__first and *__first < [__i, __last)
    // The first part is sorted.
    return false;
}

template <class _Compare, class _RandomAccessIterator>
_LIBCPP_CONSTEXPR_AFTER_CXX11 void
__nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
{
    // _Compare is known to be a reference type
@@ -5288,7 +5350,7 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
    const difference_type __limit = 7;
    while (true)
    {
    __restart:
        // __restart: -- this is the target of a "continue" below
        if (__nth == __last)
            return;
        difference_type __len = __last - __first;
@@ -5328,61 +5390,19 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
        if (!__comp(*__i, *__m))  // if *__first == *__m
        {
            // *__first == *__m, *__first doesn't go in first part
            // manually guard downward moving __j against __i
            while (true)
            {
                if (__i == --__j)
                {
                    // *__first == *__m, *__m <= all other elements
                    // Partition instead into [__first, __i) == *__first and *__first < [__i, __last)
                    ++__i;  // __first + 1
                    __j = __last;
                    if (!__comp(*__first, *--__j))  // we need a guard if *__first == *(__last-1)
                    {
                        while (true)
                        {
                            if (__i == __j)
                                return;  // [__first, __last) all equivalent elements
                            if (__comp(*__first, *__i))
                            {
            if (_VSTD::__nth_element_find_guard<_Compare>(__i, __j, __m, __comp)) {
                swap(*__i, *__j);
                ++__n_swaps;
                                ++__i;
                                break;
                            }
                            ++__i;
                        }
                    }
                    // [__first, __i) == *__first and *__first < [__j, __last) and __j == __last - 1
                    if (__i == __j)
                // found guard for downward moving __j, now use unguarded partition
            } else if (_VSTD::__nth_element_partloop<_Compare>(__first, __last, __i, __j, __n_swaps, __comp)) {
                return;
                    while (true)
                    {
                        while (!__comp(*__first, *__i))
                            ++__i;
                        while (__comp(*__first, *--__j))
                            ;
                        if (__i >= __j)
                            break;
                        swap(*__i, *__j);
                        ++__n_swaps;
                        ++__i;
                    }
                    // [__first, __i) == *__first and *__first < [__i, __last)
                    // The first part is sorted,
                    if (__nth < __i)
            } else if (__nth < __i) {
                return;
            } else {
                // __nth_element the second part
                // _VSTD::__nth_element<_Compare>(__i, __nth, __last, __comp);
                __first = __i;
                    goto __restart;
                }
                if (__comp(*__j, *__m))
                {
                    swap(*__i, *__j);
                    ++__n_swaps;
                    break;  // found guard for downward moving __j, now use unguarded partition
                }
                continue;  // i.e., goto __restart
            }
        }
        ++__i;
@@ -5426,15 +5446,16 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
            {
                // Check for [__first, __i) already sorted
                __j = __m = __first;
                while (++__j != __i)
                while (true)
                {
                    if (++__j == __i)
                        // [__first, __i) sorted
                        return;
                    if (__comp(*__j, *__m))
                        // not yet sorted, so sort
                        goto not_sorted;
                        break;
                    __m = __j;
                }
                // [__first, __i) sorted
                return;
            }
            else
            {
@@ -5442,16 +5463,16 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
                __j = __m = __i;
                while (++__j != __last)
                {
                    if (++__j == __last)
                        // [__i, __last) sorted
                        return;
                    if (__comp(*__j, *__m))
                        // not yet sorted, so sort
                        goto not_sorted;
                        break;
                    __m = __j;
                }
                // [__i, __last) sorted
                return;
            }
        }
not_sorted:
        // __nth_element on range containing __nth
        if (__nth < __i)
        {
@@ -5467,7 +5488,7 @@ not_sorted:
}

template <class _RandomAccessIterator, class _Compare>
inline _LIBCPP_INLINE_VISIBILITY
inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX17
void
nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
{
@@ -5476,7 +5497,7 @@ nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomA
}

template <class _RandomAccessIterator>
inline _LIBCPP_INLINE_VISIBILITY
inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX17
void
nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last)
{
+45 −43
Original line number Diff line number Diff line
@@ -9,61 +9,63 @@
// <algorithm>

// template<RandomAccessIterator Iter>
//   requires ShuffleIterator<Iter>
//         && LessThanComparable<Iter::value_type>
//   void
//   requires ShuffleIterator<Iter> && LessThanComparable<Iter::value_type>
//   constexpr void  // constexpr in C++20
//   nth_element(Iter first, Iter nth, Iter last);

#include <algorithm>
#include <random>
#include <cassert>

#include "test_macros.h"
#include "test_iterators.h"
#include "MoveOnly.h"

std::mt19937 randomness;

void
test_one(int N, int M)
template<class T, class Iter>
TEST_CONSTEXPR_CXX20 bool test()
{
    assert(N != 0);
    assert(M < N);
    int* array = new int[N];
    for (int i = 0; i < N; ++i)
        array[i] = i;
    std::shuffle(array, array+N, randomness);
    std::nth_element(array, array+M, array+N);
    assert(array[M] == M);
    std::nth_element(array, array+N, array+N); // begin, end, end
    delete [] array;
    int orig[15] = {3,1,4,1,5, 9,2,6,5,3, 5,8,9,7,9};
    T work[15] = {3,1,4,1,5, 9,2,6,5,3, 5,8,9,7,9};
    for (int n = 0; n < 15; ++n) {
        for (int m = 0; m < n; ++m) {
            std::nth_element(Iter(work), Iter(work+m), Iter(work+n));
            assert(std::is_permutation(work, work+n, orig));
            // No element to m's left is greater than m.
            for (int i = 0; i < m; ++i) {
                assert(!(work[i] > work[m]));
            }
            // No element to m's right is less than m.
            for (int i = m; i < n; ++i) {
                assert(!(work[i] < work[m]));
            }
            std::copy(orig, orig+15, work);
        }
    }

void
test(int N)
    {
    test_one(N, 0);
    test_one(N, 1);
    test_one(N, 2);
    test_one(N, 3);
    test_one(N, N/2-1);
    test_one(N, N/2);
    test_one(N, N/2+1);
    test_one(N, N-3);
    test_one(N, N-2);
    test_one(N, N-1);
        T input[] = {3,1,4,1,5,9,2};
        std::nth_element(Iter(input), Iter(input+4), Iter(input+7));
        assert(input[4] == 4);
        assert(input[5] + input[6] == 5 + 9);
    }
    return true;
}

int main(int, char**)
{
    int d = 0;
    std::nth_element(&d, &d, &d);
    assert(d == 0);
    test(256);
    test(257);
    test(499);
    test(500);
    test(997);
    test(1000);
    test(1009);
    test<int, random_access_iterator<int*> >();
    test<int, int*>();

#if TEST_STD_VER >= 11
    test<MoveOnly, random_access_iterator<MoveOnly*>>();
    test<MoveOnly, MoveOnly*>();
#endif

#if TEST_STD_VER >= 20
    static_assert(test<int, random_access_iterator<int*>>());
    static_assert(test<int, int*>());
    static_assert(test<MoveOnly, random_access_iterator<MoveOnly*>>());
    static_assert(test<MoveOnly, MoveOnly*>());
#endif

    return 0;
}
+43 −61
Original line number Diff line number Diff line
@@ -9,81 +9,63 @@
// <algorithm>

// template<RandomAccessIterator Iter, StrictWeakOrder<auto, Iter::value_type> Compare>
//   requires ShuffleIterator<Iter>
//         && CopyConstructible<Compare>
//   void
//   requires ShuffleIterator<Iter> && CopyConstructible<Compare>
//   constexpr void  // constexpr in C++20
//   nth_element(Iter first, Iter nth, Iter last, Compare comp);

#include <algorithm>
#include <functional>
#include <vector>
#include <random>
#include <cassert>
#include <cstddef>
#include <memory>
#include <functional>

#include "test_macros.h"
#include "test_iterators.h"
#include "MoveOnly.h"

struct indirect_less
template<class T, class Iter>
TEST_CONSTEXPR_CXX20 bool test()
{
    template <class P>
    bool operator()(const P& x, const P& y)
        {return *x < *y;}
};

std::mt19937 randomness;

void
test_one(int N, int M)
{
    assert(N != 0);
    assert(M < N);
    int* array = new int[N];
    for (int i = 0; i < N; ++i)
        array[i] = i;
    std::shuffle(array, array+N, randomness);
    std::nth_element(array, array+M, array+N, std::greater<int>());
    assert(array[M] == N-M-1);
    std::nth_element(array, array+N, array+N, std::greater<int>()); // begin, end, end
    delete [] array;
    int orig[15] = {3,1,4,1,5, 9,2,6,5,3, 5,8,9,7,9};
    T work[15] = {3,1,4,1,5, 9,2,6,5,3, 5,8,9,7,9};
    for (int n = 0; n < 15; ++n) {
        for (int m = 0; m < n; ++m) {
            std::nth_element(Iter(work), Iter(work+m), Iter(work+n), std::greater<T>());
            assert(std::is_permutation(work, work+n, orig));
            // No element to m's left is less than m.
            for (int i = 0; i < m; ++i) {
                assert(!(work[i] < work[m]));
            }
            // No element to m's right is greater than m.
            for (int i = m; i < n; ++i) {
                assert(!(work[i] > work[m]));
            }
            std::copy(orig, orig+15, work);
        }
    }

void
test(int N)
    {
    test_one(N, 0);
    test_one(N, 1);
    test_one(N, 2);
    test_one(N, 3);
    test_one(N, N/2-1);
    test_one(N, N/2);
    test_one(N, N/2+1);
    test_one(N, N-3);
    test_one(N, N-2);
    test_one(N, N-1);
        T input[] = {3,1,4,1,5,9,2};
        std::nth_element(Iter(input), Iter(input+4), Iter(input+7), std::greater<T>());
        assert(input[4] == 2);
        assert(input[5] + input[6] == 1 + 1);
    }
    return true;
}

int main(int, char**)
{
    int d = 0;
    std::nth_element(&d, &d, &d);
    assert(d == 0);
    test(256);
    test(257);
    test(499);
    test(500);
    test(997);
    test(1000);
    test(1009);
    test<int, random_access_iterator<int*> >();
    test<int, int*>();

#if TEST_STD_VER >= 11
    {
    std::vector<std::unique_ptr<int> > v(1000);
    for (int i = 0; static_cast<std::size_t>(i) < v.size(); ++i)
        v[i].reset(new int(i));
    std::nth_element(v.begin(), v.begin() + v.size()/2, v.end(), indirect_less());
    assert(static_cast<std::size_t>(*v[v.size()/2]) == v.size()/2);
    }
    test<MoveOnly, random_access_iterator<MoveOnly*>>();
    test<MoveOnly, MoveOnly*>();
#endif

#if TEST_STD_VER >= 20
    static_assert(test<int, random_access_iterator<int*>>());
    static_assert(test<int, int*>());
    static_assert(test<MoveOnly, random_access_iterator<MoveOnly*>>());
    static_assert(test<MoveOnly, MoveOnly*>());
#endif

    return 0;