动手造轮子:实现一个简单的依赖注入(一)
动手造轮子:实现一个简单的依赖注入(一)
Intro
在上一篇文章中主要介绍了一下要做的依赖注入的整体设计和大概编程体验,这篇文章要开始写代码了,开始实现自己的依赖注入框架。
类图
首先来温习一下上次提到的 UML 类图
服务生命周期
服务生命周期定义:
/// <summary>
/// Controls how long a resolved service instance lives and when it is reused.
/// Stored as an <see cref="sbyte"/>.
/// </summary>
public enum ServiceLifetime : sbyte
{
    /// <summary>
    /// One instance is created and then handed out for every request.
    /// </summary>
    Singleton = 0,

    /// <summary>
    /// One instance is created per scope and reused inside that scope.
    /// </summary>
    Scoped = 1,

    /// <summary>
    /// A brand-new instance is created on every request.
    /// </summary>
    Transient = 2
}
服务定义
服务注册定义:
/// <summary>
/// Describes one service registration: the type callers resolve (<see cref="ServiceType"/>),
/// how the instance is produced (a pre-built instance, a factory, or an implementation type
/// to construct) and the <see cref="ServiceLifetime"/> of the produced instance.
/// </summary>
public class ServiceDefinition
{
    /// <summary>The lifetime of instances produced for this registration.</summary>
    public ServiceLifetime ServiceLifetime { get; }

    /// <summary>
    /// The type to construct. Only set by the type-based constructors, where it defaults to
    /// <see cref="ServiceType"/> when no implementation type is supplied.
    /// </summary>
    public Type ImplementType { get; }

    /// <summary>The type this registration is resolved as.</summary>
    public Type ServiceType { get; }

    /// <summary>A pre-built instance; only set by the instance constructor.</summary>
    public object ImplementationInstance { get; }

    /// <summary>A factory producing the instance; only set by the factory constructor.</summary>
    public Func<IServiceProvider, object> ImplementationFactory { get; }

    /// <summary>
    /// Gets the best-known implementation type, checked in this order: the runtime type of
    /// <see cref="ImplementationInstance"/>, the type declaring the method behind
    /// <see cref="ImplementationFactory"/>, <see cref="ImplementType"/>, and finally
    /// <see cref="ServiceType"/>.
    /// </summary>
    public Type GetImplementType()
    {
        if (ImplementationInstance != null)
            return ImplementationInstance.GetType();
        // NOTE(review): for a lambda factory the declaring type is typically a compiler-generated
        // class, so different factories may report the same type here — confirm before using
        // this value to tell factory registrations apart.
        if (ImplementationFactory != null)
            return ImplementationFactory.Method.DeclaringType;
        if (ImplementType != null)
            return ImplementType;
        return ServiceType;
    }

    /// <summary>Registers an existing <paramref name="instance"/> as a singleton of <paramref name="serviceType"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="instance"/> or <paramref name="serviceType"/> is null.</exception>
    public ServiceDefinition(object instance, Type serviceType)
    {
        // A null instance would silently fall back to constructing ServiceType at resolve time.
        if (instance == null)
            throw new ArgumentNullException(nameof(instance));
        if (serviceType == null)
            throw new ArgumentNullException(nameof(serviceType));
        ImplementationInstance = instance;
        ServiceType = serviceType;
        ServiceLifetime = ServiceLifetime.Singleton;
    }

    /// <summary>Registers <paramref name="serviceType"/> as its own implementation type.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is null.</exception>
    public ServiceDefinition(Type serviceType, ServiceLifetime serviceLifetime) : this(serviceType, serviceType, serviceLifetime)
    {
    }

    /// <summary>
    /// Registers <paramref name="serviceType"/> implemented by <paramref name="implementType"/>;
    /// a null <paramref name="implementType"/> means <paramref name="serviceType"/> itself.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is null.</exception>
    public ServiceDefinition(Type serviceType, Type implementType, ServiceLifetime serviceLifetime)
    {
        if (serviceType == null)
            throw new ArgumentNullException(nameof(serviceType));
        ServiceType = serviceType;
        ImplementType = implementType ?? serviceType;
        ServiceLifetime = serviceLifetime;
    }

    /// <summary>Registers <paramref name="serviceType"/> produced by <paramref name="factory"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> or <paramref name="factory"/> is null.</exception>
    public ServiceDefinition(Type serviceType, Func<IServiceProvider, object> factory, ServiceLifetime serviceLifetime)
    {
        if (serviceType == null)
            throw new ArgumentNullException(nameof(serviceType));
        // A null factory would silently fall back to constructing ServiceType at resolve time.
        if (factory == null)
            throw new ArgumentNullException(nameof(factory));
        ServiceType = serviceType;
        ImplementationFactory = factory;
        ServiceLifetime = serviceLifetime;
    }
}
为了使用起来更方便,在 ServiceDefinition 中添加了一些静态工厂方法:
/// <summary>Creates a singleton registration of <typeparamref name="TService"/> built by <paramref name="factory"/>.</summary>
public static ServiceDefinition Singleton<TService>(Func<IServiceProvider, object> factory)
    => new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Singleton);

/// <summary>Creates a scoped registration of <typeparamref name="TService"/> built by <paramref name="factory"/>.</summary>
public static ServiceDefinition Scoped<TService>(Func<IServiceProvider, object> factory)
    => new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Scoped);

/// <summary>Creates a transient registration of <typeparamref name="TService"/> built by <paramref name="factory"/>.</summary>
public static ServiceDefinition Transient<TService>(Func<IServiceProvider, object> factory)
    => new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Transient);

/// <summary>Creates a singleton registration where <typeparamref name="TService"/> is its own implementation.</summary>
public static ServiceDefinition Singleton<TService>()
    => new ServiceDefinition(typeof(TService), ServiceLifetime.Singleton);

/// <summary>Creates a scoped registration where <typeparamref name="TService"/> is its own implementation.</summary>
public static ServiceDefinition Scoped<TService>()
    => new ServiceDefinition(typeof(TService), ServiceLifetime.Scoped);

/// <summary>Creates a transient registration where <typeparamref name="TService"/> is its own implementation.</summary>
public static ServiceDefinition Transient<TService>()
    => new ServiceDefinition(typeof(TService), ServiceLifetime.Transient);

/// <summary>Creates a singleton registration of <typeparamref name="TService"/> implemented by <typeparamref name="TServiceImplement"/>.</summary>
public static ServiceDefinition Singleton<TService, TServiceImplement>() where TServiceImplement : TService
    => new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Singleton);

/// <summary>Creates a scoped registration of <typeparamref name="TService"/> implemented by <typeparamref name="TServiceImplement"/>.</summary>
public static ServiceDefinition Scoped<TService, TServiceImplement>() where TServiceImplement : TService
    => new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Scoped);

/// <summary>Creates a transient registration of <typeparamref name="TService"/> implemented by <typeparamref name="TServiceImplement"/>.</summary>
public static ServiceDefinition Transient<TService, TServiceImplement>() where TServiceImplement : TService
    => new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Transient);
ServiceContainer
ServiceContainer v1
/// <summary>
/// First version of the container. Registrations are kept in a list; singleton instances are
/// cached in a dictionary shared by the root container and all of its scopes, scoped instances
/// in a dictionary owned by each child scope. Cache keys are the requested service type.
/// </summary>
public class ServiceContainer : IServiceContainer
{
    // Registrations; the same list instance is shared by the root container and every scope.
    internal readonly List<ServiceDefinition> _services;
    // Singleton cache; the same dictionary is shared by the root container and every scope.
    private readonly ConcurrentDictionary<Type, object> _singletonInstances;
    // Scoped cache; only created for child scopes, so it is null in the root container.
    private readonly ConcurrentDictionary<Type, object> _scopedInstances;
    // Disposable transient instances created through this container/scope.
    // List<T> is not thread-safe, so every access goes through lock (_transientDisposables).
    private readonly List<object> _transientDisposables = new List<object>();
    private readonly bool _isRootScope;

    /// <summary>Creates an empty root container.</summary>
    public ServiceContainer()
    {
        _isRootScope = true;
        _singletonInstances = new ConcurrentDictionary<Type, object>();
        _services = new List<ServiceDefinition>();
    }

    /// <summary>
    /// Creates a child scope sharing the registrations and the singleton cache of
    /// <paramref name="serviceContainer"/>, with its own scoped-instance cache.
    /// </summary>
    internal ServiceContainer(ServiceContainer serviceContainer)
    {
        _isRootScope = false;
        _singletonInstances = serviceContainer._singletonInstances;
        _services = serviceContainer._services;
        _scopedInstances = new ConcurrentDictionary<Type, object>();
    }

    /// <summary>
    /// Appends a registration. Lookups use the last registration of a service type, so a later
    /// registration takes precedence over an earlier one.
    /// </summary>
    public void Add(ServiceDefinition item)
    {
        _services.Add(item);
    }

    /// <summary>Creates a new child scope.</summary>
    public IServiceContainer CreateScope()
    {
        return new ServiceContainer(this);
    }

    private bool _disposed;

    /// <summary>
    /// Disposes the instances this container owns: the singletons for the root container or the
    /// scoped instances for a child scope, plus the disposable transients it created.
    /// Calls after the first one do nothing.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }
        // The root owns the singletons; a child scope owns its scoped instances.
        var ownedInstances = _isRootScope ? _singletonInstances : _scopedInstances;
        lock (ownedInstances)
        {
            if (_disposed)
            {
                return;
            }
            _disposed = true;
            foreach (var instance in ownedInstances.Values)
            {
                (instance as IDisposable)?.Dispose();
            }
            // Dispose a snapshot so a concurrent GetService cannot modify the list mid-enumeration.
            foreach (var o in SnapshotTransientDisposables())
            {
                (o as IDisposable)?.Dispose();
            }
        }
    }

    // Records a disposable transient instance under the list's lock.
    private void TrackTransientDisposable(object instance)
    {
        lock (_transientDisposables)
        {
            _transientDisposables.Add(instance);
        }
    }

    // Copies the tracked transient instances under the list's lock.
    private object[] SnapshotTransientDisposables()
    {
        lock (_transientDisposables)
        {
            return _transientDisposables.ToArray();
        }
    }

    /// <summary>
    /// Builds a correctly typed constructor argument. <c>Expression.Constant(null)</c> is typed as
    /// <see cref="object"/>, which <c>Expression.New</c> rejects for any other parameter type, so a
    /// null (unresolved or null-defaulted) argument becomes <c>default(parameterType)</c>; a
    /// non-null value is given the parameter's static type (this also accepts e.g. an
    /// <c>int</c> default for an <c>int?</c> parameter).
    /// </summary>
    private static Expression CreateArgumentExpression(object value, Type parameterType)
    {
        if (value == null)
        {
            return Expression.Default(parameterType);
        }
        return Expression.Constant(value, parameterType);
    }

    /// <summary>
    /// Produces a new instance for <paramref name="serviceDefinition"/>: returns the registered
    /// instance, invokes the factory with this container, or constructs the implementation type
    /// through the public constructor with the fewest parameters, resolving each parameter from
    /// this container and falling back to the parameter's default value.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The implementation type is an interface/abstract type or has no public constructor.
    /// </exception>
    private object GetServiceInstance(Type serviceType, ServiceDefinition serviceDefinition)
    {
        if (serviceDefinition.ImplementationInstance != null)
            return serviceDefinition.ImplementationInstance;
        if (serviceDefinition.ImplementationFactory != null)
            return serviceDefinition.ImplementationFactory.Invoke(this);
        var implementType = (serviceDefinition.ImplementType ?? serviceType);
        if (implementType.IsInterface || implementType.IsAbstract)
        {
            throw new InvalidOperationException($"invalid service registered, serviceType: {serviceType.FullName}, implementType: {serviceDefinition.ImplementType}");
        }
        var ctorInfos = implementType.GetConstructors(BindingFlags.Instance | BindingFlags.Public);
        if (ctorInfos.Length == 0)
        {
            throw new InvalidOperationException($"service {serviceType.FullName} does not have any public constructors");
        }
        ConstructorInfo ctor;
        if (ctorInfos.Length == 1)
        {
            ctor = ctorInfos[0];
        }
        else
        {
            // try find best ctor: currently the one with the fewest parameters
            ctor = ctorInfos
                .OrderBy(_ => _.GetParameters().Length)
                .First();
        }
        var parameters = ctor.GetParameters();
        var arguments = new Expression[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameter = parameters[index];
            var param = GetService(parameter.ParameterType);
            if (param == null && parameter.HasDefaultValue)
            {
                param = parameter.DefaultValue;
            }
            arguments[index] = CreateArgumentExpression(param, parameter.ParameterType);
        }
        // TODO: cache New Func
        // Convert to object so value-type implementations also fit Func<object>.
        var body = Expression.Convert(Expression.New(ctor, arguments), typeof(object));
        return Expression.Lambda<Func<object>>(body).Compile().Invoke();
    }

    /// <summary>
    /// Resolves <paramref name="serviceType"/> using its last registration, honoring the
    /// registration's lifetime. Returns null when the type is not registered.
    /// </summary>
    /// <exception cref="InvalidOperationException">A scoped service is requested from the root container.</exception>
    public object GetService(Type serviceType)
    {
        var serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == serviceType);
        if (null == serviceDefinition)
        {
            return null;
        }
        if (_isRootScope && serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            throw new InvalidOperationException($"can not get scope service from the root scope, serviceType: {serviceType.FullName}");
        }
        if (serviceDefinition.ServiceLifetime == ServiceLifetime.Singleton)
        {
            return _singletonInstances.GetOrAdd(serviceType, (t) => GetServiceInstance(t, serviceDefinition));
        }
        if (serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            return _scopedInstances.GetOrAdd(serviceType, (t) => GetServiceInstance(t, serviceDefinition));
        }
        var svc = GetServiceInstance(serviceType, serviceDefinition);
        if (svc is IDisposable)
        {
            TrackTransientDisposable(svc);
        }
        return svc;
    }
}
为了使得服务注册更加方便,可以写一些扩展方法来方便注册:
/// <summary>
/// Adds <paramref name="definition"/> to <paramref name="serviceContainer"/> and hands the
/// same container back so registrations can be chained.
/// </summary>
private static IServiceContainer RegisterDefinitionAndReturn(IServiceContainer serviceContainer, ServiceDefinition definition)
{
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers the existing <paramref name="service"/> instance as a singleton <typeparamref name="TService"/>.</summary>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]TService service)
    => RegisterDefinitionAndReturn(serviceContainer, new ServiceDefinition(service, typeof(TService)));

/// <summary>Registers <paramref name="serviceType"/> as a singleton implemented by itself.</summary>
public static IServiceContainer AddSingleton([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
    => RegisterDefinitionAndReturn(serviceContainer, new ServiceDefinition(serviceType, ServiceLifetime.Singleton));

/// <summary>Registers <paramref name="serviceType"/> as a singleton implemented by <paramref name="implementType"/>.</summary>
public static IServiceContainer AddSingleton([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
    => RegisterDefinitionAndReturn(serviceContainer, new ServiceDefinition(serviceType, implementType, ServiceLifetime.Singleton));

/// <summary>Registers <typeparamref name="TService"/> as a singleton produced by <paramref name="func"/>.</summary>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
    => RegisterDefinitionAndReturn(serviceContainer, ServiceDefinition.Singleton<TService>(func));

/// <summary>Registers <typeparamref name="TService"/> as a singleton implemented by itself.</summary>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer)
    => RegisterDefinitionAndReturn(serviceContainer, ServiceDefinition.Singleton<TService>());

/// <summary>Registers <typeparamref name="TService"/> as a singleton implemented by <typeparamref name="TServiceImplement"/>.</summary>
public static IServiceContainer AddSingleton<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
    => RegisterDefinitionAndReturn(serviceContainer, ServiceDefinition.Singleton<TService, TServiceImplement>());

/// <summary>Registers <paramref name="serviceType"/> as a scoped service implemented by itself.</summary>
public static IServiceContainer AddScoped([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
    => RegisterDefinitionAndReturn(serviceContainer, new ServiceDefinition(serviceType, ServiceLifetime.Scoped));

/// <summary>Registers <paramref name="serviceType"/> as a scoped service implemented by <paramref name="implementType"/>.</summary>
public static IServiceContainer AddScoped([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
    => RegisterDefinitionAndReturn(serviceContainer, new ServiceDefinition(serviceType, implementType, ServiceLifetime.Scoped));

/// <summary>Registers <typeparamref name="TService"/> as a scoped service produced by <paramref name="func"/>.</summary>
public static IServiceContainer AddScoped<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
    => RegisterDefinitionAndReturn(serviceContainer, ServiceDefinition.Scoped<TService>(func));

/// <summary>Registers <typeparamref name="TService"/> as a scoped service implemented by itself.</summary>
public static IServiceContainer AddScoped<TService>([NotNull]this IServiceContainer serviceContainer)
    => RegisterDefinitionAndReturn(serviceContainer, ServiceDefinition.Scoped<TService>());

/// <summary>Registers <typeparamref name="TService"/> as a scoped service implemented by <typeparamref name="TServiceImplement"/>.</summary>
public static IServiceContainer AddScoped<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
    => RegisterDefinitionAndReturn(serviceContainer, ServiceDefinition.Scoped<TService, TServiceImplement>());

/// <summary>Registers <paramref name="serviceType"/> as a transient service implemented by itself.</summary>
public static IServiceContainer AddTransient([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
    => RegisterDefinitionAndReturn(serviceContainer, new ServiceDefinition(serviceType, ServiceLifetime.Transient));

/// <summary>Registers <paramref name="serviceType"/> as a transient service implemented by <paramref name="implementType"/>.</summary>
public static IServiceContainer AddTransient([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
    => RegisterDefinitionAndReturn(serviceContainer, new ServiceDefinition(serviceType, implementType, ServiceLifetime.Transient));

/// <summary>Registers <typeparamref name="TService"/> as a transient service produced by <paramref name="func"/>.</summary>
public static IServiceContainer AddTransient<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
    => RegisterDefinitionAndReturn(serviceContainer, ServiceDefinition.Transient<TService>(func));

/// <summary>Registers <typeparamref name="TService"/> as a transient service implemented by itself.</summary>
public static IServiceContainer AddTransient<TService>([NotNull]this IServiceContainer serviceContainer)
    => RegisterDefinitionAndReturn(serviceContainer, ServiceDefinition.Transient<TService>());

/// <summary>Registers <typeparamref name="TService"/> as a transient service implemented by <typeparamref name="TServiceImplement"/>.</summary>
public static IServiceContainer AddTransient<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
    => RegisterDefinitionAndReturn(serviceContainer, ServiceDefinition.Transient<TService, TServiceImplement>());
通过上面的代码就可以实现基本的依赖注入了。但从功能上来说,上面的代码只支持获取单个服务的实例,不支持为一个接口注册多个实现,也无法获取一个接口的所有实现。为此,对 ServiceContainer 中缓存实例的 ConcurrentDictionary 的 Key 进行改造,使其能够以服务类型和实现类型的组合作为 Key,于是就有了第二版的 ServiceContainer。
ServiceContainer v2
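为此定义了一个 ServiceKey 类型。请注意这里一定要重写 GetHashCode 方法,并与 Equals 保持一致:ConcurrentDictionary 会先用哈希码定位,再用 Equals 比较 Key。如果只重写 Equals,默认的 GetHashCode 是基于引用的,两个内容相同但分别创建的 ServiceKey 会得到不同的哈希码,缓存就永远命中不了: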
/// <summary>
/// Cache key for resolved instances: the requested service type combined with the
/// registration's implementation type (from <c>ServiceDefinition.GetImplementType()</c>),
/// so several implementations of one service type get separate cache entries.
/// </summary>
private class ServiceKey : IEquatable<ServiceKey>
{
    public Type ServiceType { get; }
    public Type ImplementType { get; }

    public ServiceKey(Type serviceType, ServiceDefinition definition)
    {
        ServiceType = serviceType;
        ImplementType = definition.GetImplementType();
    }

    /// <summary>Keys are equal when both their service type and implementation type match.</summary>
    public bool Equals(ServiceKey other)
    {
        if (ReferenceEquals(other, null))
        {
            return false;
        }
        return ServiceType == other.ServiceType && ImplementType == other.ImplementType;
    }

    // A soft cast: comparing against a non-ServiceKey must return false, not throw InvalidCastException.
    public override bool Equals(object obj) => Equals(obj as ServiceKey);

    // Combines the two types' own hash codes (consistent with the Type equality used in Equals)
    // instead of allocating and hashing a FullName string on every lookup.
    public override int GetHashCode()
    {
        unchecked
        {
            var hash = ServiceType?.GetHashCode() ?? 0;
            return (hash * 397) ^ (ImplementType?.GetHashCode() ?? 0);
        }
    }
}
第二版的 ServiceContainer:
/// <summary>
/// Second version of the container. Instances are cached by <see cref="ServiceKey"/>
/// (service type + implementation type), so several implementations of one service type can be
/// registered and resolved together as <see cref="IEnumerable{T}"/>,
/// <see cref="IReadOnlyList{T}"/>, <see cref="IReadOnlyCollection{T}"/> or an array.
/// Open generic registrations (e.g. a service type of <c>IRepository&lt;&gt;</c>) are supported.
/// </summary>
public class ServiceContainer : IServiceContainer
{
    // Registrations; the same bag is shared by the root container and every scope.
    // NOTE(review): ConcurrentBag does not guarantee enumeration order, so the "last registration
    // wins" lookups (LastOrDefault) and the order of resolved collection items are not guaranteed
    // to follow registration order — consider an ordered collection guarded by a lock.
    internal readonly ConcurrentBag<ServiceDefinition> _services;
    // Singleton cache; the same dictionary is shared by the root container and every scope.
    private readonly ConcurrentDictionary<ServiceKey, object> _singletonInstances;
    // Scoped cache; only created for child scopes, so it is null in the root container.
    private readonly ConcurrentDictionary<ServiceKey, object> _scopedInstances;
    // Disposable transient instances created through this container/scope; set to null on dispose.
    private ConcurrentBag<object> _transientDisposables = new ConcurrentBag<object>();

    /// <summary>
    /// Cache key: the requested service type combined with the registration's implementation type.
    /// </summary>
    private class ServiceKey : IEquatable<ServiceKey>
    {
        public Type ServiceType { get; }
        public Type ImplementType { get; }

        public ServiceKey(Type serviceType, ServiceDefinition definition)
        {
            ServiceType = serviceType;
            ImplementType = definition.GetImplementType();
        }

        public bool Equals(ServiceKey other)
        {
            if (ReferenceEquals(other, null))
            {
                return false;
            }
            return ServiceType == other.ServiceType && ImplementType == other.ImplementType;
        }

        // A soft cast: comparing against a non-ServiceKey must return false, not throw.
        public override bool Equals(object obj) => Equals(obj as ServiceKey);

        // Combines the types' own hash codes instead of allocating a FullName string per lookup.
        public override int GetHashCode()
        {
            unchecked
            {
                var hash = ServiceType?.GetHashCode() ?? 0;
                return (hash * 397) ^ (ImplementType?.GetHashCode() ?? 0);
            }
        }
    }

    private readonly bool _isRootScope;

    /// <summary>Creates an empty root container.</summary>
    public ServiceContainer()
    {
        _isRootScope = true;
        _singletonInstances = new ConcurrentDictionary<ServiceKey, object>();
        _services = new ConcurrentBag<ServiceDefinition>();
    }

    /// <summary>
    /// Creates a child scope sharing the registrations and the singleton cache of
    /// <paramref name="serviceContainer"/>, with its own scoped-instance cache.
    /// </summary>
    private ServiceContainer(ServiceContainer serviceContainer)
    {
        _isRootScope = false;
        _singletonInstances = serviceContainer._singletonInstances;
        _services = serviceContainer._services;
        _scopedInstances = new ConcurrentDictionary<ServiceKey, object>();
    }

    /// <summary>
    /// Adds a registration unless one with the same service type and implementation type
    /// already exists.
    /// </summary>
    /// <exception cref="InvalidOperationException">The container has been disposed.</exception>
    public IServiceContainer Add(ServiceDefinition item)
    {
        if (_disposed)
        {
            throw new InvalidOperationException("the service container had been disposed");
        }
        if (_services.Any(_ => _.ServiceType == item.ServiceType && _.GetImplementType() == item.GetImplementType()))
        {
            return this;
        }
        _services.Add(item);
        return this;
    }

    /// <summary>Adds a registration only if its service type has no registration yet.</summary>
    /// <exception cref="InvalidOperationException">The container has been disposed.</exception>
    public IServiceContainer TryAdd(ServiceDefinition item)
    {
        if (_disposed)
        {
            throw new InvalidOperationException("the service container had been disposed");
        }
        if (_services.Any(_ => _.ServiceType == item.ServiceType))
        {
            return this;
        }
        _services.Add(item);
        return this;
    }

    /// <summary>Creates a new child scope.</summary>
    public IServiceContainer CreateScope()
    {
        return new ServiceContainer(this);
    }

    private bool _disposed;

    /// <summary>
    /// Disposes the instances this container owns: the singletons for the root container or the
    /// scoped instances for a child scope, plus the disposable transients it created, then clears
    /// the owned cache. Calls after the first one do nothing.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }
        // The root owns the singletons; a child scope owns its scoped instances.
        var ownedInstances = _isRootScope ? _singletonInstances : _scopedInstances;
        lock (ownedInstances)
        {
            if (_disposed)
            {
                return;
            }
            _disposed = true;
            foreach (var instance in ownedInstances.Values)
            {
                (instance as IDisposable)?.Dispose();
            }
            foreach (var o in _transientDisposables)
            {
                (o as IDisposable)?.Dispose();
            }
            ownedInstances.Clear();
            _transientDisposables = null;
        }
    }

    /// <summary>
    /// Builds a correctly typed constructor argument. <c>Expression.Constant(null)</c> is typed as
    /// <see cref="object"/>, which <c>Expression.New</c> rejects for any other parameter type, so a
    /// null (unresolved or null-defaulted) argument becomes <c>default(parameterType)</c>; a
    /// non-null value is given the parameter's static type.
    /// </summary>
    private static Expression CreateArgumentExpression(object value, Type parameterType)
    {
        if (value == null)
        {
            return Expression.Default(parameterType);
        }
        return Expression.Constant(value, parameterType);
    }

    /// <summary>
    /// Produces a new instance for <paramref name="serviceDefinition"/>: returns the registered
    /// instance, invokes the factory with this container, or constructs the implementation type
    /// (closing an open generic definition with <paramref name="serviceType"/>'s type arguments)
    /// through the public constructor with the fewest parameters, resolving each parameter from
    /// this container and falling back to the parameter's default value.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The implementation type is an interface/abstract type or has no public constructor.
    /// </exception>
    private object GetServiceInstance(Type serviceType, ServiceDefinition serviceDefinition)
    {
        if (serviceDefinition.ImplementationInstance != null)
            return serviceDefinition.ImplementationInstance;
        if (serviceDefinition.ImplementationFactory != null)
            return serviceDefinition.ImplementationFactory.Invoke(this);
        var implementType = (serviceDefinition.ImplementType ?? serviceType);
        if (implementType.IsInterface || implementType.IsAbstract)
        {
            throw new InvalidOperationException($"invalid service registered, serviceType: {serviceType.FullName}, implementType: {serviceDefinition.ImplementType}");
        }
        // Only an open generic definition (e.g. Repository<>) must be closed with the requested
        // type's arguments; MakeGenericType on an already-closed type (e.g. Repository<User>)
        // throws InvalidOperationException.
        if (implementType.IsGenericTypeDefinition)
        {
            implementType = implementType.MakeGenericType(serviceType.GetGenericArguments());
        }
        var ctorInfos = implementType.GetConstructors(BindingFlags.Instance | BindingFlags.Public);
        if (ctorInfos.Length == 0)
        {
            throw new InvalidOperationException($"service {serviceType.FullName} does not have any public constructors");
        }
        ConstructorInfo ctor;
        if (ctorInfos.Length == 1)
        {
            ctor = ctorInfos[0];
        }
        else
        {
            // TODO: try find best ctor (currently the one with the fewest parameters)
            ctor = ctorInfos
                .OrderBy(_ => _.GetParameters().Length)
                .First();
        }
        var parameters = ctor.GetParameters();
        var arguments = new Expression[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameter = parameters[index];
            var param = GetService(parameter.ParameterType);
            if (param == null && parameter.HasDefaultValue)
            {
                param = parameter.DefaultValue;
            }
            arguments[index] = CreateArgumentExpression(param, parameter.ParameterType);
        }
        // TODO: cache New Func
        // Convert to object so value-type implementations also fit Func<object>.
        var body = Expression.Convert(Expression.New(ctor, arguments), typeof(object));
        return Expression.Lambda<Func<object>>(body).Compile().Invoke();
    }

    /// <summary>
    /// Returns the instance for <paramref name="definition"/> resolved as
    /// <paramref name="serviceType"/>, honoring its lifetime: cached in the shared singleton
    /// cache, cached in this scope, or created fresh (and tracked for disposal if IDisposable).
    /// </summary>
    /// <exception cref="InvalidOperationException">A scoped service is requested from the root container.</exception>
    private object ResolveByDefinition(Type serviceType, ServiceDefinition definition)
    {
        // The root container has no scoped cache (_scopedInstances is null there).
        if (_isRootScope && definition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            throw new InvalidOperationException($"can not get scope service from the root scope, serviceType: {serviceType.FullName}");
        }
        if (definition.ServiceLifetime == ServiceLifetime.Singleton)
        {
            return _singletonInstances.GetOrAdd(new ServiceKey(serviceType, definition), key => GetServiceInstance(key.ServiceType, definition));
        }
        if (definition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            return _scopedInstances.GetOrAdd(new ServiceKey(serviceType, definition), key => GetServiceInstance(key.ServiceType, definition));
        }
        var svc = GetServiceInstance(serviceType, definition);
        if (svc is IDisposable)
        {
            _transientDisposables.Add(svc);
        }
        return svc;
    }

    /// <summary>
    /// Resolves every registration of T when <paramref name="serviceType"/> is a single-argument
    /// generic type that a <c>T[]</c> can be assigned to (IEnumerable, IReadOnlyList,
    /// IReadOnlyCollection, ...). Registrations of T itself and, when T is a closed generic, of
    /// its open generic definition are both included; null instances are skipped.
    /// Returns null when <paramref name="serviceType"/> is not such a collection type.
    /// </summary>
    private object ResolveEnumerable(Type serviceType)
    {
        var genericArguments = serviceType.GetGenericArguments();
        if (genericArguments.Length != 1)
        {
            return null;
        }
        var innerServiceType = genericArguments[0];
        // The result is a T[]; only answer for types that can actually hold it — a List<T>
        // request would otherwise receive an array and fail the caller's cast.
        if (!serviceType.IsAssignableFrom(innerServiceType.MakeArrayType()))
        {
            return null;
        }
        var openInnerType = innerServiceType.IsGenericType ? innerServiceType.GetGenericTypeDefinition() : null;
        var instances = new List<object>(4);
        foreach (var definition in _services.Where(d => d.ServiceType == innerServiceType || (openInnerType != null && d.ServiceType == openInnerType)))
        {
            var svc = ResolveByDefinition(innerServiceType, definition);
            if (null != svc)
            {
                instances.Add(svc);
            }
        }
        var result = Array.CreateInstance(innerServiceType, instances.Count);
        for (var index = 0; index < instances.Count; index++)
        {
            result.SetValue(instances[index], index);
        }
        return result;
    }

    /// <summary>
    /// Resolves <paramref name="serviceType"/>: an exact registration first, then an open generic
    /// registration, then (for collection types) all registrations of the element type.
    /// Returns null when nothing applies.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The container has been disposed, or a scoped service is requested from the root container.
    /// </exception>
    public object GetService(Type serviceType)
    {
        if (_disposed)
        {
            throw new InvalidOperationException($"can not get scope service from a disposed scope, serviceType: {serviceType.FullName}");
        }
        var serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == serviceType);
        if (null == serviceDefinition)
        {
            if (!serviceType.IsGenericType)
            {
                return null;
            }
            // Fall back to an open generic registration, e.g. IRepository<> for IRepository<User>.
            var genericType = serviceType.GetGenericTypeDefinition();
            serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == genericType);
            if (null == serviceDefinition)
            {
                return ResolveEnumerable(serviceType);
            }
        }
        return ResolveByDefinition(serviceType, serviceDefinition);
    }
}
这样我们不仅支持了 IEnumerable<TService> 的获取,也支持 IReadOnlyList<TService> / IReadOnlyCollection<TService> 的获取。
因为 GetService 返回的是 object,不是强类型的,所以为了使用起来方便,定义了几个扩展方法,类似于微软依赖注入框架里的 GetService<TService>() / GetServices<TService>() / GetRequiredService<TService>():
/// <summary>
/// Resolves <typeparamref name="TService"/> from <paramref name="serviceProvider"/>.
/// </summary>
/// <typeparam name="TService">The service type to resolve.</typeparam>
/// <param name="serviceProvider">The provider to resolve from.</param>
/// <returns>The resolved instance, or <c>default(TService)</c> when the provider returns null.</returns>
public static TService ResolveService<TService>([NotNull]this IServiceProvider serviceProvider)
{
    var svc = serviceProvider.GetService(typeof(TService));
    // Unboxing null into a value type throws NullReferenceException; map "not found" to default.
    return svc == null ? default(TService) : (TService)svc;
}

/// <summary>
/// Resolves <typeparamref name="TService"/> from <paramref name="serviceProvider"/>,
/// throwing when the provider returns null.
/// </summary>
/// <typeparam name="TService">The service type to resolve.</typeparam>
/// <param name="serviceProvider">The provider to resolve from.</param>
/// <returns>The resolved instance.</returns>
/// <exception cref="InvalidOperationException">The provider returned null for <typeparamref name="TService"/>.</exception>
public static TService ResolveRequiredService<TService>([NotNull] this IServiceProvider serviceProvider)
{
    var serviceType = typeof(TService);
    var svc = serviceProvider.GetService(serviceType);
    if (null == svc)
    {
        throw new InvalidOperationException($"service had not been registered, serviceType: {serviceType}");
    }
    return (TService)svc;
}

/// <summary>
/// Resolves all registered <typeparamref name="TService"/> instances by requesting
/// <see cref="IEnumerable{T}"/> of <typeparamref name="TService"/> from the provider.
/// </summary>
/// <typeparam name="TService">The service type to resolve.</typeparam>
/// <param name="serviceProvider">The provider to resolve from.</param>
/// <returns>The resolved instances; an empty sequence (never null) when the provider returns null.</returns>
public static IEnumerable<TService> ResolveServices<TService>([NotNull]this IServiceProvider serviceProvider)
    => serviceProvider.ResolveService<IEnumerable<TService>>() ?? Array.Empty<TService>();
More
后面还更新了一版,主要优化性能,目前来说还不太满意,暂时这里先不提了
Reference
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。