模拟 EF Code First,实现自己的 ORM

一.什么是ORM

    对象关系映射(Object Relational Mapping,简称ORM)模式是一种为了解决面向对象与关系数据库存在的互不匹配的现象的技术。

    简单来说,ORM 是通过使用描述对象和数据库之间映射的元数据,将程序中的对象自动持久化到关系数据库中或者将数据库的数据拉取出来

二.EF基本原理

 1.EF 是微软以 ADO.NET 为基础所发展出来的对象关系对应 (O/R Mapping) 解决方案

 2.EF 的核心对象是 DbContext。其基本原理是:实现系统的 IQueryable<T> 接口,解析 LINQ 查询生成的表达式树(Expression Tree)得到 SQL 语句,再操作数据库,并借助反射把查询结果映射为实体对象。

三.模拟EF

1.模拟 EF,首先要自己实现 lambda 表达式的解析(表达式解析是难点,需要仔细调试)

    1)构造表达式解析入口

       

 /// <summary>
        /// Entry point of the lambda-to-SQL translator: dispatches <paramref name="func"/> to the Visit* method
        /// matching its node class.
        /// </summary>
        /// <param name="func">Expression (or sub-expression) to translate; null yields an empty string.</param>
        /// <param name="dirType">
        /// Side of a binary comparison on which <paramref name="func"/> sits. Member accesses on the left (or with
        /// <c>DirectionType.None</c>) render as column names; on the right they are evaluated to literal values.
        /// </param>
        /// <returns>The SQL fragment, or an empty string for node classes that are not handled.</returns>
        public static string GetSqlByExpression(Expression func, DirectionType dirType = DirectionType.None)
        {
            var unary = func as UnaryExpression;
            if (unary != null)
            {
                // A conversion wrapped around a captured value (closure field or static member) keeps the caller's
                // side, so on the right of a comparison it is evaluated as a value instead of being rendered as the
                // variable's name. Conversions around a lambda-parameter column keep the usual unary handling.
                var isConversion = unary.NodeType == ExpressionType.Convert
                    || unary.NodeType == ExpressionType.ConvertChecked
                    || unary.NodeType == ExpressionType.TypeAs;
                if (isConversion)
                {
                    var root = unary.Operand;
                    while (root is MemberExpression)
                    {
                        root = ((MemberExpression)root).Expression;
                    }
                    if (!(root is ParameterExpression))
                    {
                        return GetSqlByExpression(unary.Operand, dirType);
                    }
                }
                return VisitUnaryExpression(unary);
            }
            // The node classes below are mutually exclusive, so the first match decides the translation.
            if (func is BinaryExpression)
            {
                return VisitBinaryExpression((BinaryExpression)func);
            }
            if (func is TypeBinaryExpression)
            {
                return VisitTypeBinaryExpression((TypeBinaryExpression)func);
            }
            if (func is ConditionalExpression)
            {
                return VisitConditionalExpression((ConditionalExpression)func);
            }
            if (func is ConstantExpression)
            {
                return VisitConstantExpression((ConstantExpression)func);
            }
            if (func is ParameterExpression)
            {
                return VisitParameterExpression((ParameterExpression)func);
            }
            if (func is MemberExpression)
            {
                return VisitMemberExpression((MemberExpression)func, dirType);
            }
            if (func is LambdaExpression)
            {
                return VisitLambdaExpression((LambdaExpression)func);
            }
            if (func is NewExpression)
            {
                return VisitNewExpression((NewExpression)func);
            }
            if (func is NewArrayExpression)
            {
                return VisitNewArrayExpression((NewArrayExpression)func);
            }
            if (func is InvocationExpression)
            {
                return VisitInvocationExpression((InvocationExpression)func);
            }
            if (func is MemberInitExpression)
            {
                return VisitMemberInitExpression((MemberInitExpression)func);
            }
            if (func is ListInitExpression)
            {
                return VisitListInitExpression((ListInitExpression)func);
            }
            if (func is MethodCallExpression)
            {
                return VisitMethodCallExpression((MethodCallExpression)func);
            }
            return "";
        }
lambda 解析入口

    2)根据不同的类型,构建不同的解析方法

 /// <summary>
        /// Translates a member access. On the left of a comparison (or with no direction), and for any member
        /// reached from the lambda parameter, the member name is emitted as a column name. Otherwise (right side,
        /// captured value) the member is evaluated and rendered as a SQL literal:
        /// strings, chars and Guids are quoted with embedded single quotes doubled,
        /// DateTime is rendered as 'yyyy-MM-dd HH:mm:ss', and booleans as 1/0.
        /// Enums become their underlying number, other values use invariant-culture text,
        /// and null becomes NULL (a null string renders as '').
        /// </summary>
        /// <param name="func">The member access to translate.</param>
        /// <param name="dirType">Which side of the comparison the member is on.</param>
        /// <returns>The column name or the SQL literal.</returns>
        private static string VisitMemberExpression(MemberExpression func, DirectionType dirType)
        {
            // Walk to the root of the member chain (p.User.Name -> p); members of the lambda parameter are columns
            // and cannot be evaluated because the parameter has no value here.
            Expression root = func.Expression;
            while (root is MemberExpression)
            {
                root = ((MemberExpression)root).Expression;
            }
            if (dirType == DirectionType.Left || dirType == DirectionType.None || root is ParameterExpression)
            {
                return func.Member.Name;
            }

            // Box through Convert so value-typed members (decimal, bool, enums, long, ...) can be compiled as
            // Func<object>; Expression.Lambda<Func<object>> rejects a value-typed body.
            var getter = Expression.Lambda<Func<object>>(Expression.Convert(func, typeof(object))).Compile();
            var value = getter();

            if (value == null)
            {
                return func.Type == typeof(string) ? "''" : "NULL";
            }
            if (value is string || value is char || value is Guid)
            {
                // Doubling single quotes keeps values such as O'Brien from terminating the literal early.
                return "'" + value.ToString().Replace("'", "''") + "'";
            }
            if (value is DateTime)
            {
                return "'" + ((DateTime)value).ToString("yyyy-MM-dd HH:mm:ss", System.Globalization.CultureInfo.InvariantCulture) + "'";
            }
            if (value is bool)
            {
                return (bool)value ? "1" : "0";
            }
            if (value is Enum)
            {
                value = Convert.ChangeType(value, Enum.GetUnderlyingType(value.GetType()), System.Globalization.CultureInfo.InvariantCulture);
            }
            return Convert.ToString(value, System.Globalization.CultureInfo.InvariantCulture);
        }
View Code
   /// <summary>
        /// Translates a unary node by translating its operand only; the unary operator itself is not emitted.
        /// </summary>
        /// <param name="func">The unary node (for example Convert or Quote).</param>
        /// <returns>SQL for the operand.</returns>
        private static string VisitUnaryExpression(UnaryExpression func)
        {
            var operand = func.Operand;
            return GetSqlByExpression(operand);
        }
View Code
 /// <summary>
        /// Translates a binary node into "(left op right)", rendering the left operand as a column and the right
        /// operand as a value. An equality/inequality test against a null constant becomes "IS NULL" /
        /// "IS NOT NULL", because "= NULL" never matches in SQL.
        /// </summary>
        /// <param name="func">The binary node to translate.</param>
        /// <returns>The parenthesised SQL fragment.</returns>
        private static string VisitBinaryExpression(BinaryExpression func)
        {
            // e.g. {(((p.Id == "1") AndAlso (p.OrderNo == "fasdf")) AndAlso (p.CreateTime == DateTime.Now))}
            var left = GetSqlByExpression(func.Left, DirectionType.Left);

            if (func.NodeType == ExpressionType.Equal || func.NodeType == ExpressionType.NotEqual)
            {
                // Look through any conversions the compiler may place around the null literal.
                var right = func.Right;
                while (right.NodeType == ExpressionType.Convert || right.NodeType == ExpressionType.ConvertChecked)
                {
                    right = ((UnaryExpression)right).Operand;
                }
                var constant = right as ConstantExpression;
                if (constant != null && constant.Value == null)
                {
                    return "(" + left + (func.NodeType == ExpressionType.Equal ? " IS NULL" : " IS NOT NULL") + ")";
                }
            }

            return "(" + left + GetNodeType(func.NodeType) + GetSqlByExpression(func.Right, DirectionType.Right) + ")";
        }
View Code
/// <summary>
        /// Type tests (<c>x is T</c>) have no SQL translation.
        /// </summary>
        /// <param name="func">The type test.</param>
        /// <returns>Never returns normally.</returns>
        /// <exception cref="NotSupportedException">
        /// Always. Returning an empty fragment would silently drop the condition or produce malformed SQL such as
        /// "( and ...)".
        /// </exception>
        private static string VisitTypeBinaryExpression(TypeBinaryExpression func)
        {
            throw new NotSupportedException("Type tests ('is') cannot be translated to SQL: " + func);
        }
View Code
/// <summary>
        /// Conditional expressions (<c>c ? a : b</c>) have no SQL translation.
        /// </summary>
        /// <param name="func">The conditional expression.</param>
        /// <returns>Never returns normally.</returns>
        /// <exception cref="NotSupportedException">
        /// Always. Returning an empty fragment would silently drop the condition or produce malformed SQL.
        /// </exception>
        private static string VisitConditionalExpression(ConditionalExpression func)
        {
            throw new NotSupportedException("Conditional (?:) expressions cannot be translated to SQL: " + func);
        }
View Code
 /// <summary>
        /// Renders a constant as a SQL literal:
        /// - null becomes NULL.
        /// - Strings, chars and Guids are quoted, with embedded single quotes doubled.
        /// - DateTime is quoted as 'yyyy-MM-dd HH:mm:ss'.
        /// - Booleans become 1/0, and enums their underlying number.
        /// - Other numeric types use invariant-culture text.
        /// </summary>
        /// <param name="func">The constant to render.</param>
        /// <returns>The SQL literal.</returns>
        /// <exception cref="NotSupportedException">The constant's type has no SQL rendering.</exception>
        private static string VisitConstantExpression(ConstantExpression func)
        {
            var value = func.Value;
            if (value == null)
            {
                return "NULL";
            }
            if (value is string || value is char || value is Guid)
            {
                // Doubling single quotes keeps values such as O'Brien from terminating the literal early.
                return "'" + value.ToString().Replace("'", "''") + "'";
            }
            if (value is DateTime)
            {
                return "'" + ((DateTime)value).ToString("yyyy-MM-dd HH:mm:ss", System.Globalization.CultureInfo.InvariantCulture) + "'";
            }
            if (value is bool)
            {
                return (bool)value ? "1" : "0";
            }
            if (value is Enum)
            {
                value = Convert.ChangeType(value, Enum.GetUnderlyingType(value.GetType()), System.Globalization.CultureInfo.InvariantCulture);
            }
            if (value is sbyte || value is byte || value is short || value is ushort || value is int || value is uint
                || value is long || value is ulong || value is float || value is double || value is decimal)
            {
                // Invariant culture keeps the decimal separator a '.', whatever the thread culture is.
                return Convert.ToString(value, System.Globalization.CultureInfo.InvariantCulture);
            }
            throw new NotSupportedException("请实现类型");
        }
View Code
 /// <summary>
        /// Builds the comma-separated column list for the parameter's type: every public property is rendered
        /// through GetProperInfo, and properties for which it returns null or an empty string are skipped.
        /// </summary>
        /// <param name="func">The lambda parameter whose type supplies the columns.</param>
        /// <returns>The column list without trailing commas.</returns>
        private static string VisitParameterExpression(ParameterExpression func)
        {
            var columns = func.Type.GetProperties()
                .Select(property => GetProperInfo(property))
                .Where(column => !string.IsNullOrEmpty(column));
            return string.Join(",", columns).TrimEnd(',');
        }
View Code
 /// <summary>
        /// Translates a method call:
        /// - Queryable Where / Select become " Where ..." / " Select ..." clauses.
        /// - A collection Contains call (instance, such as List&lt;T&gt;.Contains, or static Enumerable.Contains)
        ///   becomes "Column IN (...)".
        /// - Any other method yields an empty string.
        /// </summary>
        /// <param name="func">The method call to translate.</param>
        /// <returns>The SQL fragment.</returns>
        /// <exception cref="NotSupportedException">string.Contains, or a Contains whose tested item is not a member access.</exception>
        /// <exception cref="InvalidOperationException">The collection passed to Contains evaluates to null.</exception>
        private static String VisitMethodCallExpression(MethodCallExpression func)
        {
            // Exact names: substring matching also caught ContainsKey, SelectMany, ...
            switch (func.Method.Name)
            {
                case "Where":
                    // Arguments[0] is the source query, Arguments[1] the quoted predicate.
                    return " Where " + GetSqlByExpression(func.Arguments[1]);
                case "Select":
                    return " Select " + GetSqlByExpression(func.Arguments[1]);
                case "Contains":
                    break;
                default:
                    return "";
            }

            // Instance form: list.Contains(p.X); static form: Enumerable.Contains(array, p.X).
            Expression source;
            Expression item;
            if (func.Object != null)
            {
                source = func.Object;
                item = func.Arguments[0];
            }
            else if (func.Arguments.Count == 2)
            {
                source = func.Arguments[0];
                item = func.Arguments[1];
            }
            else
            {
                return "";
            }

            if (source.Type == typeof(string))
            {
                throw new NotSupportedException("string.Contains cannot be translated to an IN list: " + func);
            }

            // Evaluate the captured collection; Convert boxes value-typed collections for Func<object>.
            var getter = Expression.Lambda<Func<object>>(Expression.Convert(source, typeof(object))).Compile();
            var data = getter() as IEnumerable;
            if (data == null)
            {
                throw new InvalidOperationException("The collection passed to Contains evaluated to null: " + source);
            }

            // Strip method calls (e.g. p.OrderNo.Trim()) and conversions to reach the underlying column.
            var caller = item;
            while (caller != null && (caller.NodeType == ExpressionType.Call || caller.NodeType == ExpressionType.Convert))
            {
                caller = caller.NodeType == ExpressionType.Call
                    ? ((MethodCallExpression)caller).Object
                    : ((UnaryExpression)caller).Operand;
            }
            var member = caller as MemberExpression;
            if (member == null)
            {
                throw new NotSupportedException("Contains must test a member of the query parameter: " + item);
            }
            var field = VisitMemberExpression(member, DirectionType.Left);

            var values = data.Cast<object>()
                .Select(v => "'" + Convert.ToString(v, System.Globalization.CultureInfo.InvariantCulture).Replace("'", "''") + "'")
                .ToArray();
            if (values.Length == 0)
            {
                // "IN ()" is invalid SQL; an empty collection matches no rows.
                return " 1 = 0 ";
            }
            return field + " IN (" + string.Join(",", values) + ") ";
        }
View Code
/// <summary>
        /// Translates a lambda by translating only its body; the lambda's parameters are not rendered.
        /// </summary>
        /// <param name="func">The lambda to translate.</param>
        /// <returns>SQL for the lambda body.</returns>
        private static string VisitLambdaExpression(LambdaExpression func)
        {
            var body = func.Body;
            return GetSqlByExpression(body);
        }
View Code
/// <summary>
        /// Translates a constructor call, typically an anonymous-type projection such as
        /// <c>p =&gt; new { p.Id, p.OrderNo }</c>, into a comma-separated list of its translated arguments.
        /// </summary>
        /// <param name="func">The constructor call.</param>
        /// <returns>Comma-separated column list; empty when the constructor takes no arguments.</returns>
        private static string VisitNewExpression(NewExpression func)
        {
            // Every constructor argument is one projected column, not just the first one.
            var columns = func.Arguments.Select(argument => GetSqlByExpression(argument)).ToArray();
            return string.Join(",", columns);
        }
View Code

  3)根据 ExpressionType 判断条件类型

 

 /// <summary>
        /// Maps a binary node type to its SQL operator, padded with spaces.
        /// </summary>
        /// <param name="expType">Node type of a BinaryExpression.</param>
        /// <returns>The SQL operator, or an empty string for node types without a mapping.</returns>
        private static string GetNodeType(ExpressionType expType)
        {
            switch (expType)
            {
                case ExpressionType.AndAlso:
                    return " and ";
                // C# `||` produces OrElse; `|` produces Or.
                case ExpressionType.OrElse:
                case ExpressionType.Or:
                    return " or ";
                case ExpressionType.Equal:
                    return " = ";
                case ExpressionType.NotEqual:
                    return " <> ";
                case ExpressionType.GreaterThan:
                    return " > ";
                case ExpressionType.GreaterThanOrEqual:
                    return " >= ";
                case ExpressionType.LessThan:
                    return " < ";
                case ExpressionType.LessThanOrEqual:
                    return " <= ";
                case ExpressionType.Add:
                    return " + ";
                case ExpressionType.Subtract:
                    return " - ";
                case ExpressionType.Multiply:
                    return " * ";
                case ExpressionType.Divide:
                    return " / ";
                case ExpressionType.Modulo:
                    return " % ";
                default:
                    return "";
            }
        }
View Code

   4)根据 ExpressionType 判断 Expression 的子类类型(本例中未用到)

 

 /// <summary>
        /// Checks that <paramref name="exp"/> has a node type this helper knows and returns it unchanged; the
        /// per-case casts only assert the expected node class. Not used by the translator in this example.
        /// </summary>
        /// <param name="exp">Expression to check; null is returned as-is.</param>
        /// <returns>The same expression instance.</returns>
        /// <exception cref="NotSupportedException">The node type is not handled.</exception>
        public static Expression GetExpression(Expression exp)
        {
            if (exp == null)
                return exp;
            switch (exp.NodeType)
            {
                case ExpressionType.Negate:
                case ExpressionType.NegateChecked:
                case ExpressionType.Not:
                case ExpressionType.Convert:
                case ExpressionType.ConvertChecked:
                case ExpressionType.ArrayLength:
                case ExpressionType.Quote:
                case ExpressionType.TypeAs:
                    return (UnaryExpression)exp;
                case ExpressionType.Add:
                case ExpressionType.AddChecked:
                case ExpressionType.Subtract:
                case ExpressionType.SubtractChecked:
                case ExpressionType.Multiply:
                case ExpressionType.MultiplyChecked:
                case ExpressionType.Divide:
                case ExpressionType.Modulo:
                case ExpressionType.And:
                case ExpressionType.AndAlso:
                case ExpressionType.Or:
                case ExpressionType.OrElse:
                case ExpressionType.LessThan:
                case ExpressionType.LessThanOrEqual:
                case ExpressionType.GreaterThan:
                case ExpressionType.GreaterThanOrEqual:
                case ExpressionType.Equal:
                case ExpressionType.NotEqual:
                case ExpressionType.Coalesce:
                case ExpressionType.ArrayIndex:
                case ExpressionType.RightShift:
                case ExpressionType.LeftShift:
                case ExpressionType.ExclusiveOr:
                    return (BinaryExpression)exp;
                case ExpressionType.TypeIs:
                    return (TypeBinaryExpression)exp;
                case ExpressionType.Conditional:
                    return (ConditionalExpression)exp;
                case ExpressionType.Constant:
                    return (ConstantExpression)exp;
                case ExpressionType.Parameter:
                    return (ParameterExpression)exp;
                case ExpressionType.MemberAccess:
                    return (MemberExpression)exp;
                case ExpressionType.Call:
                    return (MethodCallExpression)exp;
                case ExpressionType.Lambda:
                    return (LambdaExpression)exp;
                case ExpressionType.New:
                    return (NewExpression)exp;
                case ExpressionType.NewArrayInit:
                case ExpressionType.NewArrayBounds:
                    return (NewArrayExpression)exp;
                case ExpressionType.Invoke:
                    return (InvocationExpression)exp;
                case ExpressionType.MemberInit:
                    return (MemberInitExpression)exp;
                case ExpressionType.ListInit:
                    return (ListInitExpression)exp;
                default:
                    throw new NotSupportedException(string.Format("Unhandled expression type: '{0}'", exp.NodeType));
            }
        }
View Code

2.构建上下文类型IECContext

  

 /// <summary>
    /// Base database context. It resolves the connection string named <c>key</c> from the application configuration.
    /// Both values are published through static properties, so the most recently constructed context
    /// determines the connection used process-wide.
    /// </summary>
    public class IECContext : IDisposable
    {
        public static string ConnectionString { get; set; }
        public static string ConnectionKey { get; set; }

        /// <param name="key">Name of an entry under &lt;connectionStrings&gt; in the application configuration.</param>
        /// <exception cref="ArgumentException">No connection string with that name is configured.</exception>
        public IECContext(string key)
        {
            var settings = ConfigurationManager.ConnectionStrings[key];
            if (settings == null)
            {
                throw new ArgumentException("Connection string '" + key + "' is not configured.", "key");
            }
            ConnectionKey = key;
            ConnectionString = settings.ConnectionString;
        }

        /// <summary>
        /// Releases the context. The context itself only reads configuration and holds no connection, so there is
        /// nothing to release; the method must not call itself, which would recurse until the stack overflows.
        /// </summary>
        public void Dispose()
        {
        }
    }
View Code

3.构建DBSet 对象,需要实现IQueryable<T> 接口

 /// <summary>
    /// Queryable entity set: pairs a <see cref="QueryProvider"/> with the expression tree describing the query.
    /// Enumerating the set hands that expression to the provider's Execute.
    /// </summary>
    /// <typeparam name="T">Entity type.</typeparam>
    public class DBSet<T> : IQueryable<T>, IQueryable, IEnumerable<T>, IEnumerable, IOrderedQueryable<T>, IOrderedQueryable
    {
        private readonly QueryProvider provider;
        private readonly Expression expression;

        /// <summary>Creates a root set whose expression is a constant referring to the set itself.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
        public DBSet(QueryProvider provider)
        {
            if (provider == null)
            {
                // (paramName, message): the single-string constructor treats its argument as the parameter name.
                throw new ArgumentNullException("provider", "QueryProvider不能为空");
            }
            this.provider = provider;
            this.expression = Expression.Constant(this);
        }

        /// <summary>Creates a set for a query expression built by the provider.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="provider"/> or <paramref name="expression"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="expression"/> is not an IQueryable&lt;T&gt;.</exception>
        public DBSet(QueryProvider provider, Expression expression)
        {
            if (provider == null)
            {
                throw new ArgumentNullException("provider", "QueryProvider不能为空");
            }
            if (expression == null)
            {
                throw new ArgumentNullException("expression", "Expression不能为空");
            }
            if (!typeof(IQueryable<T>).IsAssignableFrom(expression.Type))
            {
                throw new ArgumentOutOfRangeException("expression", "类型错误");
            }
            this.provider = provider;
            this.expression = expression;
        }

        Expression IQueryable.Expression
        {
            get { return this.expression; }
        }

        Type IQueryable.ElementType
        {
            get { return typeof(T); }
        }

        IQueryProvider IQueryable.Provider
        {
            get { return this.provider; }
        }

        public IEnumerator<T> GetEnumerator()
        {
            return ((IEnumerable<T>)this.provider.Execute(this.expression)).GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return ((IEnumerable)this.provider.Execute(this.expression)).GetEnumerator();
        }

        /// <summary>
        /// The SQL built so far by the provider, or "select * from" the provider's table when no operator has
        /// been applied yet. ToQueryList executes this text.
        /// </summary>
        public override string ToString()
        {
            return string.IsNullOrEmpty(provider.SQL)
                ? string.Format("select *  from {0}", provider.TableName())
                : provider.SQL;
        }

        /// <summary>
        /// Navigation paths queued on the provider for loading after the main query; an empty queue when none
        /// were requested (the provider creates its queue lazily).
        /// </summary>
        public Queue<string> GetColudNames()
        {
            return provider.CloudNames ?? new Queue<string>();
        }
    }
View Code

 3.1 添加 Incloud 方法(作用相当于 EF 的 Include),支持多表关联查询(比较重要)

 /// <summary>
        /// Queues a dot-separated navigation path (e.g. "OrderItems.ProductInfo.Info") on the query provider,
        /// creating the provider's queue on first use, and returns this set for chaining.
        /// </summary>
        /// <param name="cloudName">Navigation path to load after the main query.</param>
        /// <returns>This same set.</returns>
        public DBSet<T> Incloud(string cloudName)
        {
            var queue = provider.CloudNames ?? (provider.CloudNames = new Queue<string>());
            queue.Enqueue(cloudName);
            return this;
        }
View Code

 

 

4.实现 IQueryProvider接口

/// <summary>
/// Base <see cref="IQueryProvider"/>. Each Queryable operator applied to a DBSet is translated to SQL by
/// ExpressionHelp.GetSqlByExpression and folded into <see cref="SQL"/>. Subclasses execute expressions and
/// supply the table name.
/// </summary>
public abstract class QueryProvider : IQueryProvider
    {
        protected QueryProvider() { }

        /// <summary>
        /// SQL built from the operators translated so far.
        /// NOTE(review): nothing in this class resets it, so it keeps accumulating across separate queries created
        /// from the same provider — confirm whether the concrete provider clears it.
        /// </summary>
        public string SQL { get; set; }

        /// <summary>Navigation paths queued for loading after the main query (see Incloud).</summary>
        public Queue<string> CloudNames { get; set; }

        IQueryable<S> IQueryProvider.CreateQuery<S>(Expression expression)
        {
            var sqlWherr = ExpressionHelp.GetSqlByExpression(expression);
            // The translator renders a projection with a leading " Select ". Only the start of the fragment is
            // tested, so a string literal containing "select" inside a WHERE clause is not taken for a projection.
            var isProjection = sqlWherr.TrimStart().StartsWith("select", StringComparison.OrdinalIgnoreCase);
            if (string.IsNullOrEmpty(SQL))
            {
                SQL = isProjection
                    ? string.Format("{0} from {1}", sqlWherr, GetTableName<S>())
                    : string.Format("select *  from {0} {1}", GetTableName<S>(), sqlWherr);
            }
            else
            {
                // Wrap the SQL built so far as a derived table and apply the new clause to it.
                SQL = isProjection
                    ? string.Format("{0} from ({1}) t", sqlWherr, SQL)
                    : string.Format("select *  from ({0}) t {1}", SQL, sqlWherr);
            }

            return new DBSet<S>(this, expression);
        }

        /// <summary>Table name for the query; the type argument is ignored and <see cref="TableName"/> is used.</summary>
        private string GetTableName<T>()
        {
            return TableName();
        }

        IQueryable IQueryProvider.CreateQuery(Expression expression)
        {
            if (expression == null)
            {
                throw new ArgumentNullException("expression");
            }

            // Route through the generic overload so both entry points build the SQL the same way.
            var createQuery = typeof(IQueryProvider).GetMethods()
                .Single(m => m.Name == "CreateQuery" && m.IsGenericMethodDefinition)
                .MakeGenericMethod(GetSequenceElementType(expression.Type));
            try
            {
                return (IQueryable)createQuery.Invoke(this, new object[] { expression });
            }
            catch (System.Reflection.TargetInvocationException ex)
            {
                // Surface the real failure (with its stack trace) instead of the reflection wrapper.
                if (ex.InnerException != null)
                {
                    System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
                }
                throw;
            }
        }

        /// <summary>Element type T of a sequence type implementing IEnumerable&lt;T&gt;.</summary>
        private static Type GetSequenceElementType(Type sequenceType)
        {
            var enumerable = sequenceType.IsGenericType && sequenceType.GetGenericTypeDefinition() == typeof(IEnumerable<>)
                ? sequenceType
                : sequenceType.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>));
            if (enumerable == null)
            {
                throw new ArgumentException("Expression type is not a sequence: " + sequenceType, "expression");
            }
            return enumerable.GetGenericArguments()[0];
        }

        S IQueryProvider.Execute<S>(Expression expression)
        {
            return (S)this.Execute(expression);
        }

        object IQueryProvider.Execute(Expression expression)
        {
            return this.Execute(expression);
        }

        public abstract object Execute(Expression expression);
        public abstract string TableName();
    }
View Code

5.最后实现 ToQueryList 扩展方法,用转换得到的 SQL 连接数据库执行查询(比较重要,也比较麻烦)

 /// <summary>
        /// Runs the query. It executes the SQL text returned by <c>query.ToString()</c> against the connection named by
        /// IECContext.ConnectionKey and maps the first result table to a list.
        /// It then loads every navigation path queued with Incloud when the query is a DBSet.
        /// </summary>
        /// <param name="query">The query to execute.</param>
        /// <returns>The materialized entities.</returns>
        public static List<T> ToQueryList<T>(this IQueryable<T> query)
        {
            var sql = query.ToString();
            // NOTE(review): if ExecProcSql implements IDisposable it should be wrapped in a using block — confirm.
            var execProc = new ExecProcSql(IECContext.ConnectionKey);
            var dt = execProc.ExecuteDataSet(sql).Tables[0];
            var list = dt.DataSetToList<T>();

            var myQuery = query as DBSet<T>;
            // NOTE(review): the provider creates its include queue lazily in Incloud, so GetColudNames() may return
            // null when no include was requested; that is treated as "nothing to load".
            var queue = myQuery != null ? myQuery.GetColudNames() : null;
            if (queue != null)
            {
                while (queue.Count > 0)
                {
                    list = GetClouds(default(T), list, queue.Dequeue());
                }
            }

            return list;
        }
View Code

5.1 利用反射和递归查询子表数据,并赋值到实体对象上(比较重要,也比较麻烦)

 /// <summary>
        /// Loads one navigation path for every entity in <paramref name="list"/>. The path is dot-separated
        /// (for example "OrderItems.ProductInfo.Info"). The first segment is queried here, and the remaining
        /// segments are passed on to GetDataSetToList / CloudDataTableToList.
        /// </summary>
        /// <param name="t">Seed entity, used only when <paramref name="list"/> is null.</param>
        /// <param name="list">Entities whose navigation property is populated; their type must expose an "Id" property.</param>
        /// <param name="cloudName">Dot-separated navigation path; null or empty means nothing to load.</param>
        /// <returns>The list that was populated (a new list when <paramref name="list"/> was null).</returns>
        /// <exception cref="InvalidOperationException">
        /// The first path segment is not a property of <typeparamref name="T"/>, or the list is non-empty and
        /// <typeparamref name="T"/> has no "Id" property.
        /// </exception>
        private static List<T> GetClouds<T>(T t, List<T> list, string cloudName)
        {
            if (list == null)
            {
                list = new List<T>();
                if (t != null)
                {
                    list.Add(t);
                }
            }

            var result = list;
            if (string.IsNullOrEmpty(cloudName))
            {
                return result;
            }

            var clouds = cloudName.Split(new[] { '.' }, StringSplitOptions.RemoveEmptyEntries).ToList();
            if (clouds.Count <= 0)
            {
                return result;
            }

            var proper = typeof(T).GetProperty(clouds[0]);
            if (proper == null)
            {
                throw new InvalidOperationException("属性不存在");
            }

            // The key property is the same for every element, so it is looked up once.
            var idProperty = typeof(T).GetProperty("Id");
            if (idProperty == null && result.Count > 0)
            {
                throw new InvalidOperationException("必须存在Id 列");
            }

            var ids = new List<string>();
            foreach (var item in result)
            {
                var id = idProperty.GetValue(item);
                if (id != null)
                {
                    ids.Add(id.ToString());
                }
            }

            // No keys means nothing to attach; this also avoids sending "in ()", which is invalid SQL.
            if (ids.Count <= 0)
            {
                return result;
            }

            clouds.RemoveAt(0);

            // Renders keys as a quoted, comma-separated IN list, doubling embedded single quotes.
            Func<IEnumerable<string>, string> toInList =
                keys => string.Join(",", keys.Select(k => "'" + k.Replace("'", "''") + "'").ToArray());

            if (proper.PropertyType.GetInterface("IEnumerable") == typeof(System.Collections.IEnumerable))
            {
                // One-to-many: child rows point at the parent through a "<ParentType>_Id" column.
                var pType = proper.PropertyType.GetGenericArguments()[0];
                var sql = string.Format("SELECT * FROM {0} where {1} in ({2})", pType.Name, typeof(T).Name + "_Id", toInList(ids));
                var execProc = new ExecProcSql(IECContext.ConnectionKey);
                var dt = execProc.ExecuteDataSet(sql).Tables[0];
                GetDataSetToList(dt, result, proper.Name, clouds);
            }
            else
            {
                // One-to-one: re-read the parent rows to get the "<Property>_Id" foreign keys, then load the referenced rows.
                var sql = string.Format("select * from {0} where {1} in({2})", typeof(T).Name, "Id", toInList(ids));
                var execProc = new ExecProcSql(IECContext.ConnectionKey);
                // Rows of T's table.
                var dt1 = execProc.ExecuteDataSet(sql).Tables[0];

                var foreignKeyColumn = proper.Name + "_Id";
                var foreignIds = new List<string>();
                if (dt1.Columns.Contains(foreignKeyColumn))
                {
                    for (int i = 0; i < dt1.Rows.Count; i++)
                    {
                        var value = dt1.Rows[i][foreignKeyColumn];
                        // A NULL foreign key references nothing, so it is skipped rather than queried as Id = ''.
                        if (value != DBNull.Value)
                        {
                            foreignIds.Add(value.ToString());
                        }
                    }
                }
                foreignIds = foreignIds.Distinct().ToList();
                if (foreignIds.Count <= 0)
                {
                    return result;
                }

                sql = string.Format("select * from {0} where {1} in({2})", proper.PropertyType.Name, "Id", toInList(foreignIds));
                // Rows of the navigation property's table.
                var dt2 = execProc.ExecuteDataSet(sql).Tables[0];
                CloudDataTableToList(dt1, dt2, result, proper.Name, clouds);
            }
            return result;
        }
View Code

至此,自己的 ORM 基本完成(这里只介绍查询;新增、修改、删除相对简单,不做具体介绍)

四.配置,测试

1.继承IECContext

 /// <summary>
    /// Application context bound to the "DBContext" connection string, exposing one query set per entity type.
    /// </summary>
    public class ECContext : IECContext
    {
        public ECContext() : base("DBContext")
        {
            // Initialize every set so that none of the public properties is left null.
            QueryOrder = new DBSet<OrdersInfo>(new MyQueryProvider<OrdersInfo>());
            User = new DBSet<User>(new MyQueryProvider<User>());
            OrderItemInfo = new DBSet<OrderItemInfo>(new MyQueryProvider<OrderItemInfo>());
            ProductInfo = new DBSet<ProductInfo>(new MyQueryProvider<ProductInfo>());
            Info = new DBSet<Info>(new MyQueryProvider<Info>());
        }

        public DBSet<OrdersInfo> QueryOrder { get; set; }

        public DBSet<User> User { get; set; }

        public DBSet<OrderItemInfo> OrderItemInfo { get; set; }

        public DBSet<ProductInfo> ProductInfo { get; set; }

        public DBSet<Info> Info { get; set; }
    }
View Code

2.添加业务模型,以订单、订单项、产品、产品信息为例

 /// <summary>
 /// Order entity; queried through ECContext.QueryOrder.
 /// </summary>
 public class OrdersInfo
    {
        /// <summary>Primary key. GetClouds requires an "Id" property on entities whose navigations are loaded.</summary>
        public string Id { get; set; }

        public string OrderNo { get; set; }

        public DateTime CreateTime { get; set; }

        public int Member { get; set; }

        public MemberType MemberType { get; set; }

        /// <summary>
        /// One-to-one navigation. When requested via Incloud("User"), GetClouds reads the User_Id column of the
        /// OrdersInfo table and then loads the matching rows of the User table.
        /// </summary>
        public User User { get; set; }

        /// <summary>
        /// One-to-many navigation. When requested via Incloud("OrderItems"), GetClouds loads the rows of the
        /// OrderItemInfo table whose OrdersInfo_Id column matches this order's Id.
        /// </summary>
        public List<OrderItemInfo> OrderItems { get; set; }


    }

    /// <summary>
    /// User entity, referenced by OrdersInfo.User.
    /// </summary>
    public class User
    {
        public string Id { get; set; }

        public string Name { get; set; }
    }

    /// <summary>
    /// Order line entity; the element type of OrdersInfo.OrderItems.
    /// </summary>
    public class OrderItemInfo
    {
        public string Id { get; set; }
        /// <summary>Reference to an OrdersInfo (presumably the owning order — confirm how it is populated).</summary>
        public OrdersInfo Order { get; set; }

        public string ProductName { get; set; }

        /// <summary>One-to-one navigation; the demo in Main loads it through the path "OrderItems.ProductInfo.Info".</summary>
        public ProductInfo ProductInfo { get; set; }
    }
    /// <summary>
    /// Product entity, referenced by OrderItemInfo.ProductInfo.
    /// </summary>
    public class ProductInfo
    {
        public string Id { get; set; }
        public string ProductName { get; set; }

        /// <summary>One-to-one navigation to the product's Info row.</summary>
        public Info Info { get; set; }

    }
    /// <summary>
    /// Product detail entity, referenced by ProductInfo.Info.
    /// </summary>
    public class Info
    {
        public string Id { get; set; }

        public decimal Price { get; set; }
    }
    /// <summary>
    /// Member category stored on OrdersInfo.MemberType.
    /// NOTE(review): the only member, None, has the value 1, so default(MemberType) (0) is not a named value —
    /// confirm this is intended.
    /// </summary>
    public enum MemberType
    {
        None = 1
    }
View Code

3.配置 config(连接字符串)

<connectionStrings>
    <add name="DBContext" connectionString="server=.;database=test01;uid=sa;pwd=12345678" />
  </connectionStrings>

4.添加测试数据

 

 5.添加测试代码

 static void Main(string[] args)
        {
            using (var context = new ECContext())
            {
                // Keys used by the IN (...) examples below.
                var listst = new List<string> { "1", "2", "3" };

                // Order "1", with the navigation path OrderItems -> ProductInfo -> Info queued for loading.
                var objInfo2 = context.QueryOrder.Incloud("OrderItems.ProductInfo.Info").Where(p => p.Id == "1").ToQueryList();

                // Other query shapes:
                //var obj = context.QueryOrder.Incloud("OrderItems").Where(p => p.Id == "1" && p.OrderNo == "fasdf" && p.Member == 1 && p.CreateTime > DateTime.Now && listst.Contains(p.OrderNo));
                //var obj = context.QueryOrder.Where(p => listst.Contains(p.OrderNo));
                //var obj = context.QueryOrder.Where(p => p.Id == "1" && p.OrderNo == "fasdf" && p.Member == 1 && p.CreateTime > DateTime.Now && listst.Contains(p.OrderNo));
            }
        }
View Code

6.运行

 

 五.写在结尾

 本文重点是如何解析 lambda 表达式以及子表查询部分。由于只是自己写的一个 demo,实在拿不出手,就不上传源码了

 本文对 EF 中的缓存机制、反射机制、数据缓存机制不做介绍,demo 中也没有涉及,只是一个简单的类似于 EF 的 ORM 框架

 用 EF 的时间比较长,一直想写一个自己的 ORM 挑战一下自己,于是花了一天半左右的时间写了一个简单的实现,可能会有很多 bug,后期再努力完善。

欢迎各位指正,谢谢 

 

posted @ 2017-07-29 14:15  疯痴傻  阅读(686)  评论(1编辑  收藏  举报