1. 程式人生 > >C# ConcurrentBag的實現原理

C# ConcurrentBag的實現原理

基礎 滿足 represent safe 接下來 finally 頭指針 tlist 進行

目錄

  • 一、前言
  • 二、ConcurrentBag類
  • 三、 ConcurrentBag線程安全實現原理
    • 1. ConcurrentBag的私有字段
    • 2. 用於數據存儲的TrehadLocalList類
    • 3. ConcurrentBag實現新增元素
    • 4. ConcurrentBag 如何實現叠代器模式
  • 四、總結
  • 筆者水平有限,如果錯誤歡迎各位批評指正!


一、前言

筆者最近在做一個項目,項目中為了提升吞吐量,使用了消息隊列,中間實現了生產消費模式,在生產消費者模式中需要有一個集合,來存儲生產者所生產的物品,筆者使用了最常見的List<T>集合類型。

由於生產者線程有很多個,消費者線程也有很多個,所以不可避免的就產生了線程同步的問題。開始筆者是使用lock

關鍵字,進行線程同步,但是性能並不是特別理想,然後有網友說可以使用SynchronizedList<T>來代替使用List<T>達到線程安全的目的。於是筆者就替換成了SynchronizedList<T>,但是發現性能依舊糟糕,於是查看了SynchronizedList<T>的源代碼,發現它就是簡單的在List<T>提供的API的基礎上加了lock,所以性能基本與筆者實現方式相差無幾。

最後筆者找到了解決的方案,使用ConcurrentBag<T>類來實現,性能有很大的改觀,於是筆者查看了ConcurrentBag<T>

的源代碼,實現非常精妙,特此在這記錄一下。

二、ConcurrentBag類

ConcurrentBag<T>實現了IProducerConsumerCollection<T>接口,該接口主要用於生產者消費者模式下,可見該類基本就是為生產消費者模式定制的。然後還實現了常規的IReadOnlyCollection<T>類,實現了該類就需要實現IEnumerable<T>、IEnumerable、 ICollection類。

ConcurrentBag<T>對外提供的方法沒有List<T>那麽多,但是同樣有Enumerable

實現的擴展方法。類本身提供的方法如下所示。

名稱 說明
Add 將對象添加到 ConcurrentBag 中。
CopyTo 從指定數組索引開始,將 ConcurrentBag 元素復制到現有的一維 Array 中。
Equals(Object) 確定指定的 Object 是否等於當前的 Object。 (繼承自 Object。)
Finalize 允許對象在“垃圾回收”回收之前嘗試釋放資源並執行其他清理操作。 (繼承自 Object。)
GetEnumerator 返回循環訪問 ConcurrentBag 的枚舉器。
GetHashCode 用作特定類型的哈希函數。 (繼承自 Object。)
GetType 獲取當前實例的 Type。 (繼承自 Object。)
MemberwiseClone 創建當前 Object 的淺表副本。 (繼承自 Object。)
ToArray 將 ConcurrentBag 元素復制到新數組。
ToString 返回表示當前對象的字符串。 (繼承自 Object。)
TryPeek 嘗試從 ConcurrentBag 返回一個對象但不移除該對象。
TryTake 嘗試從 ConcurrentBag 中移除並返回對象。

三、 ConcurrentBag線程安全實現原理

1. ConcurrentBag的私有字段

ConcurrentBag線程安全實現主要是通過它的數據存儲的結構和細顆粒度的鎖。

   public class ConcurrentBag<T> : IProducerConsumerCollection<T>, IReadOnlyCollection<T>
    {
        // ThreadLocalList對象包含每個線程的數據
        ThreadLocal<ThreadLocalList> m_locals;
 
        // 這個頭指針和尾指針指向中的第一個和最後一個本地列表,這些本地列表分散在不同線程中
        // 允許在線程局部對象上枚舉
        volatile ThreadLocalList m_headList, m_tailList;
 
        // 這個標誌是告知操作線程必須同步操作
        // 在GlobalListsLock 鎖中 設置
        bool m_needSync;

}

首選我們來看它聲明的私有字段,其中需要註意的是集合的數據是存放在ThreadLocal線程本地存儲中的。也就是說訪問它的每個線程會維護一個自己的集合數據列表,一個集合中的數據可能會存放在不同線程的本地存儲空間中,所以如果線程訪問自己本地存儲的對象,那麽是沒有問題的,這就是實現線程安全的第一層,使用線程本地存儲數據

然後可以看到ThreadLocalList m_headList, m_tailList;這個是存放著本地列表對象的頭指針和尾指針,通過這兩個指針,我們就可以通過遍歷的方式來訪問所有本地列表。它使用volatile修飾,所以它是線程安全的。

最後又定義了一個標誌,這個標誌告知操作線程必須進行同步操作,這是實現了一個細顆粒度的鎖,因為只有在幾個條件滿足的情況下才需要進行線程同步。

2. 用於數據存儲的TrehadLocalList類

接下來我們來看一下ThreadLocalList類的構造,該類就是實際存儲了數據的位置。實際上它是使用雙向鏈表這種結構進行數據存儲。

[Serializable]
// 構造了雙向鏈表的節點
internal class Node
{
    public Node(T value)
    {
        m_value = value;
    }
    public readonly T m_value;
    public Node m_next;
    public Node m_prev;
}

/// <summary>
/// 集合操作類型
/// </summary>
internal enum ListOperation
{
    None,
    Add,
    Take
};

/// <summary>
/// 線程鎖定的類
/// </summary>
internal class ThreadLocalList
{
    // 雙向鏈表的頭結點 如果為null那麽表示鏈表為空
    internal volatile Node m_head;

    // 雙向鏈表的尾節點
    private volatile Node m_tail;

    // 定義當前對List進行操作的種類 
    // 與前面的 ListOperation 相對應
    internal volatile int m_currentOp;

    // 這個列表元素的計數
    private int m_count;

    // The stealing count
    // 這個不是特別理解 好像是在本地列表中 刪除某個Node 以後的計數
    internal int m_stealCount;

    // 下一個列表 可能會在其它線程中
    internal volatile ThreadLocalList m_nextList;

    // 設定鎖定是否已進行
    internal bool m_lockTaken;

    // The owner thread for this list
    internal Thread m_ownerThread;

    // 列表的版本,只有當列表從空變為非空統計是底層
    internal volatile int m_version;

    /// <summary>
    /// ThreadLocalList 構造器
    /// </summary>
    /// <param name="ownerThread">擁有這個集合的線程</param>
    internal ThreadLocalList(Thread ownerThread)
    {
        m_ownerThread = ownerThread;
    }
    /// <summary>
    /// 添加一個新的item到鏈表首部
    /// </summary>
    /// <param name="item">The item to add.</param>
    /// <param name="updateCount">是否更新計數.</param>
    internal void Add(T item, bool updateCount)
    {
        checked
        {
            m_count++;
        }
        Node node = new Node(item);
        if (m_head == null)
        {
            Debug.Assert(m_tail == null);
            m_head = node;
            m_tail = node;
            m_version++; // 因為進行初始化了,所以將空狀態改為非空狀態
        }
        else
        {
            // 使用頭插法 將新的元素插入鏈表
            node.m_next = m_head;
            m_head.m_prev = node;
            m_head = node;
        }
        if (updateCount) // 更新計數以避免此添加同步時溢出
        {
            m_count = m_count - m_stealCount;
            m_stealCount = 0;
        }
    }

    /// <summary>
    /// 從列表的頭部刪除一個item
    /// </summary>
    /// <param name="result">The removed item</param>
    internal void Remove(out T result)
    {
        // 雙向鏈表刪除頭結點數據的流程
        Debug.Assert(m_head != null);
        Node head = m_head;
        m_head = m_head.m_next;
        if (m_head != null)
        {
            m_head.m_prev = null;
        }
        else
        {
            m_tail = null;
        }
        m_count--;
        result = head.m_value;

    }

    /// <summary>
    /// 返回列表頭部的元素
    /// </summary>
    /// <param name="result">the peeked item</param>
    /// <returns>True if succeeded, false otherwise</returns>
    internal bool Peek(out T result)
    {
        Node head = m_head;
        if (head != null)
        {
            result = head.m_value;
            return true;
        }
        result = default(T);
        return false;
    }

    /// <summary>
    /// 從列表的尾部獲取一個item
    /// </summary>
    /// <param name="result">the removed item</param>
    /// <param name="remove">remove or peek flag</param>
    internal void Steal(out T result, bool remove)
    {
        Node tail = m_tail;
        Debug.Assert(tail != null);
        if (remove) // Take operation
        {
            m_tail = m_tail.m_prev;
            if (m_tail != null)
            {
                m_tail.m_next = null;
            }
            else
            {
                m_head = null;
            }
            // Increment the steal count
            m_stealCount++;
        }
        result = tail.m_value;
    }


    /// <summary>
    /// 獲取總計列表計數, 它不是線程安全的, 如果同時調用它, 則可能提供不正確的計數
    /// </summary>
    internal int Count
    {
        get
        {
            return m_count - m_stealCount;
        }
    }
}

從上面的代碼中我們可以更加驗證之前的觀點,就是ConcurentBag<T>在一個線程中存儲數據時,使用的是雙向鏈表ThreadLocalList實現了一組對鏈表增刪改查的方法。

3. ConcurrentBag實現新增元素

接下來我們看一看ConcurentBag<T>是如何新增元素的。

/// <summary>
/// 嘗試獲取無主列表,無主列表是指線程已經被暫停或者終止,但是集合中的部分數據還存儲在那裏
/// 這是避免內存泄漏的方法
/// </summary>
/// <returns></returns>
private ThreadLocalList GetUnownedList()
{
    //此時必須持有全局鎖
    Contract.Assert(Monitor.IsEntered(GlobalListsLock));

    // 從頭線程列表開始枚舉 找到那些已經被關閉的線程
    // 將它所在的列表對象 返回
    ThreadLocalList currentList = m_headList;
    while (currentList != null)
    {
        if (currentList.m_ownerThread.ThreadState == System.Threading.ThreadState.Stopped)
        {
            currentList.m_ownerThread = Thread.CurrentThread; // the caller should acquire a lock to make this line thread safe
            return currentList;
        }
        currentList = currentList.m_nextList;
    }
    return null;
}
/// <summary>
/// 本地幫助方法,通過線程對象檢索線程線程本地列表
/// </summary>
/// <param name="forceCreate">如果列表不存在,那麽創建新列表</param>
/// <returns>The local list object</returns>
private ThreadLocalList GetThreadList(bool forceCreate)
{
    ThreadLocalList list = m_locals.Value;

    if (list != null)
    {
        return list;
    }
    else if (forceCreate)
    {
        // 獲取用於更新操作的 m_tailList 鎖
        lock (GlobalListsLock)
        {
            // 如果頭列表等於空,那麽說明集合中還沒有元素
            // 直接創建一個新的
            if (m_headList == null)
            {
                list = new ThreadLocalList(Thread.CurrentThread);
                m_headList = list;
                m_tailList = list;
            }
            else
            {
               // ConcurrentBag內的數據是以雙向鏈表的形式分散存儲在各個線程的本地區域中
                // 通過下面這個方法 可以找到那些存儲有數據 但是已經被停止的線程
                // 然後將已停止線程的數據 移交到當前線程管理
                list = GetUnownedList();
                // 如果沒有 那麽就新建一個列表 然後更新尾指針的位置
                if (list == null)
                {
                    list = new ThreadLocalList(Thread.CurrentThread);
                    m_tailList.m_nextList = list;
                    m_tailList = list;
                }
            }
            m_locals.Value = list;
        }
    }
    else
    {
        return null;
    }
    Debug.Assert(list != null);
    return list;
}
/// <summary>
/// Adds an object to the <see cref="ConcurrentBag{T}"/>.
/// </summary>
/// <param name="item">The object to be added to the
/// <see cref="ConcurrentBag{T}"/>. The value can be a null reference
/// (Nothing in Visual Basic) for reference types.</param>
public void Add(T item)
{
    // 獲取該線程的本地列表, 如果此線程不存在, 則創建一個新列表 (第一次調用 add)
    ThreadLocalList list = GetThreadList(true);
    // 實際的數據添加操作 在AddInternal中執行
    AddInternal(list, item);
}

/// <summary>
/// </summary>
/// <param name="list"></param>
/// <param name="item"></param>
private void AddInternal(ThreadLocalList list, T item)
{
    bool lockTaken = false;
    try
    {
        #pragma warning disable 0420
        Interlocked.Exchange(ref list.m_currentOp, (int)ListOperation.Add);
        #pragma warning restore 0420
        // 同步案例:
        // 如果列表計數小於兩個, 因為是雙向鏈表的關系 為了避免與任何竊取線程發生沖突 必須獲取鎖
        // 如果設置了 m_needSync, 這意味著有一個線程需要凍結包 也必須獲取鎖
        if (list.Count < 2 || m_needSync)
        {
            // 將其重置為None 以避免與竊取線程的死鎖
            list.m_currentOp = (int)ListOperation.None;
            // 鎖定當前對象
            Monitor.Enter(list, ref lockTaken);
        }
        // 調用 ThreadLocalList.Add方法 將數據添加到雙向鏈表中
        // 如果已經鎖定 那麽說明線程安全  可以更新Count 計數
        list.Add(item, lockTaken);
    }
    finally
    {
        list.m_currentOp = (int)ListOperation.None;
        if (lockTaken)
        {
            Monitor.Exit(list);
        }
    }
}

從上面代碼中,我們可以很清楚的知道Add()方法是如何運行的,其中的關鍵就是GetThreadList()方法,通過該方法可以獲取當前線程的數據存儲列表對象,假如不存在數據存儲列表,它會自動創建或者通過GetUnownedList()方法來尋找那些被停止但是還存儲有數據列表的線程,然後將數據列表返回給當前線程中,防止了內存泄漏。

在數據添加的過程中,實現了細顆粒度的lock同步鎖,所以性能會很高。刪除和其它操作與新增類似,本文不再贅述。

4. ConcurrentBag 如何實現叠代器模式

看完上面的代碼後,我很好奇ConcurrentBag<T>是如何實現IEnumerator來實現叠代訪問的,因為ConcurrentBag<T>是通過分散在不同線程中的ThreadLocalList來存儲數據的,那麽在實現叠代器模式時,過程會比較復雜。

後面再查看了源碼之後,發現ConcurrentBag<T>為了實現叠代器模式,將分在不同線程中的數據全都存到一個List<T>集合中,然後返回了該副本的叠代器。所以每次訪問叠代器,它都會新建一個List<T>的副本,這樣雖然浪費了一定的存儲空間,但是邏輯上更加簡單了。

/// <summary>
/// 本地幫助器方法釋放所有本地列表鎖
/// </summary>
private void ReleaseAllLocks()
{
    // 該方法用於在執行線程同步以後 釋放掉所有本地鎖
    // 通過遍歷每個線程中存儲的 ThreadLocalList對象 釋放所占用的鎖
    ThreadLocalList currentList = m_headList;
    while (currentList != null)
    {

        if (currentList.m_lockTaken)
        {
            currentList.m_lockTaken = false;
            Monitor.Exit(currentList);
        }
        currentList = currentList.m_nextList;
    }
}

/// <summary>
/// 從凍結狀態解凍包的本地幫助器方法
/// </summary>
/// <param name="lockTaken">The lock taken result from the Freeze method</param>
private void UnfreezeBag(bool lockTaken)
{
    // 首先釋放掉 每個線程中 本地變量的鎖
    // 然後釋放全局鎖
    ReleaseAllLocks();
    m_needSync = false;
    if (lockTaken)
    {
        Monitor.Exit(GlobalListsLock);
    }
}

/// <summary>
/// 本地幫助器函數等待所有未同步的操作
/// </summary>
private void WaitAllOperations()
{
    Contract.Assert(Monitor.IsEntered(GlobalListsLock));

    ThreadLocalList currentList = m_headList;
    // 自旋等待 等待其它操作完成
    while (currentList != null)
    {
        if (currentList.m_currentOp != (int)ListOperation.None)
        {
            SpinWait spinner = new SpinWait();
            // 有其它線程進行操作時,會將cuurentOp 設置成 正在操作的枚舉
            while (currentList.m_currentOp != (int)ListOperation.None)
            {
                spinner.SpinOnce();
            }
        }
        currentList = currentList.m_nextList;
    }
}

/// <summary>
/// 本地幫助器方法獲取所有本地列表鎖
/// </summary>
private void AcquireAllLocks()
{
    Contract.Assert(Monitor.IsEntered(GlobalListsLock));

    bool lockTaken = false;
    ThreadLocalList currentList = m_headList;
    
    // 遍歷每個線程的ThreadLocalList 然後獲取對應ThreadLocalList的鎖
    while (currentList != null)
    {
        // 嘗試/最後 bllock 以避免在獲取鎖和設置所采取的標誌之間的線程港口
        try
        {
            Monitor.Enter(currentList, ref lockTaken);
        }
        finally
        {
            if (lockTaken)
            {
                currentList.m_lockTaken = true;
                lockTaken = false;
            }
        }
        currentList = currentList.m_nextList;
    }
}

/// <summary>
/// Local helper method to freeze all bag operations, it
/// 1- Acquire the global lock to prevent any other thread to freeze the bag, and also new new thread can be added
/// to the dictionary
/// 2- Then Acquire all local lists locks to prevent steal and synchronized operations
/// 3- Wait for all un-synchronized operations to be done
/// </summary>
/// <param name="lockTaken">Retrieve the lock taken result for the global lock, to be passed to Unfreeze method</param>
private void FreezeBag(ref bool lockTaken)
{
    Contract.Assert(!Monitor.IsEntered(GlobalListsLock));

    // 全局鎖定可安全地防止多線程調用計數和損壞 m_needSync
    Monitor.Enter(GlobalListsLock, ref lockTaken);

    // 這將強制同步任何將來的添加/執行操作
    m_needSync = true;

    // 獲取所有列表的鎖
    AcquireAllLocks();

    // 等待所有操作完成
    WaitAllOperations();
}

/// <summary>
/// 本地幫助器函數返回列表中的包項, 這主要由 CopyTo 和 ToArray 使用。
/// 這不是線程安全, 應該被稱為凍結/解凍袋塊
/// 本方法是私有的 只有使用 Freeze/UnFreeze之後才是安全的 
/// </summary>
/// <returns>List the contains the bag items</returns>
private List<T> ToList()
{
    Contract.Assert(Monitor.IsEntered(GlobalListsLock));
    // 創建一個新的List
    List<T> list = new List<T>();
    ThreadLocalList currentList = m_headList;
    // 遍歷每個線程中的ThreadLocalList 將裏面的Node的數據 添加到list中
    while (currentList != null)
    {
        Node currentNode = currentList.m_head;
        while (currentNode != null)
        {
            list.Add(currentNode.m_value);
            currentNode = currentNode.m_next;
        }
        currentList = currentList.m_nextList;
    }

    return list;
}

/// <summary>
/// Returns an enumerator that iterates through the <see
/// cref="ConcurrentBag{T}"/>.
/// </summary>
/// <returns>An enumerator for the contents of the <see
/// cref="ConcurrentBag{T}"/>.</returns>
/// <remarks>
/// The enumeration represents a moment-in-time snapshot of the contents
/// of the bag.  It does not reflect any updates to the collection after 
/// <see cref="GetEnumerator"/> was called.  The enumerator is safe to use
/// concurrently with reads from and writes to the bag.
/// </remarks>
public IEnumerator<T> GetEnumerator()
{
    // Short path if the bag is empty
    if (m_headList == null)
        return new List<T>().GetEnumerator(); // empty list

    bool lockTaken = false;
    try
    {
        // 首先凍結整個 ConcurrentBag集合
        FreezeBag(ref lockTaken);
        // 然後ToList 再拿到 List的 IEnumerator
        return ToList().GetEnumerator();
    }
    finally
    {
        UnfreezeBag(lockTaken);
    }
}

由上面的代碼可知道,為了獲取叠代器對象,總共進行了三步主要的操作。

  1. 使用FreezeBag()方法,凍結整個ConcurrentBag<T>集合。因為需要生成集合的List<T>副本,生成副本期間不能有其它線程更改損壞數據。
  2. ConcurrrentBag<T>生成List<T>副本。因為ConcurrentBag<T>存儲數據的方式比較特殊,直接實現叠代器模式困難,考慮到線程安全和邏輯,最佳的辦法是生成一個副本。
  3. 完成以上操作以後,就可以使用UnfreezeBag()方法解凍整個集合。

那麽FreezeBag()方法是如何來凍結整個集合的呢?也是分為三步走。

  1. 首先獲取全局鎖,通過Monitor.Enter(GlobalListsLock, ref lockTaken);這樣一條語句,這樣其它線程就不能凍結集合。
  2. 然後獲取所有線程中ThreadLocalList的鎖,通過`AcquireAllLocks()方法來遍歷獲取。這樣其它線程就不能對它進行操作損壞數據。
  3. 等待已經進入了操作流程線程結束,通過WaitAllOperations()方法來實現,該方法會遍歷每一個ThreadLocalList對象的m_currentOp屬性,確保全部處於None操作。

完成以上流程後,那麽就是真正的凍結了整個ConcurrentBag<T>集合,要解凍的話也類似。在此不再贅述。

四、總結

下面給出一張圖,描述了ConcurrentBag<T>是如何存儲數據的。通過每個線程中的ThreadLocal來實現線程本地存儲,每個線程中都有這樣的結構,互不幹擾。然後每個線程中的m_headList總是指向ConcurrentBag<T>的第一個列表,m_tailList指向最後一個列表。列表與列表之間通過m_locals 下的 m_nextList相連,構成一個單鏈表。

數據存儲在每個線程的m_locals中,通過Node類構成一個雙向鏈表。

技術分享圖片

以上就是有關ConcurrentBag<T>類的實現,筆者的一些記錄和解析。

筆者水平有限,如果錯誤歡迎各位批評指正!

附上ConcurrentBag<T>源碼地址:戳一戳

C# ConcurrentBag的實現原理