1. 程式人生 > >記一次被yield return坑的歷程。

記一次被yield return坑的歷程。

事情的經過是這樣的:

我用C#寫了一個很簡單的一個通過迭代生成序列的函式。

public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length)
{
    Checker.NullCheck(nameof(f), f);    
    Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);

    var current = initVal;
    while (--length >= 0
) { yield return (current = f(current)); } }

其中NullCheck用於檢查引數是否為null,如果是則丟擲ArgumentNullException異常。

對應的,我寫了如下單元測試程式碼去檢測這個異常。

public void TestIterate()
{
    Func<int, int> f = null;
    Assert.Throws<ArgumentNullException>(() => f.Iterate(1, 7));
    
    // Other tests
}

但是,這個測試出乎意料的fail了。

一開始,我以為是NullCheck函式的問題,可我把NullCheck直接換成了if語句,還是通不過。

後來我在Iterate函式下斷點並除錯。結果偵錯程式根本沒有停在斷點上,直接執行完了測試。

我以為是我測試的方法不對,所以我不斷的修改測試程式碼,甚至還一度以為是.NET的Unit Tests出了bug。

最終,我在這個測試程式碼發現了問題:

Assert.Throws<ArgumentNullException>(() =>
{
    var seq = f.Iterate(1, 7);
    foreach (int ele in
seq) Console.WriteLine(ele); });

當我除錯這個測試時,程式停在了我之前在Iterate函式上下的斷點。

於是,我在 var seq = f.Iterate(1, 7); 上下斷點,並逐步執行。這時我發現,當程式執行到 var seq = f.Iterate(1, 7); 時並不會進入Iterate函式;而是當程式執行到foreach語句後才進入。

這就要涉及到yield return的具體工作流程。當函式程式碼中出現yield return,呼叫這個函式會直接返回一個IEnumerable<T>或IEnumerator<T>物件,並不會執行函式體的任何程式碼。這些程式碼都被封裝到了返回物件的內部。它們會在你開始列舉的時候開始執行。

因此,上面兩個Check並不會在函式呼叫時執行,而是在當你開始foreach的時候才執行。

這並不是我想要的結果。我希望在呼叫函式時就檢查引數合法性,如果不合法便直接丟擲異常。

解決這個問題有兩種途徑,一是把它拆成兩個函式:

public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length)
{
    Checker.NullCheck(nameof(f), f);    
    Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);
            
    return IterateWithoutCheck(f, initVal, length);
}

private static IEnumerable<T> IterateWithoutCheck<T>(this Func<T, T> f, T initVal, int length)
{
    var current = initVal;
    while (--length >= 0)
    {
        yield return (current = f(current));
    }
}

或者,你也可以將這個函式包裝成一個類。

class FunctionIterator<T> : IEnumerable<T>
{
    private readonly Func<T, T> f;
    private readonly T initVal;
    private readonly int length;
        
    public FunctionIterator(Func<T, T> f, T initVal, int length)
    {
        Checker.NullCheck(nameof(f), f);
        Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);
            
        this.f = f;
        this.initVal = initVal;
        this.length = length;
    }

    public IEnumerator<T> GetEnumerator()
    {
        T current = initVal;

        for (int i = 0; i < length; ++i)
            yield return (current = f(current));
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}