1. 程式人生 > >快速選擇(quick_select) 演算法分析

快速選擇(quick_select) 演算法分析

快速選擇演算法,就是從給定的一個集合S={a1,a2,...an}中選出第K個大小的數,或者給出其所在的下標之類的。

如果使用排序,比如merge_sort,然後返回第K個元素的下標,複雜度是O(NlogN)

如果使用heap_sort,或者優先佇列,則複雜度是O(NlogK)

如果使用quick _sort的一個變種,叫 quick select,則平均複雜度為O(N),最壞複雜度為O(N^2)

如果使用一種線性選擇演算法,則可以達到最壞O(N)的複雜度,不過實際應用中,該演算法通常比quick select慢1到2倍,所以並不常用(參考Blum, Floyd, Pratt, Rivest, and Tarjan 1973  Time bounds for selection)

演算法思想:

(1)利用快速排序的分治思想,求得待搜尋陣列按照的主元S[q](pivot)(主元的選定有好幾種方法,這裡不詳細討論,可參考快速排序),以主元為界分成左右兩個區間

(2)通過比較主元的位置,判斷第K個大小的數在主元左區間?在主元又區間?還是就是主元?(還要注意邊界條件的判斷,有可能在邊界)

(3)進入子區間遞迴呼叫


這裡實現了stl風格的quick select,僅僅作為一個mark

#include <algorithm>
#include <cassert>
namespace algorithm
{
template<typename _Tp>
const _Tp& choose_pivot(const _Tp& x, const _Tp& y, const _Tp& z)
{
        if( (x < y && y < z)||(z < y && y < x) )
                return y;     
        else if( (z < x && x < y)||(y < x && x < z) )
                return x;
        else
                return z;
}
template<typename _Tp,typename _Compare>
const _Tp& choose_pivot(const _Tp& x, const _Tp& y,const _Tp& z, _Compare comp)
{
        if( (comp(x,y) && comp(y,z))||(comp(z,y)&&comp(y,x)) )
                return y;
        else if( (comp(z,x) && comp(x,y))||(comp(y,x)&&comp(x,z)))
                return x;
        return z;
}
template<typename _RandomAccessIterator,typename _Tp>
_RandomAccessIterator quick_partition(_RandomAccessIterator first,
                _RandomAccessIterator last,_Tp pivot)
{
        while( true ){
                while( *first < pivot )       
                        ++first;
                --last;
                while( pivot < *last )
                        --last;
                if( first >= last )
                        return first;
                std::swap(*first,*last);
                ++first;
        }
}
template<typename _RandomAccessIterator,typename _Tp, typename _Compare>
_RandomAccessIterator quick_partition(_RandomAccessIterator first,
        _RandomAccessIterator last, _Tp pivot, _Compare comp)
{
        while( true ){
                while( comp(*first,pivot) == true )   
                        ++first;
                --last;
                while( comp(pivot,*last) == true )
                        --last;
                if( first >= last )
                        return first;
                std::swap(*first,*last);
                ++first;        
        }
}

template<typename _RandomAccessIterator>
_RandomAccessIterator quick_select(_RandomAccessIterator first, _RandomAccessIterator last, size_t kth)
{
        typedef typename std::iterator_traits<_RandomAccessIterator>::value_type _ValueType;
        typedef typename std::iterator_traits<_RandomAccessIterator>::difference_type _DistanceType;
        if( first == last || last-first <=(_DistanceType)kth )//out of range
                return last;
        _ValueType pivot;
        _RandomAccessIterator mid;
        while( true )  
        {
                if( kth == 0 )        
                        return std::min_element(first,last);
                else if( first+kth == last - 1 )
                        return std::max_element(first,last);
                else{
                        mid = first+(last-first)/2;
                        pivot = choose_pivot(*first,*mid,*(last-1));
                        mid = quick_partition(first,last,pivot);
                        if( mid-first > (_DistanceType)kth )
                                last = mid;     
                        else{
                                kth -= mid-first;               
                                first = mid;
                        }
                }
                assert( last-first > (_DistanceType)kth);
        }
}

template<typename _RandomAccessIterator,typename _Compare>
_RandomAccessIterator quick_select(_RandomAccessIterator first, _RandomAccessIterator last, size_t kth,_Compare comp)
{
        typedef typename std::iterator_traits<_RandomAccessIterator>::value_type _ValueType;
        typedef typename std::iterator_traits<_RandomAccessIterator>::difference_type _DistanceType;
        if( first == last || last-first <=(_DistanceType)kth )//out of range
                return last;
        _ValueType pivot;
        _RandomAccessIterator mid;
        while( true )  
        {
                if( kth == 0 )        
                        return std::min_element(first,last,comp);
                else if( first+kth == last - 1 )
                        return std::max_element(first,last,comp);
                else{
                        mid = first+(last-first)/2;
                        pivot = choose_pivot(*first,*mid,*(last-1),comp);
                        mid = quick_partition(first,last,pivot,comp);
                        if( mid-first > (_DistanceType)kth )
                                last = mid;     
                        else{
                                kth -= mid-first;               
                                first = mid;
                        }
                }
                assert( last-first > (_DistanceType)kth);
        }
}
} //namespace