1. 程式人生 > >【Go 原始碼分析】從 sort.go 看排序演算法的工程實踐

【Go 原始碼分析】從 sort.go 看排序演算法的工程實踐

go version go1.11 darwin/amd64file: src/sort/sort.go

排序演算法有很多種類,比如快排、堆排、插入排序等。各種排序演算法各有其優劣性,在實際生產過程中用到的排序演算法(或者說 Sort 函式)通常是由幾種排序演算法組合而成的。通過分析 sort.go 原始檔,我們一起看一下 go 語言的排序演算法實踐。

不穩定排序演算法

不穩定排序演算法指的是 不保證排序後相同大小元素的原始次序不變的排序演算法

基本思想

首先是入口函式 Sort(data Interface),以及 Interface 介面的定義。

// 滿足 sort.Interface 的型別(比如各種資料集合)可以使用 sort 包中的 Sort 函式進行排序。
// 集合中的元素可以被數字型下標列舉
type Interface interface {
   // 集合中元素的數量
   Len() int
   // Less 函式判斷下標 i 的元素是否應該放在下標 j 的前面
   Less(i, j int) bool
   // Swap 函式交換下標 i j 對應的元素
   Swap(i, j int)
}

func Sort(data Interface) {
   n := data.Len()
   quickSort(data, 0, n, maxDepth(n))
}

其中,maxDepth 是快排遞迴的最大深度,其取值為 2*ceil(lg(n+1))

func maxDepth(n int) int {
   var depth int
   for i := n; i > 0; i >>= 1 {
      depth++
   }
   return depth * 2
}

入口的 Sort 函式呼叫的 quickSort 並不完全是快排。

quickSort 函式的整體框架是快排:當切片資料量較大時,使用快排把資料分割成兩個子問題(doPivot),把較小規模的子問題進行遞迴,較大規模的子問題繼續迭代(實現上的一種 trick,相當於遞迴,只不過少了一層函式呼叫),如果迭代或遞迴的深度超過 maxDepth

,則使用堆排序;當切片資料量較小(<= 12)時,採用希爾排序法。

quickSort

// 該函式會把 data[a, b) 區間的元素進行排序,下面稱該區間為切片 slice
func quickSort(data Interface, a, b, maxDepth int) {
   // 如果切片長度不大於 12 ,則使用希爾排序,否則,使用下面的方法排序
   for b-a > 12 {
      if maxDepth == 0 { // 如果遞迴到最大深度,則使用堆排序
         heapSort(data, a, b)
         return
      }
      maxDepth--
      // doPivot 是快排核心演算法,它取一點為軸,把不大於軸的元素放左邊,大於軸的元素放右邊,返回小於軸部分資料的最後一個下標,以及大於軸部分資料的第一個下標
      // 下標位置 a...mlo,pivot,mhi...b
      // data[a...mlo] <= data[pivot]
      // data[mhi...b] > data[pivot]
      mlo, mhi := doPivot(data, a, b)
      // 避免較大規模的子問題遞迴呼叫,保證棧深度最大為 maxDepth
      // 解釋:因為迴圈肯定比遞迴呼叫節省時間,但是兩個子問題只能一個進行迴圈,另一個只能用遞迴。這裡是把較小規模的子問題進行遞迴,較大規模子問題進行迴圈。
      if mlo-a < b-mhi {
         quickSort(data, a, mlo, maxDepth)
         a = mhi // 相當於 quickSort(data, mhi, b)
      } else {
         quickSort(data, mhi, b, maxDepth)
         b = mlo // 相當於 quickSort(data, a, mlo)
      }
   }
   
   // 較小資料集使用希爾排序
   // 第一次步長為 6,第二次步長為 1(其實就是插入排序了)
   if b-a > 1 {
      // Do ShellSort pass with gap 6
      // It could be written in this simplified form cause b-a <= 12
      for i := a + 6; i < b; i++ {
         if data.Less(i, i-6) {
            data.Swap(i, i-6)
         }
      }
      insertionSort(data, a, b)
   }
}

插入排序

插入排序的思想比較簡單:把資料分為已排序(左)和未排序(右)的兩部分,每次取未排序的第一個值,放到已排序部分中正確的地方。

InsertionSort

// Insertion sort
func insertionSort(data Interface, a, b int) {
   for i := a + 1; i < b; i++ {
      for j := i; j > a && data.Less(j, j-1); j-- {
         data.Swap(j, j-1)
      }
   }
}

堆排序

一般來說,堆排序的第一步是構建最大堆,第二步是從堆頂取出當前堆最大元素,與堆尾交換,並使堆大小減1;迴圈第二步,直到堆中沒有元素。

sort.go 中堆排序的核心函式是 siftDown(data Interface, lo, hi, first int),它用於維護(和構建)最大堆的性質。

// siftDown 維護了切片 data[lo, hi) 的最大堆性質
// first 是 lo hi 相當於陣列的偏移
func siftDown(data Interface, lo, hi, first int) {
   root := lo
   for {
      child := 2*root + 1    // 左孩子節點下標
      if child >= hi {       // 如果左孩子超出切片,則 break
         break
      }
      // child + 1 是右孩子節點
      // 以下部分程式碼會把root、左孩子及右孩子節點中的最大值調換到 root 位置
      if child+1 < hi && data.Less(first+child, first+child+1) {
         child++
      }
      if !data.Less(first+root, first+child) {
         return  // 如果 root 位置已經是最大值,則直接 return
      }
      // 如果 root 不是最大值,則把最大值調換到 root,並以調換了的 child 為 root 繼續迴圈
      data.Swap(first+root, first+child)
      root = child
   }
}

func heapSort(data Interface, a, b int) {
   first := a
   lo := 0
   hi := b - a

   // 從堆底構建最大堆
   for i := (hi - 1) / 2; i >= 0; i-- {
      siftDown(data, i, hi, first)
   }

   // 把堆頂元素移動到尾部,並繼續維護最大堆的性質
   for i := hi - 1; i >= 0; i-- {
      data.Swap(first, first+i)
      siftDown(data, lo, i, first)
   }
}

快速排序之陣列切分

快速排序的核心程式碼是切片切分,即把切片根據選定的軸切分成兩部分(不大於軸的部分,和大於軸的部分)。

瞭解快排的朋友可能知道,快排最壞時間複雜度是 O(n**2)。最壞情況是每次切分的切片極不均衡,可能全是大於軸的部分,也可能全是不大於軸的部分。所以選擇合適的軸是很必要的。

doPivot 在切分之前,先使用 medianOfThree 函式選擇一個肯定不是最大和最小的值作為軸,放在了切片首位。然後把不小於 data[pivot] 的資料放在了 [lo, b) 區間,把大於 data[pivot] 的資料放在了 (c, hi-1] 區間(其中 data[hi-1] >= data[pivot])。

之後,該演算法又估算了等於 data[pivot] 的數量,如果數量過多,則把與 data[pivot] 相等的資料放到了中間部分 區間為(b, c-1)。最後把 data[pivot] 交換到了 b-1 的位置。

至此,資料被切分成三個區間。data[lo, b-1)data[b-1, c)data[c, hi)

medianOfThree

// medianOfThree 函式把 data[m0,m1,m2] 的中間值移動到了 m1 的位置
// 同時使三個值的大小順序為 data[m0] <= data[m1] <= data[m2]
func medianOfThree(data Interface, m1, m0, m2 int) {
   // sort 3 elements
   if data.Less(m1, m0) {
      data.Swap(m1, m0)
   }
   // data[m0] <= data[m1]
   if data.Less(m2, m1) {
      data.Swap(m2, m1)
      // data[m0] <= data[m2] && data[m1] < data[m2]
      if data.Less(m1, m0) {
         data.Swap(m1, m0)
      }
   }
   // now data[m0] <= data[m1] <= data[m2]
}

其中,!data.Less(i, j) 可以看做 data[i] >= data[j]

doPivot

func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
   m := int(uint(lo+hi) >> 1) // trick:避免整型溢位的
   if hi-lo > 40 {
      // Tukey's ``Ninther,'' median of three medians of three.
      s := (hi - lo) / 8
      medianOfThree(data, lo, lo+s, lo+2*s)
      medianOfThree(data, m, m-s, m+s)
      medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
   }
   medianOfThree(data, lo, m, hi-1)

   // 以下程式碼達成目標為:
   // data[lo] = pivot (set up by ChoosePivot)
   // data[lo < i < a] < pivot
   // data[a <= i < b] <= pivot
   // data[b <= i < c] unexamined
   // data[c <= i < hi-1] > pivot
   // data[hi-1] >= pivot
   pivot := lo
   a, c := lo+1, hi-1

   for ; a < c && data.Less(a, pivot); a++ {
   }
   b := a
   for {
      for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
      }
      for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
      }
      if b >= c {
         break
      }
      // data[b] > pivot; data[c-1] <= pivot
      data.Swap(b, c-1)
      b++
      c--
   }
   // If hi-c<3 then there are duplicates (by property of median of nine).
   // Let be a bit more conservative, and set border to 5.
   protect := hi-c < 5
   if !protect && hi-c < (hi-lo)/4 {
      // Lets test some points for equality to pivot
      dups := 0
      if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
         data.Swap(c, hi-1)
         c++
         dups++
      }
      if !data.Less(b-1, pivot) { // data[b-1] = pivot
         b--
         dups++
      }
      // m-lo = (hi-lo)/2 > 6
      // b-lo > (hi-lo)*3/4-1 > 8
      // ==> m < b ==> data[m] <= pivot
      if !data.Less(m, pivot) { // data[m] = pivot
         data.Swap(m, b-1)
         b--
         dups++
      }
      // if at least 2 points are equal to pivot, assume skewed distribution
      protect = dups > 1
   }
   if protect {
      // Protect against a lot of duplicates
      // Add invariant:
      // data[a <= i < b] unexamined
      // data[b <= i < c] = pivot
      for {
         for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
         }
         for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
         }
         if a >= b {
            break
         }
         // data[a] == pivot; data[b-1] < pivot
         data.Swap(a, b-1)
         a++
         b--
      }
   }
   // Swap pivot into middle
   data.Swap(pivot, b-1)
   return b - 1, c
}

以上是不穩定排序演算法的實現。

穩定排序演算法

穩定排序演算法保持相等元素的原始次序。

go 中使用的穩定排序演算法為 symMerge,暫時稱為歸併排序吧,雖然跟我在《演算法導論》上看到過的歸併排序演算法不一樣。

這裡用到的歸併排序演算法是一種原址排序演算法:首先,它把切片按照每 blockSize:=20 個元素為一個切片,進行插入排序;迴圈合併相鄰的兩個 block,每次迴圈 blockSize 擴大二倍,直到 blockSize > n 為止。

func Stable(data Interface) {
   stable(data, data.Len())
}

func stable(data Interface, n int) {
   blockSize := 20 // 初始 blockSize 設定為 20
   a, b := 0, blockSize
   // 對每個塊(以及剩餘不足blockSize的一個塊)進行查詢排序
   for b <= n {
      insertionSort(data, a, b)
      a = b
      b += blockSize
   }
   insertionSort(data, a, n)

   for blockSize < n {
      a, b = 0, 2*blockSize
      // 每兩個 blockSize 進行合併
      for b <= n {
         symMerge(data, a, a+blockSize, b)
         a = b
         b += 2 * blockSize
      }
      // 剩餘一個多 blockSize 進行合併
      if m := a + blockSize; m < n {
         symMerge(data, a, m, n)
      }
      blockSize *= 2
   }
}

symMerge 是一種原址合併演算法,

func symMerge(data Interface, a, m, b int) {
   // 為了避免不必要的遞迴,當 data[a:m] 或者 data[m:b] 只有一個元素時,直接插入到另一個子陣列中的對應位置。

   if m-a == 1 {
      // 使用二分查詢,找到合適的位置,並插入資料
      i := m
      j := b
      for i < j {
         h := int(uint(i+j) >> 1)
         if data.Less(h, a) {
            i = h + 1
         } else {
            j = h
         }
      }
      // Swap values until data[a] reaches the position before i.
      for k := a; k < i-1; k++ {
         data.Swap(k, k+1)
      }
      return
   }

   // 同上
   // Avoid unnecessary recursions of symMerge
   // by direct insertion of data[m] into data[a:m]
   // if data[m:b] only contains one element.
   if b-m == 1 {
      // Use binary search to find the lowest index i
      // such that data[i] > data[m] for a <= i < m.
      // Exit the search loop with i == m in case no such index exists.
      i := a
      j := m
      for i < j {
         h := int(uint(i+j) >> 1)
         if !data.Less(m, h) {
            i = h + 1
         } else {
            j = h
         }
      }
      // Swap values until data[m] reaches the position i.
      for k := m; k > i; k-- {
         data.Swap(k, k-1)
      }
      return
   }

   mid := int(uint(a+b) >> 1)
   n := mid + m
   var start, r int
   if m > mid {
      start = n - b
      r = mid
   } else {
      start = a
      r = m
   }
   p := n - 1

   for start < r {
      c := int(uint(start+r) >> 1)
      if !data.Less(p-c, c) {
         start = c + 1
      } else {
         r = c
      }
   }

   end := n - start
   if start < m && m < end {
      rotate(data, start, m, end)
   }
   if a < start && start < mid {
      symMerge(data, a, start, mid)
   }
   if mid < end && end < b {
      symMerge(data, mid, end, b)
   }
   
   // 寫在後面
   // 上面這段大致意思是從兩個子切片相鄰位置找到合適的區間進行旋轉然後對旋轉後得到的子切片遞迴合併。具體真沒看懂。
}

以及 rotate 的實現:

// 假設兩個切片為 u = data[a:m] v = data[m:b]
// 整個資料為 'x u v y',則 rotate 會把資料旋轉為 'x v u y'
func rotate(data Interface, a, m, b int) {
   i := m - a
   j := b - m

   for i != j {
      if i > j {
         swapRange(data, m-i, m, j)
         i -= j
      } else {
         swapRange(data, m-i, m+j-i, i)
         j -= i
      }
   }
   // i == j
   swapRange(data, m-i, m, i)
}

func swapRange(data Interface, a, b, n int) {
   for i := 0; i < n; i++ {
      data.Swap(a+i, b+i)
   }
}

以上是穩定排序 Stable 的全部程式碼。

使用方法

sort.go 中完成了基礎資料型別的 Interface 實現。比如 []int 型別。

type IntSlice []int

func (p IntSlice) Len() int           { return len(p) }
func (p IntSlice) Less(i, j int) bool { return p[i] < p[j] }
func (p IntSlice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }

對於基本型別來說,可以直接像這樣使用。

a := []int{4, 1, 3, 7, 4, 2, 6, 3, 5, 6}
sort.Sort(sort.IntSlice(a))
fmt.Println(a)

對於複雜資料型別來說,只要實現了 sort.Interface 介面,即可使用 sort.Sort 或者 sort.Stable 函式進行排序了。