LINQ之路(3):LINQ扩展
本篇文章将从三个方面来进行LINQ扩展的阐述:扩展查询操作符、自定义查询操作符和简单模拟LINQ to SQL。
1.扩展查询操作符
在实际的使用过程中,Enumerable或Queryable中的扩展方法有时并不能满足我们的需要,我们需要自己扩展一些查询操作符以满足需要。例如,下面的例子:
// Zip pairs elements by position and stops at the end of the shorter sequence (here: 5 items).
var sums = Enumerable.Range(1, 10).Zip(Enumerable.Range(11, 5), (left, right) => left + right);
foreach (var sum in sums)
{
    Console.WriteLine(sum);
}
//output:
//12
//14
//16
//18
//20
Enumerable.Zip扩展方法用于将指定函数应用于两个序列的对应元素,以生成结果序列。这里是将序列[1...10]与序列[11...15]对应位置的元素相加而生成一个新的序列。其内部实现如下:
/// <summary>
/// Iterator equivalent to the framework's Zip implementation: applies <paramref name="resultSelector"/>
/// to the elements at the same position in both sequences and stops as soon as either sequence
/// runs out of elements.
/// </summary>
static IEnumerable<TResult> ZipIterator<TFirst, TSecond, TResult>(IEnumerable<TFirst> first, IEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector)
{
    using (IEnumerator<TFirst> firstCursor = first.GetEnumerator())
    {
        using (IEnumerator<TSecond> secondCursor = second.GetEnumerator())
        {
            // The second sequence is only advanced while the first one still has an element.
            while (firstCursor.MoveNext())
            {
                if (!secondCursor.MoveNext())
                {
                    yield break;
                }
                yield return resultSelector(firstCursor.Current, secondCursor.Current);
            }
        }
    }
}
很明显,Zip取的是“交集”:只有两个序列的对应位置都有元素时才进行处理,结果序列的长度等于较短序列的长度,所以上面的输出结果也是理所当然的。然而有时候,我们希望以第一个序列为主序列,即结果序列的长度总是等于第一个序列的长度,第二个序列不足的位置用默认值补齐。我们来扩展一个查询操作符,取名为LeftZip,实现如下:
/// <summary>
/// Merges <paramref name="rights"/> into <paramref name="lefts"/> position by position using
/// <paramref name="resultSelector"/>. Unlike Enumerable.Zip, the result has one element for every
/// element of <paramref name="lefts"/>; positions that <paramref name="rights"/> does not cover are
/// paired with default(TRight).
/// </summary>
/// <typeparam name="TLeft">Element type of the primary (left) sequence.</typeparam>
/// <typeparam name="TRight">Element type of the secondary (right) sequence.</typeparam>
/// <typeparam name="TResult">Element type produced by <paramref name="resultSelector"/>.</typeparam>
/// <param name="lefts">Primary sequence; determines the length of the result.</param>
/// <param name="rights">Secondary sequence merged into <paramref name="lefts"/>.</param>
/// <param name="resultSelector">Combines a left element with its right partner.</param>
/// <returns>A lazily evaluated sequence of merged results.</returns>
/// <exception cref="ArgumentNullException">Any argument is null.</exception>
public static IEnumerable<TResult> LeftZip<TLeft, TRight, TResult>(this IEnumerable<TLeft> lefts,
    IEnumerable<TRight> rights, Func<TLeft, TRight, TResult> resultSelector)
{
    // Arguments are checked here, eagerly; the iterator would otherwise defer these errors
    // until the result is first enumerated.
    if (lefts == null)
    {
        throw new ArgumentNullException("lefts");
    }
    if (rights == null)
    {
        throw new ArgumentNullException("rights");
    }
    if (resultSelector == null)
    {
        throw new ArgumentNullException("resultSelector");
    }

    IEnumerable<TResult> merged = LeftZipImpl(lefts, rights, resultSelector);
    return merged;
}
/// <summary>
/// Iterator behind LeftZip: walks <paramref name="lefts"/> to its end, pairing each element with the
/// element at the same position in <paramref name="rights"/>, or with default(TRight) once
/// <paramref name="rights"/> has no more elements.
/// </summary>
/// <typeparam name="TLeft">Element type of the primary sequence.</typeparam>
/// <typeparam name="TRight">Element type of the secondary sequence.</typeparam>
/// <typeparam name="TResult">Element type of the result.</typeparam>
/// <param name="lefts">Primary sequence (assumed non-null; validated by LeftZip).</param>
/// <param name="rights">Secondary sequence (assumed non-null; validated by LeftZip).</param>
/// <param name="resultSelector">Combines a left element with its right partner.</param>
/// <returns>One result per element of <paramref name="lefts"/>.</returns>
private static IEnumerable<TResult> LeftZipImpl<TLeft, TRight, TResult>(this IEnumerable<TLeft> lefts,
    IEnumerable<TRight> rights, Func<TLeft, TRight, TResult> resultSelector)
{
    using (var leftCursor = lefts.GetEnumerator())
    using (var rightCursor = rights.GetEnumerator())
    {
        // Once the right side reports no more elements it is never advanced again.
        bool rightHasMore = true;
        while (leftCursor.MoveNext())
        {
            rightHasMore = rightHasMore && rightCursor.MoveNext();
            yield return resultSelector(leftCursor.Current,
                rightHasMore ? rightCursor.Current : default(TRight));
        }
    }
}
调用LeftZip,代码如下:
// LeftZip keeps every element of the first sequence; positions without a partner get default(int), i.e. 0.
var merged = Enumerable.Range(1, 10).LeftZip(Enumerable.Range(11, 5), (left, right) => left + right);
foreach (var value in merged)
{
    Console.WriteLine(value);
}
//output:
//12
//14
//16
//18
//20
//6
//7
//8
//9
//10
2.自定义查询操作符
之前,我们在实现枚举器时介绍过一种自实现形式:不实现IEnumerable和IEnumerator接口,只需自定义一个提供GetEnumerator()方法的类和一个提供Current属性与MoveNext()方法的类,就可以使用foreach进行迭代。我们还知道,LINQ查询表达式会被编译器转换为方法的链式调用,其中的查询子句(如where、select)会被转换为对同名方法(首字母大写,如Where、Select)的调用。那么,如果我们自己实现与标准查询操作符同名的扩展方法,它会不会被调用呢?
开始尝试,创建一个静态类LinqExtensions,实现Where扩展方法,如下:
/// <summary>
/// Custom replacement for Enumerable.Where: yields the elements of <paramref name="source"/>
/// for which <paramref name="predicate"/> returns true.
/// </summary>
/// <typeparam name="TResult">Element type of the sequence.</typeparam>
/// <param name="source">Sequence to filter.</param>
/// <param name="predicate">Test applied to each element.</param>
/// <returns>A lazily evaluated filtered sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
public static IEnumerable<TResult> Where<TResult>(this IEnumerable<TResult> source,
    Func<TResult, bool> predicate)
{
    // Validate eagerly; WhereImpl is an iterator and would only fail on first enumeration.
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }
    if (predicate == null)
    {
        throw new ArgumentNullException("predicate");
    }

    IEnumerable<TResult> filtered = WhereImpl(source, predicate);
    return filtered;
}
/// <summary>
/// Iterator behind the custom Where: streams the elements of <paramref name="source"/> that
/// satisfy <paramref name="predicate"/>, in their original order.
/// </summary>
/// <typeparam name="TResult">Element type of the sequence.</typeparam>
/// <param name="source">Sequence to filter (assumed non-null; validated by Where).</param>
/// <param name="predicate">Test applied to each element.</param>
/// <returns>The matching elements.</returns>
private static IEnumerable<TResult> WhereImpl<TResult>(this IEnumerable<TResult> source,
    Func<TResult, bool> predicate)
{
    using (IEnumerator<TResult> cursor = source.GetEnumerator())
    {
        while (cursor.MoveNext())
        {
            if (!predicate(cursor.Current))
            {
                continue;
            }
            yield return cursor.Current;
        }
    }
}
调用部分代码如下:
// The compiler rewrites this query as Enumerable.Range(1, 10).Where(e => e % 2 == 0)
// (the trivial "select e" is dropped), so it binds to the custom Where extension when that
// method is the closest candidate in scope.
var evens = from e in Enumerable.Range(1, 10)
            where e % 2 == 0
            select e;
foreach (var even in evens)
{
    Console.WriteLine(even);
}
//output:
//2
//4
//6
//8
//10
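如何判断调用的确实是我们自己的Where扩展方法呢?可以通过调试、在VS中选中where后按F12转到定义,或者在Where扩展方法的实现中打印输出来判断(前提是LinqExtensions比System.Linq.Enumerable更“近”,例如与调用代码位于同一命名空间,这样编译器会优先绑定到它)。你可能会有疑问:方法签名是否必须与标准查询操作符一致呢?答案是否定的——编译器只是把查询表达式机械地翻译为同名方法调用,再按常规规则进行绑定。例如,你可以将上面的Where扩展方法改名为Select,然后把调用改为如下形式: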
// Assumes the custom Where extension above has been renamed to Select: the compiler turns
// "select e % 2 == 0" into Select(e => e % 2 == 0), so the renamed method filters the sequence.
var r = from e in Enumerable.Range(1, 10)
        //where e % 2 == 0
        select e % 2 == 0;
// The query is lazy; it must be enumerated for the output below to be produced.
foreach (var i in r)
{
    Console.WriteLine(i);
}
//output:
//2
//4
//6
//8
//10
最后结合枚举器,举一个例子:
/// <summary>
/// A minimal array-backed collection that supports foreach through the pattern-based
/// GetEnumerator()/MoveNext()/Current shape, without implementing IEnumerable&lt;T&gt;.
/// </summary>
/// <typeparam name="T">Element type.</typeparam>
public class Collection<T>
{
    private T[] items;

    /// <summary>
    /// Creates an empty collection.
    /// </summary>
    public Collection()
    {
        // Start with an empty array (not null) so enumerating an empty collection does not throw.
        items = new T[0];
    }

    /// <summary>
    /// Creates a collection holding a snapshot of the elements of <paramref name="collection"/>.
    /// </summary>
    /// <param name="collection">Source elements.</param>
    /// <exception cref="ArgumentNullException"><paramref name="collection"/> is null.</exception>
    public Collection(IEnumerable<T> collection)
    {
        if (collection == null)
            throw new ArgumentNullException("collection");
        // Enumerate the source exactly once. Lazy sources (e.g. iterator blocks) would otherwise
        // re-run their logic for every Count()/ToArray() call. ToArray already returns a fresh copy.
        items = collection.ToArray();
    }

    /// <summary>
    /// Implicitly wraps a copy of <paramref name="arr"/> in a new collection.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="arr"/> is null.</exception>
    public static implicit operator Collection<T>(T[] arr)
    {
        if (arr == null)
            throw new ArgumentNullException("arr");
        Collection<T> collection = new Collection<T>();
        collection.items = new T[arr.Length];
        Array.Copy(arr, collection.items, arr.Length);
        return collection;
    }

    /// <summary>
    /// Returns a new enumerator over the collection's items (used by foreach).
    /// </summary>
    public ItemEnumerator GetEnumerator()
    {
        return new ItemEnumerator(items);
    }

    #region Item Enumerator
    /// <summary>
    /// Enumerator over an array, shaped for pattern-based foreach.
    /// </summary>
    public class ItemEnumerator : IDisposable
    {
        private T[] items;
        // -1 = before the first element; items.Length = past the last element.
        private int index = -1;

        /// <summary>
        /// Creates an enumerator over <paramref name="arr"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="arr"/> is null.</exception>
        public ItemEnumerator(T[] arr)
        {
            if (arr == null)
                throw new ArgumentNullException("arr");
            this.items = arr;
        }

        /// <summary>
        /// The element at the current position.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// Accessed before the first MoveNext() call or after MoveNext() has returned false.
        /// </exception>
        public T Current
        {
            get
            {
                if (index < 0 || index > items.Length - 1)
                    throw new InvalidOperationException();
                return items[index];
            }
        }

        /// <summary>
        /// Advances to the next element.
        /// </summary>
        /// <returns>true if an element is available; false once the end has been passed.</returns>
        public bool MoveNext()
        {
            // Step past the last element too, so Current correctly reports "after the end"
            // instead of silently returning the last item again.
            if (index < items.Length)
                index++;
            return index < items.Length;
        }

        /// <summary>
        /// Moves back to the position before the first element.
        /// </summary>
        public void Reset()
        {
            index = -1;
        }

        #region IDisposable members
        /// <summary>
        /// Resets the position; there are no unmanaged resources to release.
        /// </summary>
        public void Dispose()
        {
            index = -1;
        }
        #endregion
    }
    #endregion
}
/// <summary>
/// LINQ-style Where/Select operators for Collection&lt;T&gt;. Because they accept and return
/// Collection&lt;T&gt;, query expressions over a Collection&lt;T&gt; can bind to these methods when they
/// are in scope. Both operators are eager: the result is materialized before it is returned.
/// </summary>
public static class EnumerableExtensions
{
    /// <summary>
    /// Returns a new collection containing the elements of <paramref name="source"/> that satisfy
    /// <paramref name="predicate"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
    public static Collection<T> Where<T>(this Collection<T> source, Func<T, bool> predicate)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        if (predicate == null)
            throw new ArgumentNullException("predicate");
        return WhereImpl(source, predicate).ToCollection();
    }

    /// <summary>
    /// Iterator that streams the matching elements of <paramref name="source"/>.
    /// </summary>
    private static IEnumerable<T> WhereImpl<T>(this Collection<T> source, Func<T, bool> predicate)
    {
        using (var e = source.GetEnumerator())
        {
            while (e.MoveNext())
            {
                if (predicate(e.Current))
                {
                    yield return e.Current;
                }
            }
        }
    }

    /// <summary>
    /// Returns a new collection with <paramref name="selector"/> applied to each element of
    /// <paramref name="source"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
    public static Collection<TResult> Select<T, TResult>(this Collection<T> source, Func<T, TResult> selector)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        if (selector == null)
            throw new ArgumentNullException("selector");
        return SelectImpl(source, selector).ToCollection();
    }

    /// <summary>
    /// Iterator that streams the projected elements of <paramref name="source"/>.
    /// </summary>
    private static IEnumerable<TResult> SelectImpl<T, TResult>(this Collection<T> source, Func<T, TResult> selector)
    {
        using (var e = source.GetEnumerator())
        {
            while (e.MoveNext())
            {
                yield return selector(e.Current);
            }
        }
    }

    /// <summary>
    /// Copies the elements of <paramref name="source"/> into a new Collection&lt;T&gt;.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public static Collection<T> ToCollection<T>(this IEnumerable<T> source)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        // Materialize exactly once. Where/Select pass lazy iterators here, and the Collection<T>
        // constructor may enumerate its argument more than once (it uses Count() and ToArray()),
        // which would invoke the predicate/selector repeatedly for every element.
        return new Collection<T>(source.ToArray());
    }
}
上面的代码包含两个类:Collection<T>作为数据源,EnumerableExtensions提供扩展方法。调用方法如下:
// The int[] converts implicitly to Collection<int>; the query then binds to the
// Collection<T>.Where extension (the trivial "select c" is dropped by the compiler).
Collection<int> numbers = new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
var evens = from c in numbers
            where c % 2 == 0
            select c;
foreach (var even in evens)
{
    Console.WriteLine(even);
}
//output:
//2
//4
//6
//8
//10
3.简单模拟LINQ to SQL
在LINQ之路(2)中,我们简单介绍了LINQ to SQL的原理。在这里,我们通过简单模拟LINQ to SQL来进一步了解其原理。
首先创建数据源类Query<T>,它实现了IQueryable<T>接口:
/// <summary>
/// Queryable data source: holds the expression tree that describes a query together with the
/// provider that executes it. Enumerating the query asks the provider to run the expression.
/// </summary>
/// <typeparam name="T">Element type produced by the query.</typeparam>
public class Query<T> : IQueryable<T>
{
    private readonly QueryProvider provider;
    private readonly Expression expression;

    /// <summary>
    /// Creates a root query whose expression is a constant node wrapping this object.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
    public Query(QueryProvider provider)
    {
        if (provider == null)
        {
            throw new ArgumentNullException("provider");
        }
        this.provider = provider;
        // The root of every query tree is a constant node that refers back to the source object.
        this.expression = Expression.Constant(this);
    }

    /// <summary>
    /// Creates a query for an expression built on top of another query.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="provider"/> or <paramref name="expression"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="expression"/> does not produce an IQueryable&lt;T&gt;.</exception>
    public Query(QueryProvider provider, Expression expression)
    {
        if (provider == null)
        {
            throw new ArgumentNullException("provider");
        }
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }
        bool producesQueryOfT = typeof(IQueryable<T>).IsAssignableFrom(expression.Type);
        if (!producesQueryOfT)
        {
            throw new ArgumentOutOfRangeException("expression");
        }
        this.provider = provider;
        this.expression = expression;
    }

    /// <summary>
    /// Executes the query through the provider and enumerates the results.
    /// </summary>
    public IEnumerator<T> GetEnumerator()
    {
        object rawResult = this.provider.Execute(this.expression);
        IEnumerable<T> results = (IEnumerable<T>)rawResult;
        return results.GetEnumerator();
    }

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return GetEnumerator();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    Type IQueryable.ElementType
    {
        get { return typeof(T); }
    }

    Expression IQueryable.Expression
    {
        get { return this.expression; }
    }

    IQueryProvider IQueryable.Provider
    {
        get { return this.provider; }
    }
}
这个类比较简单,只是实现IQueryable<T>接口,并保存构造时传入的Provider和表达式目录树。
再来看Provider,创建QueryProvider类,实现IQueryProvider接口:
/// <summary>
/// IQueryProvider that translates a query expression tree into SQL (via QueryTranslator) and
/// executes it on the supplied database connection.
/// </summary>
public class QueryProvider : IQueryProvider
{
    private IDbConnection dbConnection;

    /// <summary>
    /// Creates a provider that runs queries on <paramref name="dbConnection"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="dbConnection"/> is null.</exception>
    public QueryProvider(IDbConnection dbConnection)
    {
        if (dbConnection == null)
            throw new ArgumentNullException("dbConnection");
        this.dbConnection = dbConnection;
    }

    /// <summary>
    /// Wraps <paramref name="expression"/> in a strongly typed Query&lt;TElement&gt;.
    /// </summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return new Query<TElement>(this, expression);
    }

    /// <summary>
    /// Wraps <paramref name="expression"/> in a Query&lt;T&gt; whose T is the expression's element type.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null)
            throw new ArgumentNullException("expression");
        // Query<T> is parameterized by the ELEMENT type, while expression.Type is the sequence type
        // (e.g. IQueryable<Product>); using it directly would build Query<IQueryable<Product>>,
        // whose constructor rejects the expression.
        var elementType = GetElementType(expression.Type);
        try
        {
            return (IQueryable)Activator.CreateInstance(typeof(Query<>).MakeGenericType(elementType), this, expression);
        }
        catch (TargetInvocationException e)
        {
            // Surface the constructor's own exception while keeping its original stack trace.
            if (e.InnerException != null)
                System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw();
            throw;
        }
    }

    TResult IQueryProvider.Execute<TResult>(Expression expression)
    {
        return (TResult)this.Execute(expression);
    }

    object IQueryProvider.Execute(Expression expression)
    {
        return this.Execute(expression);
    }

    /// <summary>
    /// Translates <paramref name="expression"/> to SQL, runs it and returns the materialized rows.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public virtual object Execute(Expression expression)
    {
        if (expression == null)
            throw new ArgumentNullException("expression");
        return ExecuteImpl(expression);
    }

    /// <summary>
    /// Runs the translated SQL and maps every row to a Product.
    /// </summary>
    private IEnumerable ExecuteImpl(Expression expression)
    {
        // NOTE(review): rows are always materialized as Product, whatever the expression's element
        // type, and columns are read by ordinal (0 = ID, 1 = Name, 2 = Type). This assumes the
        // "SELECT *" column order of the mapped table matches — confirm against the schema.
        List<Product> products = new List<Product>();
        QueryTranslator queryTranslator = new QueryTranslator();
        var cmdText = queryTranslator.Translate(expression);
        // Dispose the command as well as the reader so its resources are released promptly.
        using (IDbCommand cmd = dbConnection.CreateCommand())
        {
            cmd.CommandText = cmdText;
            using (IDataReader dataReader = cmd.ExecuteReader())
            {
                while (dataReader.Read())
                {
                    Product product = new Product();
                    product.ID = dataReader.GetInt32(0);
                    product.Name = dataReader.GetString(1);
                    product.Type = dataReader.GetInt32(2);
                    products.Add(product);
                }
            }
        }
        return products;
    }

    /// <summary>
    /// Returns T when <paramref name="sequenceType"/> is or implements IEnumerable&lt;T&gt;;
    /// otherwise returns <paramref name="sequenceType"/> itself.
    /// </summary>
    private static Type GetElementType(Type sequenceType)
    {
        if (sequenceType.IsGenericType && sequenceType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            return sequenceType.GetGenericArguments()[0];
        foreach (Type candidate in sequenceType.GetInterfaces())
        {
            if (candidate.IsGenericType && candidate.GetGenericTypeDefinition() == typeof(IEnumerable<>))
                return candidate.GetGenericArguments()[0];
        }
        return sequenceType;
    }
}
再来看看查询翻译类,创建QueryTranslator类,继承自ExpressionVisitor抽象类:
/// <summary>
/// Translates a small subset of LINQ expression trees into SQL text using the ExpressionVisitor
/// pattern. Supported: Queryable.Where; comparisons (=, &lt;&gt;, &gt;, &gt;=, &lt;, &lt;=) combined with
/// &amp;&amp; / ||; comparisons with a null literal; member access on the lambda parameter; and
/// constant literals. Anything else throws NotSupportedException.
/// </summary>
public class QueryTranslator : ExpressionVisitor
{
    private StringBuilder sb;

    public QueryTranslator()
    {
    }

    /// <summary>
    /// Translates <paramref name="expression"/> into a SQL command text.
    /// </summary>
    /// <exception cref="NotSupportedException">The tree contains an unsupported construct.</exception>
    public string Translate(Expression expression)
    {
        this.sb = new StringBuilder();
        this.Visit(expression);
        return this.sb.ToString();
    }

    /// <summary>
    /// Removes the Quote nodes that wrap lambdas passed to Queryable methods.
    /// </summary>
    private static Expression StripQuotes(Expression e)
    {
        while (e.NodeType == ExpressionType.Quote)
        {
            e = ((UnaryExpression)e).Operand;
        }
        return e;
    }

    /// <summary>
    /// True when <paramref name="e"/> is a constant node holding null.
    /// </summary>
    private static bool IsNullConstant(Expression e)
    {
        return e.NodeType == ExpressionType.Constant && ((ConstantExpression)e).Value == null;
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method.DeclaringType == typeof(Queryable) &&
            node.Method.Name == "Where")
        {
            // Wrap the source query as a derived table and append the predicate.
            sb.Append("SELECT * FROM (");
            this.Visit(node.Arguments[0]);
            sb.Append(") AS T WHERE ");
            LambdaExpression lambda = (LambdaExpression)StripQuotes(node.Arguments[1]);
            this.Visit(lambda.Body);
            return node;
        }
        throw new NotSupportedException(string.Format("The Method '{0}' is not supported", node.Method.Name));
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        // In SQL "x = NULL" is never true; comparisons with a null literal need IS [NOT] NULL.
        if ((node.NodeType == ExpressionType.Equal || node.NodeType == ExpressionType.NotEqual)
            && (IsNullConstant(node.Left) || IsNullConstant(node.Right)))
        {
            sb.Append("(");
            this.Visit(IsNullConstant(node.Right) ? node.Left : node.Right);
            sb.Append(node.NodeType == ExpressionType.Equal ? " IS NULL" : " IS NOT NULL");
            sb.Append(")");
            return node;
        }

        sb.Append("(");
        this.Visit(node.Left);
        switch (node.NodeType)
        {
            case ExpressionType.AndAlso:
                sb.Append(" AND ");
                break;
            case ExpressionType.OrElse:
                sb.Append(" OR ");
                break;
            case ExpressionType.Equal:
                sb.Append(" = ");
                break;
            case ExpressionType.NotEqual:
                sb.Append(" <> ");
                break;
            case ExpressionType.GreaterThan:
                sb.Append(" > ");
                break;
            case ExpressionType.GreaterThanOrEqual:
                sb.Append(" >= ");
                break;
            case ExpressionType.LessThan:
                sb.Append(" < ");
                break;
            case ExpressionType.LessThanOrEqual:
                sb.Append(" <= ");
                break;
            default:
                throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", node.NodeType));
        }
        this.Visit(node.Right);
        sb.Append(")");
        return node;
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        IQueryable q = node.Value as IQueryable;
        if (q != null)
        {
            // Query root: emit the table mapped to the element type.
            var metaTables = DataContext.MetaTables;
            var metaTable = metaTables == null ? null : metaTables.FirstOrDefault(f => f.Type == q.ElementType);
            if (metaTable == null)
                throw new NotSupportedException(string.Format("No table mapping found for type '{0}'", q.ElementType));
            sb.Append("SELECT * FROM ");
            sb.Append(metaTable.TableName);
            return node;
        }

        if (node.Value == null)
        {
            sb.Append("NULL");
            return node;
        }

        switch (Type.GetTypeCode(node.Value.GetType()))
        {
            case TypeCode.Boolean:
                sb.Append(((bool)node.Value) ? 1 : 0);
                break;
            case TypeCode.String:
                // Double embedded single quotes so the literal cannot terminate early (e.g. O'Brien)
                // and cannot inject additional SQL.
                sb.Append("'").Append(((string)node.Value).Replace("'", "''")).Append("'");
                break;
            case TypeCode.Object:
                throw new NotSupportedException(string.Format("The constant for '{0}' is not supported", node.Value));
            default:
                // Invariant culture keeps SQL numeric literals stable (e.g. 1.5, never "1,5").
                sb.Append(Convert.ToString(node.Value, System.Globalization.CultureInfo.InvariantCulture));
                break;
        }
        return node;
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        // Only members of the lambda parameter (i.e. entity columns) are supported.
        // NOTE(review): this emits the CLR property name, not the mapped column name.
        if (node.Expression != null && node.Expression.NodeType == ExpressionType.Parameter)
        {
            sb.Append(node.Member.Name);
            return node;
        }
        throw new NotSupportedException(string.Format("The member '{0}' is not supported", node.Member.Name));
    }
}
通过重写ExpressionVisitor的相关Visit方法,以访问者(Visitor)模式遍历表达式目录树,并将其翻译为SQL语句。
最后来看下DataContext的实现:
/// <summary>
/// Opens a SQL Server connection and builds table/column mappings from the Query&lt;T&gt; properties
/// declared on the (derived) context type and the MappingAttribute on their entity types.
/// </summary>
public class DataContext : IDisposable
{
    private IDbConnection dbConnection;
    // NOTE(review): static, so every DataContext constructed replaces the mappings for all
    // instances; QueryTranslator reads them through the static MetaTables property.
    private static List<MetaTable> metaTables;

    /// <summary>
    /// Optional log writer (not written to by this class).
    /// </summary>
    public TextWriter Log { get; set; }

    /// <summary>
    /// The open connection used by this context.
    /// </summary>
    public IDbConnection DbConnection
    {
        get { return this.dbConnection; }
    }

    /// <summary>
    /// Table mappings built by the most recently constructed context.
    /// </summary>
    public static List<MetaTable> MetaTables
    {
        get { return metaTables; }
    }

    /// <summary>
    /// Opens a connection with <paramref name="connString"/> and builds the table mappings.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="connString"/> is null.</exception>
    public DataContext(string connString)
    {
        if (connString == null)
            throw new ArgumentNullException("connString");
        dbConnection = new SqlConnection(connString);
        try
        {
            dbConnection.Open();
            InitTables();
        }
        catch
        {
            // Don't leak the connection when opening it or building the mappings fails.
            dbConnection.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Scans public Query&lt;T&gt; properties of this context and records a MetaTable for every
    /// entity type carrying a MappingAttribute, with a MetaColumn per mapped property.
    /// </summary>
    private void InitTables()
    {
        metaTables = new List<MetaTable>();
        // GetType() is the runtime (derived) type, so properties of subclasses are found too.
        var props = this.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance);
        foreach (var prop in props)
        {
            var propType = prop.PropertyType;
            if (propType.IsGenericType && propType.GetGenericTypeDefinition() == typeof(Query<>))
            {
                var entityType = propType.GetGenericArguments()[0];
                var entityAttr = entityType.GetCustomAttribute<MappingAttribute>(true);
                if (entityAttr != null)
                {
                    var metaTable = new MetaTable();
                    metaTable.Type = entityType;
                    metaTable.TableName = entityAttr.Name;
                    metaTable.MappingAttribute = entityAttr;
                    var columnProps = entityType.GetProperties(BindingFlags.Public | BindingFlags.Instance);
                    foreach (var columnProp in columnProps)
                    {
                        var columnPropAttr = columnProp.GetCustomAttribute<MappingAttribute>(true);
                        if (columnPropAttr != null)
                        {
                            MetaColumn metaColumn = new MetaColumn();
                            metaColumn.MappingAttribute = columnPropAttr;
                            metaColumn.ColumnName = columnPropAttr.Name;
                            metaColumn.PropertyInfo = columnProp;
                            // Assumes MetaTable initializes its MetaColumns list — TODO confirm.
                            metaTable.MetaColumns.Add(metaColumn);
                        }
                    }
                    metaTables.Add(metaTable);
                }
            }
        }
    }

    /// <summary>
    /// Releases the connection when called from Dispose().
    /// </summary>
    protected virtual void Dispose(bool disposing)
    {
        if (!disposing) return;
        // Dispose (rather than only Close) so the connection's resources are released too.
        if (dbConnection != null)
            dbConnection.Dispose();
    }

    /// <summary>
    /// Disposes the context and its connection.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }
}
/// <summary>
/// Concrete data context for the IT_Company database that exposes a queryable source of
/// Product entities.
/// </summary>
[Database(Name = "IT_Company")]
public class QueryDataContext : DataContext
{
    /// <summary>
    /// Opens the connection (via the base class) and creates the Products query source.
    /// </summary>
    public QueryDataContext(string connString)
        : base(connString)
    {
        // Queries built from Products use a provider bound to this context's open connection.
        Products = new Query<Product>(new QueryProvider(DbConnection));
    }

    /// <summary>
    /// Root query over Product entities.
    /// </summary>
    public Query<Product> Products { get; set; }
}
调用如下:
class Program
{
    // Fallback connection string for the demo database.
    // NOTE(review): this embeds sa credentials in source; prefer passing a connection string
    // as the first command-line argument.
    private static readonly string connString =
        "Data Source=.;Initial Catalog=IT_Company;Persist Security Info=True;User ID=sa;Password=123456";

    /// <summary>
    /// Runs a filtered query through the mini LINQ-to-SQL provider and prints the product names.
    /// </summary>
    /// <param name="args">Optional: args[0] overrides the built-in connection string.</param>
    static void Main(string[] args)
    {
        string effectiveConnString = (args != null && args.Length > 0 && !string.IsNullOrWhiteSpace(args[0]))
            ? args[0]
            : connString;
        using (var context = new QueryDataContext(effectiveConnString))
        {
            var query = from product in context.Products
                        where product.Type == 1
                        select product;
            foreach (var product in query)
            {
                Console.WriteLine(product.Name);
            }
        }
        // Wait for a key only after the context (and its connection) has been disposed.
        Console.ReadKey();
    }
}
//output:
//MG500
//MG1000