The Road to LINQ (3): Extending LINQ

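This article looks at extending LINQ from three angles: extending the query operators, writing custom query operators, and a simple simulation of LINQ to SQL.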

1. Extending query operators

In practice, the extension methods in Enumerable or Queryable do not always do what we need, and we have to write query operators of our own. Take the following example:

var r = Enumerable.Range(1, 10).Zip(Enumerable.Range(11, 5), (s, d) => s + d);
foreach (var i in r)
{
    Console.WriteLine(i);
}
//output:
//12
//14
//16
//18
//20

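The Enumerable.Zip extension method applies a function to the corresponding elements of two sequences to produce a result sequence. Here it adds the elements at matching positions of [1...10] and [11...15]. The iterator at the core of the .NET Framework implementation (the public Zip method validates its arguments and then returns this iterator) looks like this: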

static IEnumerable<TResult> ZipIterator<TFirst, TSecond, TResult>(IEnumerable<TFirst> first, IEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector) {
        using (IEnumerator<TFirst> e1 = first.GetEnumerator())
            using (IEnumerator<TSecond> e2 = second.GetEnumerator())
                while (e1.MoveNext() && e2.MoveNext())
                    yield return resultSelector(e1.Current, e2.Current);
    }

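Clearly Zip takes the "intersection": a result is produced only while both sequences still have an element at the current position, which is why the output above stops after five values. Sometimes, however, we want the first sequence to drive the result, so that the result always has exactly as many elements as the first sequence. Let's write our own query operator for this, called LeftZip; once the second sequence runs out, default(TRight) is passed to the selector instead. Like any extension method, it has to be declared in a non-generic static class: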

    /// <summary>
    /// Merges the right sequence into the left sequence by using the specified selector.
    /// The result always has as many elements as the left sequence; once the right
    /// sequence runs out, default(TRight) is passed to the selector instead.
    /// </summary>
    /// <typeparam name="TLeft">Element type of the left (driving) sequence.</typeparam>
    /// <typeparam name="TRight">Element type of the right sequence.</typeparam>
    /// <typeparam name="TResult">Element type of the result sequence.</typeparam>
    /// <param name="lefts">The sequence that determines the length of the result.</param>
    /// <param name="rights">The sequence merged into <paramref name="lefts"/>.</param>
    /// <param name="resultSelector">Combines a left element with its matching right element.</param>
    /// <returns>A lazily evaluated sequence of combined elements.</returns>
    public static IEnumerable<TResult> LeftZip<TLeft, TRight, TResult>(this IEnumerable<TLeft> lefts,
        IEnumerable<TRight> rights, Func<TLeft, TRight, TResult> resultSelector)
    {
        // Arguments are checked here, eagerly; the iterator below only runs when enumerated.
        if (lefts == null)
            throw new ArgumentNullException("lefts");
        if (rights == null)
            throw new ArgumentNullException("rights");
        if (resultSelector == null)
            throw new ArgumentNullException("resultSelector");
        return LeftZipImpl(lefts, rights, resultSelector);
    }

    /// <summary>
    /// Iterator that implements LeftZip.
    /// </summary>
    private static IEnumerable<TResult> LeftZipImpl<TLeft, TRight, TResult>(this IEnumerable<TLeft> lefts,
        IEnumerable<TRight> rights, Func<TLeft, TRight, TResult> resultSelector)
    {
        using (var left = lefts.GetEnumerator())
        {
            using (var right = rights.GetEnumerator())
            {
                while (left.MoveNext())
                {
                    if (right.MoveNext())
                    {
                        // Both sequences still have elements: combine them.
                        yield return resultSelector(left.Current, right.Current);
                    }
                    else
                    {
                        // Right sequence exhausted: pad the remaining left elements with default(TRight).
                        do
                        {
                            yield return resultSelector(left.Current, default(TRight));
                        } while (left.MoveNext());
                        yield break;
                    }
                }
            }
        }
    }
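Why split LeftZip into a validating public method and a private iterator (the same split Zip uses)? A method that contains `yield return` runs none of its body until enumeration starts, so argument checks placed inside it are deferred until the first MoveNext. The sketch below, with hypothetical names (ValidationTimingDemo, CheckedInIterator, CheckedEagerly), shows the difference by passing null and never enumerating the result:

```csharp
using System;
using System.Collections.Generic;

public static class ValidationTimingDemo
{
    // Check inside the iterator body: nothing runs until MoveNext is first called.
    public static IEnumerable<int> CheckedInIterator(IEnumerable<int> source)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        foreach (var x in source)
            yield return x;
    }

    // Check in an ordinary method, then hand off to an iterator: runs at call time.
    public static IEnumerable<int> CheckedEagerly(IEnumerable<int> source)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        return Iterate(source);
    }

    private static IEnumerable<int> Iterate(IEnumerable<int> source)
    {
        foreach (var x in source)
            yield return x;
    }

    // Returns true if merely calling the method (without enumerating) throws.
    public static bool ThrowsOnCall(Func<IEnumerable<int>, IEnumerable<int>> method)
    {
        try
        {
            method(null); // the returned sequence is never enumerated
            return false;
        }
        catch (ArgumentNullException)
        {
            return true;
        }
    }
}
```

`ThrowsOnCall(CheckedInIterator)` returns false while `ThrowsOnCall(CheckedEagerly)` returns true, which is why LeftZip reports a null argument at the call site rather than somewhere inside a later foreach.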

Calling LeftZip:

var r = Enumerable.Range(1, 10).LeftZip(Enumerable.Range(11, 5), (s, d) => s + d);
foreach (var i in r)
{
    Console.WriteLine(i);
}
//output:
//12
//14
//16
//18
//20
//6
//7
//8
//9
//10
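The last five results are 6 through 10 because the missing right-hand elements are default(int), which is 0. When 0 (or null, for reference types) would be a misleading filler, one option is an overload that takes the padding value explicitly. This is only a sketch with a hypothetical name, LeftZipOr; it is not part of LINQ:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class LeftZipOrExtensions
{
    // Like LeftZip, but pads a short right sequence with rightDefault instead of default(TRight).
    public static IEnumerable<TResult> LeftZipOr<TLeft, TRight, TResult>(this IEnumerable<TLeft> lefts,
        IEnumerable<TRight> rights, TRight rightDefault, Func<TLeft, TRight, TResult> resultSelector)
    {
        if (lefts == null) throw new ArgumentNullException("lefts");
        if (rights == null) throw new ArgumentNullException("rights");
        if (resultSelector == null) throw new ArgumentNullException("resultSelector");
        return LeftZipOrImpl(lefts, rights, rightDefault, resultSelector);
    }

    private static IEnumerable<TResult> LeftZipOrImpl<TLeft, TRight, TResult>(IEnumerable<TLeft> lefts,
        IEnumerable<TRight> rights, TRight rightDefault, Func<TLeft, TRight, TResult> resultSelector)
    {
        using (var left = lefts.GetEnumerator())
        using (var right = rights.GetEnumerator())
        {
            bool rightHasMore = true;
            while (left.MoveNext())
            {
                // Once the right enumerator reports the end, stop calling MoveNext on it.
                rightHasMore = rightHasMore && right.MoveNext();
                yield return resultSelector(left.Current, rightHasMore ? right.Current : rightDefault);
            }
        }
    }
}
```

With a padding value of 100, `Enumerable.Range(1, 10).LeftZipOr(Enumerable.Range(11, 5), 100, (s, d) => s + d)` yields 12, 14, 16, 18, 20, 106, 107, 108, 109, 110.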

2. Custom query operators

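Earlier, when implementing enumerators, we saw a form that implements neither IEnumerable nor IEnumerator: a class with a GetEnumerator() method, returning an object that has Current and MoveNext(), is all foreach needs. We also know that a LINQ query expression is translated into a chain of method calls, each query clause becoming a call to a method of the corresponding name (where becomes Where, select becomes Select, and so on). So if we write our own extension methods with the same names as the standard query operators, will they be called?
Let's try it. Create a static class LinqExtensions and implement a Where extension method in it: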

    /// <summary>
    /// Filters a sequence of values based on a predicate.
    /// </summary>
    /// <typeparam name="TResult">Element type of the sequence.</typeparam>
    /// <param name="source">The sequence to filter.</param>
    /// <param name="predicate">Returns true for the elements to keep.</param>
    /// <returns>A lazily evaluated sequence of the elements that satisfy the predicate.</returns>
    public static IEnumerable<TResult> Where<TResult>(this IEnumerable<TResult> source,
        Func<TResult, bool> predicate)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        if (predicate == null)
            throw new ArgumentNullException("predicate");
        return WhereImpl(source, predicate);
    }
    /// <summary>
    /// Iterator that implements Where.
    /// </summary>
    private static IEnumerable<TResult> WhereImpl<TResult>(this IEnumerable<TResult> source,
        Func<TResult, bool> predicate)
    {
        using (var e = source.GetEnumerator())
        {
            while (e.MoveNext())
            {
                if (predicate(e.Current))
                    yield return e.Current;
            }
        }
    }

The calling code:

var r = from e in Enumerable.Range(1, 10)
        where e % 2 == 0
        select e;
foreach (var i in r)
{
    Console.WriteLine(i);
}
//output:
//2
//4
//6
//8
//10
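Why does the compiler pick LinqExtensions.Where over Enumerable.Where even though `using System.Linq;` is in effect? For extension methods, C# searches outward from the innermost enclosing namespace declaration, and at each level it first considers extension methods declared in types of that namespace, then those from namespaces imported by `using` directives at that level; the first applicable set found wins. A minimal sketch with hypothetical names (LinqDemo, CountingOperators, QueryDemo) that counts how often the custom Where is invoked by a query expression:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

namespace LinqDemo
{
    public static class CountingOperators
    {
        // Number of times this Where has been invoked.
        public static int WhereCalls;

        // Same shape as Enumerable.Where, but declared inside namespace LinqDemo,
        // so code in LinqDemo finds it before the System.Linq version.
        public static IEnumerable<T> Where<T>(this IEnumerable<T> source, Func<T, bool> predicate)
        {
            if (source == null) throw new ArgumentNullException("source");
            if (predicate == null) throw new ArgumentNullException("predicate");
            WhereCalls++;
            return WhereImpl(source, predicate);
        }

        private static IEnumerable<T> WhereImpl<T>(IEnumerable<T> source, Func<T, bool> predicate)
        {
            foreach (var item in source)
                if (predicate(item))
                    yield return item;
        }
    }

    public static class QueryDemo
    {
        // The where clause below compiles to a call to CountingOperators.Where.
        public static List<int> EvenNumbers()
        {
            var query = from e in Enumerable.Range(1, 10)
                        where e % 2 == 0
                        select e;
            return query.ToList();
        }
    }
}
```

After `LinqDemo.QueryDemo.EvenNumbers()` returns [2, 4, 6, 8, 10], `CountingOperators.WhereCalls` is 1, confirming that the query went through the custom method.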

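How can we tell that our Where extension method is the one being called? Stepping through it in the debugger, selecting Where in Visual Studio and pressing F12 (Go To Definition), or printing something from inside the Where implementation will all tell you. You might also wonder whether our method has to match the standard operator it replaces. It does not: the compiler only rewrites each clause into a call to a method with the corresponding name and then resolves that call like any other, without checking what the method actually does. For example, rename our Where extension method to Select and change the query as follows: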

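var r = from e in Enumerable.Range(1, 10)
        //where e % 2 == 0
        select e % 2 == 0;  // compiled to .Select(e => e % 2 == 0), i.e. our renamed filter
// enumerate r with the same foreach loop as before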
//output:
//2
//4
//6
//8
//10
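For comparison, here is what the same two queries do when only the standard System.Linq operators are in scope. A `where` followed by a trivial `select e` becomes a single Where call, while `select e % 2 == 0` becomes a Select call that projects every number to a bool; only the renamed method above turns that Select into a filter. A small sketch (the class name QueryTranslationDemo is made up):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class QueryTranslationDemo
{
    // Query syntax and the method-call form the compiler produces for it.
    public static bool WhereFormsMatch()
    {
        var querySyntax = from e in Enumerable.Range(1, 10)
                          where e % 2 == 0
                          select e;
        var methodSyntax = Enumerable.Range(1, 10).Where(e => e % 2 == 0);
        return querySyntax.SequenceEqual(methodSyntax);
    }

    // With the standard Select, 'select e % 2 == 0' yields one bool per element.
    public static List<bool> StandardSelect()
    {
        var query = from e in Enumerable.Range(1, 10)
                    select e % 2 == 0;
        // Equivalent to Enumerable.Range(1, 10).Select(e => e % 2 == 0).ToList()
        return query.ToList();
    }
}
```

`StandardSelect()` returns ten values (False, True, False, True, ...), not the five numbers printed above.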

Finally, an example that combines this with the custom enumerator pattern:

using System;
using System.Collections.Generic;
using System.Linq;

public class Collection<T>
{
    private T[] items;

    public Collection()
    {
        // Start empty so that GetEnumerator() never hands a null array to ItemEnumerator.
        items = new T[0];
    }

    public Collection(IEnumerable<T> collection)
    {
        if (collection == null)
            throw new ArgumentNullException("collection");
        // ToArray() enumerates the source only once and already returns a copy.
        items = collection.ToArray();
    }

    public static implicit operator Collection<T>(T[] arr)
    {
        // Reuses the copying constructor (throws ArgumentNullException for a null array).
        return new Collection<T>(arr);
    }

    public ItemEnumerator GetEnumerator()
    {
        return new ItemEnumerator(items);
    }

    #region Item Enumerator
    public class ItemEnumerator : IDisposable
    {
        private T[] items;
        private int index = -1;

        public ItemEnumerator(T[] arr)
        {
            this.items = arr;
        }
        /// <summary>
        /// Gets the element at the current position (valid after MoveNext has returned true).
        /// </summary>
        public T Current
        {
            get
            {
                if (index < 0 || index > items.Length - 1)
                    throw new InvalidOperationException();
                return items[index];
            }
        }
        /// <summary>
        /// Advances to the next element; returns false once the end of the array is reached.
        /// </summary>
        /// <returns></returns>
        public bool MoveNext()
        {
            if (index < items.Length - 1)
            {
                index++;
                return true;
            }
            else
            {
                return false;
            }
        }

        public void Reset()
        {
            index = -1;
        }
        #region IDisposable members

        public void Dispose()
        {
            index = -1;
        }

        #endregion
    }
    #endregion
}

public static class EnumerableExtensions
{
    public static Collection<T> Where<T>(this Collection<T> source, Func<T, bool> predicate)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        if (predicate == null)
            throw new ArgumentNullException("predicate");
        return WhereImpl(source, predicate).ToCollection();
    }

    private static IEnumerable<T> WhereImpl<T>(this Collection<T> source, Func<T, bool> predicate)
    {
        using (var e = source.GetEnumerator())
        {
            while (e.MoveNext())
            {
                if (predicate(e.Current))
                {
                    yield return e.Current;
                }
            }
        }
    }

    public static Collection<TResult> Select<T, TResult>(this Collection<T> source, Func<T, TResult> selector)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        if (selector == null)
            throw new ArgumentNullException("selector");
        return SelectImpl(source, selector).ToCollection();
    }

    private static IEnumerable<TResult> SelectImpl<T, TResult>(this Collection<T> source, Func<T, TResult> selector)
    {
        using (var e = source.GetEnumerator())
        {
            while (e.MoveNext())
            {
                yield return selector(e.Current);
            }
        }
    }

    public static Collection<T> ToCollection<T>(this IEnumerable<T> source)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        return new Collection<T>(source);
    }
}

There are two classes: one serves as the data source, the other provides the extension methods. They are used like this:

Collection<int> collection = new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
var r = from c in collection
        where c % 2 == 0
        select c;
foreach (var i in r)
{
    Console.WriteLine(i);
}

//output:
//2
//4
//6
//8
//10
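One behavioral difference from the standard operators: the Where and Select above end with ToCollection(), so the predicate runs as soon as the query is built, whereas Enumerable.Where is deferred and runs the predicate only when the result is enumerated. The contrast can be checked with standard LINQ alone; here ToList() stands in for an eager, materializing operator like ToCollection (an assumption made to keep the sketch self-contained), and DeferredVsEagerDemo is a made-up name:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class DeferredVsEagerDemo
{
    // Returns how many times the predicate had run right after the deferred query was built.
    public static int CallsAfterBuildingDeferredQuery()
    {
        int calls = 0;
        IEnumerable<int> query = Enumerable.Range(1, 10).Where(x => { calls++; return x % 2 == 0; });
        int callsBeforeEnumeration = calls; // still 0: Where has not pulled any element yet
        query.ToList();                     // enumeration now runs the predicate 10 times
        return callsBeforeEnumeration;
    }

    // The same filter materialized immediately, the way ToCollection() does.
    public static int CallsAfterBuildingEagerQuery()
    {
        int calls = 0;
        List<int> result = Enumerable.Range(1, 10).Where(x => { calls++; return x % 2 == 0; }).ToList();
        return calls; // 10: every element was tested while the list was built
    }
}
```

Eager materialization is what lets Where return a Collection<T> that the next operator (and foreach) can consume, at the cost of losing LINQ's usual deferred execution.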

3. A simple simulation of LINQ to SQL

In The Road to LINQ (2) we briefly introduced how LINQ to SQL works. Here we build a simple imitation of LINQ to SQL to understand its principles in more depth.

First, the data source: a Query<T> class that implements IQueryable<T>:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public class Query<T> : IQueryable<T>
{
    #region Fields

    private QueryProvider provider;
    private Expression expression;

    #endregion

    #region Constructors

    public Query(QueryProvider provider)
    {
        if (provider == null)
            throw new ArgumentNullException("provider");
        this.provider = provider;
        // The root of every query tree: a constant node that refers to this Query<T>.
        this.expression = Expression.Constant(this);
    }

    public Query(QueryProvider provider, Expression expression)
    {
        if (provider == null)
            throw new ArgumentNullException("provider");
        if (expression == null)
            throw new ArgumentNullException("expression");
        if (!typeof(IQueryable<T>).IsAssignableFrom(expression.Type))
            throw new ArgumentOutOfRangeException("expression");
        this.provider = provider;
        this.expression = expression;
    }
    #endregion

    #region Methods

    public IEnumerator<T> GetEnumerator()
    {
        // Enumeration triggers execution: the provider translates the tree to SQL and runs it.
        return ((IEnumerable<T>) this.provider.Execute(this.expression)).GetEnumerator();
    }

    #endregion

    #region IEnumerable<T> members

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return this.GetEnumerator();
    }

    #endregion

    #region IEnumerable members

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return this.GetEnumerator();
    }

    #endregion

    #region IQueryable members

    Type IQueryable.ElementType
    {
        get { return typeof(T); }
    }

    Expression IQueryable.Expression
    {
        get { return this.expression; }
    }

    IQueryProvider IQueryable.Provider
    {
        get { return this.provider; }
    }

    #endregion
}

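This class is fairly simple: it implements the interface and stores its arguments. A new Query<T> starts out with an Expression that is a constant node pointing at the query object itself.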
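To see what the provider and translator will receive, it helps to inspect the tree that Queryable.Where builds. The sketch below uses the built-in EnumerableQuery<T> that AsQueryable() returns for an array, rather than the Query<T> above, so it runs without a database; the top of the tree has the same shape either way: a MethodCallExpression for Queryable.Where whose first argument is the source's constant node and whose second argument is the lambda wrapped in a Quote node (which is what StripQuotes in the translator removes). ExpressionTreeShapeDemo is a made-up name:

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class ExpressionTreeShapeDemo
{
    // Describes the tree produced by source.Where(x => x > 2) on an IQueryable<int>.
    public static string Describe()
    {
        IQueryable<int> source = new[] { 1, 2, 3, 4 }.AsQueryable();
        IQueryable<int> query = source.Where(x => x > 2);

        // Queryable.Where(source.Expression, Quote(lambda))
        var call = (MethodCallExpression)query.Expression;
        // Strip the Quote node to get at the lambda itself.
        var lambda = (LambdaExpression)((UnaryExpression)call.Arguments[1]).Operand;

        return string.Join("|",
            call.Method.DeclaringType.Name, // declaring class of the called method
            call.Method.Name,               // operator name
            call.Arguments[0].NodeType,     // the root query object
            call.Arguments[1].NodeType,     // the quoted lambda
            lambda.Body.NodeType);          // the comparison inside the lambda
    }
}
```

`Describe()` returns `Queryable|Where|Constant|Quote|GreaterThan`.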
Next comes the provider: a QueryProvider class that implements IQueryProvider:

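using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public class QueryProvider : IQueryProvider
{
    #region Fields

    private IDbConnection dbConnection;

    #endregion

    #region Constructors
    public QueryProvider(IDbConnection dbConnection)
    {
        this.dbConnection = dbConnection;
    }
    #endregion

    #region IQueryProvider members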

    public IQueryable<TElement> CreateQuery<TElement>(System.Linq.Expressions.Expression expression)
    {
        return new Query<TElement>(this, expression);
    }

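    public IQueryable CreateQuery(System.Linq.Expressions.Expression expression)
    {
        // expression.Type is a sequence type such as IQueryable<T>; Query<> needs the element type T.
        var elementType = GetElementType(expression.Type);
        try
        {
            return (IQueryable) Activator.CreateInstance(typeof (Query<>).MakeGenericType(elementType), this, expression);
        }
        catch (TargetInvocationException e)
        {
            throw e.InnerException;
        }
    }

    private static Type GetElementType(Type seqType)
    {
        // Find IEnumerable<X> among the type itself and its interfaces, and return X.
        var enumerableType = new[] { seqType }.Concat(seqType.GetInterfaces())
            .FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>));
        return enumerableType == null ? seqType : enumerableType.GetGenericArguments()[0];
    }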

    TResult IQueryProvider.Execute<TResult>(System.Linq.Expressions.Expression expression)
    {
        return (TResult) this.Execute(expression);
    }

    object IQueryProvider.Execute(System.Linq.Expressions.Expression expression)
    {
        return this.Execute(expression);
    }

    public virtual object Execute(Expression expression)
    {
        if(expression == null)
            throw new ArgumentNullException("expression");
        return ExecuteImpl(expression);
    }

    private IEnumerable ExecuteImpl(Expression expression)
    {
        // Simplification: the entity type is hard-coded to Product, and columns are read
        // by position (assumed order: ID, Name, Type) rather than through the mapping metadata.
        List<Product> products = new List<Product>();
        QueryTranslator queryTranslator = new QueryTranslator();
        var cmdText = queryTranslator.Translate(expression);
        using (IDbCommand cmd = dbConnection.CreateCommand())
        {
            cmd.CommandText = cmdText;
            using (IDataReader dataReader = cmd.ExecuteReader())
            {
                while (dataReader.Read())
                {
                    Product product = new Product();
                    product.ID = dataReader.GetInt32(0);
                    product.Name = dataReader.GetString(1);
                    product.Type = dataReader.GetInt32(2);
                    products.Add(product);
                }
            }
        }
        return products;
    }
    #endregion
}

Next, the query translator: a QueryTranslator class that derives from the abstract ExpressionVisitor class:

using System;
using System.Linq;
using System.Linq.Expressions;
using System.Text;

public class QueryTranslator : ExpressionVisitor
{
    #region Fields

    private StringBuilder sb;

    #endregion

    #region Constructors
    public QueryTranslator()
    {

    }

    #endregion

    #region Methods

    public string Translate(Expression expression)
    {
        this.sb = new StringBuilder();
        this.Visit(expression);
        return this.sb.ToString();
    }

    private static Expression StripQuotes(Expression e)
    {
        while (e.NodeType == ExpressionType.Quote)
        {
            e = ((UnaryExpression) e).Operand;
        }
        return e;
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method.DeclaringType == typeof (Queryable) &&
            node.Method.Name == "Where")
        {
            sb.Append("SELECT * FROM (");
            this.Visit(node.Arguments[0]);
            sb.Append(") AS T WHERE ");
            LambdaExpression lambda = (LambdaExpression) StripQuotes(node.Arguments[1]);
            this.Visit(lambda.Body);
            return node;
        }
        throw new NotSupportedException(string.Format("The Method '{0}' is not supported", node.Method.Name));
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        sb.Append("(");
        this.Visit(node.Left);
        switch (node.NodeType)
        {
            case ExpressionType.Equal:
                sb.Append(" = ");
                break;
            case ExpressionType.NotEqual:
                sb.Append(" <> ");
                break;
            case ExpressionType.GreaterThan:
                sb.Append(" > ");
                break;
            case ExpressionType.GreaterThanOrEqual:
                sb.Append(" >= ");
                break;
            case ExpressionType.LessThan:
                sb.Append(" < ");
                break;
            case ExpressionType.LessThanOrEqual:
                sb.Append(" <= ");
                break;
            default:
                throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", node.NodeType));
        }
        this.Visit(node.Right);
        sb.Append(")");
        return node;
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        IQueryable q = node.Value as IQueryable;
        if (q != null)
        {
            sb.Append("SELECT * FROM ");
            sb.Append(DataContext.MetaTables.FirstOrDefault(f => f.Type == q.ElementType).TableName);
            return node;
        }
        else if(node.Value == null)
        {
            sb.Append("NULL");
        }
        else
        {
            switch (Type.GetTypeCode(node.Value.GetType()))
            {
                case TypeCode.Boolean:
                    sb.Append(((bool) node.Value) ? 1 : 0);
                    break;
                case TypeCode.String:
                    // Double embedded single quotes so the literal stays valid SQL.
                    sb.AppendFormat("'{0}'", ((string) node.Value).Replace("'", "''"));
                    break;
                case TypeCode.Object:
                    throw new NotSupportedException(string.Format("The constant for '{0}' is not supported", node.Value));
                default:
                    sb.Append(node.Value);
                    break;
            }
        }
        return node;
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression != null && node.Expression.NodeType == ExpressionType.Parameter)
        {
            sb.Append(node.Member.Name);
            return node;
        }
        throw new NotSupportedException(string.Format("The member '{0}' is not supported", node.Member.Name));
    }

    #endregion
}

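The translator overrides the relevant Visit methods and walks the expression tree using the Visitor pattern. For the query in the Program class at the end of this article, Translate produces SQL of the form `SELECT * FROM (SELECT * FROM <table>) AS T WHERE (Type = 1)`, where `<table>` is the name in Product's MappingAttribute.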
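The same visitor technique can be tried without a database. The sketch below, with hypothetical names (WhereClauseTranslator, SampleProduct), translates only a predicate's body into a WHERE condition, mirroring VisitBinary, VisitMember and VisitConstant above; it additionally maps `&&` and `||` to AND and OR, which the translator above does not support:

```csharp
using System;
using System.Linq.Expressions;
using System.Text;

// Hypothetical entity used only for this sketch.
public class SampleProduct
{
    public int ID { get; set; }
    public int Type { get; set; }
}

public class WhereClauseTranslator : ExpressionVisitor
{
    private readonly StringBuilder sb = new StringBuilder();

    // Translates the body of a predicate such as p => p.Type == 1 into "(Type = 1)".
    public static string Translate<T>(Expression<Func<T, bool>> predicate)
    {
        var translator = new WhereClauseTranslator();
        translator.Visit(predicate.Body);
        return translator.sb.ToString();
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        sb.Append("(");
        Visit(node.Left);
        switch (node.NodeType)
        {
            case ExpressionType.AndAlso: sb.Append(" AND "); break;
            case ExpressionType.OrElse: sb.Append(" OR "); break;
            case ExpressionType.Equal: sb.Append(" = "); break;
            case ExpressionType.NotEqual: sb.Append(" <> "); break;
            case ExpressionType.GreaterThan: sb.Append(" > "); break;
            case ExpressionType.GreaterThanOrEqual: sb.Append(" >= "); break;
            case ExpressionType.LessThan: sb.Append(" < "); break;
            case ExpressionType.LessThanOrEqual: sb.Append(" <= "); break;
            default:
                throw new NotSupportedException("Unsupported operator: " + node.NodeType);
        }
        Visit(node.Right);
        sb.Append(")");
        return node;
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        // Only members of the lambda parameter (p.Type) become column names.
        if (node.Expression != null && node.Expression.NodeType == ExpressionType.Parameter)
        {
            sb.Append(node.Member.Name);
            return node;
        }
        throw new NotSupportedException("Unsupported member: " + node.Member.Name);
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        if (node.Value == null)
            sb.Append("NULL");
        else if (node.Value is string s)
            sb.Append('\'').Append(s.Replace("'", "''")).Append('\'');
        else
            sb.Append(node.Value);
        return node;
    }
}
```

`Translate<SampleProduct>(p => p.Type == 1 && p.ID > 2)` returns `((Type = 1) AND (ID > 2))`. Note a limitation shared with the translator above: a captured local variable (`int t = 1; p => p.Type == t`) shows up as a member access on a closure constant rather than on the parameter, so VisitMember rejects it with NotSupportedException.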
Finally, the DataContext implementation:

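using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.IO;
using System.Reflection;

public class DataContext : IDisposable
{
    #region Fields

    private IDbConnection dbConnection;
    // Static: shared by all DataContext instances and rebuilt by every constructor call.
    private static List<MetaTable> metaTables;
    #endregion

    #region Properties

    // Mirrors LINQ to SQL's DataContext.Log; nothing in this simplified version writes to it.
    public TextWriter Log { get; set; }

    public IDbConnection DbConnection
    {
        get { return this.dbConnection; }
    }

    public static List<MetaTable> MetaTables
    {
        get { return metaTables; }
    }
    #endregion

    #region Constructors
    public DataContext(string connString)
    {
        if (connString == null)
            throw new ArgumentNullException("connString");
        dbConnection = new SqlConnection(connString);
        dbConnection.Open();
        InitTables();
    }
    #endregion

    #region Methods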

    private void InitTables()
    {
        metaTables = new List<MetaTable>();
        var props = this.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance);
        foreach (var prop in props)
        {
            var propType = prop.PropertyType;
            if (propType.IsGenericType && propType.GetGenericTypeDefinition() == typeof (Query<>))
            {
                var entityType = propType.GetGenericArguments()[0];
                var entityAttr = entityType.GetCustomAttribute<MappingAttribute>(true);
                if (entityAttr != null)
                {
                    var metaTable = new MetaTable();
                    metaTable.Type = entityType;
                    metaTable.TableName = entityAttr.Name;
                    metaTable.MappingAttribute = entityAttr;
                    var columnProps = entityType.GetProperties(BindingFlags.Public | BindingFlags.Instance);
                    foreach (var columnProp in columnProps)
                    {
                        var columnPropAttr = columnProp.GetCustomAttribute<MappingAttribute>(true);
                        if (columnPropAttr != null)
                        {
                            MetaColumn metaColumn = new MetaColumn();
                            metaColumn.MappingAttribute = columnPropAttr;
                            metaColumn.ColumnName = columnPropAttr.Name;
                            metaColumn.PropertyInfo = columnProp;
                            metaTable.MetaColumns.Add(metaColumn);
                        }
                    }
                    metaTables.Add(metaTable);
                }
            }
        }
    }
    #endregion

    #region IDisposable members

    protected virtual void Dispose(bool disposing)
    {
        if (!disposing) return;
        if (dbConnection != null)
            dbConnection.Close();
    }

    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    #endregion
}
[Database(Name = "IT_Company")]
public class QueryDataContext : DataContext
{
    public QueryDataContext(string connString)
        : base(connString)
    {
        QueryProvider provider = new QueryProvider(DbConnection);
        Products = new Query<Product>(provider);
    }
    public Query<Product> Products
    {
        get;
        set;
    }
}
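The article does not show MappingAttribute, DatabaseAttribute, MetaTable, MetaColumn or Product. Below is one possible minimal definition, inferred only from how the code above uses them: the member names match those usages, while everything else (including the table name "Products" and the column names) is an assumption. A small helper, MappingReader (also hypothetical), reads the mapping the same way the inner loop of InitTables does:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Placed on entity classes (table name) and on their properties (column name).
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Property)]
public class MappingAttribute : Attribute
{
    public string Name { get; set; }
}

// Placed on the DataContext subclass, e.g. [Database(Name = "IT_Company")].
[AttributeUsage(AttributeTargets.Class)]
public class DatabaseAttribute : Attribute
{
    public string Name { get; set; }
}

public class MetaColumn
{
    public MappingAttribute MappingAttribute { get; set; }
    public string ColumnName { get; set; }
    public PropertyInfo PropertyInfo { get; set; }
}

public class MetaTable
{
    public MetaTable()
    {
        // InitTables calls MetaColumns.Add, so the list must exist up front.
        MetaColumns = new List<MetaColumn>();
    }

    public Type Type { get; set; }
    public string TableName { get; set; }
    public MappingAttribute MappingAttribute { get; set; }
    public List<MetaColumn> MetaColumns { get; private set; }
}

// Entity with the three columns ExecuteImpl reads (ID, Name, Type).
[Mapping(Name = "Products")]
public class Product
{
    [Mapping(Name = "ID")]
    public int ID { get; set; }

    [Mapping(Name = "Name")]
    public string Name { get; set; }

    [Mapping(Name = "Type")]
    public int Type { get; set; }
}

public static class MappingReader
{
    // Builds a MetaTable for an entity type, or returns null if the type is not mapped.
    public static MetaTable Read(Type entityType)
    {
        var entityAttr = entityType.GetCustomAttribute<MappingAttribute>(true);
        if (entityAttr == null)
            return null;
        var table = new MetaTable { Type = entityType, TableName = entityAttr.Name, MappingAttribute = entityAttr };
        foreach (var prop in entityType.GetProperties(BindingFlags.Public | BindingFlags.Instance))
        {
            var columnAttr = prop.GetCustomAttribute<MappingAttribute>(true);
            if (columnAttr != null)
                table.MetaColumns.Add(new MetaColumn { MappingAttribute = columnAttr, ColumnName = columnAttr.Name, PropertyInfo = prop });
        }
        return table;
    }
}
```

With these definitions, `MappingReader.Read(typeof(Product))` yields a table named "Products" with the columns ID, Name and Type; the order in which GetProperties returns them is not guaranteed.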

Usage:

class Program
{
    private static readonly string connString =
        "Data Source=.;Initial Catalog=IT_Company;Persist Security Info=True;User ID=sa;Password=123456";
    static void Main(string[] args)
    {
        using (var context = new QueryDataContext(connString))
        {
            var query = from product in context.Products
                where product.Type == 1
                select product;
            foreach (var product in query)
            {
                Console.WriteLine(product.Name);
            }
            Console.ReadKey();
        }
    }
}
//output:
//MG500
//MG1000
posted @ 2015-10-07 03:26  jello chen