分享：扩展 Include 关联查询，支持用 Lambda 表达式指定关联属性

View Code
/// <summary>
        /// Extends <see cref="ObjectQuery{T}"/> with an <c>Include</c> overload that takes a
        /// strongly-typed property selector instead of a string path.
        /// </summary>
        /// <typeparam name="T">The element (entity) type of the query.</typeparam>
        /// <param name="query">The <see cref="ObjectQuery{T}"/> being extended.</param>
        /// <param name="selector">
        /// A lambda selecting the related property to include, e.g. <c>p =&gt; p.PersonInfo</c>.
        /// It is converted to a dotted path string (e.g. <c>p =&gt; p.A.B</c> becomes <c>"A.B"</c>).
        /// </param>
        /// <returns>The query returned by <c>ObjectQuery&lt;T&gt;.Include(string)</c> for the computed path.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="selector"/> is null.</exception>
        public static ObjectQuery<T> Include<T>(this ObjectQuery<T> query, Expression<Func<T, object>> selector)
        {
            // Extension methods can be invoked on a null receiver; fail with a clear argument error
            // instead of a NullReferenceException (or an empty Include path for a null selector).
            if (query == null)
            {
                throw new ArgumentNullException("query");
            }
            if (selector == null)
            {
                throw new ArgumentNullException("selector");
            }

            string path = new PropertyPathVisitor().GetPropertyPath(selector);
            return query.Include(path);
        }

        /// <summary>
        /// Extends <see cref="DbQuery{T}"/> with an <c>Include</c> overload that takes a
        /// strongly-typed property selector instead of a string path.
        /// </summary>
        /// <remarks>
        /// NOTE(review): EF 4.1+ also ships a lambda-based <c>Include</c> extension for
        /// <c>IQueryable&lt;T&gt;</c>; confirm call sites resolve to the intended overload.
        /// </remarks>
        /// <typeparam name="T">The element (entity) type of the query.</typeparam>
        /// <param name="query">The <see cref="DbQuery{T}"/> being extended.</param>
        /// <param name="selector">
        /// A lambda selecting the related property to include, e.g. <c>p =&gt; p.PersonInfo</c>.
        /// It is converted to a dotted path string (e.g. <c>p =&gt; p.A.B</c> becomes <c>"A.B"</c>).
        /// </param>
        /// <returns>The query returned by <c>DbQuery&lt;T&gt;.Include(string)</c> for the computed path.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="selector"/> is null.</exception>
        public static DbQuery<T> Include<T>(this DbQuery<T> query, Expression<Func<T, object>> selector)
        {
            // Extension methods can be invoked on a null receiver; fail with a clear argument error
            // instead of a NullReferenceException (or an empty Include path for a null selector).
            if (query == null)
            {
                throw new ArgumentNullException("query");
            }
            if (selector == null)
            {
                throw new ArgumentNullException("selector");
            }

            string path = new PropertyPathVisitor().GetPropertyPath(selector);
            return query.Include(path);
        }
用于从 LINQ 表达式树中解析属性路径（如 "A.B"）的访问器：
View Code
/// <summary>
        /// Expression visitor that converts a property-selector expression (e.g. <c>p =&gt; p.A.B</c>)
        /// into the dotted navigation path string (e.g. <c>"A.B"</c>) accepted by string-based
        /// <c>Include(string)</c> overloads.
        /// </summary>
        /// <remarks>
        /// Member names are pushed onto a stack while walking from the outermost member access
        /// inward, so enumerating the stack (most recent first) yields the path root-first.
        /// Arguments of <see cref="Queryable"/>/<see cref="Enumerable"/> extension methods are
        /// visited too, so <c>p =&gt; p.Orders.Select(o =&gt; o.Items)</c> yields <c>"Orders.Items"</c>.
        /// Each <see cref="GetPropertyPath"/> call replaces the internal stack, so an instance
        /// must not be used by several callers at once.
        /// </remarks>
        class PropertyPathVisitor : ExpressionVisitor
        {
            private Stack<string> _stack;

            /// <summary>
            /// Builds the dotted property path described by <paramref name="expression"/>.
            /// </summary>
            /// <param name="expression">The selector expression, typically a lambda such as <c>p =&gt; p.PersonInfo</c>.</param>
            /// <returns>The visited member names joined with "."; empty when the expression contains no member access.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
            public string GetPropertyPath(Expression expression)
            {
                if (expression == null)
                {
                    throw new ArgumentNullException("expression");
                }

                _stack = new Stack<string>();
                Visit(expression);

                // Stack<T> enumerates in pop order (last pushed first), which is root-to-leaf here.
                return string.Join(".", _stack);
            }

            /// <summary>
            /// Records the accessed member's name, then continues into the member's target
            /// expression (for <c>p.A.B</c>, "B" is recorded before "A").
            /// </summary>
            /// <param name="expression">The member access being visited.</param>
            /// <returns>The result of the base <see cref="ExpressionVisitor"/> member visit.</returns>
            protected override Expression VisitMember(MemberExpression expression)
            {
                // _stack is only initialised by GetPropertyPath; tolerate direct Visit() calls.
                if (_stack != null)
                {
                    _stack.Push(expression.Member.Name);
                }

                return base.VisitMember(expression);
            }

            /// <summary>
            /// For LINQ operators (e.g. <c>Select</c>), visits the extra arguments before the source
            /// argument; any other method call is visited by the base visitor.
            /// </summary>
            /// <param name="expression">The method call being visited.</param>
            /// <returns>The unchanged expression for LINQ operators; otherwise the base visitor's result.</returns>
            protected override Expression VisitMethodCall(MethodCallExpression expression)
            {
                if (!IsLinqOperator(expression.Method))
                {
                    return base.VisitMethodCall(expression);
                }

                // Arguments[0] is the extension method's source (e.g. p.Orders); the remaining
                // arguments (e.g. o => o.Items) describe the deeper part of the path. Visiting the
                // source last puts its names on top of the stack, i.e. at the start of the path.
                for (int i = 1; i < expression.Arguments.Count; i++)
                {
                    Visit(expression.Arguments[i]);
                }
                Visit(expression.Arguments[0]);
                return expression;
            }

            /// <summary>
            /// Determines whether <paramref name="method"/> is an extension method declared on
            /// <see cref="Queryable"/> or <see cref="Enumerable"/>.
            /// </summary>
            /// <param name="method">The method to inspect.</param>
            /// <returns><c>true</c> for Queryable/Enumerable extension methods; otherwise <c>false</c>.</returns>
            private static bool IsLinqOperator(MethodInfo method)
            {
                if (method.DeclaringType != typeof(Queryable) && method.DeclaringType != typeof(Enumerable))
                {
                    return false;
                }

                return method.IsDefined(typeof(ExtensionAttribute), false);
            }
        }

使用方法：

// Include the PersonInfo navigation property (path "PersonInfo"), then filter by project.
var list = db.ProjectPersonRels.Include(p => p.PersonInfo).Where(p => p.ProjectID == projectId).ToList();
posted @ 2013-01-07 10:14  easeyeah  阅读(561)  评论(0编辑  收藏  举报