EFCore 仓储模式的实现

仓储模式的EF实现

仓储模式（Repository Pattern）将应用层与 ORM 层解耦，为访问 ORM 层提供统一的 API。再配合依赖注入（DI），就可以很方便地实现数据库访问。下面介绍针对 EF Core 的仓储模式实现，以及相应的 DI 注册方式。
仓储模式代码

/// <summary>
/// EF Core implementation of <see cref="IRepository{TEntity}"/>. Every query starts from
/// <see cref="Table"/> (the context's set for <typeparamref name="TEntity"/>); the
/// <c>*AndSaveToDb</c> variants additionally flush pending changes through the same context.
/// </summary>
/// <typeparam name="TEntity">The entity type handled by this repository.</typeparam>
public class Repository<TEntity> : IRepository<TEntity> where TEntity : class
{
    private readonly DbContext _context;

    /// <summary>The entity set for <typeparamref name="TEntity"/> on the injected context.</summary>
    public virtual DbSet<TEntity> Table => this._context.Set<TEntity>();

    /// <summary>Creates a repository bound to <paramref name="context"/>.</summary>
    /// <param name="context">Context that owns the entity set and the change tracker.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public Repository(DbContext context)
    {
        this._context = context ?? throw new ArgumentNullException(nameof(context));
    }

    /// <summary>
    /// Attaches <paramref name="entity"/> to the context unless this exact instance is already
    /// tracked by the change tracker.
    /// </summary>
    protected virtual void AttachIfNot(TEntity entity)
    {
        // Compare against EntityEntry.Entity (the tracked object), not the EntityEntry wrapper.
        // An entry is never the entity itself, so comparing the wrapper could never match and the
        // "already tracked" short-circuit was dead code.
        var alreadyTracked = _context.ChangeTracker.Entries()
            .Any(entry => ReferenceEquals(entry.Entity, entity));
        if (alreadyTracked)
        {
            return;
        }
        Table.Attach(entity);
    }

    // ---------------------------------------------------------------- counting

    /// <summary>Number of rows in the set.</summary>
    public int Count()
    {
        return GetAll().Count();
    }

    /// <summary>Number of rows matching <paramref name="predicate"/>.</summary>
    public int Count(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Count(predicate);
    }

    /// <summary>Asynchronously counts the rows in the set.</summary>
    public async Task<int> CountAsync()
    {
        return await GetAll().CountAsync();
    }

    /// <summary>Asynchronously counts the rows matching <paramref name="predicate"/>.</summary>
    public async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().CountAsync(predicate);
    }

    // ---------------------------------------------------------------- delete

    /// <summary>Marks <paramref name="entity"/> for deletion (attaching it first if needed); does not save.</summary>
    public void Delete(TEntity entity)
    {
        AttachIfNot(entity);
        Table.Remove(entity);
    }

    /// <summary>Marks <paramref name="entity"/> for deletion and immediately saves all pending changes.</summary>
    public void DeleteAndSaveToDb(TEntity entity)
    {
        Delete(entity);
        // Go through the repository's own SaveChanges, like the other *AndSaveToDb methods.
        SaveChanges();
    }

    // ---------------------------------------------------------------- single-row lookups

    /// <summary>First row matching <paramref name="predicate"/> (LINQ <c>First</c> semantics).</summary>
    public TEntity First(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().First(predicate);
    }

    /// <summary>Async counterpart of <see cref="First"/>.</summary>
    public async Task<TEntity> FirstAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().FirstAsync(predicate);
    }

    /// <summary>First row matching <paramref name="predicate"/>, or <c>null</c> when none match.</summary>
    public TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().FirstOrDefault(predicate);
    }

    /// <summary>Async counterpart of <see cref="FirstOrDefault"/>.</summary>
    public async Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().FirstOrDefaultAsync(predicate);
    }

    /// <summary>The only row matching <paramref name="predicate"/> (LINQ <c>Single</c> semantics).</summary>
    public TEntity Single(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Single(predicate);
    }

    /// <summary>Async counterpart of <see cref="Single"/>.</summary>
    public async Task<TEntity> SingleAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().SingleAsync(predicate);
    }

    /// <summary>The only row matching <paramref name="predicate"/>, or <c>null</c> when none match.</summary>
    public TEntity SingleOrDefault(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().SingleOrDefault(predicate);
    }

    /// <summary>Async counterpart of <see cref="SingleOrDefault"/>.</summary>
    public async Task<TEntity> SingleOrDefaultAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().SingleOrDefaultAsync(predicate);
    }

    // ---------------------------------------------------------------- queries

    /// <summary>The unfiltered, composable query over <see cref="Table"/>.</summary>
    public IQueryable<TEntity> GetAll()
    {
        return Table;
    }

    /// <summary>
    /// <see cref="GetAll"/> with an <c>Include</c> applied for each selector; with no selectors
    /// this is just <see cref="GetAll"/>.
    /// </summary>
    public IQueryable<TEntity> GetAllIncluding(params Expression<Func<TEntity, object>>[] propertySelectors)
    {
        var query = GetAll();
        if (propertySelectors == null || propertySelectors.Length == 0)
        {
            return query;
        }
        foreach (var eachSelector in propertySelectors)
        {
            query = query.Include(eachSelector);
        }
        return query;
    }

    /// <summary>Materializes every row into a list.</summary>
    public List<TEntity> GetAllList()
    {
        return GetAll().ToList();
    }

    /// <summary>Materializes the rows matching <paramref name="predicate"/> into a list.</summary>
    public List<TEntity> GetAllList(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Where(predicate).ToList();
    }

    /// <summary>Async counterpart of <see cref="GetAllList()"/>.</summary>
    public async Task<List<TEntity>> GetAllListAsync()
    {
        return await GetAll().ToListAsync();
    }

    /// <summary>Async counterpart of <see cref="GetAllList(Expression{Func{TEntity, bool}})"/>.</summary>
    public async Task<List<TEntity>> GetAllListAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).ToListAsync();
    }

    /// <summary>Runs an arbitrary projection/aggregation over <see cref="GetAll"/>.</summary>
    public T Query<T>(Func<IQueryable<TEntity>, T> queryMethod)
    {
        return queryMethod(GetAll());
    }

    // ---------------------------------------------------------------- insert

    /// <summary>Adds <paramref name="entity"/> to the set; does not save.</summary>
    public TEntity Insert(TEntity entity)
    {
        return Table.Add(entity).Entity;
    }

    /// <summary>Adds <paramref name="entity"/> and immediately saves all pending changes.</summary>
    public TEntity InsertAndSaveToDb(TEntity entity)
    {
        var inserted = Insert(entity);
        SaveChanges();
        return inserted;
    }

    /// <summary>Async counterpart of <see cref="InsertAndSaveToDb"/>.</summary>
    public async Task<TEntity> InsertAndSaveToDbAsync(TEntity entity)
    {
        var inserted = await InsertAsync(entity);
        await SaveChangesAsync();
        return inserted;
    }

    /// <summary>Adds <paramref name="entity"/> via <c>DbSet.AddAsync</c>; does not save.</summary>
    public async Task<TEntity> InsertAsync(TEntity entity)
    {
        return (await Table.AddAsync(entity)).Entity;
    }

    // ---------------------------------------------------------------- update

    /// <summary>Marks <paramref name="entity"/> as modified (attaching it first if needed); does not save.</summary>
    public TEntity Update(TEntity entity)
    {
        AttachIfNot(entity);
        return Table.Update(entity).Entity;
    }

    /// <summary>Marks <paramref name="entity"/> as modified and immediately saves all pending changes.</summary>
    public TEntity UpdateAndSaveToDb(TEntity entity)
    {
        var updated = Update(entity);
        SaveChanges();
        return updated;
    }

    /// <summary>Async counterpart of <see cref="UpdateAndSaveToDb"/>.</summary>
    public async Task<TEntity> UpdateAndSaveToDbAsync(TEntity entity)
    {
        var updated = await UpdateAsync(entity);
        await SaveChangesAsync();
        return updated;
    }

    /// <summary>
    /// Task-returning wrapper over <see cref="Update"/>; the update itself runs synchronously
    /// because marking an entity as modified does not touch the database.
    /// </summary>
    public Task<TEntity> UpdateAsync(TEntity entity)
    {
        return Task.FromResult(Update(entity));
    }

    // ---------------------------------------------------------------- persistence

    /// <summary>Flushes all pending changes on the shared context; returns the affected row count.</summary>
    public int SaveChanges()
    {
        return _context.SaveChanges();
    }

    /// <summary>Async counterpart of <see cref="SaveChanges"/>.</summary>
    public async Task<int> SaveChangesAsync()
    {
        return await _context.SaveChangesAsync();
    }
}

下面是接口

/// <summary>
/// Uniform data-access contract for one entity type. Methods without an <c>AndSaveToDb</c>
/// suffix only stage changes; call <see cref="SaveChanges"/> / <see cref="SaveChangesAsync"/>
/// (or use an <c>*AndSaveToDb</c> variant) to persist them.
/// </summary>
/// <typeparam name="TEntity">The entity type exposed by the repository.</typeparam>
public interface IRepository<TEntity> where TEntity : class
{
    #region Composable queries

    /// <summary>Unfiltered, composable query over the entity set.</summary>
    IQueryable<TEntity> GetAll();

    /// <summary>Like <see cref="GetAll"/>, eagerly loading the selected navigation properties.</summary>
    IQueryable<TEntity> GetAllIncluding(params Expression<Func<TEntity, object>>[] propertySelectors);

    /// <summary>Applies <paramref name="queryMethod"/> to the entity query and returns its result.</summary>
    T Query<T>(Func<IQueryable<TEntity>, T> queryMethod);

    #endregion

    #region Materialized lists

    /// <summary>All entities, materialized.</summary>
    List<TEntity> GetAllList();

    /// <summary>All entities, materialized asynchronously.</summary>
    Task<List<TEntity>> GetAllListAsync();

    /// <summary>Entities matching <paramref name="predicate"/>, materialized.</summary>
    List<TEntity> GetAllList(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Entities matching <paramref name="predicate"/>, materialized asynchronously.</summary>
    Task<List<TEntity>> GetAllListAsync(Expression<Func<TEntity, bool>> predicate);

    #endregion

    #region Single-entity lookups

    /// <summary>Exactly one match (LINQ <c>Single</c> semantics).</summary>
    TEntity Single(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Async <see cref="Single"/>.</summary>
    Task<TEntity> SingleAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>At most one match (LINQ <c>SingleOrDefault</c> semantics).</summary>
    TEntity SingleOrDefault(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Async <see cref="SingleOrDefault"/>.</summary>
    Task<TEntity> SingleOrDefaultAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>First match (LINQ <c>First</c> semantics).</summary>
    TEntity First(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Async <see cref="First"/>.</summary>
    Task<TEntity> FirstAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>First match or default (LINQ <c>FirstOrDefault</c> semantics).</summary>
    TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Async <see cref="FirstOrDefault"/>.</summary>
    Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate);

    #endregion

    #region Insert

    /// <summary>Stages <paramref name="entity"/> for insertion.</summary>
    TEntity Insert(TEntity entity);

    /// <summary>Async <see cref="Insert"/>.</summary>
    Task<TEntity> InsertAsync(TEntity entity);

    /// <summary>Stages <paramref name="entity"/> for insertion and persists pending changes.</summary>
    TEntity InsertAndSaveToDb(TEntity entity);

    /// <summary>Async <see cref="InsertAndSaveToDb"/>.</summary>
    Task<TEntity> InsertAndSaveToDbAsync(TEntity entity);

    #endregion

    #region Update

    /// <summary>Stages <paramref name="entity"/> as modified.</summary>
    TEntity Update(TEntity entity);

    /// <summary>Async <see cref="Update"/>.</summary>
    Task<TEntity> UpdateAsync(TEntity entity);

    /// <summary>Stages <paramref name="entity"/> as modified and persists pending changes.</summary>
    TEntity UpdateAndSaveToDb(TEntity entity);

    /// <summary>Async <see cref="UpdateAndSaveToDb"/>.</summary>
    Task<TEntity> UpdateAndSaveToDbAsync(TEntity entity);

    #endregion

    #region Delete

    /// <summary>Stages <paramref name="entity"/> for deletion.</summary>
    void Delete(TEntity entity);

    /// <summary>Stages <paramref name="entity"/> for deletion and persists pending changes.</summary>
    void DeleteAndSaveToDb(TEntity entity);

    #endregion

    #region Counting

    /// <summary>Total number of entities.</summary>
    int Count();

    /// <summary>Async <see cref="Count()"/>.</summary>
    Task<int> CountAsync();

    /// <summary>Number of entities matching <paramref name="predicate"/>.</summary>
    int Count(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Async <see cref="Count(Expression{Func{TEntity, bool}})"/>.</summary>
    Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate);

    #endregion

    #region Persistence

    /// <summary>Persists all staged changes; returns the number of affected rows.</summary>
    int SaveChanges();

    /// <summary>Async <see cref="SaveChanges"/>.</summary>
    Task<int> SaveChangesAsync();

    #endregion
}

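最后是 DI 的注册代码：它通过反射找出 DbContext 上所有 DbSet<T> 类型的公共属性，为每个实体类型（即每张表）注册一个 IRepository<T> → Repository<T> 的映射。注意：Repository<TEntity> 的构造函数依赖的是基类 DbContext，因此容器中还必须能够解析 DbContext 本身（例如将其映射到具体的 CorPayContext）。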

/// <summary>
/// Lists the entity types a context exposes: for every public instance property whose type is,
/// derives from, or implements <c>DbSet&lt;&gt;</c>, yields that property type's first generic argument.
/// </summary>
/// <param name="dbContextType">The DbContext type to inspect.</param>
/// <returns>A lazily filtered sequence of entity types, one per matching property, in reflection order.</returns>
public static IEnumerable<Type> ExtractTableTypesFromDbContext(Type dbContextType)
{
    var candidateProperties = dbContextType.GetProperties(BindingFlags.Public | BindingFlags.Instance);

    // NOTE(review): GenericTypeArguments[0] assumes the property type is itself DbSet<T>;
    // a non-generic subclass of DbSet<T> would have no generic arguments here.
    return candidateProperties
        .Where(prop => IsAssignableToGeneralType(prop.PropertyType, typeof(DbSet<>)))
        .Select(prop => prop.PropertyType.GenericTypeArguments[0]);
}



/// <summary>
/// Tells whether <paramref name="givenType"/>, any of its base classes, or any interface they
/// implement is a closed construction of the open generic <paramref name="genericType"/>
/// (e.g. <c>List&lt;int&gt;</c> vs <c>IEnumerable&lt;&gt;</c>).
/// </summary>
/// <param name="givenType">The type to test; must not be null.</param>
/// <param name="genericType">An open generic type definition such as <c>typeof(DbSet&lt;&gt;)</c>.</param>
public static bool IsAssignableToGeneralType(Type givenType, Type genericType)
{
    var current = givenType;
    while (true)
    {
        var matchesItself = current.IsGenericType && current.GetGenericTypeDefinition() == genericType;
        if (matchesItself)
        {
            return true;
        }

        var matchesAnInterface = Array.Exists(
            current.GetInterfaces(),
            candidate => candidate.IsGenericType && candidate.GetGenericTypeDefinition() == genericType);
        if (matchesAnInterface)
        {
            return true;
        }

        // Walk up the inheritance chain; reaching the root without a match means "not assignable".
        current = current.BaseType;
        if (current == null)
        {
            return false;
        }
    }
}

// 组测DI的方法
private IServiceCollection RegisterRepository(IServiceCollection services)
{
  foreach (var eachEntity in RepositoryHelper.ExtractTableTypesFromDbContext(typeof(CorPayContext)))
  {
    var repositoryInterfaceType = typeof(IRepository<>)
          .MakeGenericType(eachEntity);
    var repositoryImplementType = typeof(Repository<>)
        .MakeGenericType(eachEntity);
    services.AddScoped(repositoryInterfaceType, repositoryImplementType);                
  }
  return services;
}

  
posted @ 2022-06-26 20:12  kongshu  阅读(527)  评论(0)  编辑  收藏  举报