【手撸一个ORM】第八步、查询工具类
一、实体查询
using MyOrm.Commons;
using MyOrm.DbParameters;
using MyOrm.Expressions;
using MyOrm.Mappers;
using MyOrm.Reflections;
using MyOrm.SqlBuilder;
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.Linq;
using System.Linq.Expressions;
using System.Text;

namespace MyOrm.Queryable
{
    /// <summary>
    /// Fluent query builder for the entity type <typeparamref name="T"/>.
    /// Collects Include / Where / OrderBy / Select information, turns it into a
    /// SQL Server statement and materializes the result through a data-reader converter.
    /// </summary>
    /// <typeparam name="T">Entity type that is mapped to the master table.</typeparam>
    public class MyQueryable<T> where T : class, new()
    {
        private readonly string _connectionString;

        // Navigation properties to load: property name -> selected field names
        // (an empty array means "all mapped fields").
        private readonly Dictionary<string, string[]> _includeProperties = new Dictionary<string, string[]>();

        // Navigation properties referenced by the WHERE clause (they need a JOIN as well).
        private List<string> _whereProperties = new List<string>();

        // Local cache of entity metadata for navigation properties, so the global
        // container is not consulted on every lookup.
        private readonly List<MyEntity> _entityCache = new List<MyEntity>();

        // Members requested through Select().
        private readonly List<SelectResolveResult> _selectProperties = new List<SelectResolveResult>();

        // Metadata of the master table.
        private readonly MyEntity _masterEntity;

        // Parameters required by the query.
        private readonly MyDbParameters _parameters = new MyDbParameters();

        // Whether Where() has already been called.
        private bool _hasInitWhere;

        // The assembled WHERE clause.
        private string _where;

        // The assembled ORDER BY clause.
        private string _orderBy;

        /// <summary>Creates a query over <typeparamref name="T"/>.</summary>
        /// <param name="connectionString">Connection string used when the query is executed.</param>
        public MyQueryable(string connectionString)
        {
            _masterEntity = MyEntityContainer.Get(typeof(T));
            _connectionString = connectionString;
        }

        #region Include

        /// <summary>
        /// Includes a navigation property with all of its mapped fields.
        /// Expressions that are not a direct member access on the lambda parameter are ignored.
        /// </summary>
        public MyQueryable<T> Include<TProperty>(Expression<Func<T, TProperty>> expression) where TProperty : IEntity
        {
            // The original also tested memberExpr.Member.GetType().IsClass, which inspects the
            // MemberInfo object itself and is therefore always true; it has been dropped.
            if (expression.Body.NodeType == ExpressionType.MemberAccess)
            {
                var memberExpr = (MemberExpression)expression.Body;
                if (memberExpr.Expression != null &&
                    memberExpr.Expression.NodeType == ExpressionType.Parameter)
                {
                    _includeProperties.TryAdd(memberExpr.Member.Name, Array.Empty<string>());
                }
            }

            return this;
        }

        /// <summary>
        /// Includes a navigation property, restricted to the fields named by <paramref name="fields"/>.
        /// </summary>
        public MyQueryable<T> Include<TProperty>(
            Expression<Func<T, TProperty>> property,
            Expression<Func<TProperty, object>> fields) where TProperty : IEntity
        {
            if (property.Body.NodeType == ExpressionType.MemberAccess)
            {
                var visitor = new ObjectMemberVisitor();

                visitor.Visit(property);
                var member = visitor.GetPropertyList().First();

                // Reuse the same visitor for the field selector.
                visitor.Clear();
                visitor.Visit(fields);
                var fieldList = visitor.GetPropertyList();

                _includeProperties.TryAdd(member, fieldList.ToArray());
            }

            return this;
        }

        /// <summary>Includes a navigation property, identified by name, with all of its mapped fields.</summary>
        /// <exception cref="InvalidOperationException">No (or more than one) property has that name.</exception>
        public MyQueryable<T> Include(string navPropertyName)
        {
            return Include(navPropertyName, Array.Empty<string>());
        }

        /// <summary>Includes a navigation property, identified by name, restricted to the given fields.</summary>
        /// <param name="navPropertyName">Name of the navigation property on <typeparamref name="T"/>.</param>
        /// <param name="fields">Property names to select; null or empty selects every mapped field.</param>
        /// <exception cref="InvalidOperationException">No (or more than one) property has that name.</exception>
        public MyQueryable<T> Include(string navPropertyName, string[] fields)
        {
            // Single() throws when nothing matches, so no null check is needed afterwards.
            var property = _masterEntity.Properties.Single(p => p.Name == navPropertyName);
            if (property.JoinAble)
            {
                // TryAdd (like the expression overloads) so a repeated Include is a no-op
                // instead of a duplicate-key exception.
                _includeProperties.TryAdd(property.Name, fields ?? Array.Empty<string>());
            }

            return this;
        }

        #endregion

        #region Where

        /// <summary>Sets the filter of the query. May be called only once per query.</summary>
        /// <exception cref="ArgumentException">Where has already been called on this query.</exception>
        public MyQueryable<T> Where(Expression<Func<T, bool>> expr)
        {
            if (_hasInitWhere)
            {
                throw new ArgumentException("每个查询只能调用一次Where方法");
            }

            _hasInitWhere = true;

            var condition = new QueryConditionResolver<T>(_masterEntity);
            var result = condition.Resolve(expr.Body);

            _where = result.Condition;
            _parameters.AddParameters(result.Parameters);

            // Navigation properties used in the filter must be joined too.
            _entityCache.AddRange(result.NavPropertyList);
            _whereProperties = result.NavPropertyList.Select(p => p.Name).ToList();

            return this;
        }

        #endregion

        #region OrderBy,ThenOrderBy

        /// <summary>
        /// Sets the primary sort column. Expressions that cannot be translated
        /// (not a member access, or nested deeper than one navigation property) are ignored.
        /// </summary>
        public MyQueryable<T> OrderBy<TProperty>(Expression<Func<T, TProperty>> expression, MyDbOrderBy orderBy = MyDbOrderBy.Asc)
        {
            var item = BuildOrderByItem(expression.Body, orderBy);
            if (item != null)
            {
                _orderBy = item;
            }

            return this;
        }

        /// <summary>Appends a secondary sort column; requires a previous OrderBy call.</summary>
        /// <exception cref="ArgumentNullException">No sort column has been set yet.</exception>
        public MyQueryable<T> ThenOrderBy<TProperty>(Expression<Func<T, TProperty>> expression, MyDbOrderBy orderBy = MyDbOrderBy.Asc)
        {
            if (string.IsNullOrWhiteSpace(_orderBy))
            {
                throw new ArgumentNullException(nameof(_orderBy), "排序字段为空,必须先调用OrderBy或OrderByDesc才能调用此方法");
            }

            var item = BuildOrderByItem(expression.Body, orderBy);
            if (item != null)
            {
                _orderBy += "," + item;
            }

            return this;
        }

        #endregion

        #region Select

        /// <summary>
        /// Projects the query onto <typeparamref name="TTarget"/> using the members named
        /// in <paramref name="expression"/>.
        /// </summary>
        public MySelect<TTarget> Select<TTarget>(Expression<Func<T, object>> expression)
        {
            var visitor = new SelectExpressionResolver();
            visitor.Visit(expression);
            _selectProperties.AddRange(visitor.GetPropertyList());

            // GetFields() must run before GetFrom(): in projection mode it rebuilds the
            // include list that GetFrom() turns into JOINs.
            var fields = GetFields();
            var from = GetFrom();

            return new MySelect<TTarget>(_connectionString, fields, from, _where, _parameters, _orderBy);
        }

        #endregion

        #region Output

        /// <summary>Executes the query and returns every matching entity.</summary>
        public List<T> ToList()
        {
            var fields = GetFields();
            var from = GetFrom();

            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.Select(from, fields, _where, _orderBy);

            // NOTE(review): the original used the parameterless converter here while
            // ToPageList/FirstOrDefault pass the included navigation properties; aligned
            // with them so Include() also takes effect for ToList() - confirm.
            var converter = new SqlDataReaderConverter<T>(GetIncludedPropertyNames());

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                try
                {
                    conn.Open();
                    using (var sdr = command.ExecuteReader())
                    {
                        return converter.ConvertToEntityList(sdr);
                    }
                }
                finally
                {
                    // A SqlParameter can belong to only one collection at a time; detach the
                    // shared parameters so this query (or a derived MySelect) can run again.
                    command.Parameters.Clear();
                }
            }
        }

        /// <summary>Executes the query for a single page.</summary>
        /// <param name="pageIndex">Page index passed through to the SQL builder.</param>
        /// <param name="pageSize">Page size passed through to the SQL builder.</param>
        /// <param name="recordCount">Receives the value of the @RecordCount output parameter (0 if it is not set).</param>
        public List<T> ToPageList(int pageIndex, int pageSize, out int recordCount)
        {
            var fields = GetFields();
            var from = GetFrom();
            recordCount = 0;

            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.PagingSelect(from, fields, _where, _orderBy, pageIndex, pageSize);

            var converter = new SqlDataReaderConverter<T>(GetIncludedPropertyNames());
            var countParam = new SqlParameter("@RecordCount", SqlDbType.Int)
            {
                Direction = ParameterDirection.Output
            };

            List<T> result;
            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                command.Parameters.Add(countParam);
                try
                {
                    conn.Open();
                    using (var sdr = command.ExecuteReader())
                    {
                        result = converter.ConvertToEntityList(sdr);
                    }

                    // Output parameters are populated only after the reader has been closed.
                    // Guard against DBNull/null instead of failing with an InvalidCastException.
                    if (countParam.Value is int count)
                    {
                        recordCount = count;
                    }
                }
                finally
                {
                    command.Parameters.Clear();
                }
            }

            return result;
        }

        /// <summary>Executes the query limited to one row and returns what the converter produces for it.</summary>
        public T FirstOrDefault()
        {
            var fields = GetFields();
            var from = GetFrom();

            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.Select(from, fields, _where, _orderBy, 1);

            var converter = new SqlDataReaderConverter<T>(GetIncludedPropertyNames());

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                try
                {
                    conn.Open();
                    // The reader was never disposed in the original.
                    using (var sdr = command.ExecuteReader())
                    {
                        return converter.ConvertToEntity2(sdr);
                    }
                }
                finally
                {
                    command.Parameters.Clear();
                }
            }
        }

        #endregion

        #region Helpers

        /// <summary>
        /// Returns the metadata of a navigation property's entity type, caching it locally
        /// so the global container is not queried every time.
        /// </summary>
        private MyEntity GetIncludePropertyEntityInfo(Type type)
        {
            var entity = _entityCache.FirstOrDefault(e => e.Name == type.FullName);
            if (entity != null)
            {
                return entity;
            }

            entity = MyEntityContainer.Get(type);
            _entityCache.Add(entity);
            return entity;
        }

        /// <summary>Names of the navigation properties that are currently included.</summary>
        private string[] GetIncludedPropertyNames()
        {
            return _includeProperties.Keys.ToArray();
        }

        /// <summary>
        /// Builds one ORDER BY item, or returns null when the expression cannot be translated,
        /// so that callers never emit a bare " DESC" or a dangling comma.
        /// </summary>
        private string BuildOrderByItem(Expression body, MyDbOrderBy orderBy)
        {
            if (body.NodeType != ExpressionType.MemberAccess)
            {
                return null;
            }

            var column = GetOrderByString((MemberExpression)body);
            if (string.IsNullOrEmpty(column))
            {
                return null;
            }

            return orderBy == MyDbOrderBy.Desc ? column + " DESC" : column;
        }

        /// <summary>
        /// Builds the column list of the SELECT statement.
        /// Without a projection: every mapped master column plus the columns of each included
        /// navigation property (aliased "Property_Field"). With a projection: only the selected
        /// members; the include list is rebuilt from the navigation properties they use.
        /// </summary>
        public string GetFields()
        {
            var columns = new List<string>();

            if (_selectProperties.Count == 0)
            {
                columns.AddRange(
                    _masterEntity.Properties
                        .Where(p => p.IsMap)
                        .Select(p => $"[{_masterEntity.TableName}].[{p.FieldName}] AS [{p.Name}]"));

                // Order by key: ordering the KeyValuePair itself throws as soon as two pairs
                // have to be compared, because KeyValuePair is not IComparable.
                foreach (var include in _includeProperties.OrderBy(i => i.Key))
                {
                    var prop = _masterEntity.Properties.Single(p => p.Name == include.Key);
                    var propEntity = GetIncludePropertyEntityInfo(prop.PropertyInfo.PropertyType);

                    var mapped = propEntity.Properties.Where(p => p.IsMap);
                    if (include.Value.Length > 0)
                    {
                        mapped = mapped.Where(p => include.Value.Contains(p.Name));
                    }

                    // The joined table is aliased with the property name in GetFrom(), and T-SQL
                    // requires the alias (not the table name) once one has been assigned.
                    columns.AddRange(
                        mapped.Select(p => $"[{include.Key}].[{p.FieldName}] AS [{include.Key}_{p.Name}]"));
                }

                // Joining a flat list puts a comma between the groups of different includes
                // and avoids a trailing comma when a field filter matches nothing.
                return string.Join(",", columns);
            }

            _includeProperties.Clear();

            foreach (var selected in _selectProperties)
            {
                if (string.IsNullOrWhiteSpace(selected.FieldName))
                {
                    // Plain column of the master table.
                    var prop = _masterEntity.Properties.Single(p => p.Name == selected.PropertyName);
                    columns.Add($"[{_masterEntity.TableName}].[{prop.FieldName}] AS [{selected.MemberName}]");
                    continue;
                }

                // Column of a navigation property; unknown properties are skipped.
                var navProp = _masterEntity.Properties.SingleOrDefault(p => p.Name == selected.PropertyName);
                if (navProp == null)
                {
                    continue;
                }

                // Several selected members may come from the same navigation property.
                _includeProperties.TryAdd(selected.PropertyName, Array.Empty<string>());

                var navEntity = GetIncludePropertyEntityInfo(navProp.PropertyInfo.PropertyType);
                var field = navEntity.Properties.Single(p => p.Name == selected.FieldName);
                columns.Add($"[{selected.PropertyName}].[{field.FieldName}] AS [{selected.MemberName}]");
            }

            // Returns an empty string (instead of throwing from StringBuilder.Remove)
            // when no selected member could be resolved.
            return string.Join(",", columns);
        }

        /// <summary>
        /// Builds the FROM clause: the master table plus one LEFT JOIN per navigation property
        /// that is included or referenced by the WHERE clause.
        /// </summary>
        public string GetFrom()
        {
            var joinProperties = _includeProperties.Keys
                .Concat(_whereProperties)
                .Distinct()
                .ToList();

            var sb = new StringBuilder($"[{_masterEntity.TableName}]");

            foreach (var name in joinProperties)
            {
                var prop = _masterEntity.Properties.Single(p => p.Name == name);
                var propEntity = GetIncludePropertyEntityInfo(prop.PropertyInfo.PropertyType);

                // The ON clause refers to the joined table through its alias; using the
                // original table name after aliasing is rejected by SQL Server.
                sb.Append($" LEFT JOIN [{propEntity.TableName}] AS [{name}] ON [{_masterEntity.TableName}].[{prop.ForeignKey}]=[{name}].[{propEntity.KeyColumn}]");
            }

            return sb.ToString();
        }

        /// <summary>
        /// Translates a member access into a qualified column name. Supports a master property
        /// (one level) and a property of a navigation property (two levels); anything else
        /// yields an empty string.
        /// </summary>
        private string GetOrderByString(MemberExpression expression)
        {
            expression.GetRootType(out var stack);

            if (stack.Count == 1)
            {
                var propName = stack.Pop();
                var prop = _masterEntity.Properties.Single(p => p.Name == propName);
                return $"[{_masterEntity.TableName}].[{prop.FieldName}]";
            }

            if (stack.Count == 2)
            {
                var slavePropName = stack.Pop();
                var propertyName = stack.Pop();

                var masterProp = _masterEntity.Properties.Single(p => p.Name == propertyName);
                var slaveEntity = GetIncludePropertyEntityInfo(masterProp.PropertyInfo.PropertyType);
                var slaveProperty = slaveEntity.Properties.Single(p => p.Name == slavePropName);

                return $"[{masterProp.Name}].[{slaveProperty.FieldName}]";
            }

            return string.Empty;
        }

        #endregion
    }
}
二、按需查询 Select<T>()
using MyOrm.DbParameters;
using MyOrm.Mappers;
using MyOrm.SqlBuilder;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;

namespace MyOrm.Queryable
{
    /// <summary>
    /// Executes a projection query: the SQL fragments are prepared by the caller and
    /// each result row is mapped onto <typeparamref name="T"/> by a data-reader mapper.
    /// </summary>
    /// <typeparam name="T">Type every result row is mapped to.</typeparam>
    public class MySelect<T>
    {
        private readonly string _connectionString;
        private readonly string _fields;
        private readonly string _table;
        private readonly string _where;
        private readonly string _orderBy;
        private readonly MyDbParameters _parameters;

        /// <summary>Creates a projection query from prepared SQL fragments.</summary>
        /// <param name="connectionString">Connection string used when the query is executed.</param>
        /// <param name="fields">Column list of the SELECT statement.</param>
        /// <param name="table">FROM clause (table plus joins).</param>
        /// <param name="where">WHERE clause.</param>
        /// <param name="dbParameters">Parameters referenced by the WHERE clause.</param>
        /// <param name="orderBy">ORDER BY clause.</param>
        public MySelect(string connectionString, string fields, string table, string where, MyDbParameters dbParameters, string orderBy)
        {
            _fields = fields;
            _table = table;
            _where = where;
            _parameters = dbParameters;
            _orderBy = orderBy;
            _connectionString = connectionString;
        }

        /// <summary>Executes the query and maps every row to <typeparamref name="T"/>.</summary>
        public List<T> ToList()
        {
            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.Select(_table, _fields, _where, _orderBy);
            var mapper = new SqlDataReaderMapper();

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                try
                {
                    conn.Open();
                    using (var sdr = command.ExecuteReader())
                    {
                        return mapper.ConvertToList<T>(sdr);
                    }
                }
                finally
                {
                    // A SqlParameter can belong to only one collection at a time; detach the
                    // parameters so the same MyDbParameters instance can be used by another command.
                    command.Parameters.Clear();
                }
            }
        }

        /// <summary>Executes the query for a single page.</summary>
        /// <param name="pageIndex">Page index passed through to the SQL builder.</param>
        /// <param name="pageSize">Page size passed through to the SQL builder.</param>
        /// <param name="recordCount">Receives the value of the @RecordCount output parameter (0 if it is not set).</param>
        public List<T> ToPageList(int pageIndex, int pageSize, out int recordCount)
        {
            recordCount = 0;

            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.PagingSelect2008(_table, _fields, _where, _orderBy, pageIndex, pageSize);

            var mapper = new SqlDataReaderMapper();
            var countParam = new SqlParameter("@RecordCount", SqlDbType.Int)
            {
                Direction = ParameterDirection.Output
            };

            List<T> result;
            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                command.Parameters.Add(countParam);
                try
                {
                    conn.Open();
                    using (var sdr = command.ExecuteReader())
                    {
                        result = mapper.ConvertToList<T>(sdr);
                    }

                    // Output parameters are populated only after the reader has been closed.
                    // Guard against DBNull/null instead of failing with an InvalidCastException.
                    if (countParam.Value is int count)
                    {
                        recordCount = count;
                    }
                }
                finally
                {
                    command.Parameters.Clear();
                }
            }

            return result;
        }

        /// <summary>Executes the query limited to one row and returns what the mapper produces for it.</summary>
        public T FirstOrDefault()
        {
            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.Select(_table, _fields, _where, _orderBy, 1);
            var mapper = new SqlDataReaderMapper();

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                try
                {
                    conn.Open();
                    // The reader was never disposed in the original.
                    using (var sdr = command.ExecuteReader())
                    {
                        return mapper.ConvertToEntity<T>(sdr);
                    }
                }
                finally
                {
                    command.Parameters.Clear();
                }
            }
        }
    }
}