Linq.Where 源码分析

336 阅读2分钟

Linq.Where

代码示例:

/// <summary>
/// Minimal data holder for a student, used by the LINQ <c>Where</c> demo below.
/// </summary>
public class Student
{
    private string _name;
    private int _age;

    /// <summary>The student's name; the demo filters on this value.</summary>
    public string Name
    {
        get => _name;
        set => _name = value;
    }

    /// <summary>The student's age; the demo prints it for each match.</summary>
    public int Age
    {
        get => _age;
        set => _age = value;
    }
}

class Program
{
    /// <summary>
    /// Demo entry point: builds three students, keeps only those named "xiaoming"
    /// with <c>Where</c>, prints each match's age, then waits for a key press.
    /// </summary>
    /// <param name="args">Command-line arguments (unused).</param>
    static void Main(string[] args)
    {
        var students = new List<Student>();
        students.Add(new Student { Name = "xiaoming", Age = 11 });
        students.Add(new Student { Name = "xiaohong", Age = 12 });
        students.Add(new Student { Name = "xiaohuang", Age = 13 });

        // Where only builds an iterator here; the predicate runs element by element
        // as the foreach below pulls items through the iterator's MoveNext.
        IEnumerable<Student> namedXiaoming = students.Where(student => student.Name == "xiaoming");

        foreach (Student match in namedXiaoming)
        {
            Console.WriteLine($"年龄:{match.Age}");
        }

        // Keep the console window open until a key is pressed.
        Console.ReadKey();
    }
}

结果(控制台输出): 年龄:11(只有 xiaoming 满足筛选条件)

解析Where源码

在解析源码前,笔者通过多次 Debug,整理出了各方法之间的调用顺序:

(原文此处为记录方法调用顺序的调试截图 image.png,图片已缺失)

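Where 返回的迭代器的 MoveNext()(调试器中显示为 System.Linq.Enumerable 下的迭代器类;本例数据源是 List<Student>,因此实际是 WhereListIterator<TSource>.MoveNext())会不断调用 System.Collections.Generic.List<T>.Enumerator.MoveNext()。当后者返回 true 时,迭代器会把取到的元素传给我们传入的谓词委托 x => x.Name == "xiaoming",而这个 lambda 中对两个 string 使用的 == 运算符,最终调用了 System.String.Equals() 来判断数据是否匹配。笔者最初猜测,这里调用 String.Equals() 是因为我定义的 Name 是 string 类型,不同的类型可能会有不同的判断方法;实际上 Where 本身只负责调用谓词,具体用什么方法比较完全由传入的 lambda 决定:换成其他类型的属性或其他条件,调用的就是相应的比较逻辑。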

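笔者认为重点代码是以下这些。注意:本例的数据源是 List<Student>,从下文 Where 的源码可以看到,此时实际返回的是 WhereListIterator<TSource>。这里列出的 WhereEnumerableIterator<TSource> 是针对一般 IEnumerable<TSource> 的版本,两者 MoveNext 的逻辑相同: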

namespace System.Linq
{
    // Excerpt from dotnet/runtime (System.Linq): only the fields, constructor and
    // MoveNext of the general-purpose Where iterator are shown.
    // NOTE(review): `private` is only legal on nested types; upstream this class is
    // nested inside `Enumerable`, which this excerpt omits. `_state`, `_current` and
    // `Dispose()` are declared outside this excerpt (base `Iterator<TSource>`).

    /// <summary>
    /// Lazily filters an arbitrary <see cref="IEnumerable{TSource}"/> with a predicate.
    /// <c>Enumerable.Where</c> creates it for sources that are not already an
    /// <c>Iterator&lt;TSource&gt;</c>, an array, or a <see cref="List{TSource}"/>.
    /// </summary>
    private sealed partial class WhereEnumerableIterator<TSource> : Iterator<TSource>
    {
        // The sequence being filtered.
        private readonly IEnumerable<TSource> _source;
        // Test applied to each element; only elements for which it returns true are yielded.
        private readonly Func<TSource, bool> _predicate;
        // Enumerator over _source; stays null until the first MoveNext call creates it.
        private IEnumerator<TSource>? _enumerator;

        /// <summary>Stores the source and predicate; no enumeration happens here.</summary>
        /// <param name="source">Sequence to filter (only asserted non-null; Where validates it).</param>
        /// <param name="predicate">Filter to apply (only asserted non-null; Where validates it).</param>
        public WhereEnumerableIterator(IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            Debug.Assert(source != null);
            Debug.Assert(predicate != null);
            _source = source;
            _predicate = predicate;
        }
        /// <summary>
        /// Advances to the next element of the source that satisfies the predicate.
        /// </summary>
        /// <returns>
        /// <c>true</c> when a matching element was found and stored in <c>_current</c>;
        /// <c>false</c> once the source is exhausted (after calling <c>Dispose()</c>)
        /// or when <c>_state</c> is neither 1 nor 2.
        /// </returns>
        public override bool MoveNext()
        {
            switch (_state)
            {
                case 1:
                    // First step (assumes the base class sets _state to 1 when enumeration
                    // begins — confirm in Iterator<TSource>): open the source enumerator
                    // lazily, then fall straight into the filtering loop.
                    _enumerator = _source.GetEnumerator();
                    _state = 2;
                    goto case 2;
                case 2:
                    // Pull source elements until one passes the predicate. Returning from
                    // inside the loop leaves _enumerator where it is, so the next call
                    // resumes with the following element.
                    while (_enumerator.MoveNext())
                    {
                        TSource item = _enumerator.Current;
                        if (_predicate(item))
                        {
                            _current = item;
                            return true;
                        }
                    }

                    // Source exhausted: Dispose() is defined outside this excerpt.
                    Dispose();
                    break;
            }

            return false;
        }
    }
}
    
namespace System.Collections.Generic
{
    // Implements a variable-size List that uses an array of objects to store the
    // elements. A List has a capacity, which is the allocated length
    // of the internal array. As elements are added to a List, the capacity
    // of the List is automatically increased as required by reallocating the
    // internal array.
    //
    // Excerpt from dotnet/runtime: only GetEnumerator and part of the nested
    // Enumerator are shown. _version, _size, _items, Current and MoveNextRare are
    // declared outside this excerpt.
    [DebuggerTypeProxy(typeof(ICollectionDebugView<>))]
    [DebuggerDisplay("Count = {Count}")]
    [Serializable]
    [TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
    public class List<T> : IList<T>, IList, IReadOnlyList<T>
    {
	// Returns a new struct Enumerator positioned before the first element.
	public Enumerator GetEnumerator() => new Enumerator(this);

        /// <summary>
        /// Value-type cursor over a <see cref="List{T}"/>; this is the
        /// "System.Collections.Generic.MoveNext()" seen in the call trace.
        /// </summary>
        public struct Enumerator : IEnumerator<T>, IEnumerator
        {
            // List being enumerated.
            private readonly List<T> _list;
            // Index of the next element to read.
            private int _index;
            // Snapshot of the list's _version taken at construction; compared on every
            // MoveNext to detect modification of the list during enumeration.
            private readonly int _version;
            // Element produced by the most recent successful MoveNext.
            private T? _current;

            internal Enumerator(List<T> list)
            {
                _list = list;
                _index = 0;
                _version = list._version;
                _current = default;
            }

            // Nothing to release in the portion shown here.
            public void Dispose()
            {
            }

            /// <summary>
            /// Fast path: if the list is unchanged and elements remain, reads the next
            /// item into <c>_current</c>, advances <c>_index</c> and returns <c>true</c>.
            /// Every other case is delegated to <c>MoveNextRare</c>.
            /// </summary>
            public bool MoveNext()
            {
                // Read the field once into a local.
                List<T> localList = _list;

                // Casting both sides to uint folds the "index >= 0" check into the single
                // "index < size" comparison (a negative int becomes a huge uint).
                if (_version == localList._version && ((uint)_index < (uint)localList._size))
                {
                    _current = localList._items[_index];
                    _index++;
                    return true;
                }
                // Not shown here; presumably handles end-of-list and the version-mismatch
                // (list modified during enumeration) case — confirm upstream.
                return MoveNextRare();
         }
      }
   }
}

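第一次调用 System.Linq.WhereEnumerableIterator.MoveNext() 时,_state 为 1。此时它先调用 GetEnumerator() 取得数据源的枚举器,并把 _state 设置成 2,然后直接进入 case 2。枚举器只是一个逐个读取元素的"游标",并不会一次性拿到所有数据。

在 case 2 中,它循环调用数据源枚举器的 MoveNext()(对 List<T> 而言就是 System.Collections.Generic.List<T>.Enumerator.MoveNext())逐个取出元素,并用传入 Where 的谓词(predicate)进行判断:
- 如果匹配成功,就把该元素保存到 _current 并返回 true,foreach 随即拿到这个元素。
- 如果匹配失败,则继续取下一个元素,以此类推。

下一次 foreach 调用 MoveNext() 时,会从上次停下的位置继续查找。当数据源的所有元素都遍历完、没有更多匹配项时,它会调用 Dispose() 并返回 false,foreach 循环随之结束。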

附截取的部分调用源码

System.Linq.Enumerable

public static partial class Enumerable
{
    /// <summary>
    /// Filters a sequence of values based on a predicate.
    /// </summary>
    /// <remarks>
    /// Execution is deferred: this method only validates its arguments and chooses an
    /// iterator specialised for the source's runtime type. The predicate runs later,
    /// inside the iterator's <c>MoveNext</c>, as the result is enumerated.
    /// </remarks>
    /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
    /// <param name="source">The sequence to filter.</param>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <returns>
    /// A sequence that yields the elements of <paramref name="source"/> for which
    /// <paramref name="predicate"/> returns <c>true</c>.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="source"/> or <paramref name="predicate"/> is <c>null</c>.
    /// </exception>
    public static IEnumerable<TSource> Where<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        if (source == null)
        {
            ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
        }

        if (predicate == null)
        {
            ThrowHelper.ThrowArgumentNullException(ExceptionArgument.predicate);
        }

        // The source is already a LINQ iterator: let it build the filtered sequence
        // itself (Iterator<TSource>.Where is not part of this excerpt).
        if (source is Iterator<TSource> iterator)
        {
            return iterator.Where(predicate);
        }

        // Arrays: an empty array short-circuits to Empty<TSource>() instead of
        // allocating an iterator.
        if (source is TSource[] array)
        {
            return array.Length == 0 ?
                Empty<TSource>() :
                new WhereArrayIterator<TSource>(array, predicate);
        }

        // List<T> gets its own iterator. This is the branch a List<Student> source
        // takes, so WhereListIterator (below) is what runs for such a list.
        if (source is List<TSource> list)
        {
            return new WhereListIterator<TSource>(list, predicate);
        }

        // Any other sequence uses the general-purpose WhereEnumerableIterator.
        return new WhereEnumerableIterator<TSource>(source, predicate);
    }

    /// <summary>
    /// An iterator that filters each item of a <see cref="List{TSource}"/>.
    /// </summary>
    /// <typeparam name="TSource">The type of the source list.</typeparam>
    /// <remarks>
    /// Excerpt: only the fields, constructor and <c>MoveNext</c> are shown.
    /// <c>_state</c>, <c>_current</c> and <c>Dispose()</c> are declared outside this
    /// excerpt (base <c>Iterator&lt;TSource&gt;</c> or elided members).
    /// </remarks>
    private sealed partial class WhereListIterator<TSource> : Iterator<TSource>
    {
        // The list being filtered.
        private readonly List<TSource> _source;
        // Test applied to each element; only elements for which it returns true are yielded.
        private readonly Func<TSource, bool> _predicate;
        // Struct enumerator over _source. Deliberately not readonly: MoveNext() must
        // advance this stored value, whereas on a readonly field each call would act
        // on a defensive copy and never progress.
        private List<TSource>.Enumerator _enumerator;

        /// <summary>Stores the list and predicate; no enumeration happens here.</summary>
        /// <param name="source">List to filter (only asserted non-null; Where validates it).</param>
        /// <param name="predicate">Filter to apply (only asserted non-null; Where validates it).</param>
        public WhereListIterator(List<TSource> source, Func<TSource, bool> predicate)
        {
            Debug.Assert(source != null);
            Debug.Assert(predicate != null);
            _source = source;
            _predicate = predicate;
        }

        /// <summary>
        /// Advances to the next list element that satisfies the predicate.
        /// </summary>
        /// <returns>
        /// <c>true</c> when a matching element was found and stored in <c>_current</c>;
        /// <c>false</c> once the list is exhausted (after calling <c>Dispose()</c>) or
        /// when <c>_state</c> is neither 1 nor 2.
        /// </returns>
        public override bool MoveNext()
        {
            switch (_state)
            {
                case 1:
                    // First step: obtain the list's struct enumerator, then fall straight
                    // into the filtering loop.
                    _enumerator = _source.GetEnumerator();
                    _state = 2;
                    goto case 2;
                case 2:
                    // Pull elements until one passes the predicate; returning from inside
                    // the loop keeps _enumerator's position for the next call.
                    while (_enumerator.MoveNext())
                    {
                        TSource item = _enumerator.Current;
                        if (_predicate(item))
                        {
                            _current = item;
                            return true;
                        }
                    }

                    // List exhausted: Dispose() is defined outside this excerpt.
                    Dispose();
                    break;
            }
            return false;
        }
    }
}

System.StartupHookProvider

using System;
using System.Diagnostics;
using System.Diagnostics.Tracing;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Reflection;
using System.Runtime.Loader;
 
namespace System
{
    // Excerpt from dotnet/runtime. It is included in the article because it appeared
    // in the captured call stack; it is not part of LINQ. CallStartupHook and SR are
    // declared outside this excerpt.
    internal static class StartupHookProvider
    {
        // NOTE(review): not used in the visible code; presumably the type/method that
        // CallStartupHook looks up in each hook assembly — confirm upstream.
        private const string StartupHookTypeName = "StartupHook";
        private const string InitializeMethodName = "Initialize";
        // Simple assembly names ending in this suffix (case-insensitive) are rejected.
        private const string DisallowedSimpleAssemblyNameSuffix = ".dll";

        // Reads the "System.StartupHookProvider.IsSupported" AppContext switch;
        // defaults to true when the switch is not set.
        private static bool IsSupported => AppContext.TryGetSwitch("System.StartupHookProvider.IsSupported", out bool isSupported) ? isSupported : true;

        // One parsed STARTUP_HOOKS entry: either Path (fully qualified path) or
        // AssemblyName (simple name) is set; both stay null for empty entries.
        private struct StartupHookNameOrPath
        {
            public AssemblyName AssemblyName;
            public string Path;
        }
        

        // Parse a string specifying a list of assemblies and types
        // containing a startup hook, and call each hook in turn.
        // Throws ArgumentException for an entry that is neither a fully qualified
        // path nor a valid simple assembly name.
        private static void ProcessStartupHooks()
        {
            if (!IsSupported)
                return;

            // Initialize tracing before any user code can be called if EventSource is enabled.
            if (EventSource.IsSupported)
            {
                System.Diagnostics.Tracing.RuntimeEventSource.Initialize();
            }

            // Absent (or non-string) STARTUP_HOOKS data means there is nothing to run.
            string? startupHooksVariable = AppContext.GetData("STARTUP_HOOKS") as string;
            if (startupHooksVariable == null)
            {
                return;
            }

            // Characters that may not appear in a simple assembly name. `Path` here is
            // System.IO.Path; the struct field of the same name is not in scope.
            ReadOnlySpan<char> disallowedSimpleAssemblyNameChars = stackalloc char[4]
            {
                Path.DirectorySeparatorChar,
                Path.AltDirectorySeparatorChar,
                ' ',
                ','
            };

            // Parse startup hooks variable
            // Entries are separated by the platform's path-list separator.
            string[] startupHookParts = startupHooksVariable.Split(Path.PathSeparator);
            StartupHookNameOrPath[] startupHooks = new StartupHookNameOrPath[startupHookParts.Length];
            for (int i = 0; i < startupHookParts.Length; i++)
            {
                string startupHookPart = startupHookParts[i];
                if (string.IsNullOrEmpty(startupHookPart))
                {
                    // Leave the slot in startupHooks empty (nulls for everything). This is simpler than shifting and resizing the array.
                    continue;
                }

                if (Path.IsPathFullyQualified(startupHookPart))
                {
                    startupHooks[i].Path = startupHookPart;
                }
                else
                {
                    // The intent here is to only support simple assembly names, but AssemblyName .ctor accepts
                    // lot of other forms (fully qualified assembly name, strings which look like relative paths and so on).
                    // So add a check on top which will disallow any directory separator, space or comma in the assembly name.
                    for (int j = 0; j < disallowedSimpleAssemblyNameChars.Length; j++)
                    {
                        if (startupHookPart.Contains(disallowedSimpleAssemblyNameChars[j]))
                        {
                            throw new ArgumentException(SR.Format(SR.Argument_InvalidStartupHookSimpleAssemblyName, startupHookPart));
                        }
                    }

                    // A ".dll" suffix indicates a file name rather than a simple assembly name.
                    if (startupHookPart.EndsWith(DisallowedSimpleAssemblyNameSuffix, StringComparison.OrdinalIgnoreCase))
                    {
                        throw new ArgumentException(SR.Format(SR.Argument_InvalidStartupHookSimpleAssemblyName, startupHookPart));
                    }

                    try
                    {
                        // This will throw if the string is not a valid assembly name.
                        startupHooks[i].AssemblyName = new AssemblyName(startupHookPart);
                    }
                    catch (Exception assemblyNameException)
                    {
                        // Surface any parse failure as ArgumentException, keeping the original as InnerException.
                        throw new ArgumentException(SR.Format(SR.Argument_InvalidStartupHookSimpleAssemblyName, startupHookPart), assemblyNameException);
                    }
                }
            }

            // Call each hook in turn
            // Empty slots are passed as well; presumably CallStartupHook (not shown)
            // skips entries whose Path and AssemblyName are both null — confirm upstream.
            foreach (StartupHookNameOrPath startupHook in startupHooks)
            {
#pragma warning disable IL2026 // suppressed in ILLink.Suppressions.LibraryBuild.xml
                CallStartupHook(startupHook);
#pragma warning restore IL2026
            }
        }
    }
}

System.MulticastDelegate

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Serialization;
 
namespace System
{
    // Excerpt from dotnet/runtime: only two fields and CtorClosed are shown. _target,
    // _methodPtr and ThrowNullThisInDelegateToInstance are declared outside this excerpt.
    [ClassInterface(ClassInterfaceType.None)]
    [ComVisible(true)]
    public abstract class MulticastDelegate : Delegate
    {
        // This is set under 2 circumstances
        // 1. Multicast delegate
        // 2. Wrapper delegate
        private object? _invocationList; // Initialized by VM as needed
        private IntPtr _invocationCount;

        // NOTE(review): the matching `#pragma warning restore IDE0060` is not part of this excerpt.
#pragma warning disable IDE0060
        // Binds this delegate to an instance (`target`) and a method pointer; throws via
        // ThrowNullThisInDelegateToInstance when target is null. Presumably invoked by the
        // runtime when a closed delegate is constructed (e.g. the Func<Student, bool> for
        // the Where lambda) — confirm upstream. DebuggerNonUserCode marks it as non-user
        // code for the debugger.
        [System.Diagnostics.DebuggerNonUserCode]
        private void CtorClosed(object target, IntPtr methodPtr)
        {
            if (target == null)
                ThrowNullThisInDelegateToInstance();
            this._target = target;
            this._methodPtr = methodPtr;
        }
    }
}

System.String

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text.Unicode;

using Internal.Runtime.CompilerServices;

namespace System
{
    // Excerpt from dotnet/runtime: only the static Equals(string?, string?) is shown.
    // It appears in the Where trace because the demo lambda compares two strings with ==.
    public partial class String
    {

        // Determines whether two Strings match.
        // Returns true for the same reference (including both null); false when exactly
        // one is null or the lengths differ; otherwise defers to EqualsHelper (not shown;
        // documented upstream as an ordinal, case-sensitive comparison — confirm there).
        public static bool Equals(string? a, string? b)
        {
            if (object.ReferenceEquals(a, b))
            {
                return true;
            }

            // Cheap rejections before comparing contents.
            if (a is null || b is null || a.Length != b.Length)
            {
                return false;
            }

            return EqualsHelper(a, b);
        }
    }
 }

System.Collections.Generic.List

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
 
namespace System.Collections.Generic
{
    // Implements a variable-size List that uses an array of objects to store the
    // elements. A List has a capacity, which is the allocated length
    // of the internal array. As elements are added to a List, the capacity
    // of the List is automatically increased as required by reallocating the
    // internal array.
    //
    // Excerpt from dotnet/runtime: only GetEnumerator is shown; the nested Enumerator
    // struct and the rest of List<T> are declared outside this excerpt.
    [DebuggerTypeProxy(typeof(ICollectionDebugView<>))]
    [DebuggerDisplay("Count = {Count}")]
    [Serializable]
    [TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
    public class List<T> : IList<T>, IList, IReadOnlyList<T>
    {
       // Returns a new Enumerator over this list; WhereListIterator.MoveNext obtains
       // its enumerator through this method.
       public Enumerator GetEnumerator() => new Enumerator(this);
    }
}
    

参考链接

Enumerable 类 (System.Linq) | Microsoft Learn

IEnumerable 接口 (System.Collections.Generic) | Microsoft Learn

迭代器 | Microsoft Learn

使用 Visual Studio 调试 .NET 和 ASP.NET Core 源代码 | Microsoft Docs (microsoft.com)