反射练习-简易ORM
本篇是我学习反射时,围绕一个小的应用场景所做的学习笔记,算是一个简单的总结。
使用技术:泛型、反射、表达式树...
客户端调用:
/// <summary>
/// Demo entry point: runs the insert, delete, update and select extension methods
/// with <c>Person</c> entities and then waits for a key press.
/// </summary>
/// <param name="args">Command-line arguments (not used).</param>
static void Main(string[] args)
{
    // NOTE(review): the connection string (including the password) is hard-coded for the demo;
    // real code should load it from configuration.
    // The using statement disposes the connection even if one of the calls below throws.
    using (var connection = new SqlConnection("Data Source=.;User Id=sa;Password=123456;Database=fanDB;"))
    {
        // Create
        connection.Insert(new Person() { Name = "fan11", Age = 1 });
        connection.Insert(new List<Person>
        {
            new Person() { Name = "fan432", Age = 24 },
            new Person() { Name = "fan", Age = 4 }
        });

        // Delete
        connection.Delete<Person>(5);
        connection.Delete(new Person() { ID = 6 });

        // Update
        connection.Update(new Person() { ID = 17, Name = "fanfan", Age = 18 });

        // Read
        var list = connection.Select<Person>(p => p.Name == "fan" || p.Name.Contains("fan1") || p.Name.StartsWith("fan") || p.Name.EndsWith("fan") && p.Age > 3);
    }

    Console.ReadKey();
}
ORM:
/// <summary>
/// Minimal reflection-based ORM exposed as extension methods on <see cref="SqlConnection"/>.
/// The table name is the entity type's name, and each property returned by
/// <see cref="Type.GetProperties()"/> is treated as a column of the same name.
/// The column named "ID" identifies the row for Delete/Update and is excluded from INSERT column lists.
/// </summary>
public static class ORM
{
    private const string ID_NAME = "ID";
    private const string INSERT_SQL = "INSERT INTO @TABLE_NAME(@COLUMNS) VALUES(@VALUES)";
    private const string SELECT_SQL = "SELECT * FROM @TABLE_NAME WHERE @WHERE";
    private const string DELETE_SQL = "DELETE FROM @TABLE_NAME WHERE @WHERE";
    private const string UPDATE_SQL = "UPDATE @TABLE_NAME SET @UPDATE_COLUMNS WHERE @WHERE";

    // Per-type cache of reflected properties so reflection runs once per entity type.
    private static readonly ConcurrentDictionary<Type, PropertyInfo[]> PROPERTIES_CACHE = new System.Collections.Concurrent.ConcurrentDictionary<Type, PropertyInfo[]>();

    // Builds the WHERE clause from an expression tree.
    // NOTE(review): this single instance is shared by every Select call; confirm WhereBuilder
    // is safe to use concurrently before calling Select from multiple threads.
    private static readonly WhereBuilder WHERE_BUILDER = null;

    static ORM()
    {
        WHERE_BUILDER = new WhereBuilder('[', ']');
    }

    /// <summary>
    /// Inserts one entity. Every column except "ID" is written.
    /// </summary>
    /// <typeparam name="T">Entity type; its name is used as the table name.</typeparam>
    /// <param name="connection">Connection to use; opened if it is not already open.</param>
    /// <param name="entity">Entity whose property values are inserted.</param>
    /// <returns>The value returned by ExecuteNonQuery (rows affected).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public static int Insert<T>(this SqlConnection connection, T entity)
    {
        var tableName = typeof(T).Name;

        // Materialize once: the filtered sequence is enumerated three times below.
        var insertColumns = GetColumnInfos(entity).Where(c => c.Name != ID_NAME).ToList();

        string sql = INSERT_SQL.Replace("@TABLE_NAME", tableName)
            .Replace("@COLUMNS", string.Join(',', insertColumns.Select(c => c.Name)))
            .Replace("@VALUES", string.Join(',', insertColumns.Select(c => "@" + c.Name)));

        return ExecuteNonQuery(connection, sql, ToSqlParameters(insertColumns));
    }

    /// <summary>
    /// Inserts every entity in <paramref name="list"/>, one INSERT statement per entity.
    /// </summary>
    /// <returns>The sum of the per-entity insert results.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="list"/> is null.</exception>
    public static int Insert<T>(this SqlConnection connection, List<T> list)
    {
        if (list == null)
        {
            throw new ArgumentNullException(nameof(list));
        }

        int result = 0;
        foreach (var entity in list)
        {
            result += connection.Insert(entity);
        }
        return result;
    }

    /// <summary>
    /// Selects all rows matching <paramref name="whereExp"/> and maps each row to a new <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">Entity type; its name is used as the table name.</typeparam>
    /// <param name="connection">Connection to use; opened if it is not already open.</param>
    /// <param name="whereExp">Predicate that is translated into the WHERE clause.</param>
    /// <returns>The mapped entities; empty when no row matches.</returns>
    public static List<T> Select<T>(this SqlConnection connection, Expression<Func<T, bool>> whereExp) where T : new()
    {
        var tableName = typeof(T).Name;
        var wherePart = WHERE_BUILDER.ToSql<T>(whereExp);

        // SqlClient treats a null parameter value as "not supplied"; NULL must be sent as DBNull.Value.
        var paras = wherePart.Parameters
            .Select(p => new SqlParameter(p.Key, p.Value ?? DBNull.Value))
            .ToArray();

        string sql = SELECT_SQL.Replace("@TABLE_NAME", tableName)
            .Replace("@WHERE", wherePart.Sql);

        var list = new List<T>();
        OpenConnection(connection);
        using (var command = connection.CreateCommand())
        {
            command.CommandType = CommandType.Text;
            command.CommandText = sql;
            command.Parameters.AddRange(paras);
            using (var reader = command.ExecuteReader())
            {
                while (reader.Read())
                {
                    list.Add(ReaderToEntity<T>(reader));
                }
            }
        }
        return list;
    }

    /// <summary>
    /// Deletes the row whose "ID" column equals <paramref name="ID"/>.
    /// </summary>
    /// <returns>The value returned by ExecuteNonQuery (rows affected).</returns>
    public static int Delete<T>(this SqlConnection connection, int ID)
    {
        var tableName = typeof(T).Name;
        string sql = DELETE_SQL
            .Replace("@TABLE_NAME", tableName)
            .Replace("@WHERE", $"{ID_NAME}=@{ID_NAME}");
        SqlParameter[] paras = new SqlParameter[] { new SqlParameter("@" + ID_NAME, ID) };

        return ExecuteNonQuery(connection, sql, paras);
    }

    /// <summary>
    /// Deletes the row identified by the entity's "ID" property.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The entity's runtime type has no "ID" property.</exception>
    public static int Delete<T>(this SqlConnection connection, T entity)
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        var IDProperty = entity.GetType().GetProperty(ID_NAME);
        if (IDProperty == null)
        {
            throw new InvalidOperationException($"Type {entity.GetType().Name} has no {ID_NAME} property.");
        }

        int ID = (int)IDProperty.GetValue(entity);
        return connection.Delete<T>(ID);
    }

    /// <summary>
    /// Updates every non-"ID" column of the row whose "ID" column equals the entity's "ID" value.
    /// </summary>
    /// <returns>The value returned by ExecuteNonQuery (rows affected).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public static int Update<T>(this SqlConnection connection, T entity)
    {
        var tableName = typeof(T).Name;
        var columnInfoList = GetColumnInfos(entity);
        var assignments = columnInfoList
            .Where(c => c.Name != ID_NAME)
            .Select(c => c.Name + "=@" + c.Name);

        string sql = UPDATE_SQL.Replace("@TABLE_NAME", tableName)
            .Replace("@UPDATE_COLUMNS", string.Join(',', assignments))
            .Replace("@WHERE", $"{ID_NAME}=@{ID_NAME}");

        // All columns are passed as parameters: the ID column feeds the WHERE clause.
        return ExecuteNonQuery(connection, sql, ToSqlParameters(columnInfoList));
    }

    /// <summary>
    /// Opens the connection if needed, runs <paramref name="sql"/> as a text command with the
    /// given parameters and returns the ExecuteNonQuery result.
    /// </summary>
    private static int ExecuteNonQuery(SqlConnection connection, string sql, SqlParameter[] parameters)
    {
        OpenConnection(connection);
        using (var command = connection.CreateCommand())
        {
            command.CommandType = CommandType.Text;
            command.CommandText = sql;
            command.Parameters.AddRange(parameters);
            return command.ExecuteNonQuery();
        }
    }

    /// <summary>
    /// Turns column infos into "@Name" parameters, sending null values as DBNull.Value.
    /// </summary>
    private static SqlParameter[] ToSqlParameters(IEnumerable<ColumnInfo> columns)
    {
        // SqlClient treats a null parameter value as "not supplied"; NULL must be sent as DBNull.Value.
        return columns
            .Select(c => new SqlParameter("@" + c.Name, c.Value ?? DBNull.Value))
            .ToArray();
    }

    /// <summary>
    /// Creates a <typeparamref name="T"/> and fills each writable property from the reader column of the same name.
    /// </summary>
    private static T ReaderToEntity<T>(SqlDataReader reader) where T : new()
    {
        // Held as object so that, for a struct T, the property setters act on one boxed instance.
        object entity = new T();
        foreach (var propertyInfo in GetPropertys<T>())
        {
            if (!propertyInfo.CanWrite)
            {
                continue; // computed / get-only property: nothing to assign
            }

            var value = reader[propertyInfo.Name];

            // A database NULL arrives as DBNull.Value, which cannot be assigned to a typed property.
            propertyInfo.SetValue(entity, value is DBNull ? null : value);
        }
        return (T)entity;
    }

    /// <summary>
    /// Returns the properties of <typeparamref name="T"/>, reflecting only on first use per type.
    /// </summary>
    private static PropertyInfo[] GetPropertys<T>()
    {
        return PROPERTIES_CACHE.GetOrAdd(typeof(T), t => t.GetProperties());
    }

    /// <summary>
    /// Reads name, CLR type name and current value of every property of <typeparamref name="T"/> from <paramref name="entity"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    private static List<ColumnInfo> GetColumnInfos<T>(T entity)
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        var properties = GetPropertys<T>();
        var columnInfos = new List<ColumnInfo>(properties.Length);
        foreach (var prop in properties)
        {
            columnInfos.Add(new ColumnInfo(prop.Name, prop.PropertyType.FullName, prop.GetValue(entity)));
        }
        return columnInfos;
    }

    /// <summary>
    /// Maps a CLR type's full name to a <see cref="DbType"/>; unknown names map to <see cref="DbType.String"/>.
    /// </summary>
    private static DbType GetDbType(string typeName)
    {
        DbType type = DbType.String;
        switch (typeName)
        {
            case "System.String": type = DbType.String; break;
            case "System.Int32": type = DbType.Int32; break;
            case "System.Decimal": type = DbType.Decimal; break;
            // Other types (Guid, DateTime, ...) are left for the reader to add.
        }
        return type;
    }

    /// <summary>
    /// Opens the connection unless it is already open.
    /// </summary>
    private static void OpenConnection(IDbConnection connection)
    {
        if (connection.State != ConnectionState.Open)
        {
            connection.Open();
        }
    }
}

/// <summary>
/// Name, CLR type name and value of one entity property, used as one table column.
/// </summary>
public class ColumnInfo
{
    public ColumnInfo(string name, string typeName, object value)
    {
        this.Name = name;
        this.TypeName = typeName;
        this.Value = value;
    }

    /// <summary>Property name, used as the column name.</summary>
    public string Name { get; set; }

    /// <summary>Full name of the property's CLR type.</summary>
    public string TypeName { get; set; }

    /// <summary>Property value read from the entity; may be null.</summary>
    public object Value { get; set; }
}
WhereBuilder:将表达式树转成where子句(从第三方扒下来的)
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;

/// <summary>
/// Generates the SQL text and parameters of a WHERE condition from an expression tree.
/// </summary>
/// <remarks>
/// ToSql stores the lambda's parameters in an instance field that the translation reads,
/// so concurrent ToSql calls on the same instance can interfere with each other.
/// </remarks>
public class WhereBuilder
{
    private readonly char _columnBeginChar = '[';
    private readonly char _columnEndChar = ']';

    // Parameters of the lambda currently being translated; used to tell column references
    // (members of the lambda parameter) from values that must be evaluated.
    private System.Collections.ObjectModel.ReadOnlyCollection<ParameterExpression> expressParameterNameCollection;

    /// <summary>Uses the same character before and after each column name.</summary>
    public WhereBuilder(char columnChar = '`')
    {
        this._columnBeginChar = this._columnEndChar = columnChar;
    }

    /// <summary>Uses separate characters before and after each column name.</summary>
    public WhereBuilder(char columnBeginChar = '[', char columnEndChar = ']')
    {
        this._columnBeginChar = columnBeginChar;
        this._columnEndChar = columnEndChar;
    }

    /// <summary>
    /// Translates a LINQ predicate to SQL, numbering parameters from 1.
    /// </summary>
    /// <typeparam name="T">Type of the lambda parameter.</typeparam>
    /// <param name="expression">Predicate to translate.</param>
    /// <returns>The SQL text and its parameter values.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public WherePart ToSql<T>(Expression<Func<T, bool>> expression)
    {
        var i = 1;
        return ToSql(ref i, expression);
    }

    /// <summary>
    /// Translates a LINQ predicate to SQL, numbering parameters from <paramref name="i"/>.
    /// </summary>
    /// <typeparam name="T">Type of the lambda parameter.</typeparam>
    /// <param name="i">Seed for parameter numbering; advanced past the last number used.</param>
    /// <param name="expression">Predicate to translate.</param>
    /// <returns>The SQL text and its parameter values.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public WherePart ToSql<T>(ref int i, Expression<Func<T, bool>> expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        if (expression.Parameters.Count > 0)
        {
            this.expressParameterNameCollection = expression.Parameters;
        }
        return Recurse(ref i, expression.Body, isUnary: true);
    }

    /// <summary>
    /// Recursively translates one expression node.
    /// </summary>
    /// <param name="i">Next parameter number; incremented for every parameter emitted.</param>
    /// <param name="expression">Node to translate.</param>
    /// <param name="isUnary">True when the node stands alone as a condition, so a bare bool must be turned into a comparison.</param>
    /// <param name="prefix">Text prepended to string values (LIKE wildcard).</param>
    /// <param name="postfix">Text appended to string values (LIKE wildcard).</param>
    /// <returns>The SQL fragment and its parameter values.</returns>
    /// <exception cref="NotSupportedException">The node type or method call is not supported.</exception>
    private WherePart Recurse(ref int i, Expression expression, bool isUnary = false, string prefix = null, string postfix = null)
    {
        // Unary operators and conversions
        var unary = expression as UnaryExpression;
        if (unary != null)
        {
            if (unary.NodeType == ExpressionType.Convert)
            {
                // A conversion applied to a member of the lambda parameter, e.g. (long)m.Age,
                // is still a column reference. Evaluating it would fail because the lambda
                // parameter has no value, so translate the operand instead.
                bool operandUsesParameter = false;
                this.IsContainsParameterExpress(unary.Operand, ref operandUsesParameter);
                if (operandUsesParameter)
                {
                    return Recurse(ref i, unary.Operand, isUnary, prefix, postfix);
                }

                // Example: m.Birthday == DateTime.Now (the right side is evaluated to a value)
                return WherePart.IsParameter(i++, ApplyAffixes(GetValue(expression), prefix, postfix));
            }

            // Example: !m.IsActive
            return WherePart.Concat(NodeTypeToString(unary.NodeType), Recurse(ref i, unary.Operand, true));
        }

        var binary = expression as BinaryExpression;
        if (binary != null)
        {
            // "= NULL" never matches in SQL; a comparison with a constant null must use IS [NOT] NULL.
            if (binary.NodeType == ExpressionType.Equal || binary.NodeType == ExpressionType.NotEqual)
            {
                Expression tested = null;
                if (IsNullConstant(binary.Right))
                {
                    tested = binary.Left;
                }
                else if (IsNullConstant(binary.Left))
                {
                    tested = binary.Right;
                }

                if (tested != null)
                {
                    var operand = Recurse(ref i, tested);
                    var nullTest = binary.NodeType == ExpressionType.Equal ? "IS NULL" : "IS NOT NULL";
                    return new WherePart() { Parameters = operand.Parameters, Sql = $"({operand.Sql} {nullTest})" };
                }
            }

            return WherePart.Concat(Recurse(ref i, binary.Left), NodeTypeToString(binary.NodeType), Recurse(ref i, binary.Right));
        }

        // Constants. Example right-hand side: m.ID == 123
        var constant = expression as ConstantExpression;
        if (constant != null)
        {
            var value = constant.Value;
            if (value is int)
            {
                // Ints are inlined as literals; invariant formatting keeps the sign culture-independent.
                return WherePart.IsSql(((int)value).ToString(System.Globalization.CultureInfo.InvariantCulture));
            }

            value = ApplyAffixes(value, prefix, postfix);
            if (value is bool && isUnary)
            {
                return WherePart.Concat(WherePart.IsParameter(i++, value), "=", WherePart.IsSql("1"));
            }
            return WherePart.IsParameter(i++, value);
        }

        // Member access
        var member = expression as MemberExpression;
        if (member != null)
        {
            bool memberUsesParameter = false;
            this.IsContainsParameterExpress(member, ref memberUsesParameter);

            if (member.Member is PropertyInfo && memberUsesParameter)
            {
                var colName = member.Member.Name;
                if (isUnary && member.Type == typeof(bool))
                {
                    // A bare bool column used as a condition becomes "[col] = @n" with n bound to true.
                    return WherePart.Concat(Recurse(ref i, expression), "=", WherePart.IsParameter(i++, true));
                }
                return WherePart.IsSql(string.Format("{0}{1}{2}", this._columnBeginChar, colName, this._columnEndChar));
            }

            if (member.Member is FieldInfo || !memberUsesParameter)
            {
                // Captured variable or other value-producing member: evaluate it and bind as a parameter.
                return WherePart.IsParameter(i++, ApplyAffixes(GetValue(member), prefix, postfix));
            }

            throw new NotSupportedException($"Expression does not refer to a property or field: {expression}");
        }

        // Method calls
        var methodCall = expression as MethodCallExpression;
        if (methodCall != null)
        {
            // Does the call involve the lambda parameter (directly or through nested expressions)?
            bool callUsesParameter = false;
            this.IsContainsParameterExpress(methodCall, ref callUsesParameter);

            if (!callUsesParameter)
            {
                // The call does not touch the lambda parameter: evaluate it and bind the result.
                return WherePart.IsParameter(i++, ApplyAffixes(GetValue(expression), prefix, postfix));
            }

            // LIKE queries:
            if (methodCall.Method == typeof(string).GetMethod("Contains", new[] { typeof(string) }))
            {
                return WherePart.Concat(Recurse(ref i, methodCall.Object), "LIKE", Recurse(ref i, methodCall.Arguments[0], prefix: "%", postfix: "%"));
            }
            if (methodCall.Method == typeof(string).GetMethod("StartsWith", new[] { typeof(string) }))
            {
                return WherePart.Concat(Recurse(ref i, methodCall.Object), "LIKE", Recurse(ref i, methodCall.Arguments[0], postfix: "%"));
            }
            if (methodCall.Method == typeof(string).GetMethod("EndsWith", new[] { typeof(string) }))
            {
                return WherePart.Concat(Recurse(ref i, methodCall.Object), "LIKE", Recurse(ref i, methodCall.Arguments[0], prefix: "%"));
            }

            // IN queries:
            if (methodCall.Method.Name == "Contains")
            {
                Expression collection;
                Expression property;
                if (methodCall.Method.IsDefined(typeof(ExtensionAttribute)) && methodCall.Arguments.Count == 2)
                {
                    // Extension-method form: Enumerable.Contains(collection, item)
                    collection = methodCall.Arguments[0];
                    property = methodCall.Arguments[1];
                }
                else if (!methodCall.Method.IsDefined(typeof(ExtensionAttribute)) && methodCall.Arguments.Count == 1)
                {
                    // Instance-method form: collection.Contains(item)
                    collection = methodCall.Object;
                    property = methodCall.Arguments[0];
                }
                else
                {
                    throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
                }

                var values = (IEnumerable)GetValue(collection);
                return WherePart.Concat(Recurse(ref i, property), "IN", WherePart.IsCollection(ref i, values));
            }

            throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
        }

        // Object creation, e.g. new DateTime(2018, 10, 31)
        var creation = expression as NewExpression;
        if (creation != null)
        {
            return WherePart.IsParameter(i++, ApplyAffixes(GetValue(creation), prefix, postfix));
        }

        throw new NotSupportedException("Unsupported expression: " + expression.GetType().Name);
    }

    /// <summary>
    /// Wraps string values in the optional LIKE prefix/postfix; other values are returned unchanged.
    /// </summary>
    private static object ApplyAffixes(object value, string prefix, string postfix)
    {
        var text = value as string;
        return text != null ? prefix + text + postfix : value;
    }

    /// <summary>
    /// True when the expression is a constant null, optionally wrapped in Convert nodes.
    /// </summary>
    private static bool IsNullConstant(Expression expression)
    {
        while (expression != null && expression.NodeType == ExpressionType.Convert)
        {
            expression = ((UnaryExpression)expression).Operand;
        }

        var constant = expression as ConstantExpression;
        return constant != null && constant.Value == null;
    }

    /// <summary>
    /// Sets <paramref name="result"/> to true when the expression refers to one of the lambda's parameters.
    /// Only member accesses and method calls are inspected; <paramref name="result"/> is never reset to false.
    /// </summary>
    /// <param name="expression">Expression to inspect.</param>
    /// <param name="result">Receives true when a lambda parameter is found.</param>
    private void IsContainsParameterExpress(Expression expression, ref bool result)
    {
        if (this.expressParameterNameCollection == null || this.expressParameterNameCollection.Count == 0 || expression == null)
        {
            return;
        }

        if (expression is MemberExpression)
        {
            if (this.expressParameterNameCollection.Contains(((MemberExpression)expression).Expression))
            {
                result = true;
            }
            return;
        }

        if (!(expression is MethodCallExpression))
        {
            return;
        }

        MethodCallExpression methodCallExpression = (MethodCallExpression)expression;
        if (methodCallExpression.Object != null)
        {
            if (methodCallExpression.Object is MethodCallExpression)
            {
                // Example 1: m.ID.ToString().Contains("123")
                this.IsContainsParameterExpress(methodCallExpression.Object, ref result);
            }
            else if (methodCallExpression.Object is MemberExpression)
            {
                // Example 2: m.ID.Contains(123)
                MemberExpression target = (MemberExpression)methodCallExpression.Object;
                if (target.Expression != null && this.expressParameterNameCollection.Contains(target.Expression))
                {
                    result = true;
                }
            }
        }

        // Example 3: int[] ids = { 1, 2, 3 }; ids.Contains(m.ID)
        if (result == false && methodCallExpression.Arguments != null && methodCallExpression.Arguments.Count > 0)
        {
            foreach (Expression express in methodCallExpression.Arguments)
            {
                if (express is MemberExpression || express is MethodCallExpression)
                {
                    this.IsContainsParameterExpress(express, ref result);
                }
                else if (this.expressParameterNameCollection.Contains(express))
                {
                    result = true;
                    break;
                }
            }
        }
    }

    /// <summary>
    /// Evaluates an expression that does not depend on the lambda parameter by compiling and invoking it.
    /// </summary>
    private static object GetValue(Expression member)
    {
        // source: http://stackoverflow.com/a/2616980/291955
        var objectMember = Expression.Convert(member, typeof(object));
        var getterLambda = Expression.Lambda<Func<object>>(objectMember);
        var getter = getterLambda.Compile();
        return getter();
    }

    /// <summary>
    /// Maps an expression node type to its SQL operator text.
    /// </summary>
    /// <exception cref="NotSupportedException">The node type has no mapping.</exception>
    private static string NodeTypeToString(ExpressionType nodeType)
    {
        switch (nodeType)
        {
            case ExpressionType.Add: return "+";
            case ExpressionType.And: return "&";
            case ExpressionType.AndAlso: return "AND";
            case ExpressionType.Divide: return "/";
            case ExpressionType.Equal: return "=";
            case ExpressionType.ExclusiveOr: return "^";
            case ExpressionType.GreaterThan: return ">";
            case ExpressionType.GreaterThanOrEqual: return ">=";
            case ExpressionType.LessThan: return "<";
            case ExpressionType.LessThanOrEqual: return "<=";
            case ExpressionType.Modulo: return "%";
            case ExpressionType.Multiply: return "*";
            case ExpressionType.Negate: return "-";
            case ExpressionType.Not: return "NOT";
            case ExpressionType.NotEqual: return "<>";
            case ExpressionType.Or: return "|";
            case ExpressionType.OrElse: return "OR";
            case ExpressionType.Subtract: return "-";
        }
        throw new NotSupportedException($"Unsupported node type: {nodeType}");
    }
}

/// <summary>
/// A SQL fragment together with the values of the parameters it references.
/// </summary>
public class WherePart
{
    /// <summary>
    /// SQL text containing parameter placeholders.
    /// </summary>
    public string Sql { get; set; }

    /// <summary>
    /// Parameter values keyed by parameter number (without the "@" prefix).
    /// </summary>
    public Dictionary<string, object> Parameters { get; set; } = new Dictionary<string, object>();

    /// <summary>Creates a fragment of literal SQL with no parameters.</summary>
    public static WherePart IsSql(string sql)
    {
        return new WherePart()
        {
            Parameters = new Dictionary<string, object>(),
            Sql = sql
        };
    }

    /// <summary>Creates a fragment "@count" bound to <paramref name="value"/>.</summary>
    public static WherePart IsParameter(int count, object value)
    {
        return new WherePart()
        {
            Parameters = { { count.ToString(), value } },
            Sql = $"@{count}"
        };
    }

    /// <summary>
    /// Creates a parenthesized parameter list "(@n,@n+1,...)" for an IN clause,
    /// advancing <paramref name="countStart"/> once per value. An empty sequence yields "(null)".
    /// </summary>
    public static WherePart IsCollection(ref int countStart, IEnumerable values)
    {
        var parameters = new Dictionary<string, object>();
        var sql = new StringBuilder("(");
        foreach (var value in values)
        {
            parameters.Add((countStart).ToString(), value);
            sql.Append($"@{countStart},");
            countStart++;
        }

        if (sql.Length == 1)
        {
            sql.Append("null,");
        }

        // Replace the trailing comma with the closing parenthesis.
        sql[sql.Length - 1] = ')';
        return new WherePart()
        {
            Parameters = parameters,
            Sql = sql.ToString()
        };
    }

    /// <summary>Creates "(operator operand)" for a prefix operator.</summary>
    public static WherePart Concat(string @operator, WherePart operand)
    {
        return new WherePart()
        {
            Parameters = operand.Parameters,
            Sql = $"({@operator} {operand.Sql})"
        };
    }

    /// <summary>Creates "(left operator right)" and merges the parameters of both sides.</summary>
    public static WherePart Concat(WherePart left, string @operator, WherePart right)
    {
        return new WherePart()
        {
            Parameters = left.Parameters.Union(right.Parameters).ToDictionary(kvp => kvp.Key, kvp => kvp.Value),
            Sql = $"({left.Sql} {@operator} {right.Sql})"
        };
    }
}