動手造輪子:實現一個簡單的依賴注入(一)
阿新 • • 發佈:2019-10-29
動手造輪子:實現一個簡單的依賴注入(一)
Intro
在上一篇文章中主要介紹了一下要做的依賴注入的整體設計和大概程式設計體驗,這篇文章要開始寫程式碼了,開始實現自己的依賴注入框架。
類圖
首先來溫習一下上次提到的類圖
服務生命週期
服務生命週期定義:
/// <summary>
/// How long an instance produced for a registered service is kept.
/// </summary>
public enum ServiceLifetime : sbyte
{
    /// <summary>
    /// Only one instance of the service is ever created.
    /// </summary>
    Singleton = 0,

    /// <summary>
    /// One instance of the service is created per scope.
    /// </summary>
    Scoped = 1,

    /// <summary>
    /// A fresh instance of the service is created on every request.
    /// </summary>
    Transient = 2
}
服務定義
服務註冊定義:
/// <summary>
/// One service registration: the service type, its lifetime and how an instance is obtained
/// (a pre-built instance, a factory delegate, or an implementation type to construct).
/// </summary>
public class ServiceDefinition
{
    /// <summary>Lifetime of the service.</summary>
    public ServiceLifetime ServiceLifetime { get; }

    /// <summary>Implementation type; only set by the type-based constructors, otherwise null.</summary>
    public Type ImplementType { get; }

    /// <summary>The type the service is registered as.</summary>
    public Type ServiceType { get; }

    /// <summary>Pre-built implementation instance; only set by the instance constructor, otherwise null.</summary>
    public object ImplementationInstance { get; }

    /// <summary>Implementation factory; only set by the factory constructor, otherwise null.</summary>
    public Func<IServiceProvider, object> ImplementationFactory { get; }

    /// <summary>
    /// Gets the "real" implementation type, checked in this order:
    /// the runtime type of <see cref="ImplementationInstance"/>;
    /// the declaring type of the <see cref="ImplementationFactory"/> method (for a lambda this is typically
    /// a compiler-generated class, not the type the factory produces);
    /// <see cref="ImplementType"/>;
    /// and finally <see cref="ServiceType"/>.
    /// </summary>
    public Type GetImplementType()
    {
        if (ImplementationInstance != null)
        {
            return ImplementationInstance.GetType();
        }

        if (ImplementationFactory != null)
        {
            return ImplementationFactory.Method.DeclaringType;
        }

        return ImplementType ?? ServiceType;
    }

    /// <summary>Registers a pre-built <paramref name="instance"/>; the lifetime is always singleton.</summary>
    public ServiceDefinition(object instance, Type serviceType)
    {
        ServiceLifetime = ServiceLifetime.Singleton;
        ServiceType = serviceType;
        ImplementationInstance = instance;
    }

    /// <summary>Registers <paramref name="serviceType"/> implemented by itself.</summary>
    public ServiceDefinition(Type serviceType, ServiceLifetime serviceLifetime)
        : this(serviceType, serviceType, serviceLifetime)
    {
    }

    /// <summary>
    /// Registers <paramref name="serviceType"/> implemented by <paramref name="implementType"/>
    /// (falls back to <paramref name="serviceType"/> when <paramref name="implementType"/> is null).
    /// </summary>
    public ServiceDefinition(Type serviceType, Type implementType, ServiceLifetime serviceLifetime)
    {
        ServiceLifetime = serviceLifetime;
        ServiceType = serviceType;
        ImplementType = implementType ?? serviceType;
    }

    /// <summary>Registers <paramref name="serviceType"/> produced by <paramref name="factory"/>.</summary>
    public ServiceDefinition(Type serviceType, Func<IServiceProvider, object> factory, ServiceLifetime serviceLifetime)
    {
        ServiceLifetime = serviceLifetime;
        ServiceType = serviceType;
        ImplementationFactory = factory;
    }
}
為了使用起來更方便,新增了一些靜態工廠方法:
public static ServiceDefinition Singleton<TService>(Func<IServiceProvider, object> factory) { return new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Singleton); } public static ServiceDefinition Scoped<TService>(Func<IServiceProvider, object> factory) { return new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Scoped); } public static ServiceDefinition Transient<TService>(Func<IServiceProvider, object> factory) { return new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Transient); } public static ServiceDefinition Singleton<TService>() { return new ServiceDefinition(typeof(TService), ServiceLifetime.Singleton); } public static ServiceDefinition Scoped<TService>() { return new ServiceDefinition(typeof(TService), ServiceLifetime.Scoped); } public static ServiceDefinition Transient<TService>() { return new ServiceDefinition(typeof(TService), ServiceLifetime.Transient); } public static ServiceDefinition Singleton<TService, TServiceImplement>() where TServiceImplement : TService { return new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Singleton); } public static ServiceDefinition Scoped<TService, TServiceImplement>() where TServiceImplement : TService { return new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Scoped); } public static ServiceDefinition Transient<TService, TServiceImplement>() where TServiceImplement : TService { return new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Transient); }
ServiceContainer
ServiceContainer v1
public class ServiceContainer : IServiceContainer
{
    // Registered definitions in registration order; shared with every scope created from this container.
    internal readonly List<ServiceDefinition> _services;

    // Singleton instances keyed by service type; the same dictionary is shared by the root and all scopes.
    private readonly ConcurrentDictionary<Type, object> _singletonInstances;

    // Scoped instances keyed by service type; only created for child scopes (null on the root container).
    private readonly ConcurrentDictionary<Type, object> _scopedInstances;

    // Disposable transient instances created by this container, disposed together with it.
    // ConcurrentBag rather than List because GetService may add to it from several threads at once.
    private readonly ConcurrentBag<object> _transientDisposables = new ConcurrentBag<object>();

    // true for the container built by the public constructor, false for scopes created by CreateScope().
    private readonly bool _isRootScope;

    /// <summary>Creates a root container with no registrations.</summary>
    public ServiceContainer()
    {
        _isRootScope = true;
        _singletonInstances = new ConcurrentDictionary<Type, object>();
        _services = new List<ServiceDefinition>();
    }

    /// <summary>
    /// Creates a child scope that shares registrations and singleton instances with
    /// <paramref name="serviceContainer"/> but has its own scoped instances.
    /// </summary>
    internal ServiceContainer(ServiceContainer serviceContainer)
    {
        _isRootScope = false;
        _singletonInstances = serviceContainer._singletonInstances;
        _services = serviceContainer._services;
        _scopedInstances = new ConcurrentDictionary<Type, object>();
    }

    /// <summary>Appends a registration; a later registration for the same service type wins in GetService.</summary>
    public void Add(ServiceDefinition item)
    {
        _services.Add(item);
    }

    /// <summary>Creates a new child scope of this container.</summary>
    public IServiceContainer CreateScope()
    {
        return new ServiceContainer(this);
    }

    private bool _disposed;

    /// <summary>
    /// Disposes the instances owned by this container: singletons plus tracked transients for the root,
    /// scoped instances plus tracked transients for a child scope. Repeated calls are no-ops.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }
        if (_isRootScope)
        {
            lock (_singletonInstances)
            {
                if (_disposed)
                {
                    return;
                }
                _disposed = true;
                foreach (var instance in _singletonInstances.Values)
                {
                    (instance as IDisposable)?.Dispose();
                }
                foreach (var o in _transientDisposables)
                {
                    (o as IDisposable)?.Dispose();
                }
            }
        }
        else
        {
            lock (_scopedInstances)
            {
                if (_disposed)
                {
                    return;
                }
                _disposed = true;
                foreach (var instance in _scopedInstances.Values)
                {
                    (instance as IDisposable)?.Dispose();
                }
                foreach (var o in _transientDisposables)
                {
                    (o as IDisposable)?.Dispose();
                }
            }
        }
    }

    /// <summary>
    /// Produces a new instance for <paramref name="serviceDefinition"/>: the registered instance, the result of
    /// the registered factory, or a constructor call whose parameters are resolved through GetService.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The implementation type is an interface/abstract class or has no public constructor.
    /// </exception>
    private object GetServiceInstance(Type serviceType, ServiceDefinition serviceDefinition)
    {
        if (serviceDefinition.ImplementationInstance != null)
            return serviceDefinition.ImplementationInstance;
        if (serviceDefinition.ImplementationFactory != null)
            return serviceDefinition.ImplementationFactory.Invoke(this);

        var implementType = (serviceDefinition.ImplementType ?? serviceType);
        if (implementType.IsInterface || implementType.IsAbstract)
        {
            throw new InvalidOperationException($"invalid service registered, serviceType: {serviceType.FullName}, implementType: {serviceDefinition.ImplementType}");
        }

        var ctorInfos = implementType.GetConstructors(BindingFlags.Instance | BindingFlags.Public);
        if (ctorInfos.Length == 0)
        {
            throw new InvalidOperationException($"service {serviceType.FullName} does not have any public constructors");
        }

        // TODO: try find best ctor; currently the constructor with the fewest parameters is chosen
        var ctor = ctorInfos.Length == 1
            ? ctorInfos[0]
            : ctorInfos.OrderBy(_ => _.GetParameters().Length).First();

        var parameters = ctor.GetParameters();
        var ctorParams = new object[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameter = parameters[index];
            var param = GetService(parameter.ParameterType);
            if (param == null && parameter.HasDefaultValue)
            {
                param = parameter.DefaultValue;
            }
            ctorParams[index] = param;
        }

        // TODO: cache New Func
        return CreateInstance(ctor, parameters, ctorParams);
    }

    /// <summary>
    /// Invokes <paramref name="ctor"/> with <paramref name="arguments"/> through a compiled expression.
    /// A null argument is passed as default(T) of its parameter type.
    /// </summary>
    private static object CreateInstance(ConstructorInfo ctor, ParameterInfo[] parameters, object[] arguments)
    {
        var argumentExpressions = new Expression[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameterType = parameters[index].ParameterType;
            // An untyped Expression.Constant(null) is typed object, which Expression.New rejects for an
            // interface/class/struct parameter; typing every constant with its parameter type avoids that.
            argumentExpressions[index] = arguments[index] == null
                ? (Expression)Expression.Default(parameterType)
                : Expression.Constant(arguments[index], parameterType);
        }
        // Convert boxes value-type implementations; Lambda<Func<object>> rejects a value-typed body.
        var body = Expression.Convert(Expression.New(ctor, argumentExpressions), typeof(object));
        return Expression.Lambda<Func<object>>(body).Compile().Invoke();
    }

    /// <summary>
    /// Resolves <paramref name="serviceType"/> using the last registration for it, or returns null when the
    /// type is not registered.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The container is disposed, or a scoped service is requested from the root container.
    /// </exception>
    public object GetService(Type serviceType)
    {
        if (_disposed)
        {
            // instances created after Dispose would never be disposed
            throw new InvalidOperationException($"can not get scope service from a disposed scope, serviceType: {serviceType.FullName}");
        }

        var serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == serviceType);
        if (null == serviceDefinition)
        {
            return null;
        }

        if (_isRootScope && serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            throw new InvalidOperationException($"can not get scope service from the root scope, serviceType: {serviceType.FullName}");
        }

        // NOTE(review): a singleton first requested from a child scope is constructed by that scope
        // (GetServiceInstance runs on `this`), so its dependencies come from the scope. Also,
        // ConcurrentDictionary.GetOrAdd may run the value factory more than once under contention.
        if (serviceDefinition.ServiceLifetime == ServiceLifetime.Singleton)
        {
            var svc = _singletonInstances.GetOrAdd(serviceType, (t) => GetServiceInstance(t, serviceDefinition));
            return svc;
        }
        else if (serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            var svc = _scopedInstances.GetOrAdd(serviceType, (t) => GetServiceInstance(t, serviceDefinition));
            return svc;
        }
        else
        {
            var svc = GetServiceInstance(serviceType, serviceDefinition);
            if (svc is IDisposable)
            {
                _transientDisposables.Add(svc);
            }
            return svc;
        }
    }
}
為了讓服務註冊更加方便,可以寫一些擴展方法:
/// <summary>Registers an existing <paramref name="service"/> instance as <typeparamref name="TService"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]TService service)
{
    var definition = new ServiceDefinition(service, typeof(TService));
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <paramref name="serviceType"/> as a singleton implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
{
    var definition = new ServiceDefinition(serviceType, ServiceLifetime.Singleton);
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <paramref name="serviceType"/> as a singleton implemented by <paramref name="implementType"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
{
    var definition = new ServiceDefinition(serviceType, implementType, ServiceLifetime.Singleton);
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <typeparamref name="TService"/> as a singleton created by <paramref name="func"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
{
    var definition = ServiceDefinition.Singleton<TService>(func);
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <typeparamref name="TService"/> as a singleton implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer)
{
    var definition = ServiceDefinition.Singleton<TService>();
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <typeparamref name="TService"/> as a singleton implemented by <typeparamref name="TServiceImplement"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
{
    var definition = ServiceDefinition.Singleton<TService, TServiceImplement>();
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <paramref name="serviceType"/> as a scoped service implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
{
    var definition = new ServiceDefinition(serviceType, ServiceLifetime.Scoped);
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <paramref name="serviceType"/> as a scoped service implemented by <paramref name="implementType"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
{
    var definition = new ServiceDefinition(serviceType, implementType, ServiceLifetime.Scoped);
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <typeparamref name="TService"/> as a scoped service created by <paramref name="func"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
{
    var definition = ServiceDefinition.Scoped<TService>(func);
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <typeparamref name="TService"/> as a scoped service implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped<TService>([NotNull]this IServiceContainer serviceContainer)
{
    var definition = ServiceDefinition.Scoped<TService>();
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <typeparamref name="TService"/> as a scoped service implemented by <typeparamref name="TServiceImplement"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
{
    var definition = ServiceDefinition.Scoped<TService, TServiceImplement>();
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <paramref name="serviceType"/> as a transient service implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
{
    var definition = new ServiceDefinition(serviceType, ServiceLifetime.Transient);
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <paramref name="serviceType"/> as a transient service implemented by <paramref name="implementType"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
{
    var definition = new ServiceDefinition(serviceType, implementType, ServiceLifetime.Transient);
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <typeparamref name="TService"/> as a transient service created by <paramref name="func"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
{
    var definition = ServiceDefinition.Transient<TService>(func);
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <typeparamref name="TService"/> as a transient service implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient<TService>([NotNull]this IServiceContainer serviceContainer)
{
    var definition = ServiceDefinition.Transient<TService>();
    serviceContainer.Add(definition);
    return serviceContainer;
}
/// <summary>Registers <typeparamref name="TService"/> as a transient service implemented by <typeparamref name="TServiceImplement"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
{
    var definition = ServiceDefinition.Transient<TService, TServiceImplement>();
    serviceContainer.Add(definition);
    return serviceContainer;
}
通過上面的程式碼就可以實現基本的依賴注入了。但從功能上來說,上面的程式碼只支援獲取單個服務的實例,不支援為一個介面註冊多個實現、並獲取該介面的所有實現。為此,對 ServiceContainer 中存放實例的 ConcurrentDictionary 的 Key 進行了改造,使其能夠以介面型別和實現型別聯合作為 key,於是就有了第二版的 ServiceContainer。
ServiceContainer v2
為此定義了一個 ServiceKey 型別。請注意,這裡一定要重寫 GetHashCode 方法(並且要與 Equals 保持一致),因為 ConcurrentDictionary 會先用 GetHashCode 定位 key,再用 Equals 判斷兩個 key 是否相等:
/// <summary>
/// Cache key for created instances: the requested service type combined with the implementation type
/// reported by <see cref="ServiceDefinition.GetImplementType"/>.
/// </summary>
/// <remarks>
/// NOTE(review): for factory registrations GetImplementType() returns the declaring type of the factory
/// delegate's method, so two different factories declared in the same class produce equal keys.
/// </remarks>
private class ServiceKey : IEquatable<ServiceKey>
{
    /// <summary>The type the service was requested as.</summary>
    public Type ServiceType { get; }

    /// <summary>The implementation type taken from the registration.</summary>
    public Type ImplementType { get; }

    public ServiceKey(Type serviceType, ServiceDefinition definition)
    {
        ServiceType = serviceType;
        ImplementType = definition.GetImplementType();
    }

    public bool Equals(ServiceKey other)
    {
        return other != null
               && ServiceType == other.ServiceType
               && ImplementType == other.ImplementType;
    }

    // `as` instead of a hard cast: comparing with null or a non-ServiceKey must return false, not throw.
    public override bool Equals(object obj)
    {
        return Equals(obj as ServiceKey);
    }

    // Combines the Type hashes directly: consistent with the Type equality used in Equals, allocation-free,
    // and safe when ImplementType is null (a delegate method's DeclaringType can be null).
    public override int GetHashCode()
    {
        unchecked
        {
            return ((ServiceType?.GetHashCode() ?? 0) * 397) ^ (ImplementType?.GetHashCode() ?? 0);
        }
    }
}
第二版的 ServiceContainer:
public class ServiceContainer : IServiceContainer
{
    // Registered definitions in registration order, shared by the root container and all of its scopes.
    // All access is guarded by lock (_services). A List is used instead of ConcurrentBag because
    // ConcurrentBag does not preserve insertion order, and both the "last registration wins" lookup and the
    // order of IEnumerable<T> results depend on it.
    internal readonly List<ServiceDefinition> _services;

    // Singleton instances, shared by the root container and all of its scopes.
    private readonly ConcurrentDictionary<ServiceKey, object> _singletonInstances;

    // Scoped instances of this scope; null on the root container (scoped services cannot be resolved there).
    private readonly ConcurrentDictionary<ServiceKey, object> _scopedInstances;

    // Disposable transient instances created through this container, disposed together with it.
    private ConcurrentBag<object> _transientDisposables = new ConcurrentBag<object>();

    /// <summary>
    /// Cache key for singleton/scoped instances: the requested (closed) service type plus the registration
    /// that produced the instance.
    /// </summary>
    /// <remarks>
    /// The registration object is used rather than ServiceDefinition.GetImplementType(): for factory
    /// registrations that method returns the type declaring the factory delegate, so different factories
    /// declared in the same class would otherwise share one cached instance.
    /// </remarks>
    private class ServiceKey : IEquatable<ServiceKey>
    {
        public Type ServiceType { get; }

        public ServiceDefinition Definition { get; }

        public ServiceKey(Type serviceType, ServiceDefinition definition)
        {
            ServiceType = serviceType;
            Definition = definition;
        }

        public bool Equals(ServiceKey other)
        {
            return other != null
                   && ServiceType == other.ServiceType
                   && ReferenceEquals(Definition, other.Definition);
        }

        // `as` instead of a hard cast: comparing with a non-ServiceKey returns false instead of throwing.
        public override bool Equals(object obj) => Equals(obj as ServiceKey);

        public override int GetHashCode()
        {
            unchecked
            {
                return ((ServiceType?.GetHashCode() ?? 0) * 397) ^ (Definition?.GetHashCode() ?? 0);
            }
        }
    }

    // true for the container built by the public constructor, false for scopes created by CreateScope().
    private readonly bool _isRootScope;

    /// <summary>Creates a root container with no registrations.</summary>
    public ServiceContainer()
    {
        _isRootScope = true;
        _singletonInstances = new ConcurrentDictionary<ServiceKey, object>();
        _services = new List<ServiceDefinition>();
    }

    /// <summary>
    /// Creates a child scope that shares registrations and singletons with <paramref name="serviceContainer"/>.
    /// </summary>
    private ServiceContainer(ServiceContainer serviceContainer)
    {
        _isRootScope = false;
        _singletonInstances = serviceContainer._singletonInstances;
        _services = serviceContainer._services;
        _scopedInstances = new ConcurrentDictionary<ServiceKey, object>();
    }

    /// <summary>
    /// Registers <paramref name="item"/> unless an equivalent registration for the same service type exists.
    /// </summary>
    /// <exception cref="InvalidOperationException">The container has been disposed.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    public IServiceContainer Add(ServiceDefinition item)
    {
        if (_disposed)
        {
            throw new InvalidOperationException("the service container had been disposed");
        }
        if (item == null)
        {
            throw new ArgumentNullException(nameof(item));
        }
        lock (_services)
        {
            // check-and-add under one lock so concurrent callers cannot both insert the same registration
            if (!_services.Any(existing => IsSameRegistration(existing, item)))
            {
                _services.Add(item);
            }
        }
        return this;
    }

    /// <summary>
    /// Registers <paramref name="item"/> only if nothing is registered yet for its service type.
    /// </summary>
    /// <exception cref="InvalidOperationException">The container has been disposed.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    public IServiceContainer TryAdd(ServiceDefinition item)
    {
        if (_disposed)
        {
            throw new InvalidOperationException("the service container had been disposed");
        }
        if (item == null)
        {
            throw new ArgumentNullException(nameof(item));
        }
        lock (_services)
        {
            if (!_services.Any(existing => existing.ServiceType == item.ServiceType))
            {
                _services.Add(item);
            }
        }
        return this;
    }

    // Two registrations are duplicates when they target the same service type and either use the same factory
    // delegate, or (when neither uses a factory) report the same implementation type. Factories are compared
    // by delegate because GetImplementType() only reports the factory method's declaring type.
    private static bool IsSameRegistration(ServiceDefinition existing, ServiceDefinition item)
    {
        if (existing.ServiceType != item.ServiceType)
        {
            return false;
        }
        if (existing.ImplementationFactory != null || item.ImplementationFactory != null)
        {
            return object.Equals(existing.ImplementationFactory, item.ImplementationFactory);
        }
        return existing.GetImplementType() == item.GetImplementType();
    }

    /// <summary>Creates a new child scope of this container.</summary>
    public IServiceContainer CreateScope()
    {
        return new ServiceContainer(this);
    }

    private bool _disposed;

    /// <summary>
    /// Disposes the instances owned by this container: singletons plus tracked transients for the root,
    /// scoped instances plus tracked transients for a child scope. Repeated calls are no-ops.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }
        if (_isRootScope)
        {
            lock (_singletonInstances)
            {
                if (_disposed)
                {
                    return;
                }
                _disposed = true;
                foreach (var instance in _singletonInstances.Values)
                {
                    (instance as IDisposable)?.Dispose();
                }
                foreach (var o in _transientDisposables)
                {
                    (o as IDisposable)?.Dispose();
                }
                _singletonInstances.Clear();
                _transientDisposables = null;
            }
        }
        else
        {
            lock (_scopedInstances)
            {
                if (_disposed)
                {
                    return;
                }
                _disposed = true;
                foreach (var instance in _scopedInstances.Values)
                {
                    (instance as IDisposable)?.Dispose();
                }
                foreach (var o in _transientDisposables)
                {
                    (o as IDisposable)?.Dispose();
                }
                _scopedInstances.Clear();
                _transientDisposables = null;
            }
        }
    }

    /// <summary>
    /// Produces a new instance for <paramref name="serviceDefinition"/>: the registered instance, the result of
    /// the registered factory, or a constructor call whose parameters are resolved through GetService.
    /// Open generic implementation types are closed over <paramref name="serviceType"/>'s type arguments.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The implementation type is an interface/abstract class or has no public constructor.
    /// </exception>
    private object GetServiceInstance(Type serviceType, ServiceDefinition serviceDefinition)
    {
        if (serviceDefinition.ImplementationInstance != null)
            return serviceDefinition.ImplementationInstance;
        if (serviceDefinition.ImplementationFactory != null)
            return serviceDefinition.ImplementationFactory.Invoke(this);

        var implementType = (serviceDefinition.ImplementType ?? serviceType);
        if (implementType.IsInterface || implementType.IsAbstract)
        {
            throw new InvalidOperationException($"invalid service registered, serviceType: {serviceType.FullName}, implementType: {serviceDefinition.ImplementType}");
        }
        // Only an open generic definition (e.g. Repo<>) must be closed over the requested type arguments;
        // MakeGenericType on an already-closed type such as Repo<User> throws InvalidOperationException.
        if (implementType.IsGenericTypeDefinition)
        {
            implementType = implementType.MakeGenericType(serviceType.GetGenericArguments());
        }

        var ctorInfos = implementType.GetConstructors(BindingFlags.Instance | BindingFlags.Public);
        if (ctorInfos.Length == 0)
        {
            throw new InvalidOperationException($"service {serviceType.FullName} does not have any public constructors");
        }

        // TODO: try find best ctor; currently the constructor with the fewest parameters is chosen
        var ctor = ctorInfos.Length == 1
            ? ctorInfos[0]
            : ctorInfos.OrderBy(_ => _.GetParameters().Length).First();

        var parameters = ctor.GetParameters();
        var ctorParams = new object[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameter = parameters[index];
            var param = GetService(parameter.ParameterType);
            if (param == null && parameter.HasDefaultValue)
            {
                param = parameter.DefaultValue;
            }
            ctorParams[index] = param;
        }

        // TODO: cache New Func
        return CreateInstance(ctor, parameters, ctorParams);
    }

    /// <summary>
    /// Invokes <paramref name="ctor"/> with <paramref name="arguments"/> through a compiled expression.
    /// A null argument is passed as default(T) of its parameter type.
    /// </summary>
    private static object CreateInstance(ConstructorInfo ctor, ParameterInfo[] parameters, object[] arguments)
    {
        var argumentExpressions = new Expression[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameterType = parameters[index].ParameterType;
            // An untyped Expression.Constant(null) is typed object, which Expression.New rejects for an
            // interface/class/struct parameter; typing every constant with its parameter type avoids that.
            argumentExpressions[index] = arguments[index] == null
                ? (Expression)Expression.Default(parameterType)
                : Expression.Constant(arguments[index], parameterType);
        }
        // Convert boxes value-type implementations; Lambda<Func<object>> rejects a value-typed body.
        var body = Expression.Convert(Expression.New(ctor, argumentExpressions), typeof(object));
        return Expression.Lambda<Func<object>>(body).Compile().Invoke();
    }

    // Returns the most recently added registration whose ServiceType equals serviceType, or null.
    private ServiceDefinition FindLastDefinition(Type serviceType)
    {
        lock (_services)
        {
            for (var index = _services.Count - 1; index >= 0; index--)
            {
                if (_services[index].ServiceType == serviceType)
                {
                    return _services[index];
                }
            }
        }
        return null;
    }

    // Resolves one registration according to its lifetime: cached per container for singletons, cached per
    // scope for scoped services, created on every call (and tracked if disposable) for transients.
    private object ResolveDefinition(Type serviceType, ServiceDefinition serviceDefinition)
    {
        if (_isRootScope && serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            throw new InvalidOperationException($"can not get scope service from the root scope, serviceType: {serviceType.FullName}");
        }
        if (serviceDefinition.ServiceLifetime == ServiceLifetime.Singleton)
        {
            return _singletonInstances.GetOrAdd(new ServiceKey(serviceType, serviceDefinition), key => GetServiceInstance(key.ServiceType, serviceDefinition));
        }
        if (serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            return _scopedInstances.GetOrAdd(new ServiceKey(serviceType, serviceDefinition), key => GetServiceInstance(key.ServiceType, serviceDefinition));
        }
        var svc = GetServiceInstance(serviceType, serviceDefinition);
        if (svc is IDisposable)
        {
            _transientDisposables.Add(svc);
        }
        return svc;
    }

    // Resolves a collection request such as IEnumerable<T>, IReadOnlyList<T> or IReadOnlyCollection<T> into a
    // T[] holding one instance per registration of T (closed or open generic), in registration order.
    // Returns null when serviceType cannot be satisfied by a T[] (e.g. List<T> or non-collection generics).
    private object ResolveServiceCollection(Type serviceType)
    {
        var elementType = serviceType.GetGenericArguments()[0];
        if (!serviceType.IsAssignableFrom(elementType.MakeArrayType()))
        {
            return null;
        }

        var openElementType = elementType.IsGenericType ? elementType.GetGenericTypeDefinition() : elementType;
        ServiceDefinition[] definitions;
        lock (_services)
        {
            // snapshot so instances are constructed outside the lock
            definitions = _services
                .Where(definition => definition.ServiceType == elementType || definition.ServiceType == openElementType)
                .ToArray();
        }

        var instances = new List<object>(definitions.Length);
        foreach (var definition in definitions)
        {
            var instance = ResolveDefinition(elementType, definition);
            if (null != instance)
            {
                instances.Add(instance);
            }
        }

        var result = Array.CreateInstance(elementType, instances.Count);
        for (var index = 0; index < instances.Count; index++)
        {
            result.SetValue(instances[index], index);
        }
        return result;
    }

    /// <summary>
    /// Resolves <paramref name="serviceType"/>: the last registration for the exact type, otherwise the last
    /// open generic registration for its generic definition, otherwise (for collection types) all
    /// registrations of the element type. Returns null when nothing applies.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The container is disposed, or a scoped service is requested from the root container.
    /// </exception>
    public object GetService(Type serviceType)
    {
        if (_disposed)
        {
            throw new InvalidOperationException($"can not get scope service from a disposed scope, serviceType: {serviceType.FullName}");
        }

        var serviceDefinition = FindLastDefinition(serviceType);
        if (null == serviceDefinition)
        {
            if (!serviceType.IsGenericType)
            {
                return null;
            }
            // open generic registration, e.g. IRepository<> registered and IRepository<User> requested
            serviceDefinition = FindLastDefinition(serviceType.GetGenericTypeDefinition());
            if (null == serviceDefinition)
            {
                return ResolveServiceCollection(serviceType);
            }
        }
        return ResolveDefinition(serviceType, serviceDefinition);
    }
}
這樣我們不僅支援獲取 IEnumerable<TService>,也支援獲取 IReadOnlyList<TService> / IReadOnlyCollection<TService>,而且不需要額外註冊這些集合型別。
因為 GetService 返回的是 object,不是強型別的,所以為了使用起來方便,定義了幾個擴展方法,類似於微軟依賴注入框架裡的 GetService<TService>() / GetServices<TService>() / GetRequiredService<TService>():
/// <summary>
/// Resolves <typeparamref name="TService"/> from <paramref name="serviceProvider"/>.
/// </summary>
/// <typeparam name="TService">The service type to request.</typeparam>
/// <param name="serviceProvider">The provider to resolve from.</param>
/// <returns>
/// The provider's result cast to <typeparamref name="TService"/>; null when the provider returns null
/// (for a non-nullable value type the cast of null throws instead).
/// </returns>
public static TService ResolveService<TService>([NotNull]this IServiceProvider serviceProvider)
{
    var instance = serviceProvider.GetService(typeof(TService));
    return (TService)instance;
}
/// <summary>
/// Resolves <typeparamref name="TService"/> from <paramref name="serviceProvider"/>, failing loudly when the
/// provider has nothing for it.
/// </summary>
/// <typeparam name="TService">The service type to request.</typeparam>
/// <param name="serviceProvider">The provider to resolve from.</param>
/// <returns>The resolved instance cast to <typeparamref name="TService"/>.</returns>
/// <exception cref="InvalidOperationException">The provider returned null for <typeparamref name="TService"/>.</exception>
public static TService ResolveRequiredService<TService>([NotNull] this IServiceProvider serviceProvider)
{
    var requestedType = typeof(TService);
    var instance = serviceProvider.GetService(requestedType);
    if (instance != null)
    {
        return (TService)instance;
    }
    throw new InvalidOperationException($"service had not been registered, serviceType: {requestedType}");
}
/// <summary>
/// Resolves every registered <typeparamref name="TService"/> by requesting
/// <c>IEnumerable&lt;TService&gt;</c> from <paramref name="serviceProvider"/>.
/// </summary>
/// <typeparam name="TService">The element service type.</typeparam>
/// <param name="serviceProvider">The provider to resolve from.</param>
/// <returns>The provider's result cast to <c>IEnumerable&lt;TService&gt;</c>, or null if it returned null.</returns>
public static IEnumerable<TService> ResolveServices<TService>([NotNull]this IServiceProvider serviceProvider)
{
    var collectionType = typeof(IEnumerable<TService>);
    return (IEnumerable<TService>)serviceProvider.GetService(collectionType);
}
More
後面還更新了一版,主要優化效能,目前來說還不太滿意,暫時這裡先不提了
Reference
- Dynamic Casting using Reflection
- https://github.com/WeihanLi/WeihanLi.Common/blob/dev/test/WeihanLi.Common.Test/DependencyInjectionTest.cs