Using ASP.NET Core Middleware and Analyzing Its Source Code (Part 2): Registering Custom Middleware with the UseMiddleware Extension Method

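Some middleware is simple, while other middleware is more complex and depends on other components. Besides registering middleware directly with the Use() method of IApplicationBuilder, you can also register custom middleware with the UseMiddleware() extension method of IApplicationBuilder.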

Building on the previous article, we add a custom middleware class, CustomMiddleware, as shown below:

using Microsoft.AspNetCore.Http;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace NETCoreMiddleware.Middlewares
{
    /// <summary>
    /// Custom middleware
    /// </summary>
    public class CustomMiddleware
    {
        /// <summary>
        /// Holds the next middleware in the pipeline
        /// </summary>
        private readonly RequestDelegate _next;

        /// <summary>
        /// Constructor
        /// </summary>
        /// <param name="next">The next middleware in the pipeline</param>
        public CustomMiddleware(RequestDelegate next)
        {
            Console.WriteLine($"{typeof(CustomMiddleware)} constructed");
            _next = next;
        }

        /// <summary>
        /// Method invoked by the pipeline for each request
        /// </summary>
        public async Task InvokeAsync(HttpContext context)
        {
            await Task.Run(() => Console.WriteLine($"This is CustomMiddleware Start"));
            await _next.Invoke(context); // Skipping the call to _next short-circuits the request pipeline
            await Task.Run(() => Console.WriteLine($"This is CustomMiddleware End"));
        }
    }
}

Next, we add this middleware to the HTTP request pipeline, as shown below (see the app.UseMiddleware<CustomMiddleware>() call):

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

using NETCoreMiddleware.Middlewares;

namespace NETCoreMiddleware
{
    public class Startup
    {
        public Startup(IConfiguration configuration)
        {
            Configuration = configuration;
        }

        public IConfiguration Configuration { get; }

        // Service registration (adds services to the container)
        // This method gets called by the runtime. Use this method to add services to the container.
        public void ConfigureServices(IServiceCollection services)
        {
            services.AddControllersWithViews();
        }

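        /// <summary>
        /// Configures the HTTP request pipeline.
        /// The HTTP request pipeline model is the sequence of steps that process an HTTP request.
        /// In other words, a pipeline takes an HttpContext through a series of processing steps to produce the Response.
        /// </summary>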
        /// <param name="app"></param>
        /// <param name="env"></param>
        // This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
        public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
        {
            #region Environment settings

            if (env.IsDevelopment())
            {
                app.UseDeveloperExceptionPage();
            }
            else
            {
                app.UseExceptionHandler("/Home/Error");
            }

            #endregion Environment settings

            // Static files middleware
            app.UseStaticFiles();

            #region Use middleware

            // Middleware 1
            app.Use(next =>
            {
                Console.WriteLine("middleware 1");
                return async context =>
                {
                    await Task.Run(() =>
                    {
                        Console.WriteLine("");
                        Console.WriteLine("===================================Middleware===================================");
                        Console.WriteLine($"This is middleware 1 Start");
                    });
                    await next.Invoke(context);
                    await Task.Run(() =>
                    {
                        Console.WriteLine($"This is middleware 1 End");
                        Console.WriteLine("===================================Middleware===================================");
                    });
                };
            });

            // Middleware 2
            app.Use(next =>
            {
                Console.WriteLine("middleware 2");
                return async context =>
                {
                    await Task.Run(() => Console.WriteLine($"This is middleware 2 Start"));
                    await next.Invoke(context); // Skipping the call to next short-circuits the request pipeline
                    await Task.Run(() => Console.WriteLine($"This is middleware 2 End"));
                };
            });

            // Middleware 3
            app.Use(next =>
            {
                Console.WriteLine("middleware 3");
                return async context =>
                {
                    await Task.Run(() => Console.WriteLine($"This is middleware 3 Start"));
                    await next.Invoke(context);
                    await Task.Run(() => Console.WriteLine($"This is middleware 3 End"));
                };
            });

            // Middleware 4
            // Another overload of the Use method
            app.Use(async (context, next) =>
            {
                await Task.Run(() => Console.WriteLine($"This is middleware 4 Start"));
                await next();
                await Task.Run(() => Console.WriteLine($"This is middleware 4 End"));
            });

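            #endregion Use middleware

            #region UseMiddleware middleware

            app.UseMiddleware<CustomMiddleware>();

            #endregion UseMiddleware middleware

            #region Terminal middleware

            // app.Run registers a terminal middleware: it never calls next,
            // so middleware registered after it is not executed.
            //app.Use(_ => handler);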
            app.Run(async context =>
            {
                await Task.Run(() => Console.WriteLine($"This is Run"));
            });

            #endregion Terminal middleware

            #region Finally, hand the request over to MVC

            app.UseRouting();
            app.UseAuthorization();
            app.UseEndpoints(endpoints =>
            {
                endpoints.MapControllerRoute(
                    name: "areas",
                    pattern: "{area:exists}/{controller=Home}/{action=Index}/{id?}");

                endpoints.MapControllerRoute(
                    name: "default",
                    pattern: "{controller=Home}/{action=Index}/{id?}");
            });

            #endregion Finally, hand the request over to MVC
        }
    }
}

Next, we start the site from the command line (CLI).

Once it has started, we request "/home/index" and check the console output.

The output shows that our custom middleware has taken effect.
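To make the order of the console messages easier to reason about, here is a minimal sketch of a pipeline builder. It assumes only what the listings show: each Use() call stores a factory, and the factories run when the pipeline is built, which is also when UseMiddleware constructs the middleware instance. MiniBuilder, Handler, and PipelineDemo are hypothetical names used for illustration, not ASP.NET Core types, and the real ApplicationBuilder does more work (for example, it supplies a default terminal handler).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Hypothetical stand-in for RequestDelegate: the "context" is just a log list.
public delegate Task Handler(List<string> log);

// Hypothetical, heavily simplified stand-in for ApplicationBuilder.
public class MiniBuilder
{
    private readonly List<Func<Handler, Handler>> _components = new List<Func<Handler, Handler>>();

    public MiniBuilder Use(Func<Handler, Handler> component)
    {
        _components.Add(component);
        return this;
    }

    public Handler Build()
    {
        // Start from a terminal handler that does nothing.
        Handler app = log => Task.CompletedTask;

        // Wrap from the last registered component to the first, so that the
        // first registered component ends up outermost.
        foreach (var component in Enumerable.Reverse(_components))
        {
            app = component(app);
        }

        return app;
    }
}

public static class PipelineDemo
{
    public static List<string> Run()
    {
        var log = new List<string>();
        var builder = new MiniBuilder();

        foreach (var name in new[] { "1", "2", "custom" })
        {
            builder.Use(next =>
            {
                // Runs once, while the pipeline is being built.
                log.Add($"build {name}");
                return async l =>
                {
                    // Runs for every request.
                    l.Add($"{name} start");
                    await next(l);
                    l.Add($"{name} end");
                };
            });
        }

        var app = builder.Build();
        app(log).GetAwaiter().GetResult();
        return log;
    }
}
```

In this sketch the "build" entries appear in reverse registration order, while the "start" entries follow registration order and the "end" entries unwind in reverse.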

 

Next, let's look at the ASP.NET Core source code to see how this works:

Place the cursor on UseMiddleware and press F12 to go to its definition.

It is defined in the UseMiddlewareExtensions extension class, whose source code is shown below:

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Abstractions;
using Microsoft.Extensions.Internal;

namespace Microsoft.AspNetCore.Builder
{
    /// <summary>
    /// Extension methods for adding typed middleware.
    /// </summary>
    public static class UseMiddlewareExtensions
    {
        internal const string InvokeMethodName = "Invoke";
        internal const string InvokeAsyncMethodName = "InvokeAsync";

        private static readonly MethodInfo GetServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetService), BindingFlags.NonPublic | BindingFlags.Static);

        /// <summary>
        /// Adds a middleware type to the application's request pipeline.
        /// </summary>
        /// <typeparam name="TMiddleware">The middleware type.</typeparam>
        /// <param name="app">The <see cref="IApplicationBuilder"/> instance.</param>
        /// <param name="args">The arguments to pass to the middleware type instance's constructor.</param>
        /// <returns>The <see cref="IApplicationBuilder"/> instance.</returns>
        public static IApplicationBuilder UseMiddleware<TMiddleware>(this IApplicationBuilder app, params object[] args)
        {
            return app.UseMiddleware(typeof(TMiddleware), args);
        }

        /// <summary>
        /// Adds a middleware type to the application's request pipeline.
        /// </summary>
        /// <param name="app">The <see cref="IApplicationBuilder"/> instance.</param>
        /// <param name="middleware">The middleware type.</param>
        /// <param name="args">The arguments to pass to the middleware type instance's constructor.</param>
        /// <returns>The <see cref="IApplicationBuilder"/> instance.</returns>
        public static IApplicationBuilder UseMiddleware(this IApplicationBuilder app, Type middleware, params object[] args)
        {
            if (typeof(IMiddleware).GetTypeInfo().IsAssignableFrom(middleware.GetTypeInfo()))
            {
                // IMiddleware doesn't support passing args directly since it's
                // activated from the container
                if (args.Length > 0)
                {
                    throw new NotSupportedException(Resources.FormatException_UseMiddlewareExplicitArgumentsNotSupported(typeof(IMiddleware)));
                }

                return UseMiddlewareInterface(app, middleware);
            }

            var applicationServices = app.ApplicationServices;
            return app.Use(next =>
            {
                var methods = middleware.GetMethods(BindingFlags.Instance | BindingFlags.Public);
                var invokeMethods = methods.Where(m =>
                    string.Equals(m.Name, InvokeMethodName, StringComparison.Ordinal)
                    || string.Equals(m.Name, InvokeAsyncMethodName, StringComparison.Ordinal)
                    ).ToArray();

                if (invokeMethods.Length > 1)
                {
                    throw new InvalidOperationException(Resources.FormatException_UseMiddleMutlipleInvokes(InvokeMethodName, InvokeAsyncMethodName));
                }

                if (invokeMethods.Length == 0)
                {
                    throw new InvalidOperationException(Resources.FormatException_UseMiddlewareNoInvokeMethod(InvokeMethodName, InvokeAsyncMethodName, middleware));
                }

                var methodInfo = invokeMethods[0];
                if (!typeof(Task).IsAssignableFrom(methodInfo.ReturnType))
                {
                    throw new InvalidOperationException(Resources.FormatException_UseMiddlewareNonTaskReturnType(InvokeMethodName, InvokeAsyncMethodName, nameof(Task)));
                }

                var parameters = methodInfo.GetParameters();
                if (parameters.Length == 0 || parameters[0].ParameterType != typeof(HttpContext))
                {
                    throw new InvalidOperationException(Resources.FormatException_UseMiddlewareNoParameters(InvokeMethodName, InvokeAsyncMethodName, nameof(HttpContext)));
                }

                var ctorArgs = new object[args.Length + 1];
                ctorArgs[0] = next;
                Array.Copy(args, 0, ctorArgs, 1, args.Length);
                var instance = ActivatorUtilities.CreateInstance(app.ApplicationServices, middleware, ctorArgs);
                if (parameters.Length == 1)
                {
                    return (RequestDelegate)methodInfo.CreateDelegate(typeof(RequestDelegate), instance);
                }

                var factory = Compile<object>(methodInfo, parameters);

                return context =>
                {
                    var serviceProvider = context.RequestServices ?? applicationServices;
                    if (serviceProvider == null)
                    {
                        throw new InvalidOperationException(Resources.FormatException_UseMiddlewareIServiceProviderNotAvailable(nameof(IServiceProvider)));
                    }

                    return factory(instance, context, serviceProvider);
                };
            });
        }

        private static IApplicationBuilder UseMiddlewareInterface(IApplicationBuilder app, Type middlewareType)
        {
            return app.Use(next =>
            {
                return async context =>
                {
                    var middlewareFactory = (IMiddlewareFactory)context.RequestServices.GetService(typeof(IMiddlewareFactory));
                    if (middlewareFactory == null)
                    {
                        // No middleware factory
                        throw new InvalidOperationException(Resources.FormatException_UseMiddlewareNoMiddlewareFactory(typeof(IMiddlewareFactory)));
                    }

                    var middleware = middlewareFactory.Create(middlewareType);
                    if (middleware == null)
                    {
                        // The factory returned null, it's a broken implementation
                        throw new InvalidOperationException(Resources.FormatException_UseMiddlewareUnableToCreateMiddleware(middlewareFactory.GetType(), middlewareType));
                    }

                    try
                    {
                        await middleware.InvokeAsync(context, next);
                    }
                    finally
                    {
                        middlewareFactory.Release(middleware);
                    }
                };
            });
        }

        private static Func<T, HttpContext, IServiceProvider, Task> Compile<T>(MethodInfo methodInfo, ParameterInfo[] parameters)
        {
            // If we call something like
            //
            // public class Middleware
            // {
            //    public Task Invoke(HttpContext context, ILoggerFactory loggerFactory)
            //    {
            //
            //    }
            // }
            //

            // We'll end up with something like this:
            //   Generic version:
            //
            //   Task Invoke(Middleware instance, HttpContext httpContext, IServiceProvider provider)
            //   {
            //      return instance.Invoke(httpContext, (ILoggerFactory)UseMiddlewareExtensions.GetService(provider, typeof(ILoggerFactory));
            //   }

            //   Non generic version:
            //
            //   Task Invoke(object instance, HttpContext httpContext, IServiceProvider provider)
            //   {
            //      return ((Middleware)instance).Invoke(httpContext, (ILoggerFactory)UseMiddlewareExtensions.GetService(provider, typeof(ILoggerFactory));
            //   }

            var middleware = typeof(T);

            var httpContextArg = Expression.Parameter(typeof(HttpContext), "httpContext");
            var providerArg = Expression.Parameter(typeof(IServiceProvider), "serviceProvider");
            var instanceArg = Expression.Parameter(middleware, "middleware");

            var methodArguments = new Expression[parameters.Length];
            methodArguments[0] = httpContextArg;
            for (int i = 1; i < parameters.Length; i++)
            {
                var parameterType = parameters[i].ParameterType;
                if (parameterType.IsByRef)
                {
                    throw new NotSupportedException(Resources.FormatException_InvokeDoesNotSupportRefOrOutParams(InvokeMethodName));
                }

                var parameterTypeExpression = new Expression[]
                {
                    providerArg,
                    Expression.Constant(parameterType, typeof(Type)),
                    Expression.Constant(methodInfo.DeclaringType, typeof(Type))
                };

                var getServiceCall = Expression.Call(GetServiceInfo, parameterTypeExpression);
                methodArguments[i] = Expression.Convert(getServiceCall, parameterType);
            }

            Expression middlewareInstanceArg = instanceArg;
            if (methodInfo.DeclaringType != typeof(T))
            {
                middlewareInstanceArg = Expression.Convert(middlewareInstanceArg, methodInfo.DeclaringType);
            }

            var body = Expression.Call(middlewareInstanceArg, methodInfo, methodArguments);

            var lambda = Expression.Lambda<Func<T, HttpContext, IServiceProvider, Task>>(body, instanceArg, httpContextArg, providerArg);

            return lambda.Compile();
        }

        private static object GetService(IServiceProvider sp, Type type, Type middleware)
        {
            var service = sp.GetService(type);
            if (service == null)
            {
                throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));
            }

            return service;
        }
    }
}

(Source code of the Microsoft.AspNetCore.Builder.UseMiddlewareExtensions class)
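The Compile helper in the listing is the densest part, so here is a minimal sketch of the same idea: building an expression tree that calls Invoke with the context as the first argument and resolves every remaining parameter from an IServiceProvider. FakeContext, Greeter, InjectedMiddleware, DictionaryProvider, and CompileSketch are hypothetical names so that the example runs without ASP.NET Core. Unlike the real helper, this sketch does not throw a descriptive exception when a service is missing.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;

// Hypothetical stand-ins so the sketch runs without ASP.NET Core.
public class FakeContext
{
    public List<string> Log { get; } = new List<string>();
}

public class Greeter
{
    public string Greet() => "hello";
}

public class InjectedMiddleware
{
    // The second parameter is resolved from the service provider per call.
    public Task Invoke(FakeContext context, Greeter greeter)
    {
        context.Log.Add(greeter.Greet());
        return Task.CompletedTask;
    }
}

public class DictionaryProvider : IServiceProvider
{
    private readonly Dictionary<Type, object> _services = new Dictionary<Type, object>();

    public DictionaryProvider Add(Type type, object instance)
    {
        _services[type] = instance;
        return this;
    }

    public object GetService(Type serviceType) =>
        _services.TryGetValue(serviceType, out var service) ? service : null;
}

public static class CompileSketch
{
    private static readonly MethodInfo GetServiceMethod =
        typeof(IServiceProvider).GetMethod(nameof(IServiceProvider.GetService));

    public static Func<object, FakeContext, IServiceProvider, Task> Compile(MethodInfo method)
    {
        var instanceArg = Expression.Parameter(typeof(object), "instance");
        var contextArg = Expression.Parameter(typeof(FakeContext), "context");
        var providerArg = Expression.Parameter(typeof(IServiceProvider), "provider");

        var parameters = method.GetParameters();
        var arguments = new Expression[parameters.Length];
        arguments[0] = contextArg;

        for (var i = 1; i < parameters.Length; i++)
        {
            // Builds: (TParam)provider.GetService(typeof(TParam))
            var type = parameters[i].ParameterType;
            var call = Expression.Call(providerArg, GetServiceMethod, Expression.Constant(type, typeof(Type)));
            arguments[i] = Expression.Convert(call, type);
        }

        // Builds: ((TMiddleware)instance).Invoke(context, service1, ...)
        var body = Expression.Call(Expression.Convert(instanceArg, method.DeclaringType), method, arguments);

        return Expression
            .Lambda<Func<object, FakeContext, IServiceProvider, Task>>(body, instanceArg, contextArg, providerArg)
            .Compile();
    }
}
```

The compiled delegate is created once and reused for every request, which avoids paying the cost of MethodInfo.Invoke on each call; only the service lookups happen per request.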

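From the source, we can see that the generic UseMiddleware<TMiddleware>() ultimately calls the following non-generic overload: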

/// <summary>
/// Adds a middleware type to the application's request pipeline.
/// </summary>
/// <param name="app">The <see cref="IApplicationBuilder"/> instance.</param>
/// <param name="middleware">The middleware type.</param>
/// <param name="args">The arguments to pass to the middleware type instance's constructor.</param>
/// <returns>The <see cref="IApplicationBuilder"/> instance.</returns>
public static IApplicationBuilder UseMiddleware(this IApplicationBuilder app, Type middleware, params object[] args)
{
    if (typeof(IMiddleware).GetTypeInfo().IsAssignableFrom(middleware.GetTypeInfo()))
    {
        // IMiddleware doesn't support passing args directly since it's
        // activated from the container
        if (args.Length > 0)
        {
            throw new NotSupportedException(Resources.FormatException_UseMiddlewareExplicitArgumentsNotSupported(typeof(IMiddleware)));
        }

        return UseMiddlewareInterface(app, middleware);
    }

    var applicationServices = app.ApplicationServices;
    return app.Use(next =>
    {
        var methods = middleware.GetMethods(BindingFlags.Instance | BindingFlags.Public);
        var invokeMethods = methods.Where(m =>
            string.Equals(m.Name, InvokeMethodName, StringComparison.Ordinal)
            || string.Equals(m.Name, InvokeAsyncMethodName, StringComparison.Ordinal)
            ).ToArray();

        if (invokeMethods.Length > 1)
        {
            throw new InvalidOperationException(Resources.FormatException_UseMiddleMutlipleInvokes(InvokeMethodName, InvokeAsyncMethodName));
        }

        if (invokeMethods.Length == 0)
        {
            throw new InvalidOperationException(Resources.FormatException_UseMiddlewareNoInvokeMethod(InvokeMethodName, InvokeAsyncMethodName, middleware));
        }

        var methodInfo = invokeMethods[0];
        if (!typeof(Task).IsAssignableFrom(methodInfo.ReturnType))
        {
            throw new InvalidOperationException(Resources.FormatException_UseMiddlewareNonTaskReturnType(InvokeMethodName, InvokeAsyncMethodName, nameof(Task)));
        }

        var parameters = methodInfo.GetParameters();
        if (parameters.Length == 0 || parameters[0].ParameterType != typeof(HttpContext))
        {
            throw new InvalidOperationException(Resources.FormatException_UseMiddlewareNoParameters(InvokeMethodName, InvokeAsyncMethodName, nameof(HttpContext)));
        }

        var ctorArgs = new object[args.Length + 1];
        ctorArgs[0] = next;
        Array.Copy(args, 0, ctorArgs, 1, args.Length);
        var instance = ActivatorUtilities.CreateInstance(app.ApplicationServices, middleware, ctorArgs);
        if (parameters.Length == 1)
        {
            return (RequestDelegate)methodInfo.CreateDelegate(typeof(RequestDelegate), instance);
        }

        var factory = Compile<object>(methodInfo, parameters);

        return context =>
        {
            var serviceProvider = context.RequestServices ?? applicationServices;
            if (serviceProvider == null)
            {
                throw new InvalidOperationException(Resources.FormatException_UseMiddlewareIServiceProviderNotAvailable(nameof(IServiceProvider)));
            }

            return factory(instance, context, serviceProvider);
        };
    });
}
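The convention-based branch of this method can be summarized in a small sketch: find the single public Invoke or InvokeAsync method, validate its return type and first parameter, construct the middleware with next as a constructor argument, and bind the method to that instance as a delegate. FakeContext, FakeRequestDelegate, SampleMiddleware, and ConventionSketch are hypothetical names so that the example runs without ASP.NET Core. The sketch uses Activator.CreateInstance instead of ActivatorUtilities and covers only the single-parameter case, so it is a simplification rather than the framework's behavior.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;

// Hypothetical stand-ins for HttpContext and RequestDelegate.
public class FakeContext
{
    public List<string> Log { get; } = new List<string>();
}

public delegate Task FakeRequestDelegate(FakeContext context);

public class SampleMiddleware
{
    private readonly FakeRequestDelegate _next;

    public SampleMiddleware(FakeRequestDelegate next) => _next = next;

    public async Task InvokeAsync(FakeContext context)
    {
        context.Log.Add("start");
        await _next(context);
        context.Log.Add("end");
    }
}

public static class ConventionSketch
{
    public static FakeRequestDelegate Wrap(Type middleware, FakeRequestDelegate next)
    {
        // Exactly one public instance method named Invoke or InvokeAsync is allowed.
        var invokeMethods = middleware
            .GetMethods(BindingFlags.Instance | BindingFlags.Public)
            .Where(m => m.Name == "Invoke" || m.Name == "InvokeAsync")
            .ToArray();

        if (invokeMethods.Length != 1)
        {
            throw new InvalidOperationException("Expected exactly one Invoke or InvokeAsync method.");
        }

        var method = invokeMethods[0];
        if (!typeof(Task).IsAssignableFrom(method.ReturnType))
        {
            throw new InvalidOperationException("Invoke or InvokeAsync must return a Task.");
        }

        var parameters = method.GetParameters();
        if (parameters.Length == 0 || parameters[0].ParameterType != typeof(FakeContext))
        {
            throw new InvalidOperationException("The first parameter must be the context.");
        }

        if (parameters.Length > 1)
        {
            throw new NotSupportedException("This sketch does not resolve additional parameters.");
        }

        // The next delegate is passed as the constructor argument.
        var instance = Activator.CreateInstance(middleware, new object[] { next });

        // Bind the instance method directly as the pipeline delegate.
        return (FakeRequestDelegate)method.CreateDelegate(typeof(FakeRequestDelegate), instance);
    }
}
```

Because the instance is created once when the pipeline is built, a middleware written this way is shared by all requests, so per-request state belongs in the context rather than in fields.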

The source code of the ActivatorUtilities class, which this method uses to create the middleware instance, is as follows:

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.ExceptionServices;

#if ActivatorUtilities_In_DependencyInjection
using Microsoft.Extensions.Internal;

namespace Microsoft.Extensions.DependencyInjection
#else
namespace Microsoft.Extensions.Internal
#endif
{
    /// <summary>
    /// Helper code for the various activator services.
    /// </summary>

#if ActivatorUtilities_In_DependencyInjection
    public
#else
    // Do not take a dependency on this class unless you are explicitly trying to avoid taking a
    // dependency on Microsoft.AspNetCore.DependencyInjection.Abstractions.
    internal
#endif
    static class ActivatorUtilities
    {
        private static readonly MethodInfo GetServiceInfo =
            GetMethodInfo<Func<IServiceProvider, Type, Type, bool, object>>((sp, t, r, c) => GetService(sp, t, r, c));

        /// <summary>
        /// Instantiate a type with constructor arguments provided directly and/or from an <see cref="IServiceProvider"/>.
        /// </summary>
        /// <param name="provider">The service provider used to resolve dependencies</param>
        /// <param name="instanceType">The type to activate</param>
        /// <param name="parameters">Constructor arguments not provided by the <paramref name="provider"/>.</param>
        /// <returns>An activated object of type instanceType</returns>
        public static object CreateInstance(IServiceProvider provider, Type instanceType, params object[] parameters)
        {
            int bestLength = -1;
            var seenPreferred = false;

            ConstructorMatcher bestMatcher = default;

            if (!instanceType.GetTypeInfo().IsAbstract)
            {
                foreach (var constructor in instanceType
                    .GetTypeInfo()
                    .DeclaredConstructors)
                {
                    if (!constructor.IsStatic && constructor.IsPublic)
                    {
                        var matcher = new ConstructorMatcher(constructor);
                        var isPreferred = constructor.IsDefined(typeof(ActivatorUtilitiesConstructorAttribute), false);
                        var length = matcher.Match(parameters);

                        if (isPreferred)
                        {
                            if (seenPreferred)
                            {
                                ThrowMultipleCtorsMarkedWithAttributeException();
                            }

                            if (length == -1)
                            {
                                ThrowMarkedCtorDoesNotTakeAllProvidedArguments();
                            }
                        }

                        if (isPreferred || bestLength < length)
                        {
                            bestLength = length;
                            bestMatcher = matcher;
                        }

                        seenPreferred |= isPreferred;
                    }
                }
            }

            if (bestLength == -1)
            {
                var message = $"A suitable constructor for type '{instanceType}' could not be located. Ensure the type is concrete and services are registered for all parameters of a public constructor.";
                throw new InvalidOperationException(message);
            }

            return bestMatcher.CreateInstance(provider);
        }

        /// <summary>
        /// Create a delegate that will instantiate a type with constructor arguments provided directly
        /// and/or from an <see cref="IServiceProvider"/>.
        /// </summary>
        /// <param name="instanceType">The type to activate</param>
        /// <param name="argumentTypes">
        /// The types of objects, in order, that will be passed to the returned function as its second parameter
        /// </param>
        /// <returns>
        /// A factory that will instantiate instanceType using an <see cref="IServiceProvider"/>
        /// and an argument array containing objects matching the types defined in argumentTypes
        /// </returns>
        public static ObjectFactory CreateFactory(Type instanceType, Type[] argumentTypes)
        {
            FindApplicableConstructor(instanceType, argumentTypes, out ConstructorInfo constructor, out int?[] parameterMap);

            var provider = Expression.Parameter(typeof(IServiceProvider), "provider");
            var argumentArray = Expression.Parameter(typeof(object[]), "argumentArray");
            var factoryExpressionBody = BuildFactoryExpression(constructor, parameterMap, provider, argumentArray);

            var factoryLamda = Expression.Lambda<Func<IServiceProvider, object[], object>>(
                factoryExpressionBody, provider, argumentArray);

            var result = factoryLamda.Compile();
            return result.Invoke;
        }

        /// <summary>
        /// Instantiate a type with constructor arguments provided directly and/or from an <see cref="IServiceProvider"/>.
        /// </summary>
        /// <typeparam name="T">The type to activate</typeparam>
        /// <param name="provider">The service provider used to resolve dependencies</param>
        /// <param name="parameters">Constructor arguments not provided by the <paramref name="provider"/>.</param>
        /// <returns>An activated object of type T</returns>
        public static T CreateInstance<T>(IServiceProvider provider, params object[] parameters)
        {
            return (T)CreateInstance(provider, typeof(T), parameters);
        }


        /// <summary>
        /// Retrieve an instance of the given type from the service provider. If one is not found then instantiate it directly.
        /// </summary>
        /// <typeparam name="T">The type of the service</typeparam>
        /// <param name="provider">The service provider used to resolve dependencies</param>
        /// <returns>The resolved service or created instance</returns>
        public static T GetServiceOrCreateInstance<T>(IServiceProvider provider)
        {
            return (T)GetServiceOrCreateInstance(provider, typeof(T));
        }

        /// <summary>
        /// Retrieve an instance of the given type from the service provider. If one is not found then instantiate it directly.
        /// </summary>
        /// <param name="provider">The service provider</param>
        /// <param name="type">The type of the service</param>
        /// <returns>The resolved service or created instance</returns>
        public static object GetServiceOrCreateInstance(IServiceProvider provider, Type type)
        {
            return provider.GetService(type) ?? CreateInstance(provider, type);
        }

        private static MethodInfo GetMethodInfo<T>(Expression<T> expr)
        {
            var mc = (MethodCallExpression)expr.Body;
            return mc.Method;
        }

        private static object GetService(IServiceProvider sp, Type type, Type requiredBy, bool isDefaultParameterRequired)
        {
            var service = sp.GetService(type);
            if (service == null && !isDefaultParameterRequired)
            {
                var message = $"Unable to resolve service for type '{type}' while attempting to activate '{requiredBy}'.";
                throw new InvalidOperationException(message);
            }
            return service;
        }

        private static Expression BuildFactoryExpression(
            ConstructorInfo constructor,
            int?[] parameterMap,
            Expression serviceProvider,
            Expression factoryArgumentArray)
        {
            var constructorParameters = constructor.GetParameters();
            var constructorArguments = new Expression[constructorParameters.Length];

            for (var i = 0; i < constructorParameters.Length; i++)
            {
                var constructorParameter = constructorParameters[i];
                var parameterType = constructorParameter.ParameterType;
                var hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out var defaultValue);

                if (parameterMap[i] != null)
                {
                    constructorArguments[i] = Expression.ArrayAccess(factoryArgumentArray, Expression.Constant(parameterMap[i]));
                }
                else
                {
                    var parameterTypeExpression = new Expression[] { serviceProvider,
                        Expression.Constant(parameterType, typeof(Type)),
                        Expression.Constant(constructor.DeclaringType, typeof(Type)),
                        Expression.Constant(hasDefaultValue) };
                    constructorArguments[i] = Expression.Call(GetServiceInfo, parameterTypeExpression);
                }

                // Support optional constructor arguments by passing in the default value
                // when the argument would otherwise be null.
                if (hasDefaultValue)
                {
                    var defaultValueExpression = Expression.Constant(defaultValue);
                    constructorArguments[i] = Expression.Coalesce(constructorArguments[i], defaultValueExpression);
                }

                constructorArguments[i] = Expression.Convert(constructorArguments[i], parameterType);
            }

            return Expression.New(constructor, constructorArguments);
        }

        private static void FindApplicableConstructor(
            Type instanceType,
            Type[] argumentTypes,
            out ConstructorInfo matchingConstructor,
            out int?[] parameterMap)
        {
            matchingConstructor = null;
            parameterMap = null;

            if (!TryFindPreferredConstructor(instanceType, argumentTypes, ref matchingConstructor, ref parameterMap) &&
                !TryFindMatchingConstructor(instanceType, argumentTypes, ref matchingConstructor, ref parameterMap))
            {
                var message = $"A suitable constructor for type '{instanceType}' could not be located. Ensure the type is concrete and services are registered for all parameters of a public constructor.";
                throw new InvalidOperationException(message);
            }
        }

        // Tries to find constructor based on provided argument types
        private static bool TryFindMatchingConstructor(
            Type instanceType,
            Type[] argumentTypes,
            ref ConstructorInfo matchingConstructor,
            ref int?[] parameterMap)
        {
            foreach (var constructor in instanceType.GetTypeInfo().DeclaredConstructors)
            {
                if (constructor.IsStatic || !constructor.IsPublic)
                {
                    continue;
                }

                if (TryCreateParameterMap(constructor.GetParameters(), argumentTypes, out int?[] tempParameterMap))
                {
                    if (matchingConstructor != null)
                    {
                        throw new InvalidOperationException($"Multiple constructors accepting all given argument types have been found in type '{instanceType}'. There should only be one applicable constructor.");
                    }

                    matchingConstructor = constructor;
                    parameterMap = tempParameterMap;
                }
            }

            return matchingConstructor != null;
        }

        // Tries to find constructor marked with ActivatorUtilitiesConstructorAttribute
        private static bool TryFindPreferredConstructor(
            Type instanceType,
            Type[] argumentTypes,
            ref ConstructorInfo matchingConstructor,
            ref int?[] parameterMap)
        {
            var seenPreferred = false;
            foreach (var constructor in instanceType.GetTypeInfo().DeclaredConstructors)
            {
                if (constructor.IsStatic || !constructor.IsPublic)
                {
                    continue;
                }

                if (constructor.IsDefined(typeof(ActivatorUtilitiesConstructorAttribute), false))
                {
                    if (seenPreferred)
                    {
                        ThrowMultipleCtorsMarkedWithAttributeException();
                    }

                    if (!TryCreateParameterMap(constructor.GetParameters(), argumentTypes, out int?[] tempParameterMap))
                    {
                        ThrowMarkedCtorDoesNotTakeAllProvidedArguments();
                    }

                    matchingConstructor = constructor;
                    parameterMap = tempParameterMap;
                    seenPreferred = true;
                }
            }

            return matchingConstructor != null;
        }

        // Creates an injective parameterMap from argumentTypes to assignable constructorParameters.
        // Returns true if each given argument type is assignable to a unique constructor parameter; otherwise, false.
        private static bool TryCreateParameterMap(ParameterInfo[] constructorParameters, Type[] argumentTypes, out int?[] parameterMap)
        {
            parameterMap = new int?[constructorParameters.Length];

            for (var i = 0; i < argumentTypes.Length; i++)
            {
                var foundMatch = false;
                var givenParameter = argumentTypes[i].GetTypeInfo();

                for (var j = 0; j < constructorParameters.Length; j++)
                {
                    if (parameterMap[j] != null)
                    {
                        // This ctor parameter has already been matched
                        continue;
                    }

                    if (constructorParameters[j].ParameterType.GetTypeInfo().IsAssignableFrom(givenParameter))
                    {
                        foundMatch = true;
                        parameterMap[j] = i;
                        break;
                    }
                }

                if (!foundMatch)
                {
                    return false;
                }
            }

            return true;
        }

        private struct ConstructorMatcher
        {
            private readonly ConstructorInfo _constructor;
            private readonly ParameterInfo[] _parameters;
            private readonly object[] _parameterValues;

            public ConstructorMatcher(ConstructorInfo constructor)
            {
                _constructor = constructor;
                _parameters = _constructor.GetParameters();
                _parameterValues = new object[_parameters.Length];
            }

            public int Match(object[] givenParameters)
            {
                var applyIndexStart = 0;
                var applyExactLength = 0;
                for (var givenIndex = 0; givenIndex != givenParameters.Length; givenIndex++)
                {
                    var givenType = givenParameters[givenIndex]?.GetType().GetTypeInfo();
                    var givenMatched = false;

                    for (var applyIndex = applyIndexStart; givenMatched == false && applyIndex != _parameters.Length; ++applyIndex)
                    {
                        if (_parameterValues[applyIndex] == null &&
                            _parameters[applyIndex].ParameterType.GetTypeInfo().IsAssignableFrom(givenType))
                        {
                            givenMatched = true;
                            _parameterValues[applyIndex] = givenParameters[givenIndex];
                            if (applyIndexStart == applyIndex)
                            {
                                applyIndexStart++;
                                if (applyIndex == givenIndex)
                                {
                                    applyExactLength = applyIndex;
                                }
                            }
                        }
                    }

                    if (givenMatched == false)
                    {
                        return -1;
                    }
                }
                return applyExactLength;
            }

            public object CreateInstance(IServiceProvider provider)
            {
                for (var index = 0; index != _parameters.Length; index++)
                {
                    if (_parameterValues[index] == null)
                    {
                        var value = provider.GetService(_parameters[index].ParameterType);
                        if (value == null)
                        {
                            if (!ParameterDefaultValue.TryGetDefaultValue(_parameters[index], out var defaultValue))
                            {
                                throw new InvalidOperationException($"Unable to resolve service for type '{_parameters[index].ParameterType}' while attempting to activate '{_constructor.DeclaringType}'.");
                            }
                            else
                            {
                                _parameterValues[index] = defaultValue;
                            }
                        }
                        else
                        {
                            _parameterValues[index] = value;
                        }
                    }
                }

#if NETCOREAPP
                return _constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, parameters: _parameterValues, culture: null);
#else
                try
                {
                    return _constructor.Invoke(_parameterValues);
                }
                catch (TargetInvocationException ex) when (ex.InnerException != null)
                {
                    ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
                    // The above line will always throw, but the compiler requires we throw explicitly.
                    throw;
                }
#endif
            }
        }

        private static void ThrowMultipleCtorsMarkedWithAttributeException()
        {
            throw new InvalidOperationException($"Multiple constructors were marked with {nameof(ActivatorUtilitiesConstructorAttribute)}.");
        }

        private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
        {
            throw new InvalidOperationException($"Constructor marked with {nameof(ActivatorUtilitiesConstructorAttribute)} does not accept all given argument types.");
        }
    }
}
Source code of the Microsoft.Extensions.Internal.ActivatorUtilities class

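Reading the source above closely, we find the following. If the Invoke (or InvokeAsync) method has only one parameter (HttpContext), the framework wraps the method directly as a RequestDelegate, then as a Func<RequestDelegate, RequestDelegate>, and finally registers it by calling app.Use(...). If the Invoke (or InvokeAsync) method has more than one parameter, it no longer matches the signature of the RequestDelegate delegate, so the framework wraps it a second time to make it fit. During this second wrapping the framework tries to resolve the extra parameters of Invoke (or InvokeAsync) from the dependency injection container.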
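The two paths described above can be sketched without ASP.NET Core itself. This is only a minimal illustration: FakeHttpContext, FakeRequestDelegate, DictionaryServiceProvider, Greeter and MiddlewareBinder are hypothetical stand-ins invented for this example, and the multi-parameter path uses plain reflection for readability, whereas the framework compiles an expression tree for the same purpose.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Threading.Tasks;

// Hypothetical stand-ins for HttpContext and RequestDelegate.
public class FakeHttpContext
{
    public IServiceProvider RequestServices { get; set; }
    public List<string> Log { get; } = new List<string>();
}

public delegate Task FakeRequestDelegate(FakeHttpContext context);

// Hypothetical dictionary-backed container, just enough for the sketch.
public class DictionaryServiceProvider : IServiceProvider
{
    private readonly Dictionary<Type, object> _services = new Dictionary<Type, object>();

    public void Add<T>(T instance) => _services[typeof(T)] = instance;

    public object GetService(Type serviceType) =>
        _services.TryGetValue(serviceType, out var service) ? service : null;
}

public class Greeter
{
    public string Greet() => "hello";
}

public class SingleParameterMiddleware
{
    public Task Invoke(FakeHttpContext context)
    {
        context.Log.Add("single");
        return Task.CompletedTask;
    }
}

public class MultiParameterMiddleware
{
    public Task Invoke(FakeHttpContext context, Greeter greeter)
    {
        context.Log.Add(greeter.Greet());
        return Task.CompletedTask;
    }
}

public static class MiddlewareBinder
{
    public static FakeRequestDelegate Bind(object instance)
    {
        MethodInfo method = instance.GetType().GetMethod("Invoke");
        ParameterInfo[] parameters = method.GetParameters();

        if (parameters.Length == 1)
        {
            // The signature already matches, so the method is bound directly as a delegate.
            return (FakeRequestDelegate)method.CreateDelegate(typeof(FakeRequestDelegate), instance);
        }

        // Extra parameters: wrap the call and resolve them from the container on each request.
        return context =>
        {
            var args = new object[parameters.Length];
            args[0] = context;
            for (var i = 1; i < parameters.Length; i++)
            {
                Type type = parameters[i].ParameterType;
                args[i] = context.RequestServices.GetService(type)
                    ?? throw new InvalidOperationException($"Unable to resolve service for type '{type}'.");
            }

            return (Task)method.Invoke(instance, args);
        };
    }
}
```

Calling MiddlewareBinder.Bind on either middleware instance yields a delegate with the same shape, which is why the rest of the pipeline does not need to know how many parameters the original method had.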

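We can also draw the following conclusions.

About the method of a custom middleware:

1. The method must be named Invoke or InvokeAsync, and it must be public and non-static.

2. The first parameter of Invoke or InvokeAsync must be of type HttpContext.

3. The return type of Invoke or InvokeAsync must be Task.

4. Invoke or InvokeAsync may take additional parameters; every parameter other than HttpContext is resolved from the dependency injection container.

5. Invoke or InvokeAsync must not be overloaded.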
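As an illustration of these method rules, here is a minimal sketch of a middleware whose InvokeAsync takes one extra parameter. IRequestIdProvider, RequestIdProvider, RequestIdMiddleware and the X-Request-Id header are hypothetical names chosen for this example; it assumes the service has been registered, for instance with services.AddScoped<IRequestIdProvider, RequestIdProvider>().

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;

// Hypothetical per-request service, used only for this illustration.
public interface IRequestIdProvider
{
    string RequestId { get; }
}

public class RequestIdProvider : IRequestIdProvider
{
    public string RequestId { get; } = Guid.NewGuid().ToString("N");
}

public class RequestIdMiddleware
{
    private readonly RequestDelegate _next;

    public RequestIdMiddleware(RequestDelegate next)
    {
        _next = next;
    }

    // Public, non-static, returns Task, HttpContext comes first, no overloads.
    // The second parameter is resolved from the container for every request.
    public async Task InvokeAsync(HttpContext context, IRequestIdProvider requestIdProvider)
    {
        context.Response.Headers["X-Request-Id"] = requestIdProvider.RequestId;
        await _next(context);
    }
}
```

It would be registered with app.UseMiddleware<RequestIdMiddleware>(). Because a middleware registered this way is constructed once when the pipeline is built, method parameters are the place to receive services that must be fresh for each request.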

 

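About the constructor of a custom middleware:

1. The constructor must have a RequestDelegate parameter; the value passed in is the next middleware.

2. The RequestDelegate parameter does not have to be the first constructor parameter; it can be in any position.

3. The constructor may have multiple parameters. Each parameter is first looked up in the given argument list, then resolved from the dependency injection container; if that fails, its default value is tried, and if all of these fail an exception is thrown.

4. There may be more than one constructor, in which case the one that best matches the given argument list is selected.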
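The lookup order in rule 3 can be observed directly with ActivatorUtilities.CreateInstance, the method shown in the source above. The following is a minimal sketch; Clock, ReportOptions, ReportBuilder and ConstructorResolutionDemo are hypothetical types made up for this example, and it assumes the Microsoft.Extensions.DependencyInjection package is available.

```csharp
using System;
using Microsoft.Extensions.DependencyInjection;

public class Clock
{
    public DateTime Now => DateTime.UtcNow;
}

public class ReportOptions
{
    public string Title { get; set; } = "untitled";
}

public class ReportBuilder
{
    public ReportOptions Options { get; }
    public Clock Clock { get; }
    public int PageSize { get; }

    // options is matched against the explicitly supplied arguments,
    // clock is resolved from the container,
    // pageSize is in neither place and falls back to its default value.
    public ReportBuilder(ReportOptions options, Clock clock, int pageSize = 20)
    {
        Options = options;
        Clock = clock;
        PageSize = pageSize;
    }
}

public static class ConstructorResolutionDemo
{
    public static ReportBuilder Create()
    {
        var services = new ServiceCollection();
        services.AddSingleton<Clock>();

        using (var provider = services.BuildServiceProvider())
        {
            // Only ReportOptions is passed explicitly; it is not registered in the container.
            return ActivatorUtilities.CreateInstance<ReportBuilder>(
                provider, new ReportOptions { Title = "weekly" });
        }
    }
}
```

If Clock were not registered, the call would throw an InvalidOperationException, because that parameter has no default value to fall back on.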

 

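Personal recommendations:

1. Except in very special cases, keep only one constructor, which avoids the extra constructor matching checks.

2. Inject the dependencies you need in the constructor rather than in Invoke or InvokeAsync. Note that a middleware registered this way is constructed only once, when the pipeline is built, so this applies to services that can safely live for the whole application; scoped services still have to be received as parameters of Invoke or InvokeAsync.

3. For the order of constructor parameters, put RequestDelegate first; then the arguments passed to UseMiddleware, preferably in the same order as in the given argument list; then the parameters to be injected; and finally the parameters with default values. Apart from the defaulted parameters, which must come last, none of this ordering is required, but following it is clearer and keeps the cost of creating the instance to a minimum.

4. Where no per-request service is needed, keep HttpContext as the only parameter of Invoke or InvokeAsync, which avoids the second wrapping of the method.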
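Put together, a middleware that follows these recommendations might look like the sketch below. StampHeaderMiddleware, UseStampHeader, the header value "stamped" and the overwrite flag are hypothetical and exist only to show the parameter ordering; ILogger<T> is assumed to be registered by the host, as it is in the default ASP.NET Core setup.

```csharp
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;

public class StampHeaderMiddleware
{
    private readonly RequestDelegate _next;
    private readonly string _headerName;
    private readonly ILogger<StampHeaderMiddleware> _logger;
    private readonly bool _overwrite;

    // A single constructor, ordered as: RequestDelegate, arguments given to UseMiddleware,
    // injected services, parameters with default values.
    public StampHeaderMiddleware(
        RequestDelegate next,
        string headerName,
        ILogger<StampHeaderMiddleware> logger,
        bool overwrite = true)
    {
        _next = next;
        _headerName = headerName;
        _logger = logger;
        _overwrite = overwrite;
    }

    // HttpContext is the only parameter, so no second wrapping is needed.
    public Task InvokeAsync(HttpContext context)
    {
        if (_overwrite || !context.Response.Headers.ContainsKey(_headerName))
        {
            context.Response.Headers[_headerName] = "stamped";
            _logger.LogDebug("Stamped header {HeaderName}", _headerName);
        }

        return _next(context);
    }
}

public static class StampHeaderMiddlewareExtensions
{
    // headerName is forwarded as the explicitly given constructor argument.
    public static IApplicationBuilder UseStampHeader(this IApplicationBuilder app, string headerName)
        => app.UseMiddleware<StampHeaderMiddleware>(headerName);
}
```

In Configure it would then be enough to call app.UseStampHeader("X-Stamp").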

 

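Parts of this article draw on the blog post: https://www.cnblogs.com/durow/p/5748124.html

For more on ASP.NET Core middleware, see the official Microsoft documentation: https://docs.microsoft.com/zh-cn/aspnet/core/fundamentals/middleware/?view=aspnetcore-6.0

That is everything for this article. If you found it helpful, please remember to give it a like!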

 

Demo source code:

Link: https://pan.baidu.com/s/11NLDtt_cKPFxDbA3fW8Btg
Extraction code: 0hyo

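This article was carefully written by the blogger. When reposting, please keep the link to the original: https://www.cnblogs.com/xyh9039/p/16414344.html

Copyright notice: any resemblance is purely coincidental. If there is any infringement, please contact me promptly so I can make corrections. Thank you!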

posted @ 2022-07-02 23:32 谢友海