Loading

动手造轮子:实现一个简单的依赖注入(一)

动手造轮子:实现一个简单的依赖注入(一)

Intro

上一篇文章中主要介绍了一下要做的依赖注入的整体设计和大概编程体验,这篇文章要开始写代码了,开始实现自己的依赖注入框架。

类图

首先来温习一下上次提到的 UML 类图

服务生命周期

服务生命周期定义:

/// <summary>
/// Describes how long an instance created for a service is reused.
/// </summary>
public enum ServiceLifetime : sbyte
{
    /// <summary>
    /// Only one instance of the service is ever created.
    /// </summary>
    Singleton,

    /// <summary>
    /// One instance of the service is created per scope.
    /// </summary>
    Scoped,

    /// <summary>
    /// A fresh instance of the service is created on every request.
    /// </summary>
    Transient,
}

服务定义

服务注册定义:

/// <summary>
/// Describes a single service registration: the service type, its lifetime and
/// exactly one way of obtaining the implementation (an existing instance, a factory
/// delegate or an implementation type).
/// </summary>
public class ServiceDefinition
{
    /// <summary>Lifetime of the registered service.</summary>
    public ServiceLifetime ServiceLifetime { get; }

    /// <summary>Implementation type; only set by the type-based constructors.</summary>
    public Type ImplementType { get; }

    /// <summary>Type the service is registered under.</summary>
    public Type ServiceType { get; }

    /// <summary>Pre-built instance; only set by the instance-based constructor.</summary>
    public object ImplementationInstance { get; }

    /// <summary>Factory delegate; only set by the factory-based constructor.</summary>
    public Func<IServiceProvider, object> ImplementationFactory { get; }

    /// <summary>
    /// Gets the type that identifies the implementation behind this registration.
    /// </summary>
    /// <returns>
    /// The runtime type of <see cref="ImplementationInstance"/> when an instance is registered;
    /// otherwise the declaring type of the factory method when a factory is registered;
    /// otherwise <see cref="ImplementType"/>, falling back to <see cref="ServiceType"/>.
    /// </returns>
    public Type GetImplementType()
    {
        if (ImplementationInstance != null)
        {
            return ImplementationInstance.GetType();
        }

        if (ImplementationFactory != null)
        {
            // This is the type that declares the factory method, not the type of the
            // object the factory produces (which is unknown until the factory runs).
            return ImplementationFactory.Method.DeclaringType;
        }

        return ImplementType ?? ServiceType;
    }

    /// <summary>
    /// Creates a singleton registration backed by an existing instance.
    /// </summary>
    /// <param name="instance">instance to register</param>
    /// <param name="serviceType">type the instance is registered under</param>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is null</exception>
    public ServiceDefinition(object instance, Type serviceType)
    {
        if (serviceType == null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }

        ImplementationInstance = instance;
        ServiceType = serviceType;
        ServiceLifetime = ServiceLifetime.Singleton;
    }

    /// <summary>
    /// Creates a registration whose implementation type is the service type itself.
    /// </summary>
    /// <param name="serviceType">service (and implementation) type</param>
    /// <param name="serviceLifetime">lifetime of the service</param>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is null</exception>
    public ServiceDefinition(Type serviceType, ServiceLifetime serviceLifetime) : this(serviceType, serviceType, serviceLifetime)
    {
    }

    /// <summary>
    /// Creates a registration that maps a service type to an implementation type.
    /// </summary>
    /// <param name="serviceType">type the service is registered under</param>
    /// <param name="implementType">implementation type; <paramref name="serviceType"/> is used when null</param>
    /// <param name="serviceLifetime">lifetime of the service</param>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is null</exception>
    public ServiceDefinition(Type serviceType, Type implementType, ServiceLifetime serviceLifetime)
    {
        if (serviceType == null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }

        ServiceType = serviceType;
        ImplementType = implementType ?? serviceType;
        ServiceLifetime = serviceLifetime;
    }

    /// <summary>
    /// Creates a registration whose instances are produced by a factory delegate.
    /// </summary>
    /// <param name="serviceType">type the service is registered under</param>
    /// <param name="factory">delegate that creates the service instance</param>
    /// <param name="serviceLifetime">lifetime of the service</param>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is null</exception>
    public ServiceDefinition(Type serviceType, Func<IServiceProvider, object> factory, ServiceLifetime serviceLifetime)
    {
        if (serviceType == null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }

        ServiceType = serviceType;
        ImplementationFactory = factory;
        ServiceLifetime = serviceLifetime;
    }
}

为了使用起来更方便,添加了一些静态方法:

/// <summary>
/// Builds a singleton registration for <typeparamref name="TService"/> that is created by a factory.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <param name="factory">delegate that creates the service instance</param>
/// <returns>the new registration</returns>
public static ServiceDefinition Singleton<TService>(Func<IServiceProvider, object> factory)
    => new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Singleton);

/// <summary>
/// Builds a scoped registration for <typeparamref name="TService"/> that is created by a factory.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <param name="factory">delegate that creates the service instance</param>
/// <returns>the new registration</returns>
public static ServiceDefinition Scoped<TService>(Func<IServiceProvider, object> factory)
    => new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Scoped);

/// <summary>
/// Builds a transient registration for <typeparamref name="TService"/> that is created by a factory.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <param name="factory">delegate that creates the service instance</param>
/// <returns>the new registration</returns>
public static ServiceDefinition Transient<TService>(Func<IServiceProvider, object> factory)
    => new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Transient);

/// <summary>
/// Builds a singleton registration where <typeparamref name="TService"/> is both the service and the implementation type.
/// </summary>
/// <typeparam name="TService">service (and implementation) type</typeparam>
/// <returns>the new registration</returns>
public static ServiceDefinition Singleton<TService>()
    => new ServiceDefinition(typeof(TService), ServiceLifetime.Singleton);

/// <summary>
/// Builds a scoped registration where <typeparamref name="TService"/> is both the service and the implementation type.
/// </summary>
/// <typeparam name="TService">service (and implementation) type</typeparam>
/// <returns>the new registration</returns>
public static ServiceDefinition Scoped<TService>()
    => new ServiceDefinition(typeof(TService), ServiceLifetime.Scoped);

/// <summary>
/// Builds a transient registration where <typeparamref name="TService"/> is both the service and the implementation type.
/// </summary>
/// <typeparam name="TService">service (and implementation) type</typeparam>
/// <returns>the new registration</returns>
public static ServiceDefinition Transient<TService>()
    => new ServiceDefinition(typeof(TService), ServiceLifetime.Transient);

/// <summary>
/// Builds a singleton registration that maps <typeparamref name="TService"/> to <typeparamref name="TServiceImplement"/>.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <typeparam name="TServiceImplement">implementation type</typeparam>
/// <returns>the new registration</returns>
public static ServiceDefinition Singleton<TService, TServiceImplement>() where TServiceImplement : TService
    => new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Singleton);

/// <summary>
/// Builds a scoped registration that maps <typeparamref name="TService"/> to <typeparamref name="TServiceImplement"/>.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <typeparam name="TServiceImplement">implementation type</typeparam>
/// <returns>the new registration</returns>
public static ServiceDefinition Scoped<TService, TServiceImplement>() where TServiceImplement : TService
    => new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Scoped);

/// <summary>
/// Builds a transient registration that maps <typeparamref name="TService"/> to <typeparamref name="TServiceImplement"/>.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <typeparam name="TServiceImplement">implementation type</typeparam>
/// <returns>the new registration</returns>
public static ServiceDefinition Transient<TService, TServiceImplement>() where TServiceImplement : TService
    => new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Transient);

ServiceContainer

ServiceContainer v1

/// <summary>
/// First version of the container. Registrations are kept in a list and singleton /
/// scoped instances are cached per service type, so for each service type only the
/// last registration can be resolved.
/// </summary>
public class ServiceContainer : IServiceContainer
{
    // All registrations; the same list object is shared with every scope created from this container.
    internal readonly List<ServiceDefinition> _services;

    // Singleton instances by service type; the same dictionary is shared with every scope.
    private readonly ConcurrentDictionary<Type, object> _singletonInstances;

    // Scoped instances by service type; only assigned for containers created through CreateScope.
    private readonly ConcurrentDictionary<Type, object> _scopedInstances;

    // Disposable transient instances handed out by this container.
    // List<T> is not thread-safe, so every access is done under lock (_transientDisposables).
    private readonly List<object> _transientDisposables = new List<object>();

    private readonly bool _isRootScope;

    private bool _disposed;

    /// <summary>
    /// Creates an empty root container.
    /// </summary>
    public ServiceContainer()
    {
        _isRootScope = true;
        _singletonInstances = new ConcurrentDictionary<Type, object>();
        _services = new List<ServiceDefinition>();
    }

    /// <summary>
    /// Creates a scope that shares registrations and singletons with <paramref name="serviceContainer"/>.
    /// </summary>
    /// <param name="serviceContainer">container the scope is created from</param>
    internal ServiceContainer(ServiceContainer serviceContainer)
    {
        _isRootScope = false;
        _singletonInstances = serviceContainer._singletonInstances;
        _services = serviceContainer._services;
        _scopedInstances = new ConcurrentDictionary<Type, object>();
    }

    /// <summary>
    /// Appends a registration.
    /// </summary>
    /// <param name="item">registration to add</param>
    public void Add(ServiceDefinition item)
    {
        _services.Add(item);
    }

    /// <summary>
    /// Creates a new scope based on this container.
    /// </summary>
    /// <returns>the new scope</returns>
    public IServiceContainer CreateScope()
    {
        return new ServiceContainer(this);
    }

    /// <summary>
    /// Disposes the instances owned by this container: the cached singletons for the
    /// root container, the cached scoped instances for a scope, plus the disposable
    /// transients this container created. Calling it again is a no-op.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }

        // The root container owns the singleton cache, a scope owns its scoped cache.
        var ownedInstances = _isRootScope ? _singletonInstances : _scopedInstances;
        lock (ownedInstances)
        {
            if (_disposed)
            {
                return;
            }

            _disposed = true;
            foreach (var instance in ownedInstances.Values)
            {
                (instance as IDisposable)?.Dispose();
            }

            // Snapshot under the list lock and dispose outside of it, so a Dispose
            // implementation that resolves services cannot modify the list mid-enumeration.
            object[] transients;
            lock (_transientDisposables)
            {
                transients = _transientDisposables.ToArray();
            }

            foreach (var transient in transients)
            {
                (transient as IDisposable)?.Dispose();
            }
        }
    }

    /// <summary>
    /// Wraps a resolved constructor argument in an expression typed as the parameter.
    /// </summary>
    /// <param name="parameter">constructor parameter the value is passed to</param>
    /// <param name="value">resolved value, may be null</param>
    /// <returns>expression usable as the argument for <paramref name="parameter"/></returns>
    private static Expression ToArgumentExpression(ParameterInfo parameter, object value)
    {
        var parameterType = parameter.ParameterType;

        // An optional value-type parameter declared with "= default" reports a null default value.
        if (value == null && parameterType.IsValueType && parameter.HasDefaultValue)
        {
            return Expression.Default(parameterType);
        }

        // The explicit type matters: Expression.Constant(null) is typed as object,
        // which Expression.New rejects for any other parameter type.
        return Expression.Constant(value, parameterType);
    }

    /// <summary>
    /// Produces an instance for a registration: the registered instance, the factory
    /// result, or a new object created through a public constructor whose parameters
    /// are resolved from this container.
    /// </summary>
    /// <param name="serviceType">requested service type</param>
    /// <param name="serviceDefinition">registration to create the instance from</param>
    /// <returns>the service instance</returns>
    /// <exception cref="InvalidOperationException">
    /// the implementation type is an interface or abstract, or has no public constructor
    /// </exception>
    private object GetServiceInstance(Type serviceType, ServiceDefinition serviceDefinition)
    {
        if (serviceDefinition.ImplementationInstance != null)
        {
            return serviceDefinition.ImplementationInstance;
        }

        if (serviceDefinition.ImplementationFactory != null)
        {
            return serviceDefinition.ImplementationFactory.Invoke(this);
        }

        var implementType = serviceDefinition.ImplementType ?? serviceType;
        if (implementType.IsInterface || implementType.IsAbstract)
        {
            throw new InvalidOperationException($"invalid service registered, serviceType: {serviceType.FullName}, implementType: {serviceDefinition.ImplementType}");
        }

        var ctorInfos = implementType.GetConstructors(BindingFlags.Instance | BindingFlags.Public);
        if (ctorInfos.Length == 0)
        {
            throw new InvalidOperationException($"service {serviceType.FullName} does not have any public constructors");
        }

        // With several public constructors the one with the fewest parameters is used.
        var ctor = ctorInfos.Length == 1
            ? ctorInfos[0]
            : ctorInfos.OrderBy(_ => _.GetParameters().Length).First();

        var parameters = ctor.GetParameters();
        var arguments = new Expression[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameter = parameters[index];
            var value = GetService(parameter.ParameterType);
            if (value == null && parameter.HasDefaultValue)
            {
                value = parameter.DefaultValue;
            }

            arguments[index] = ToArgumentExpression(parameter, value);
        }

        // TODO: cache New Func
        return Expression.Lambda<Func<object>>(Expression.New(ctor, arguments)).Compile().Invoke();
    }

    /// <summary>
    /// Resolves a service using the last registration made for <paramref name="serviceType"/>.
    /// </summary>
    /// <param name="serviceType">service type to resolve</param>
    /// <returns>the service instance, or null when the type is not registered</returns>
    /// <exception cref="InvalidOperationException">
    /// a scoped service is requested from the root container
    /// </exception>
    public object GetService(Type serviceType)
    {
        var serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == serviceType);
        if (null == serviceDefinition)
        {
            return null;
        }

        if (_isRootScope && serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            throw new InvalidOperationException($"can not get scope service from the root scope, serviceType: {serviceType.FullName}");
        }

        if (serviceDefinition.ServiceLifetime == ServiceLifetime.Singleton)
        {
            return _singletonInstances.GetOrAdd(serviceType, t => GetServiceInstance(t, serviceDefinition));
        }

        if (serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            return _scopedInstances.GetOrAdd(serviceType, t => GetServiceInstance(t, serviceDefinition));
        }

        // Transient: always a new instance; disposables are remembered so Dispose can release them.
        var svc = GetServiceInstance(serviceType, serviceDefinition);
        if (svc is IDisposable)
        {
            lock (_transientDisposables)
            {
                _transientDisposables.Add(svc);
            }
        }

        return svc;
    }
}

为了使得服务注册更加方便,可以写一些扩展方法来方便注册:

/// <summary>
/// Registers an existing object as the singleton for <typeparamref name="TService"/>.
/// </summary>
/// <typeparam name="TService">type the instance is registered under</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="service">instance to register</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]TService service)
{
    var definition = new ServiceDefinition(service, typeof(TService));
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>
/// Registers <paramref name="serviceType"/> as a singleton implemented by itself.
/// </summary>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="serviceType">service (and implementation) type</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddSingleton([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
{
    var definition = new ServiceDefinition(serviceType, ServiceLifetime.Singleton);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>
/// Registers <paramref name="serviceType"/> as a singleton implemented by <paramref name="implementType"/>.
/// </summary>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="serviceType">type the service is registered under</param>
/// <param name="implementType">implementation type</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddSingleton([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
{
    var definition = new ServiceDefinition(serviceType, implementType, ServiceLifetime.Singleton);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>
/// Registers <typeparamref name="TService"/> as a singleton created by a factory delegate.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="func">delegate that creates the service instance</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
{
    var definition = ServiceDefinition.Singleton<TService>(func);
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>
/// Registers <typeparamref name="TService"/> as a singleton implemented by itself.
/// </summary>
/// <typeparam name="TService">service (and implementation) type</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer)
{
    var definition = ServiceDefinition.Singleton<TService>();
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>
/// Registers <typeparamref name="TService"/> as a singleton implemented by <typeparamref name="TServiceImplement"/>.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <typeparam name="TServiceImplement">implementation type</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddSingleton<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
{
    var definition = ServiceDefinition.Singleton<TService, TServiceImplement>();
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>
/// Registers <paramref name="serviceType"/> as a scoped service implemented by itself.
/// </summary>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="serviceType">service (and implementation) type</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddScoped([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
{
    var definition = new ServiceDefinition(serviceType, ServiceLifetime.Scoped);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>
/// Registers <paramref name="serviceType"/> as a scoped service implemented by <paramref name="implementType"/>.
/// </summary>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="serviceType">type the service is registered under</param>
/// <param name="implementType">implementation type</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddScoped([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
{
    var definition = new ServiceDefinition(serviceType, implementType, ServiceLifetime.Scoped);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>
/// Registers <typeparamref name="TService"/> as a scoped service created by a factory delegate.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="func">delegate that creates the service instance</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddScoped<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
{
    var definition = ServiceDefinition.Scoped<TService>(func);
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>
/// Registers <typeparamref name="TService"/> as a scoped service implemented by itself.
/// </summary>
/// <typeparam name="TService">service (and implementation) type</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddScoped<TService>([NotNull]this IServiceContainer serviceContainer)
{
    var definition = ServiceDefinition.Scoped<TService>();
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>
/// Registers <typeparamref name="TService"/> as a scoped service implemented by <typeparamref name="TServiceImplement"/>.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <typeparam name="TServiceImplement">implementation type</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddScoped<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
{
    var definition = ServiceDefinition.Scoped<TService, TServiceImplement>();
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>
/// Registers <paramref name="serviceType"/> as a transient service implemented by itself.
/// </summary>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="serviceType">service (and implementation) type</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddTransient([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
{
    var definition = new ServiceDefinition(serviceType, ServiceLifetime.Transient);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>
/// Registers <paramref name="serviceType"/> as a transient service implemented by <paramref name="implementType"/>.
/// </summary>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="serviceType">type the service is registered under</param>
/// <param name="implementType">implementation type</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddTransient([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
{
    var definition = new ServiceDefinition(serviceType, implementType, ServiceLifetime.Transient);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>
/// Registers <typeparamref name="TService"/> as a transient service created by a factory delegate.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <param name="func">delegate that creates the service instance</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddTransient<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
{
    var definition = ServiceDefinition.Transient<TService>(func);
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>
/// Registers <typeparamref name="TService"/> as a transient service implemented by itself.
/// </summary>
/// <typeparam name="TService">service (and implementation) type</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddTransient<TService>([NotNull]this IServiceContainer serviceContainer)
{
    var definition = ServiceDefinition.Transient<TService>();
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>
/// Registers <typeparamref name="TService"/> as a transient service implemented by <typeparamref name="TServiceImplement"/>.
/// </summary>
/// <typeparam name="TService">type the service is registered under</typeparam>
/// <typeparam name="TServiceImplement">implementation type</typeparam>
/// <param name="serviceContainer">container that receives the registration</param>
/// <returns>the same <paramref name="serviceContainer"/>, to allow chaining</returns>
public static IServiceContainer AddTransient<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
{
    var definition = ServiceDefinition.Transient<TService, TServiceImplement>();
    serviceContainer.Add(definition);
    return serviceContainer;
}

通过上面的代码就可以实现基本的依赖注入了。但是从功能上来说,上面的代码只支持获取单个服务的实例,不支持为一个接口注册多个实现,也不支持获取接口的所有实现。为此需要对 ServiceContainer 中保存实例的 ConcurrentDictionary 的 Key 进行改造,使其能够以服务类型和实现类型联合作为 key,于是就有了第二版的 ServiceContainer。

ServiceContainer v2

为此定义了一个 ServiceKey 的类型。请注意这里一定要重写 GetHashCode 方法(同时重写 Equals),因为 ConcurrentDictionary 是依据 GetHashCode 和 Equals 来判断两个 key 是否相同的:

/// <summary>
/// Cache key made of the requested service type and the implementation type of the
/// registration, so several implementations of one service can be cached side by side.
/// </summary>
private class ServiceKey : IEquatable<ServiceKey>
{
    /// <summary>Requested service type.</summary>
    public Type ServiceType { get; }

    /// <summary>Implementation type reported by the registration.</summary>
    public Type ImplementType { get; }

    /// <summary>
    /// Creates the key for a service type resolved through a given registration.
    /// </summary>
    /// <param name="serviceType">requested service type</param>
    /// <param name="definition">registration used to resolve the service</param>
    public ServiceKey(Type serviceType, ServiceDefinition definition)
    {
        ServiceType = serviceType;
        ImplementType = definition.GetImplementType();
    }

    /// <summary>
    /// Two keys are equal when both the service type and the implementation type match.
    /// </summary>
    /// <param name="other">key to compare with; null is never equal</param>
    /// <returns>true when both types match</returns>
    public bool Equals(ServiceKey other)
    {
        if (ReferenceEquals(other, null))
        {
            return false;
        }

        return ServiceType == other.ServiceType && ImplementType == other.ImplementType;
    }

    /// <summary>
    /// Object-based equality; anything that is not a <see cref="ServiceKey"/> is not equal.
    /// </summary>
    /// <param name="obj">object to compare with</param>
    /// <returns>true when <paramref name="obj"/> is an equal key</returns>
    public override bool Equals(object obj)
    {
        // "as" instead of a hard cast: Equals must return false for foreign objects, not throw.
        return Equals(obj as ServiceKey);
    }

    /// <summary>
    /// Combines the hash codes of both types, consistent with <see cref="Equals(ServiceKey)"/>.
    /// </summary>
    /// <returns>hash code of the key</returns>
    public override int GetHashCode()
    {
        // Hashing the Type objects avoids building a string on every dictionary lookup
        // and tolerates a null type.
        unchecked
        {
            var serviceHash = ServiceType == null ? 0 : ServiceType.GetHashCode();
            var implementHash = ImplementType == null ? 0 : ImplementType.GetHashCode();
            return (serviceHash * 397) ^ implementHash;
        }
    }
}

第二版的 ServiceContainer :

/// <summary>
/// Second version of the container. Instances are cached per (service type,
/// implementation type) pair, which allows several implementations of one service,
/// open generic registrations and resolving all implementations as a collection.
/// </summary>
public class ServiceContainer : IServiceContainer
{
    // All registrations; the same bag object is shared with every scope created from this container.
    // NOTE(review): ConcurrentBag<T> is documented as an unordered collection, so the
    // LastOrDefault lookups below may not return the most recently added registration.
    // Confirm before relying on "last registration wins".
    internal readonly ConcurrentBag<ServiceDefinition> _services;

    // Singleton instances; the same dictionary is shared with every scope.
    private readonly ConcurrentDictionary<ServiceKey, object> _singletonInstances;

    // Scoped instances; only assigned for containers created through CreateScope.
    private readonly ConcurrentDictionary<ServiceKey, object> _scopedInstances;

    // Disposable transient instances handed out by this container; set to null by Dispose.
    private ConcurrentBag<object> _transientDisposables = new ConcurrentBag<object>();

    private readonly bool _isRootScope;

    private bool _disposed;

    /// <summary>
    /// Cache key made of the requested service type and the implementation type of the registration.
    /// </summary>
    private class ServiceKey : IEquatable<ServiceKey>
    {
        /// <summary>Requested service type.</summary>
        public Type ServiceType { get; }

        /// <summary>Implementation type reported by the registration.</summary>
        public Type ImplementType { get; }

        public ServiceKey(Type serviceType, ServiceDefinition definition)
        {
            ServiceType = serviceType;
            ImplementType = definition.GetImplementType();
        }

        public bool Equals(ServiceKey other)
        {
            if (ReferenceEquals(other, null))
            {
                return false;
            }

            return ServiceType == other.ServiceType && ImplementType == other.ImplementType;
        }

        public override bool Equals(object obj)
        {
            // "as" instead of a hard cast: Equals must return false for foreign objects, not throw.
            return Equals(obj as ServiceKey);
        }

        public override int GetHashCode()
        {
            // Hashing the Type objects avoids building a string on every lookup and tolerates null.
            unchecked
            {
                var serviceHash = ServiceType == null ? 0 : ServiceType.GetHashCode();
                var implementHash = ImplementType == null ? 0 : ImplementType.GetHashCode();
                return (serviceHash * 397) ^ implementHash;
            }
        }
    }

    /// <summary>
    /// Creates an empty root container.
    /// </summary>
    public ServiceContainer()
    {
        _isRootScope = true;
        _singletonInstances = new ConcurrentDictionary<ServiceKey, object>();
        _services = new ConcurrentBag<ServiceDefinition>();
    }

    /// <summary>
    /// Creates a scope that shares registrations and singletons with <paramref name="serviceContainer"/>.
    /// </summary>
    /// <param name="serviceContainer">container the scope is created from</param>
    private ServiceContainer(ServiceContainer serviceContainer)
    {
        _isRootScope = false;
        _singletonInstances = serviceContainer._singletonInstances;
        _services = serviceContainer._services;
        _scopedInstances = new ConcurrentDictionary<ServiceKey, object>();
    }

    /// <summary>
    /// Adds a registration unless one with the same service type and implementation type exists.
    /// </summary>
    /// <param name="item">registration to add</param>
    /// <returns>this container</returns>
    /// <exception cref="InvalidOperationException">the container has been disposed</exception>
    public IServiceContainer Add(ServiceDefinition item)
    {
        EnsureNotDisposed();

        var implementType = item.GetImplementType();
        if (!_services.Any(_ => _.ServiceType == item.ServiceType && _.GetImplementType() == implementType))
        {
            _services.Add(item);
        }

        return this;
    }

    /// <summary>
    /// Adds a registration only when the service type has no registration yet.
    /// </summary>
    /// <param name="item">registration to add</param>
    /// <returns>this container</returns>
    /// <exception cref="InvalidOperationException">the container has been disposed</exception>
    public IServiceContainer TryAdd(ServiceDefinition item)
    {
        EnsureNotDisposed();

        if (!_services.Any(_ => _.ServiceType == item.ServiceType))
        {
            _services.Add(item);
        }

        return this;
    }

    /// <summary>
    /// Creates a new scope based on this container.
    /// </summary>
    /// <returns>the new scope</returns>
    public IServiceContainer CreateScope()
    {
        return new ServiceContainer(this);
    }

    /// <summary>
    /// Disposes and clears the instances owned by this container: the cached singletons
    /// for the root container, the cached scoped instances for a scope, plus the
    /// disposable transients this container created. Calling it again is a no-op.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }

        // The root container owns the singleton cache, a scope owns its scoped cache.
        var ownedInstances = _isRootScope ? _singletonInstances : _scopedInstances;
        lock (ownedInstances)
        {
            if (_disposed)
            {
                return;
            }

            _disposed = true;
            foreach (var instance in ownedInstances.Values)
            {
                (instance as IDisposable)?.Dispose();
            }

            foreach (var transient in _transientDisposables)
            {
                (transient as IDisposable)?.Dispose();
            }

            ownedInstances.Clear();
            _transientDisposables = null;
        }
    }

    // Shared guard for the registration methods.
    private void EnsureNotDisposed()
    {
        if (_disposed)
        {
            throw new InvalidOperationException("the service container had been disposed");
        }
    }

    /// <summary>
    /// Wraps a resolved constructor argument in an expression typed as the parameter.
    /// </summary>
    /// <param name="parameter">constructor parameter the value is passed to</param>
    /// <param name="value">resolved value, may be null</param>
    /// <returns>expression usable as the argument for <paramref name="parameter"/></returns>
    private static Expression ToArgumentExpression(ParameterInfo parameter, object value)
    {
        var parameterType = parameter.ParameterType;

        // An optional value-type parameter declared with "= default" reports a null default value.
        if (value == null && parameterType.IsValueType && parameter.HasDefaultValue)
        {
            return Expression.Default(parameterType);
        }

        // The explicit type matters: Expression.Constant(null) is typed as object,
        // which Expression.New rejects for any other parameter type.
        return Expression.Constant(value, parameterType);
    }

    /// <summary>
    /// Produces an instance for a registration: the registered instance, the factory
    /// result, or a new object created through a public constructor whose parameters
    /// are resolved from this container. Open generic implementation types are closed
    /// with the generic arguments of <paramref name="serviceType"/>.
    /// </summary>
    /// <param name="serviceType">requested service type</param>
    /// <param name="serviceDefinition">registration to create the instance from</param>
    /// <returns>the service instance</returns>
    /// <exception cref="InvalidOperationException">
    /// the implementation type is an interface or abstract, or has no public constructor
    /// </exception>
    private object GetServiceInstance(Type serviceType, ServiceDefinition serviceDefinition)
    {
        if (serviceDefinition.ImplementationInstance != null)
        {
            return serviceDefinition.ImplementationInstance;
        }

        if (serviceDefinition.ImplementationFactory != null)
        {
            return serviceDefinition.ImplementationFactory.Invoke(this);
        }

        var implementType = serviceDefinition.ImplementType ?? serviceType;
        if (implementType.IsInterface || implementType.IsAbstract)
        {
            throw new InvalidOperationException($"invalid service registered, serviceType: {serviceType.FullName}, implementType: {serviceDefinition.ImplementType}");
        }

        // Only an open generic definition (e.g. Repository<>) can be closed;
        // MakeGenericType throws for an already closed generic type such as Repository<User>.
        if (implementType.IsGenericTypeDefinition)
        {
            implementType = implementType.MakeGenericType(serviceType.GetGenericArguments());
        }

        var ctorInfos = implementType.GetConstructors(BindingFlags.Instance | BindingFlags.Public);
        if (ctorInfos.Length == 0)
        {
            throw new InvalidOperationException($"service {serviceType.FullName} does not have any public constructors");
        }

        // TODO: try find best ctor. For now the constructor with the fewest parameters is used.
        var ctor = ctorInfos.Length == 1
            ? ctorInfos[0]
            : ctorInfos.OrderBy(_ => _.GetParameters().Length).First();

        var parameters = ctor.GetParameters();
        var arguments = new Expression[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameter = parameters[index];
            var value = GetService(parameter.ParameterType);
            if (value == null && parameter.HasDefaultValue)
            {
                value = parameter.DefaultValue;
            }

            arguments[index] = ToArgumentExpression(parameter, value);
        }

        // TODO: cache New Func
        return Expression.Lambda<Func<object>>(Expression.New(ctor, arguments)).Compile().Invoke();
    }

    /// <summary>
    /// Returns the instance for one registration according to its lifetime: cached in
    /// the singleton or scoped cache, or newly created for transient services.
    /// </summary>
    /// <param name="serviceType">requested service type</param>
    /// <param name="definition">registration to resolve</param>
    /// <returns>the service instance</returns>
    /// <exception cref="InvalidOperationException">
    /// a scoped service is requested from the root container
    /// </exception>
    private object ResolveByLifetime(Type serviceType, ServiceDefinition definition)
    {
        switch (definition.ServiceLifetime)
        {
            case ServiceLifetime.Singleton:
                return _singletonInstances.GetOrAdd(
                    new ServiceKey(serviceType, definition),
                    key => GetServiceInstance(key.ServiceType, definition));

            case ServiceLifetime.Scoped:
                // The root container has no scoped cache.
                if (_isRootScope)
                {
                    throw new InvalidOperationException($"can not get scope service from the root scope, serviceType: {serviceType.FullName}");
                }

                return _scopedInstances.GetOrAdd(
                    new ServiceKey(serviceType, definition),
                    key => GetServiceInstance(key.ServiceType, definition));

            default:
            {
                // Transient: always a new instance; disposables are remembered so Dispose can release them.
                var instance = GetServiceInstance(serviceType, definition);
                if (instance is IDisposable)
                {
                    _transientDisposables.Add(instance);
                }

                return instance;
            }
        }
    }

    /// <summary>
    /// Resolves an unregistered generic collection type such as IEnumerable&lt;T&gt; or
    /// IReadOnlyList&lt;T&gt; by collecting every registration of its element type.
    /// </summary>
    /// <param name="serviceType">requested generic type</param>
    /// <returns>
    /// an array of the element type holding the resolved instances, or null when
    /// <paramref name="serviceType"/> is not a collection type an array can be assigned to
    /// </returns>
    private object ResolveServiceCollection(Type serviceType)
    {
        var elementType = serviceType.GetGenericArguments().First();
        if (!typeof(IEnumerable<>).MakeGenericType(elementType).IsAssignableFrom(serviceType))
        {
            return null;
        }

        // The result is an array, so only hand it out for types an array can be cast to.
        if (!serviceType.IsAssignableFrom(elementType.MakeArrayType()))
        {
            return null;
        }

        // A generic element type matches registrations of the closed type and of its open definition.
        var openElementType = elementType.IsGenericType
            ? elementType.GetGenericTypeDefinition()
            : elementType;

        var instances = new List<object>(4);
        foreach (var definition in _services.Where(_ => _.ServiceType == elementType || _.ServiceType == openElementType))
        {
            var instance = ResolveByLifetime(elementType, definition);
            if (null != instance)
            {
                instances.Add(instance);
            }
        }

        var result = Array.CreateInstance(elementType, instances.Count);
        for (var index = 0; index < instances.Count; index++)
        {
            result.SetValue(instances[index], index);
        }

        return result;
    }

    /// <summary>
    /// Resolves a service. The lookup order is: a registration of the exact type, a
    /// registration of the open generic definition, then (for generic collection
    /// types) all registrations of the element type.
    /// </summary>
    /// <param name="serviceType">service type to resolve</param>
    /// <returns>the service instance, or null when nothing matches</returns>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is null</exception>
    /// <exception cref="InvalidOperationException">
    /// the container has been disposed, or a scoped service is requested from the root container
    /// </exception>
    public object GetService(Type serviceType)
    {
        if (serviceType == null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }

        if (_disposed)
        {
            throw new InvalidOperationException($"can not get scope service from a disposed scope, serviceType: {serviceType.FullName}");
        }

        var serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == serviceType);
        if (null == serviceDefinition)
        {
            if (!serviceType.IsGenericType)
            {
                return null;
            }

            var genericType = serviceType.GetGenericTypeDefinition();
            serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == genericType);
            if (null == serviceDefinition)
            {
                return ResolveServiceCollection(serviceType);
            }
        }

        return ResolveByLifetime(serviceType, serviceDefinition);
    }
}

这样我们不仅支持解析 IEnumerable<TService>,也支持解析 IReadOnlyList<TService> / IReadOnlyCollection<TService>

因为 GetService 返回是 object , 不是强类型的,所以为了使用起来方便,定义了几个扩展方法,类似于微软的依赖注入框架里的 GetService<TService>()/GetServices<TService>()/GetRequiredService<TService>()

/// <summary>
/// Strongly typed wrapper around <see cref="IServiceProvider.GetService(Type)"/>.
/// </summary>
/// <typeparam name="TService">service type to resolve</typeparam>
/// <param name="serviceProvider">provider to resolve the service from</param>
/// <returns>the result of the provider cast to <typeparamref name="TService"/></returns>
public static TService ResolveService<TService>([NotNull]this IServiceProvider serviceProvider)
{
    var resolved = serviceProvider.GetService(typeof(TService));
    return (TService)resolved;
}

/// <summary>
/// Resolves a service and fails when the provider returns null.
/// </summary>
/// <typeparam name="TService">service type to resolve</typeparam>
/// <param name="serviceProvider">provider to resolve the service from</param>
/// <returns>the resolved service</returns>
/// <exception cref="InvalidOperationException">the provider returned null for the service type</exception>
public static TService ResolveRequiredService<TService>([NotNull] this IServiceProvider serviceProvider)
{
    var requestedType = typeof(TService);
    var resolved = serviceProvider.GetService(requestedType);
    if (resolved != null)
    {
        return (TService)resolved;
    }

    throw new InvalidOperationException($"service had not been registered, serviceType: {requestedType}");
}

/// <summary>
/// Resolves a collection of <typeparamref name="TService"/> by requesting
/// IEnumerable&lt;TService&gt; from the provider.
/// </summary>
/// <typeparam name="TService">element service type</typeparam>
/// <param name="serviceProvider">provider to resolve the services from</param>
/// <returns>whatever the provider returns for IEnumerable&lt;TService&gt;</returns>
public static IEnumerable<TService> ResolveServices<TService>([NotNull]this IServiceProvider serviceProvider)
{
    return serviceProvider.ResolveService<IEnumerable<TService>>();
}

More

后面还更新了一版,主要优化性能,目前来说还不太满意,暂时这里先不提了

Reference

posted @ 2019-10-28 23:36  WeihanLi  阅读(1804)  评论(6编辑  收藏  举报