EF Core 仓储模式的实现
仓储模式的EF实现
仓储模式(Repository Pattern)将应用层与ORM层解耦,为数据访问提供统一的API。再配合依赖注入(DI),可以很方便地实现数据库访问。下面介绍针对EF Core的仓储模式实现,以及对应的DI注册方式。
仓储模式代码
/// <summary>
/// Generic EF Core repository for <typeparamref name="TEntity"/>, built on top of a <see cref="DbContext"/>.
/// </summary>
/// <remarks>
/// <c>Insert</c>, <c>Update</c> and <c>Delete</c> only stage changes in the context's change tracker.
/// Nothing is written to the database until <c>SaveChanges</c>/<c>SaveChangesAsync</c> is called,
/// or one of the <c>*AndSaveToDb</c> methods is used.
/// </remarks>
/// <typeparam name="TEntity">Entity type handled by this repository.</typeparam>
public class Repository<TEntity> : IRepository<TEntity> where TEntity : class
{
    private readonly DbContext _context;

    /// <summary>The set for <typeparamref name="TEntity"/>, obtained from the context on each access.</summary>
    public virtual DbSet<TEntity> Table => this._context.Set<TEntity>();

    /// <summary>Creates a repository that operates on <paramref name="context"/>.</summary>
    /// <param name="context">The EF Core context used for querying and change tracking.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is <c>null</c>.</exception>
    public Repository(DbContext context)
    {
        this._context = context ?? throw new ArgumentNullException(nameof(context));
    }

    /// <summary>
    /// Attaches <paramref name="entity"/> to the context unless that exact instance is already tracked.
    /// </summary>
    /// <param name="entity">The entity instance to attach.</param>
    protected virtual void AttachIfNot(TEntity entity)
    {
        // Match on EntityEntry.Entity. The EntityEntry is only a wrapper around the tracked object
        // and is never the entity itself, so it must not be compared to the entity directly.
        var isTracked = _context.ChangeTracker.Entries()
            .Any(entry => ReferenceEquals(entry.Entity, entity));
        if (isTracked)
        {
            return;
        }

        Table.Attach(entity);
    }

    /// <summary>Counts all entities.</summary>
    public int Count()
    {
        return GetAll().Count();
    }

    /// <summary>Counts the entities matching <paramref name="predicate"/>.</summary>
    public int Count(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Where(predicate).Count();
    }

    /// <summary>Asynchronously counts all entities.</summary>
    public async Task<int> CountAsync()
    {
        return await GetAll().CountAsync();
    }

    /// <summary>Asynchronously counts the entities matching <paramref name="predicate"/>.</summary>
    public async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).CountAsync();
    }

    /// <summary>
    /// Marks <paramref name="entity"/> for deletion, attaching it first if it is not tracked.
    /// The change is not saved.
    /// </summary>
    public void Delete(TEntity entity)
    {
        AttachIfNot(entity);
        Table.Remove(entity);
    }

    /// <summary>Marks <paramref name="entity"/> for deletion and immediately saves all pending changes.</summary>
    public void DeleteAndSaveToDb(TEntity entity)
    {
        Delete(entity);
        // Go through this class's SaveChanges, like the other *AndSaveToDb methods do.
        SaveChanges();
    }

    /// <summary>Returns the first entity matching <paramref name="predicate"/>.</summary>
    /// <exception cref="InvalidOperationException">No entity matches.</exception>
    public TEntity First(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Where(predicate).First();
    }

    /// <summary>Asynchronously returns the first entity matching <paramref name="predicate"/>.</summary>
    /// <exception cref="InvalidOperationException">No entity matches.</exception>
    public async Task<TEntity> FirstAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).FirstAsync();
    }

    /// <summary>Returns the first entity matching <paramref name="predicate"/>, or <c>null</c> if none does.</summary>
    public TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Where(predicate).FirstOrDefault();
    }

    /// <summary>
    /// Asynchronously returns the first entity matching <paramref name="predicate"/>, or <c>null</c> if none does.
    /// </summary>
    public async Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).FirstOrDefaultAsync();
    }

    /// <summary>Returns a composable query over all entities (<see cref="Table"/>).</summary>
    public IQueryable<TEntity> GetAll()
    {
        return Table;
    }

    /// <summary>Returns a query over all entities that eagerly loads each of the given navigation properties.</summary>
    /// <param name="propertySelectors">Navigation properties to include. May be <c>null</c> or empty.</param>
    public IQueryable<TEntity> GetAllIncluding(params Expression<Func<TEntity, object>>[] propertySelectors)
    {
        if (propertySelectors == null || propertySelectors.Length == 0)
        {
            return GetAll();
        }

        var query = GetAll();
        foreach (var eachSelector in propertySelectors)
        {
            query = query.Include(eachSelector);
        }

        return query;
    }

    /// <summary>Loads all entities into a list.</summary>
    public List<TEntity> GetAllList()
    {
        return GetAll().ToList();
    }

    /// <summary>Loads all entities matching <paramref name="predicate"/> into a list.</summary>
    public List<TEntity> GetAllList(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Where(predicate).ToList();
    }

    /// <summary>Asynchronously loads all entities into a list.</summary>
    public async Task<List<TEntity>> GetAllListAsync()
    {
        return await GetAll().ToListAsync();
    }

    /// <summary>Asynchronously loads all entities matching <paramref name="predicate"/> into a list.</summary>
    public async Task<List<TEntity>> GetAllListAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).ToListAsync();
    }

    /// <summary>Stages <paramref name="entity"/> for insertion (not saved) and returns the tracked instance.</summary>
    public TEntity Insert(TEntity entity)
    {
        return Table.Add(entity).Entity;
    }

    /// <summary>Stages <paramref name="entity"/> for insertion and immediately saves all pending changes.</summary>
    public TEntity InsertAndSaveToDb(TEntity entity)
    {
        var inserted = Insert(entity);
        SaveChanges();
        return inserted;
    }

    /// <summary>
    /// Asynchronously stages <paramref name="entity"/> for insertion and saves all pending changes.
    /// </summary>
    public async Task<TEntity> InsertAndSaveToDbAsync(TEntity entity)
    {
        var inserted = await InsertAsync(entity);
        await SaveChangesAsync();
        return inserted;
    }

    /// <summary>Asynchronously stages <paramref name="entity"/> for insertion (not saved).</summary>
    public async Task<TEntity> InsertAsync(TEntity entity)
    {
        return (await Table.AddAsync(entity)).Entity;
    }

    /// <summary>Runs <paramref name="queryMethod"/> against <see cref="GetAll"/> and returns its result.</summary>
    public T Query<T>(Func<IQueryable<TEntity>, T> queryMethod)
    {
        return queryMethod(GetAll());
    }

    /// <summary>Saves all pending changes in the underlying context.</summary>
    /// <returns>The number of state entries written, as reported by the context.</returns>
    public int SaveChanges()
    {
        return _context.SaveChanges();
    }

    /// <summary>Asynchronously saves all pending changes in the underlying context.</summary>
    public async Task<int> SaveChangesAsync()
    {
        return await _context.SaveChangesAsync();
    }

    /// <summary>Returns the only entity matching <paramref name="predicate"/>.</summary>
    /// <exception cref="InvalidOperationException">Zero or more than one entity matches.</exception>
    public TEntity Single(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Where(predicate).Single();
    }

    /// <summary>Asynchronously returns the only entity matching <paramref name="predicate"/>.</summary>
    /// <exception cref="InvalidOperationException">Zero or more than one entity matches.</exception>
    public async Task<TEntity> SingleAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).SingleAsync();
    }

    /// <summary>Returns the only entity matching <paramref name="predicate"/>, or <c>null</c> if none does.</summary>
    /// <exception cref="InvalidOperationException">More than one entity matches.</exception>
    public TEntity SingleOrDefault(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Where(predicate).SingleOrDefault();
    }

    /// <summary>
    /// Asynchronously returns the only entity matching <paramref name="predicate"/>, or <c>null</c> if none does.
    /// </summary>
    /// <exception cref="InvalidOperationException">More than one entity matches.</exception>
    public async Task<TEntity> SingleOrDefaultAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).SingleOrDefaultAsync();
    }

    /// <summary>
    /// Marks <paramref name="entity"/> as modified (not saved), attaching it first if it is not tracked.
    /// </summary>
    public TEntity Update(TEntity entity)
    {
        AttachIfNot(entity);
        return Table.Update(entity).Entity;
    }

    /// <summary>Marks <paramref name="entity"/> as modified and immediately saves all pending changes.</summary>
    public TEntity UpdateAndSaveToDb(TEntity entity)
    {
        var updated = Update(entity);
        SaveChanges();
        return updated;
    }

    /// <summary>Marks <paramref name="entity"/> as modified and asynchronously saves all pending changes.</summary>
    public async Task<TEntity> UpdateAndSaveToDbAsync(TEntity entity)
    {
        var updated = await UpdateAsync(entity);
        await SaveChangesAsync();
        return updated;
    }

    /// <summary>
    /// Task-returning wrapper around <see cref="Update"/>. The update itself runs synchronously
    /// because staging a change does not touch the database.
    /// </summary>
    public Task<TEntity> UpdateAsync(TEntity entity)
    {
        return Task.FromResult(Update(entity));
    }
}
下面是对应的接口定义:
/// <summary>
/// Data-access abstraction for a single entity type: querying, insert/update/delete staging,
/// counting, and persisting changes.
/// </summary>
/// <typeparam name="TEntity">The entity type the repository manages.</typeparam>
public interface IRepository<TEntity> where TEntity : class
{
    // ---- Query composition ------------------------------------------------

    /// <summary>Returns a composable query over all entities.</summary>
    IQueryable<TEntity> GetAll();

    /// <summary>Returns a query over all entities that eagerly loads the given navigation properties.</summary>
    IQueryable<TEntity> GetAllIncluding(params Expression<Func<TEntity, object>>[] propertySelectors);

    /// <summary>Applies <paramref name="queryMethod"/> to the entity query and returns its result.</summary>
    T Query<T>(Func<IQueryable<TEntity>, T> queryMethod);

    // ---- Materialized lists -----------------------------------------------

    /// <summary>Loads every entity into a list.</summary>
    List<TEntity> GetAllList();

    /// <summary>Loads every entity matching <paramref name="predicate"/> into a list.</summary>
    List<TEntity> GetAllList(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Asynchronously loads every entity into a list.</summary>
    Task<List<TEntity>> GetAllListAsync();

    /// <summary>Asynchronously loads every entity matching <paramref name="predicate"/> into a list.</summary>
    Task<List<TEntity>> GetAllListAsync(Expression<Func<TEntity, bool>> predicate);

    // ---- Single-entity lookups --------------------------------------------

    /// <summary>Returns the only entity matching <paramref name="predicate"/>.</summary>
    TEntity Single(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Asynchronously returns the only entity matching <paramref name="predicate"/>.</summary>
    Task<TEntity> SingleAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Returns the only matching entity, or the default value when none matches.</summary>
    TEntity SingleOrDefault(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Asynchronously returns the only matching entity, or the default value when none matches.</summary>
    Task<TEntity> SingleOrDefaultAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Returns the first entity matching <paramref name="predicate"/>.</summary>
    TEntity First(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Asynchronously returns the first entity matching <paramref name="predicate"/>.</summary>
    Task<TEntity> FirstAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Returns the first matching entity, or the default value when none matches.</summary>
    TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Asynchronously returns the first matching entity, or the default value when none matches.</summary>
    Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate);

    // ---- Insert -------------------------------------------------------------

    /// <summary>Stages <paramref name="entity"/> for insertion.</summary>
    TEntity Insert(TEntity entity);

    /// <summary>Asynchronously stages <paramref name="entity"/> for insertion.</summary>
    Task<TEntity> InsertAsync(TEntity entity);

    /// <summary>Stages <paramref name="entity"/> for insertion and saves changes.</summary>
    TEntity InsertAndSaveToDb(TEntity entity);

    /// <summary>Asynchronously stages <paramref name="entity"/> for insertion and saves changes.</summary>
    Task<TEntity> InsertAndSaveToDbAsync(TEntity entity);

    // ---- Update -------------------------------------------------------------

    /// <summary>Stages <paramref name="entity"/> as modified.</summary>
    TEntity Update(TEntity entity);

    /// <summary>Asynchronously stages <paramref name="entity"/> as modified.</summary>
    Task<TEntity> UpdateAsync(TEntity entity);

    /// <summary>Stages <paramref name="entity"/> as modified and saves changes.</summary>
    TEntity UpdateAndSaveToDb(TEntity entity);

    /// <summary>Asynchronously stages <paramref name="entity"/> as modified and saves changes.</summary>
    Task<TEntity> UpdateAndSaveToDbAsync(TEntity entity);

    // ---- Delete -------------------------------------------------------------

    /// <summary>Stages <paramref name="entity"/> for deletion.</summary>
    void Delete(TEntity entity);

    /// <summary>Stages <paramref name="entity"/> for deletion and saves changes.</summary>
    void DeleteAndSaveToDb(TEntity entity);

    // ---- Counting -----------------------------------------------------------

    /// <summary>Counts all entities.</summary>
    int Count();

    /// <summary>Counts entities matching <paramref name="predicate"/>.</summary>
    int Count(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Asynchronously counts all entities.</summary>
    Task<int> CountAsync();

    /// <summary>Asynchronously counts entities matching <paramref name="predicate"/>.</summary>
    Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate);

    // ---- Persistence --------------------------------------------------------

    /// <summary>Persists all pending changes.</summary>
    int SaveChanges();

    /// <summary>Asynchronously persists all pending changes.</summary>
    Task<int> SaveChangesAsync();
}
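最后是DI的注册代码。它通过反射找出DbContext中所有公开的DbSet<T>属性,取出对应的实体类型(即每张表),并为每个实体类型注册一个Repository。注意:Repository的构造函数依赖DbContext(基类),因此需确保容器中能够解析DbContext(例如将其映射到具体的CorPayContext),否则注入IRepository<T>时会失败。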
/// <summary>
/// Returns the entity type of every public instance property on <paramref name="dbContextType"/>
/// whose type is, or derives from, a closed <c>DbSet&lt;TEntity&gt;</c>.
/// </summary>
/// <param name="dbContextType">The DbContext type to inspect.</param>
/// <returns>
/// One entity type per matching property. The property list is read immediately;
/// the filtering is deferred until the sequence is enumerated.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="dbContextType"/> is <c>null</c>.</exception>
public static IEnumerable<Type> ExtractTableTypesFromDbContext(Type dbContextType)
{
    if (dbContextType == null)
    {
        throw new ArgumentNullException(nameof(dbContextType));
    }

    return
        from property in dbContextType.GetProperties(BindingFlags.Public | BindingFlags.Instance)
        let entityType = FindDbSetEntityType(property.PropertyType)
        where entityType != null
        select entityType;

    // Walks the base-class chain to the closed DbSet<T> and returns its T, or null if there is none.
    // Reading T from DbSet<T> itself, rather than from GenericTypeArguments[0] of the declared
    // property type, stays correct for DbSet subclasses. Examples: a non-generic
    // `class Orders : DbSet<Order>` (which has no type arguments), or `MySet<TKey, T> : DbSet<T>`
    // (where argument 0 is not the entity).
    Type FindDbSetEntityType(Type candidate)
    {
        for (var current = candidate; current != null; current = current.BaseType)
        {
            if (current.IsGenericType && current.GetGenericTypeDefinition() == typeof(DbSet<>))
            {
                return current.GenericTypeArguments[0];
            }
        }

        return null;
    }
}
/// <summary>
/// Tells whether <paramref name="givenType"/>, one of its base classes, or one of its interfaces
/// is a construction of the generic type definition <paramref name="genericType"/>.
/// For example, <c>List&lt;int&gt;</c> matches both <c>List&lt;&gt;</c> and <c>IEnumerable&lt;&gt;</c>.
/// </summary>
/// <param name="givenType">The type to test. Must not be <c>null</c>.</param>
/// <param name="genericType">An open generic type definition such as <c>typeof(DbSet&lt;&gt;)</c>.</param>
/// <returns><c>true</c> if a matching construction is found anywhere in the hierarchy.</returns>
public static bool IsAssignableToGeneralType(Type givenType, Type genericType)
{
    bool IsConstructionOf(Type candidate) =>
        candidate.IsGenericType && candidate.GetGenericTypeDefinition() == genericType;

    // For each level of the base-class chain: check the type itself first, then its interfaces.
    var current = givenType;
    do
    {
        if (IsConstructionOf(current))
        {
            return true;
        }

        foreach (var implemented in current.GetInterfaces())
        {
            if (IsConstructionOf(implemented))
            {
                return true;
            }
        }

        current = current.BaseType;
    }
    while (current != null);

    return false;
}
/// <summary>
/// Registers, for every entity type exposed as a <c>DbSet</c> property on <c>CorPayContext</c>,
/// a scoped mapping from <c>IRepository&lt;TEntity&gt;</c> to <c>Repository&lt;TEntity&gt;</c>.
/// </summary>
/// <param name="services">The service collection to add registrations to.</param>
/// <returns>The same <paramref name="services"/> instance, for chaining.</returns>
/// <remarks>
/// NOTE(review): <c>Repository&lt;TEntity&gt;</c>'s constructor takes a <c>DbContext</c>. Confirm that
/// <c>DbContext</c> itself is resolvable from the container (for example mapped to <c>CorPayContext</c>);
/// otherwise resolving any <c>IRepository&lt;TEntity&gt;</c> registered here will fail.
/// </remarks>
private IServiceCollection RegisterRepository(IServiceCollection services)
{
    var openRepositoryInterface = typeof(IRepository<>);
    var openRepositoryImplementation = typeof(Repository<>);

    foreach (var entityType in RepositoryHelper.ExtractTableTypesFromDbContext(typeof(CorPayContext)))
    {
        services.AddScoped(
            openRepositoryInterface.MakeGenericType(entityType),
            openRepositoryImplementation.MakeGenericType(entityType));
    }

    return services;
}