使用表达式从IDataReader上填充到实体

假设有这样一段查询代码：

// Example usage: build a join + filter + sort + paging query on the context.
// The builder chain's return value is not assigned; presumably each call records SQL on the context
// (later read back as db.Sql / db.SqlParameter by DbContextEx.Query) — verify against the builder.
using(IDbConnection conn = GetDbConnection()) {
    var db = conn.DbContext();
    // Presumably translates to SELECT ... FROM T INNER JOIN T2 ON ... WHERE ... ORDER BY ... with paging (from..to);
    // the exact SQL is not visible here.
    db.DbSet.Select<T>()
        .InnerJoin<T2>((t,t2) => t.XXX == t2.XXX)
        .Where(t => whereLambda)
        .OrderBy(t => t.XXX)
        .Paging(from, to);
}

我希望能像下面这样，直接得到一个 IEnumerable<T> 结果：

// Execute the SQL recorded on db through the DbContextEx.Query<T> extension method.
// NOTE(review): db is declared inside the using block above, so in real code this line belongs inside that block.
IEnumerable<T> result = db.Query<T>();

在 DbContext 上添加一个扩展方法：

public static class DbContextEx {
    /// <summary>
    /// Runs the SQL text and parameters currently held by <paramref name="self"/> and returns the rows as <typeparamref name="T"/>.
    /// </summary>
    /// <param name="self">Context whose Sql / SqlParameter hold the query to execute.</param>
    /// <returns>The sequence produced by the context's IDbAction (DbExec).</returns>
    public static IEnumerable<T> Query<T>(this DbContext self) =>
        // DbExec is simply a property of type IDbAction that is responsible for executing the SQL.
        self.DbExec.Query<T>(self.Sql, self.SqlParameter);
}

// IDbAction.Query 
/// <summary>
/// Runs <paramref name="sql"/> with parameter object <paramref name="p"/> on this action's connection.
/// </summary>
/// <param name="sql">SQL text to execute.</param>
/// <param name="p">Parameter object forwarded as <c>param</c>; may be null.</param>
/// <returns>The rows mapped to <typeparamref name="T"/>.</returns>
/// <remarks>
/// There is intentionally no try/catch: rethrowing with <c>throw ex;</c> resets the stack trace and adds nothing.
/// NOTE(review): if <c>conn.Query</c> is the deferred, iterator-based internalQuery, execution errors surface only
/// when the result is enumerated, so a try/catch here could not observe them anyway.
/// </remarks>
public IEnumerable<T> Query<T>(string sql, object p) {
    // conn is an IDbConnection
    return conn.Query<T>(sql, param: p);
}

在 IDbConnection 上添加 Query 扩展方法。没错，这就是参考 Dapper 写的。当然也可以直接使用 Dapper，但对我来说，Dapper 或许太大了。

/// <summary>
/// Executes <paramref name="command"/> on <paramref name="conn"/> and yields one <typeparamref name="T"/> per row.
/// </summary>
/// <param name="conn">Connection to run on; if it is closed on entry it is opened here and closed again afterwards.</param>
/// <param name="command">SQL text, command type and parameter object.</param>
/// <returns>A lazily evaluated sequence: nothing executes until it is enumerated.</returns>
/// <remarks>
/// This is an iterator method, so the finally block runs when enumeration completes or the enumerator is disposed.
/// </remarks>
private static IEnumerable<T> internalQuery<T>(this IDbConnection conn, CommandDefinition command) {
    // Cache: per-query info (parameter reader, compiled deserializer) looked up via a certificate built from
    // the SQL text, command type, connection, result type and parameter type.
    var parameter = command.Parameters;
    Certificate certificate = new Certificate(command.CommandText, command.CommandType, conn, typeof(T), parameter?.GetType());
    CacheInfo cacheInfo = CacheInfo.GetCacheInfo(certificate, parameter);

    IDbCommand cmd = null;
    IDataReader reader = null;
    var wasClosed = conn.State == ConnectionState.Closed;
    try {
        cmd = command.SetupCommand(conn, cacheInfo.ParameterReader);
        if (wasClosed)
            conn.Open();
        reader = ExecuteReaderWithFlagsFallback(cmd, wasClosed, CommandBehavior.SingleResult);
        // Compiled from the first reader's schema and stored on cacheInfo; presumably reused by later executions
        // that resolve to the same cache entry.
        if (cacheInfo.Deserializer == null) {
            cacheInfo.Deserializer = BuildDeserializer<T>(reader);
        }
        while (reader.Read()) {
            var val = cacheInfo.Deserializer(reader);
            yield return GetValue<T>(val);
        }
    } finally {
        if (reader != null) {
            // Read() returning false does not close the reader, so this runs on normal completion as well as
            // on early exit. Cancelling before disposing mirrors Dapper's cleanup.
            if (!reader.IsClosed) {
                try {
                    cmd.Cancel();
                } catch {
                    // Best effort: a failed Cancel must not prevent the reader/connection cleanup below.
                }
            }
            reader.Dispose();
        }
        if (wasClosed) {
            conn.Close();
        }
        cmd?.Dispose();
    }
}
// BuildDeserializer
/// <summary>
/// Compiles a row-to-object delegate for <typeparamref name="T"/> using the expression-tree based ExpressionBuilder.
/// </summary>
/// <param name="reader">Open reader whose current schema drives the mapping.</param>
/// <returns>A delegate that materializes one object from the reader's current row.</returns>
private static Func<IDataReader, object> BuildDeserializer<T>(IDataReader reader) =>
    ((IDeserializer)new ExpressionBuilder()).BuildDeserializer<T>(reader);

本文问题的关键就是 ExpressionBuilder 实现的 BuildDeserializer方法

/// <summary>
/// IDeserializer entry point: compiles a delegate mapping the reader's current row to <typeparamref name="T"/>,
/// without requiring every member to be mapped.
/// </summary>
/// <param name="reader">Open reader whose schema drives the mapping.</param>
/// <returns>The compiled row materializer.</returns>
public Func<IDataReader, object> BuildDeserializer<T>(IDataReader reader) {
    // NOTE(review): CurrentCulture makes parsing of string columns into numbers/dates machine-dependent;
    // InvariantCulture may be what is intended for database data — confirm before changing.
    CultureInfo parsingCulture = CultureInfo.CurrentCulture;
    const bool mapEveryMember = false;
    return BuildFunc<T>(reader, parsingCulture, mapEveryMember);
}

BuildFunc 的工作内容

用 Expression 动态生成下面这样一段代码——很明显，它就是一个签名与 Func<IDataReader, object> 相同的 Lambda：

// Illustrative shape of the generated delegate (pseudo-code, not compilable C#).
record => {
    return new T {
        Member0 = (memberType)record.Get_XXX[0],
        Member1 = (memberType)record.Get_XXX[1],
        Member2 = (memberType)record.Get_XXX[2],
        Member3 = (memberType)record.Get_XXX[3],
        Member4 = record.IsDBNull(4) ? default(memberType) : (memberType)record.Get_XXX[4],
    }
}

一点思考：关于 SortedDictionary。此前使用 List 存储 MemberBinding 时，经常无规律地出现一些异常，猜测可能与 DataReader 的流式读取有关（List 中绑定的顺序是成员顺序，而不是列顺序），因此改为以列序号为键的 SortedDictionary，使绑定按列顺序生成。

第一版代码：先用 GetOrdinal 按属性名查找列序号，再用 GetValue 从 DataReader 获取值。存在的问题：1. 类型转换存在 Bug；2. 在 .NET Core 上 null 值处理异常。

/// <summary>
/// First-version deserializer. It compiles <c>reader => { var e = new T(); e.Prop = ...; return e; }</c> for every public,
/// writable, non-indexer instance property of <paramref name="targetType"/>, resolving each column by name on every row.
/// </summary>
/// <param name="reader">Not inspected at build time in this version; columns are looked up per row with GetOrdinal.</param>
/// <param name="targetType">Entity type; needs a public parameterless constructor (Expression.New).</param>
/// <returns>Compiled delegate that materializes one entity (boxed if it is a value type) from the current row.</returns>
/// <remarks>
/// Every mapped property needs a column with the same name: IDataRecord.GetOrdinal throws IndexOutOfRangeException
/// for an unknown name. DBNull / null values become default(PropertyType). BuildFunc is the schema-driven successor.
/// </remarks>
public Func<IDataReader, object> BuildDeserializer(IDataReader reader, Type targetType) {
    Type type = targetType;
    ParameterExpression SourceInstance = Expression.Parameter(typeof(IDataReader), "reader");
    var props = type.GetProperties(BindingFlags.Public | BindingFlags.Instance);

    List<Expression> body = new List<Expression>();
    var eExp = Expression.Variable(type, "e");
    // var e = new T();
    body.Add(Expression.Assign(eExp, Expression.New(type)));

    foreach (var prop in props) {
        // Indexers are returned by GetProperties too, but cannot be assigned through Expression.Property(e, prop).
        if (!prop.CanWrite || prop.GetIndexParameters().Length > 0) continue;
        // reader.GetOrdinal("Prop")
        var indexExp = Expression.Call(SourceInstance, DataRecord_GetOrdinal, Expression.Constant(prop.Name));
        // reader.GetValue(ordinal)
        var valueExp = Expression.Call(SourceInstance, DataRecord_GetValue, indexExp);
        // e.Prop = <converted value>
        body.Add(Expression.Assign(Expression.Property(eExp, prop), GetRealValueExpression(valueExp, prop.PropertyType)));
    }

    // return e;
    body.Add(eExp);
    Expression block = Expression.Block(new[] { eExp }, body);
    // Func<IDataReader, object> needs a reference-typed body; box struct entities.
    if (block.Type.IsValueType)
        block = Expression.Convert(block, typeof(object));
    var lambda = Expression.Lambda<Func<IDataReader, object>>(block, SourceInstance);
    return lambda.Compile();

    // Wraps an object-typed reader.GetValue(...) call: DBNull / null become default(targetPropType), anything else goes
    // through Convert.ChangeType. The value is stored in a local so GetOrdinal/GetValue run once per property.
    Expression GetRealValueExpression(Expression valueExp, Type targetPropType) {
        var raw = Expression.Variable(typeof(object), "raw");
        // Convert.ChangeType cannot target Nullable<T> (it throws InvalidCastException), so convert to the underlying
        // type and let the final Expression.Convert (unbox) produce the Nullable<T> value.
        var changeTarget = Nullable.GetUnderlyingType(targetPropType) ?? targetPropType;
        var isNull = Expression.OrElse(
            Expression.TypeIs(raw, typeof(DBNull)),
            Expression.ReferenceEqual(raw, Expression.Constant(null)));
        var converted = Expression.Convert(
            Expression.Call(Convert_ChangeType, raw, Expression.Constant(changeTarget, typeof(Type))),
            targetPropType);
        /*
         * object raw = reader.GetValue(i);
         * return (raw is DBNull || raw == null) ? default(T) : (T)Convert.ChangeType(raw, underlying(T));
         */
        return Expression.Block(
            targetPropType,
            new[] { raw },
            Expression.Assign(raw, Expression.Convert(valueExp, typeof(object))),
            Expression.Condition(isNull, Expression.Default(targetPropType), converted));
    }
}

完整代码：通过 GetFieldType 获取每个字段的类型，并根据类型选择对应的 GetXXX 方法，以减少装箱/拆箱操作；从 SchemaTable 读取 AllowDBNull，以决定是否生成 IsDBNull 判断；使用 SortedDictionary 保证按 DataReader 的列顺序读取。

/// <summary>
/// Compiles a delegate that materializes one <typeparamref name="Target"/> from the current row of a data record.
/// For ordinary classes the generated shape is:
/// record => new Target {
///     Member0 = (memberType)record.Get_XXX(0),
///     ...
///     MemberN = record.IsDBNull(N) ? default(memberType) : (memberType)record.Get_XXX(N),
/// }
/// System.Tuple types are built through their single constructor (items matched to columns by position), and
/// elementary types (per IsElementaryType) are read from column 0.
/// </summary>
/// <typeparam name="Target">Type produced for each row.</typeparam>
/// <param name="RecordInstance">Record whose names, field types and schema drive the mapping; must also be an IDataReader.</param>
/// <param name="Culture">Culture used when a string column has to be parsed into a non-string member.</param>
/// <param name="MustMapAllProperties">When true, throw if a tuple item, public field or writable property has no matching column.</param>
/// <returns>The compiled delegate; value-typed results are boxed.</returns>
/// <exception cref="ArgumentException">Tuple without exactly one constructor, or an unmapped member when <paramref name="MustMapAllProperties"/> is true.</exception>
/// <exception cref="NotSupportedException">Tuple with more than 7 items (nested TRest).</exception>
private Func<IDataRecord, object> BuildFunc<Target>(IDataRecord RecordInstance, CultureInfo Culture, bool MustMapAllProperties) {
    ParameterExpression recordInstanceExp = Expression.Parameter(typeof(IDataRecord), "Record");
    Type TargetType = typeof(Target);
    DataTable SchemaTable = ((IDataReader)RecordInstance).GetSchemaTable();
    Expression Body;
    // Tuples: one constructor argument per item, matched to columns by position.
    if (TargetType.FullName.StartsWith("System.Tuple`", StringComparison.Ordinal)) {
        ConstructorInfo[] Constructors = TargetType.GetConstructors();
        if (Constructors.Length != 1)
            throw new ArgumentException("Tuple must have one Constructor");
        var Constructor = Constructors[0];
        var Parameters = Constructor.GetParameters();
        if (Parameters.Length > 7)
            throw new NotSupportedException("Nested Tuples are not supported");
        Expression[] TargetValueExpressions = new Expression[Parameters.Length];
        for (int Ordinal = 0; Ordinal < Parameters.Length; Ordinal++) {
            var ParameterType = Parameters[Ordinal].ParameterType;
            if (Ordinal >= RecordInstance.FieldCount) {
                if (MustMapAllProperties) { throw new ArgumentException("Tuple has more fields than the DataReader"); }
                TargetValueExpressions[Ordinal] = Expression.Default(ParameterType);
            } else {
                TargetValueExpressions[Ordinal] = GetTargetValueExpression(
                                                RecordInstance,
                                                Culture,
                                                recordInstanceExp,
                                                SchemaTable,
                                                Ordinal,
                                                ParameterType);
            }
        }
        Body = Expression.New(Constructor, TargetValueExpressions);
    }
    // Elementary types, e.g. IEnumerable<int> / IEnumerable<string>: the value of column 0.
    else if (TargetType.IsElementaryType()) {
        Body = GetTargetValueExpression(RecordInstance, Culture, recordInstanceExp, SchemaTable, 0, TargetType);
    }
    // Everything else: a member-init expression over public instance fields and writable properties.
    else {
        // Keyed by column ordinal so that bindings (and therefore the column reads) are emitted in column order.
        SortedDictionary<int, MemberBinding> Bindings = new SortedDictionary<int, MemberBinding>();
        foreach (FieldInfo TargetMember in TargetType.GetFields(BindingFlags.Public | BindingFlags.Instance)) {
            bool bound = TryBindMemberToColumn(TargetMember, TargetMember.FieldType, RecordInstance, Culture, recordInstanceExp, SchemaTable, Bindings);
            if (!bound && MustMapAllProperties) {
                throw new ArgumentException(String.Format("TargetField {0} is not matched by any field in the DataReader", TargetMember.Name));
            }
        }
        foreach (PropertyInfo TargetMember in TargetType.GetProperties(BindingFlags.Public | BindingFlags.Instance)) {
            // Indexers have parameters and cannot be member-initialized.
            if (!TargetMember.CanWrite || TargetMember.GetIndexParameters().Length > 0) continue;
            bool bound = TryBindMemberToColumn(TargetMember, TargetMember.PropertyType, RecordInstance, Culture, recordInstanceExp, SchemaTable, Bindings);
            if (!bound && MustMapAllProperties) {
                throw new ArgumentException(String.Format("TargetProperty {0} is not matched by any Field in the DataReader", TargetMember.Name));
            }
        }
        Body = Expression.MemberInit(Expression.New(TargetType), Bindings.Values);
    }
    // Func<IDataRecord, object> needs a reference-typed body: box value types (int, DateTime, structs, ...).
    // Without this, Expression.Lambda throws for e.g. IEnumerable<int>.
    if (Body.Type.IsValueType) {
        Body = Expression.Convert(Body, typeof(object));
    }
    return Expression.Lambda<Func<IDataRecord, object>>(Body, recordInstanceExp).Compile();
}

// Binds TargetMember to the first column whose name matches it (MemberMatchesName) and returns true; false if none matches.
// SortedDictionary.Add throws ArgumentException if another member already claimed the same column ordinal.
private static bool TryBindMemberToColumn(
    MemberInfo TargetMember,
    Type MemberType,
    IDataRecord RecordInstance,
    CultureInfo Culture,
    ParameterExpression recordInstanceExp,
    DataTable SchemaTable,
    SortedDictionary<int, MemberBinding> Bindings) {
    for (int Ordinal = 0; Ordinal < RecordInstance.FieldCount; Ordinal++) {
        if (!MemberMatchesName(TargetMember, RecordInstance.GetName(Ordinal)))
            continue;
        Expression TargetValueExpression = GetTargetValueExpression(
                                                RecordInstance,
                                                Culture,
                                                recordInstanceExp,
                                                SchemaTable,
                                                Ordinal,
                                                MemberType);
        Bindings.Add(Ordinal, Expression.Bind(TargetMember, TargetValueExpression));
        return true;
    }
    return false;
}
/// <summary>
/// Returns true when the reader column name <paramref name="Name"/> equals the member's [ColumnName] value or the
/// member's own name. The comparison is ordinal and case-insensitive.
/// </summary>
/// <param name="Member">Field or property being mapped.</param>
/// <param name="Name">Column name reported by the data record.</param>
private static bool MemberMatchesName(MemberInfo Member, string Name) {
    object[] ColumnAttributes = Member.GetCustomAttributes(typeof(ColumnNameAttribute), true);
    // Only consult the attribute when it is present. Treating a missing attribute as "" made a column with an
    // empty name (some providers report "" for an un-aliased expression) match every member.
    if (ColumnAttributes.Length > 0
        && string.Equals(((ColumnNameAttribute)ColumnAttributes[0]).Name, Name, StringComparison.OrdinalIgnoreCase)) {
        return true;
    }
    // OrdinalIgnoreCase rather than ToLower(): identifiers must not depend on the current culture (e.g. Turkish 'I').
    return string.Equals(Member.Name, Name, StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// Builds the expression that reads column <paramref name="Ordinal"/> and converts it to <paramref name="TargetMemberType"/>,
/// guarded by an IsDBNull check (yielding default(TargetMemberType)) unless the schema says the column is NOT NULL.
/// </summary>
/// <param name="RecordInstance">Record used at build time to read the column's CLR type.</param>
/// <param name="Culture">Culture for parsing string columns.</param>
/// <param name="recordInstanceExp">Lambda parameter representing the record at run time.</param>
/// <param name="SchemaTable">Result of GetSchemaTable(); its row at <paramref name="Ordinal"/> supplies AllowDBNull.</param>
/// <param name="Ordinal">Column index.</param>
/// <param name="TargetMemberType">Type of the member/constructor argument receiving the value.</param>
private static Expression GetTargetValueExpression(
    IDataRecord RecordInstance,
    CultureInfo Culture,
    ParameterExpression recordInstanceExp,
    DataTable SchemaTable,
    int Ordinal,
    Type TargetMemberType) {
    Type RecordFieldType = RecordInstance.GetFieldType(Ordinal);
    Expression RecordFieldExpression = GetRecordFieldExpression(recordInstanceExp, Ordinal, RecordFieldType);
    Expression ConvertedRecordFieldExpression = GetConversionExpression(RecordFieldType, RecordFieldExpression, TargetMemberType, Culture);
    if (!ColumnAllowsDBNull(SchemaTable, Ordinal)) {
        return ConvertedRecordFieldExpression;
    }
    // Record.IsDBNull(Ordinal) ? default(TargetMemberType) : <converted value>
    return Expression.Condition(
        GetNullCheckExpression(recordInstanceExp, Ordinal),
        Expression.Default(TargetMemberType),
        ConvertedRecordFieldExpression,
        TargetMemberType
        );

    // Reads AllowDBNull for the column. When the schema is missing, has no AllowDBNull column, or the flag itself is
    // DBNull/null (Convert.ToBoolean would throw), assume the column can be NULL: an extra IsDBNull check is always safe.
    bool ColumnAllowsDBNull(DataTable schema, int ordinal) {
        if (schema == null || !schema.Columns.Contains("AllowDBNull") || ordinal >= schema.Rows.Count)
            return true;
        object flag = schema.Rows[ordinal]["AllowDBNull"];
        return flag == null || flag is DBNull || Convert.ToBoolean(flag);
    }
}
/// <summary>
/// Builds the expression that reads column <paramref name="Ordinal"/> from the record. It uses the getter registered for
/// <paramref name="RecordFieldType"/> in typeMapMethod, falling back to DataRecord_GetValue.
/// </summary>
/// <param name="recordInstanceExp">Lambda parameter representing the record.</param>
/// <param name="Ordinal">Column index.</param>
/// <param name="RecordFieldType">CLR type the record reports for the column.</param>
/// <returns>An expression whose Type is <paramref name="RecordFieldType"/>.</returns>
private static Expression GetRecordFieldExpression(ParameterExpression recordInstanceExp, int Ordinal, Type RecordFieldType) {
    typeMapMethod.TryGetValue(RecordFieldType, out var GetValueMethod);
    if (GetValueMethod == null)
        GetValueMethod = DataRecord_GetValue;
    Expression OrdinalExpression = Expression.Constant(Ordinal, typeof(int));
    // Static helpers take (record, ordinal); instance getters are invoked on the record itself. Deciding by IsStatic
    // (instead of by the byte[] type) keeps the GetValue fallback valid for byte[] columns too.
    Expression RecordFieldExpression = GetValueMethod.IsStatic
        ? Expression.Call(GetValueMethod, recordInstanceExp, OrdinalExpression)
        : Expression.Call(recordInstanceExp, GetValueMethod, OrdinalExpression);
    // The GetValue fallback returns object; cast it so the expression really has RecordFieldType, which
    // GetConversionExpression assumes when it compares the source type with the target type.
    if (RecordFieldExpression.Type != RecordFieldType)
        RecordFieldExpression = Expression.Convert(RecordFieldExpression, RecordFieldType);
    return RecordFieldExpression;
}
/// <summary>Builds <c>Record.IsDBNull(Ordinal)</c> for the given record parameter.</summary>
/// <param name="RecordInstance">Lambda parameter representing the record.</param>
/// <param name="Ordinal">Column index to test.</param>
private static MethodCallExpression GetNullCheckExpression(ParameterExpression RecordInstance, int Ordinal) =>
    Expression.Call(
        RecordInstance,
        DataRecord_IsDBNull,
        Expression.Constant(Ordinal, typeof(int)));
/// <summary>
/// Converts <paramref name="SourceExpression"/> (of the column's CLR type <paramref name="SourceType"/>) to <paramref name="TargetType"/>.
/// </summary>
/// <remarks>
/// Rules, in order:
/// - same type: passed through;
/// - string source: parsed (GetParseExpression);
/// - string target: SourceType.ToString();
/// - bool / bool? target: Convert.ToBoolean;
/// - byte[] source: GetArrayHandlerExpression;
/// - anything else: Expression.Convert.
/// </remarks>
/// <param name="Culture">Culture forwarded to string parsing.</param>
private static Expression GetConversionExpression(Type SourceType, Expression SourceExpression, Type TargetType, CultureInfo Culture) {
    if (ReferenceEquals(TargetType, SourceType))
        return SourceExpression;
    if (ReferenceEquals(SourceType, typeof(string)))
        return GetParseExpression(SourceExpression, TargetType, Culture);
    if (ReferenceEquals(TargetType, typeof(string)))
        return Expression.Call(SourceExpression, SourceType.GetMethod("ToString", Type.EmptyTypes));
    if (ReferenceEquals(GetUnderlyingType(TargetType), typeof(bool))) {
        // Covers bool? as well: Expression.Convert from e.g. int to bool? has no coercion operator and would throw.
        MethodInfo ToBooleanMethod = typeof(Convert).GetMethod("ToBoolean", new[] { SourceType });
        Expression BooleanExpression = Expression.Call(ToBooleanMethod, SourceExpression);
        return ReferenceEquals(TargetType, typeof(bool))
            ? BooleanExpression
            : Expression.Convert(BooleanExpression, TargetType);
    }
    if (ReferenceEquals(SourceType, typeof(Byte[])))
        return GetArrayHandlerExpression(SourceExpression, TargetType);
    return Expression.Convert(SourceExpression, TargetType);
}
/// <summary>
/// Adapts a byte[]-typed expression to <paramref name="targetType"/>. byte[] passes through, and MemoryStream is
/// constructed over the array.
/// </summary>
/// <exception cref="ArgumentException">Any other target type.</exception>
private static Expression GetArrayHandlerExpression(Expression sourceExpression, Type targetType) {
    if (object.ReferenceEquals(targetType, typeof(MemoryStream))) {
        // new MemoryStream(bytes)
        ConstructorInfo fromBytes = targetType.GetConstructor(new[] { typeof(byte[]) });
        return Expression.New(fromBytes, sourceExpression);
    }
    if (object.ReferenceEquals(targetType, typeof(byte[]))) {
        return sourceExpression;
    }
    throw new ArgumentException("Cannot convert a byte array to " + targetType.Name);
}
/// <summary>
/// Builds an expression that parses the string-typed <paramref name="SourceExpression"/> into <paramref name="TargetType"/>
/// or its Nullable&lt;T&gt; form.
/// </summary>
/// <param name="SourceExpression">Expression of type string.</param>
/// <param name="TargetType">Destination type. Supported: enums, integral types, Single, Double, Decimal, DateTime, Boolean, Char, Guid.</param>
/// <param name="Culture">Supplies NumberFormat / DateTimeFormat for the culture-sensitive parses.</param>
/// <exception cref="ArgumentException">The (underlying) target type is not supported.</exception>
private static Expression GetParseExpression(Expression SourceExpression, Type TargetType, CultureInfo Culture) {
    Type UnderlyingType = GetUnderlyingType(TargetType);
    if (UnderlyingType.IsEnum) {
        MethodCallExpression ParsedEnumExpression = GetEnumParseExpression(SourceExpression, UnderlyingType);
        //Enum.Parse returns an object that needs to be unboxed
        return Expression.Unbox(ParsedEnumExpression, TargetType);
    }
    Expression ParseExpression;
    switch (UnderlyingType.FullName) {
        case "System.Byte":
        case "System.UInt16":
        case "System.UInt32":
        case "System.UInt64":
        case "System.SByte":
        case "System.Int16":
        case "System.Int32":
        case "System.Int64":
        case "System.Single": // previously missing although Double and Decimal were supported
        case "System.Double":
        case "System.Decimal":
            ParseExpression = GetNumberParseExpression(SourceExpression, UnderlyingType, Culture);
            break;
        case "System.DateTime":
            ParseExpression = GetDateTimeParseExpression(SourceExpression, UnderlyingType, Culture);
            break;
        case "System.Boolean":
        case "System.Char":
        case "System.Guid":
            ParseExpression = GetGenericParseExpression(SourceExpression, UnderlyingType);
            break;
        default:
            throw new ArgumentException(string.Format("Conversion from {0} to {1} is not supported", "String", TargetType));
    }
    // Lift to Nullable<T> when the target is nullable.
    return Nullable.GetUnderlyingType(TargetType) == null
        ? ParseExpression
        : Expression.Convert(ParseExpression, TargetType);

    // type.Parse(string)
    Expression GetGenericParseExpression(Expression sourceExpression, Type type) {
        MethodInfo ParseMethod = type.GetMethod("Parse", new[] { typeof(string) });
        return Expression.Call(ParseMethod, sourceExpression);
    }
    // DateTime.Parse(string, culture.DateTimeFormat)
    Expression GetDateTimeParseExpression(Expression sourceExpression, Type type, CultureInfo culture) {
        MethodInfo ParseMethod = type.GetMethod("Parse", new[] { typeof(string), typeof(DateTimeFormatInfo) });
        ConstantExpression ProviderExpression = Expression.Constant(culture.DateTimeFormat, typeof(DateTimeFormatInfo));
        return Expression.Call(ParseMethod, sourceExpression, ProviderExpression);
    }
    // Enum.Parse(type, string, ignoreCase: true), returning object
    MethodCallExpression GetEnumParseExpression(Expression sourceExpression, Type type) {
        MethodInfo EnumParseMethod = typeof(Enum).GetMethod("Parse", new[] { typeof(Type), typeof(string), typeof(bool) });
        ConstantExpression TargetMemberTypeExpression = Expression.Constant(type);
        ConstantExpression IgnoreCase = Expression.Constant(true, typeof(bool));
        return Expression.Call(EnumParseMethod, TargetMemberTypeExpression, sourceExpression, IgnoreCase);
    }
    // type.Parse(string, culture.NumberFormat)
    MethodCallExpression GetNumberParseExpression(Expression sourceExpression, Type type, CultureInfo culture) {
        MethodInfo ParseMethod = type.GetMethod("Parse", new[] { typeof(string), typeof(NumberFormatInfo) });
        ConstantExpression ProviderExpression = Expression.Constant(culture.NumberFormat, typeof(NumberFormatInfo));
        return Expression.Call(ParseMethod, sourceExpression, ProviderExpression);
    }
}
/// <summary>Unwraps Nullable&lt;T&gt; to T; any other type is returned unchanged.</summary>
private static Type GetUnderlyingType(Type targetType) {
    Type inner = Nullable.GetUnderlyingType(targetType);
    return inner == null ? targetType : inner;
}

自用ORM Github地址

posted @ 2021-05-23 19:29  yaoqinglin_mtiter  阅读(249)  评论(0编辑  收藏  举报