参照:

https://www.cnblogs.com/youhui/articles/10813468.html

接口类:

/// <summary>
/// Generic repository contract for entities of type <typeparamref name="TEntity"/>
/// identified by a key of type <typeparamref name="TKey"/>.
/// Groups query, insert, delete, update and raw-SQL operations, each with a
/// synchronous and a Task-returning variant.
/// </summary>
/// <typeparam name="TEntity">Entity type; must be a reference type.</typeparam>
/// <typeparam name="TKey">Type of the entity's primary key.</typeparam>
public  interface IRepository<TEntity,TKey> where TEntity : class
    {
        // Query operations.
        #region 查找数据
        /// <summary>Returns the number of entities matching <paramref name="predicate"/>.</summary>
        /// <param name="predicate">Optional filter; defaults to null.</param>
        long Count(Expression<Func<TEntity, bool>> predicate = null);
        /// <summary>Asynchronous counterpart of <see cref="Count"/>.</summary>
        /// <param name="predicate">Optional filter; defaults to null.</param>
        Task<long> CountAsync(Expression<Func<TEntity, bool>> predicate = null);
 
        /// <summary>Returns a single entity matching <paramref name="predicate"/>.</summary>
        /// <param name="predicate">Filter expression.</param>
        /// <param name="isNoTracking">Whether the query should be a no-tracking query.</param>
        TEntity Get(Expression<Func<TEntity, bool>> predicate, bool isNoTracking);
        /// <summary>Asynchronous counterpart of <see cref="Get"/>.</summary>
        /// <param name="predicate">Filter expression.</param>
        /// <param name="isNoTracking">Whether the query should be a no-tracking query.</param>
        Task<TEntity> GetAsync(Expression<Func<TEntity, bool>> predicate, bool isNoTracking);
 
      
        /// <summary>Asynchronously looks up a single entity by its key.</summary>
        /// <param name="id">Key value of the entity.</param>
        Task<TEntity> GetAsync(TKey id);
 
        /// <summary>Returns a queryable over the entities matching <paramref name="predicate"/>.</summary>
        /// <param name="predicate">Filter expression.</param>
        /// <param name="isNoTracking">Whether the query should be a no-tracking query.</param>
        IQueryable<TEntity> Load(Expression<Func<TEntity, bool>> predicate , bool isNoTracking);
        /// <summary>Task-returning counterpart of <see cref="Load"/>.</summary>
        /// <param name="predicate">Filter expression.</param>
        /// <param name="isNoTracking">Whether the query should be a no-tracking query.</param>
        Task<IQueryable<TEntity>> LoadAsync(Expression<Func<TEntity, bool>> predicate , bool isNoTracking);
 
        /// <summary>Returns the matching entities as a list.</summary>
        /// <param name="predicate">Filter expression.</param>
        /// <param name="ordering">Ordering specification passed as a string.</param>
        /// <param name="isNoTracking">Whether the query should be a no-tracking query.</param>
        List<TEntity> GetList(Expression<Func<TEntity, bool>> predicate, string ordering, bool isNoTracking );
        /// <summary>Asynchronous counterpart of <see cref="GetList"/>.</summary>
        /// <param name="predicate">Filter expression.</param>
        /// <param name="ordering">Ordering specification passed as a string.</param>
        /// <param name="isNoTracking">Whether the query should be a no-tracking query.</param>
        Task<List<TEntity>> GetListAsync(Expression<Func<TEntity, bool>> predicate, string ordering, bool isNoTracking );
        
        #endregion
 
        // Insert operations.
        #region 插入数据
        /// <summary>Inserts a single entity.</summary>
        /// <param name="entity">Entity to insert.</param>
        /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
        bool Insert(TEntity entity, bool isSaveChange);
        /// <summary>Asynchronously inserts a single entity.</summary>
        /// <param name="entity">Entity to insert.</param>
        /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
        Task<bool> InsertAsync(TEntity entity, bool isSaveChange);
        /// <summary>Inserts a list of entities.</summary>
        /// <param name="entitys">Entities to insert.</param>
        /// <param name="isSaveChange">Whether to persist the change as part of this call; defaults to true.</param>
        bool Insert(List<TEntity> entitys, bool isSaveChange = true);
        /// <summary>Asynchronously inserts a list of entities.</summary>
        /// <param name="entitys">Entities to insert.</param>
        /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
        Task<bool> InsertAsync(List<TEntity> entitys, bool isSaveChange);
        #endregion
 
        // Delete operations (the original region title notes that a query is needed before deleting).
        #region 删除(删除之前需要查询)
        /// <summary>Deletes a single entity.</summary>
        /// <param name="entity">Entity to delete.</param>
        /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
        bool Delete(TEntity entity, bool isSaveChange);
        /// <summary>Deletes a list of entities.</summary>
        /// <param name="entitys">Entities to delete.</param>
        /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
        bool Delete(List<TEntity> entitys, bool isSaveChange);
        /// <summary>Asynchronously deletes a single entity.</summary>
        /// <param name="entity">Entity to delete.</param>
        /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
        Task<bool> DeleteAsync(TEntity entity, bool isSaveChange);
        /// <summary>Asynchronously deletes a list of entities.</summary>
        /// <param name="entitys">Entities to delete.</param>
        /// <param name="isSaveChange">Whether to persist the change as part of this call; defaults to true.</param>
        Task<bool> DeleteAsync(List<TEntity> entitys, bool isSaveChange = true);
        #endregion
 
        // Update operations.
        #region 修改数据
       /// <summary>Updates a single entity.</summary>
       /// <param name="entity">Entity carrying the new values.</param>
       /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
       /// <param name="updatePropertyList">Names of the properties to update.</param>
       bool Update(TEntity entity, bool isSaveChange, List<string> updatePropertyList);
       /// <summary>Asynchronously updates a single entity.</summary>
       /// <param name="entity">Entity carrying the new values.</param>
       /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
       /// <param name="updatePropertyList">Names of the properties to update.</param>
       Task<bool> UpdateAsync(TEntity entity, bool isSaveChange, List<string> updatePropertyList);
       /// <summary>Updates a list of entities.</summary>
       /// <param name="entitys">Entities carrying the new values.</param>
       /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
       bool Update(List<TEntity> entitys, bool isSaveChange);
       /// <summary>Asynchronously updates a list of entities.</summary>
       /// <param name="entitys">Entities carrying the new values.</param>
       /// <param name="isSaveChange">Whether to persist the change as part of this call.</param>
       Task<bool> UpdateAsync(List<TEntity> entitys, bool isSaveChange );
        #endregion
 
        // Raw SQL operations.
        #region 执行Sql语句
        /// <summary>Bulk-inserts a list of items of type <typeparamref name="T"/>.</summary>
        /// <typeparam name="T">Type of the items to insert.</typeparam>
        /// <param name="entities">Items to insert.</param>
        void BulkInsert<T>(List<T> entities);
        /// <summary>Executes a SQL statement and returns an integer result.</summary>
        /// <param name="sql">SQL text to execute.</param>
        int ExecuteSql(string sql);
        /// <summary>Asynchronously executes a SQL statement and returns an integer result.</summary>
        /// <param name="sql">SQL text to execute.</param>
        Task<int> ExecuteSqlAsync(string sql);
        /// <summary>Executes a parameterized SQL statement and returns an integer result.</summary>
        /// <param name="sql">SQL text to execute.</param>
        /// <param name="spList">Parameters supplied to the statement.</param>
        int ExecuteSql(string sql, List<DbParameter> spList);
        /// <summary>Asynchronously executes a parameterized SQL statement and returns an integer result.</summary>
        /// <param name="sql">SQL text to execute.</param>
        /// <param name="spList">Parameters supplied to the statement.</param>
        Task<int> ExecuteSqlAsync(string sql, List<DbParameter> spList);
        /// <summary>Runs a SQL query and returns its result as a <see cref="DataTable"/>.</summary>
        /// <param name="sql">SQL query text.</param>
        DataTable GetDataTableWithSql(string sql);
 
        /// <summary>Runs a parameterized SQL query and returns its result as a <see cref="DataTable"/>.</summary>
        /// <param name="sql">SQL query text.</param>
        /// <param name="spList">Parameters supplied to the query.</param>
        DataTable GetDataTableWithSql(string sql, List<DbParameter> spList);
 
       
 
        #endregion
}

实现类:

/// <summary>
/// Entity Framework Core backed base implementation of <see cref="IRepository{TEntity, TKey}"/>.
/// Database-specific operations (<see cref="BulkInsert{T}"/> and
/// <c>GetDataTableWithSql</c>) are virtual so that provider-specific subclasses can override them.
/// </summary>
/// <remarks>
/// All write methods return <c>true</c> only when <c>isSaveChange</c> is <c>true</c> AND
/// <see cref="SaveChanges"/> reports at least one affected row. When <c>isSaveChange</c> is
/// <c>false</c> the change is only staged in the context and the method returns <c>false</c>.
/// </remarks>
/// <typeparam name="TEntity">Entity type; must be a reference type.</typeparam>
/// <typeparam name="TKey">Type of the entity's primary key.</typeparam>
public abstract class BaseRepository<TEntity, TKey> : IRepository<TEntity, TKey> where TEntity : class
    {
        private readonly DbSet<TEntity> _dbSet;

        /// <summary>The context this repository reads from and writes to.</summary>
        public GeneralDbContext _dbContext { get; }

        /// <summary>
        /// Connection string, for subclasses that open their own connections.
        /// </summary>
        protected string _connectionString { get; set; }

        /// <summary>
        /// Database type.
        /// </summary>
        private DatabaseType _dbType { get; set; }

        /// <summary>Creates a repository bound to <paramref name="context"/>.</summary>
        /// <param name="context">Context that owns the entity set.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        public BaseRepository(GeneralDbContext context)
        {
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            _dbContext = context;
            _dbSet = _dbContext.Set<TEntity>();
        }

        /// <summary>Database facade of the underlying context.</summary>
        public DatabaseFacade Database => _dbContext.Database;

        /// <summary>No-tracking query over all entities of this type.</summary>
        public IQueryable<TEntity> Entities => _dbSet.AsQueryable().AsNoTracking();

        /// <summary>Saves all pending changes of the context.</summary>
        /// <returns>The value returned by the context's <c>SaveChanges</c>.</returns>
        public int SaveChanges()
        {
            return _dbContext.SaveChanges();
        }

        /// <summary>Asynchronously saves all pending changes of the context.</summary>
        /// <returns>The value returned by the context's <c>SaveChangesAsync</c>.</returns>
        public async Task<int> SaveChangesAsync()
        {
            return await _dbContext.SaveChangesAsync();
        }

        /// <summary>Tells whether any entity satisfies <paramref name="whereLambd"/>.</summary>
        /// <param name="whereLambd">Filter expression.</param>
        public bool Any(Expression<Func<TEntity, bool>> whereLambd)
        {
            return _dbSet.Where(whereLambd).Any();
        }

        #region Private helpers

        /// <summary>Saves when requested; reports whether at least one row was affected.</summary>
        private bool Commit(bool isSaveChange) => isSaveChange && SaveChanges() > 0;

        /// <summary>Asynchronous counterpart of <see cref="Commit"/>.</summary>
        private async Task<bool> CommitAsync(bool isSaveChange)
        {
            return isSaveChange && await SaveChangesAsync() > 0;
        }

        /// <summary>Attaches <paramref name="entity"/> and marks it for deletion.</summary>
        private void MarkDeleted(TEntity entity)
        {
            _dbSet.Attach(entity);
            _dbSet.Remove(entity);
        }

        /// <summary>
        /// Attaches <paramref name="entity"/> and flags it as modified: every property when
        /// <paramref name="updatePropertyList"/> is null, otherwise only the listed properties.
        /// </summary>
        private void MarkModified(TEntity entity, List<string> updatePropertyList)
        {
            _dbSet.Attach(entity);
            var entry = _dbContext.Entry(entity);
            if (updatePropertyList == null)
            {
                // Full update: all columns are written.
                entry.State = EntityState.Modified;
                return;
            }
            // Partial update: only the named properties are written.
            foreach (var propertyName in updatePropertyList)
            {
                entry.Property(propertyName).IsModified = true;
            }
        }

        /// <summary>
        /// Builds the filtered query shared by all read methods.
        /// A null <paramref name="predicate"/> means "no filter".
        /// </summary>
        private IQueryable<TEntity> BuildQuery(Expression<Func<TEntity, bool>> predicate, bool isNoTracking)
        {
            IQueryable<TEntity> query = _dbSet;
            if (predicate != null)
            {
                query = query.Where(predicate);
            }
            return isNoTracking ? query.AsNoTracking() : query;
        }

        /// <summary>Applies the string ordering when one was supplied.</summary>
        private static IQueryable<TEntity> ApplyOrdering(IQueryable<TEntity> query, string ordering)
        {
            return string.IsNullOrEmpty(ordering) ? query : query.OrderByBatch(ordering);
        }

        /// <summary>Converts the parameter list to the array form expected by the SQL executor; null becomes empty.</summary>
        private static object[] ToParameterArray(List<DbParameter> spList)
        {
            if (spList == null)
            {
                return new object[0];
            }
            return spList.ToArray();
        }

        #endregion

        #region Insert

        /// <summary>Adds <paramref name="entity"/> to the set and optionally saves.</summary>
        /// <param name="entity">Entity to insert.</param>
        /// <param name="isSaveChange">Save immediately when true.</param>
        /// <returns>True when saved and at least one row was affected; otherwise false.</returns>
        public bool Insert(TEntity entity, bool isSaveChange = true)
        {
            _dbSet.Add(entity);
            return Commit(isSaveChange);
        }

        /// <summary>Asynchronous counterpart of <see cref="Insert(TEntity, bool)"/>.</summary>
        public async Task<bool> InsertAsync(TEntity entity, bool isSaveChange = true)
        {
            _dbSet.Add(entity);
            return await CommitAsync(isSaveChange);
        }

        /// <summary>Adds all <paramref name="entitys"/> to the set and optionally saves.</summary>
        /// <param name="entitys">Entities to insert.</param>
        /// <param name="isSaveChange">Save immediately when true.</param>
        /// <returns>True when saved and at least one row was affected; otherwise false.</returns>
        public bool Insert(List<TEntity> entitys, bool isSaveChange = true)
        {
            _dbSet.AddRange(entitys);
            return Commit(isSaveChange);
        }

        /// <summary>Asynchronous counterpart of <see cref="Insert(List{TEntity}, bool)"/>.</summary>
        public async Task<bool> InsertAsync(List<TEntity> entitys, bool isSaveChange = true)
        {
            _dbSet.AddRange(entitys);
            return await CommitAsync(isSaveChange);
        }

        #endregion

        #region Delete

        /// <summary>Attaches and removes <paramref name="entity"/>, optionally saving.</summary>
        /// <param name="entity">Entity to delete.</param>
        /// <param name="isSaveChange">Save immediately when true.</param>
        /// <returns>True when saved and at least one row was affected; otherwise false.</returns>
        public bool Delete(TEntity entity, bool isSaveChange = true)
        {
            MarkDeleted(entity);
            return Commit(isSaveChange);
        }

        /// <summary>Attaches and removes every entity in <paramref name="entitys"/>, optionally saving.</summary>
        public bool Delete(List<TEntity> entitys, bool isSaveChange = true)
        {
            foreach (var entity in entitys)
            {
                MarkDeleted(entity);
            }
            return Commit(isSaveChange);
        }

        /// <summary>Asynchronous counterpart of <see cref="Delete(TEntity, bool)"/>.</summary>
        public virtual async Task<bool> DeleteAsync(TEntity entity, bool isSaveChange = true)
        {
            MarkDeleted(entity);
            return await CommitAsync(isSaveChange);
        }

        /// <summary>Asynchronous counterpart of <see cref="Delete(List{TEntity}, bool)"/>.</summary>
        public virtual async Task<bool> DeleteAsync(List<TEntity> entitys, bool isSaveChange = true)
        {
            foreach (var entity in entitys)
            {
                MarkDeleted(entity);
            }
            return await CommitAsync(isSaveChange);
        }

        #endregion

        #region Update

        /// <summary>Marks <paramref name="entity"/> as modified and optionally saves.</summary>
        /// <param name="entity">Entity carrying the new values; null returns false.</param>
        /// <param name="isSaveChange">Save immediately when true.</param>
        /// <param name="updatePropertyList">Properties to update; null updates every property.</param>
        /// <returns>True when saved and at least one row was affected; otherwise false.</returns>
        public bool Update(TEntity entity, bool isSaveChange = true, List<string> updatePropertyList = null)
        {
            if (entity == null)
            {
                return false;
            }
            MarkModified(entity, updatePropertyList);
            return Commit(isSaveChange);
        }

        /// <summary>Marks every non-null entity in <paramref name="entitys"/> as fully modified and optionally saves.</summary>
        /// <param name="entitys">Entities to update; null or empty returns false.</param>
        /// <param name="isSaveChange">Save immediately when true.</param>
        public bool Update(List<TEntity> entitys, bool isSaveChange = true)
        {
            if (entitys == null || entitys.Count == 0)
            {
                return false;
            }
            foreach (var entity in entitys)
            {
                if (entity != null)
                {
                    MarkModified(entity, null);
                }
            }
            return Commit(isSaveChange);
        }

        /// <summary>Asynchronous counterpart of <see cref="Update(TEntity, bool, List{string})"/>.</summary>
        public async Task<bool> UpdateAsync(TEntity entity, bool isSaveChange = true, List<string> updatePropertyList = null)
        {
            if (entity == null)
            {
                return false;
            }
            MarkModified(entity, updatePropertyList);
            return await CommitAsync(isSaveChange);
        }

        /// <summary>
        /// Asynchronous counterpart of <see cref="Update(List{TEntity}, bool)"/>.
        /// Null elements are skipped, matching the synchronous overload.
        /// </summary>
        public async Task<bool> UpdateAsync(List<TEntity> entitys, bool isSaveChange = true)
        {
            if (entitys == null || entitys.Count == 0)
            {
                return false;
            }
            foreach (var entity in entitys)
            {
                if (entity != null)
                {
                    MarkModified(entity, null);
                }
            }
            return await CommitAsync(isSaveChange);
        }

        #endregion

        #region Query

        /// <summary>Counts the entities matching <paramref name="predicate"/>; null counts all.</summary>
        public long Count(Expression<Func<TEntity, bool>> predicate = null)
        {
            return BuildQuery(predicate, false).LongCount();
        }

        /// <summary>Asynchronous counterpart of <see cref="Count"/>.</summary>
        public async Task<long> CountAsync(Expression<Func<TEntity, bool>> predicate = null)
        {
            return await BuildQuery(predicate, false).LongCountAsync();
        }

        /// <summary>Finds an entity by key.</summary>
        /// <param name="id">Key value; a null key returns null without querying.</param>
        public TEntity Get(TKey id)
        {
            if (id == null)
            {
                return default(TEntity);
            }
            return _dbSet.Find(id);
        }

        /// <summary>Returns the first entity matching <paramref name="predicate"/>, or null when none matches.</summary>
        /// <param name="predicate">Filter; null means no filter (previously this threw).</param>
        /// <param name="isNoTracking">Use a no-tracking query when true.</param>
        public TEntity Get(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true)
        {
            return BuildQuery(predicate, isNoTracking).FirstOrDefault();
        }

        /// <summary>Asynchronous counterpart of <see cref="Get(TKey)"/>.</summary>
        public async Task<TEntity> GetAsync(TKey id)
        {
            if (id == null)
            {
                return default(TEntity);
            }
            return await _dbSet.FindAsync(id);
        }

        /// <summary>Asynchronous counterpart of <see cref="Get(Expression{Func{TEntity, bool}}, bool)"/>.</summary>
        public async Task<TEntity> GetAsync(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true)
        {
            return await BuildQuery(predicate, isNoTracking).FirstOrDefaultAsync();
        }

        /// <summary>Asynchronously returns the matching entities as a list.</summary>
        /// <param name="predicate">Filter; null means no filter (previously this threw).</param>
        /// <param name="ordering">Ordering string handed to <c>OrderByBatch</c>; ignored when null or empty.</param>
        /// <param name="isNoTracking">Use a no-tracking query when true.</param>
        public async Task<List<TEntity>> GetListAsync(Expression<Func<TEntity, bool>> predicate = null, string ordering = "", bool isNoTracking = true)
        {
            return await ApplyOrdering(BuildQuery(predicate, isNoTracking), ordering).ToListAsync();
        }

        /// <summary>Synchronous counterpart of <see cref="GetListAsync"/>.</summary>
        public List<TEntity> GetList(Expression<Func<TEntity, bool>> predicate = null, string ordering = "", bool isNoTracking = true)
        {
            return ApplyOrdering(BuildQuery(predicate, isNoTracking), ordering).ToList();
        }

        /// <summary>
        /// Task-returning counterpart of <see cref="Load"/>. Building the query performs no I/O,
        /// so the result is returned as an already-completed task instead of via <c>Task.Run</c>.
        /// </summary>
        public Task<IQueryable<TEntity>> LoadAsync(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true)
        {
            return Task.FromResult(Load(predicate, isNoTracking));
        }

        /// <summary>Returns a (not yet executed) query over the entities matching <paramref name="predicate"/>.</summary>
        /// <param name="predicate">Filter; null means no filter.</param>
        /// <param name="isNoTracking">Use a no-tracking query when true.</param>
        public IQueryable<TEntity> Load(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true)
        {
            return BuildQuery(predicate, isNoTracking);
        }

        #endregion

        #region Raw SQL

        /// <summary>Bulk insert hook; the base implementation does nothing.</summary>
        public virtual void BulkInsert<T>(List<T> entities)
        { }

        /// <summary>Executes <paramref name="sql"/> through the context's database facade.</summary>
        /// <returns>The value returned by <c>ExecuteSqlCommand</c>.</returns>
        public int ExecuteSql(string sql)
        {
            return _dbContext.Database.ExecuteSqlCommand(sql);
        }

        /// <summary>Asynchronous counterpart of <see cref="ExecuteSql(string)"/>.</summary>
        public Task<int> ExecuteSqlAsync(string sql)
        {
            return _dbContext.Database.ExecuteSqlCommandAsync(sql);
        }

        /// <summary>Executes <paramref name="sql"/> with the given parameters.</summary>
        /// <param name="sql">SQL text to execute.</param>
        /// <param name="spList">Parameters; null is treated as "no parameters".</param>
        public int ExecuteSql(string sql, List<DbParameter> spList)
        {
            return _dbContext.Database.ExecuteSqlCommand(sql, ToParameterArray(spList));
        }

        /// <summary>Asynchronous counterpart of <see cref="ExecuteSql(string, List{DbParameter})"/>.</summary>
        public Task<int> ExecuteSqlAsync(string sql, List<DbParameter> spList)
        {
            return _dbContext.Database.ExecuteSqlCommandAsync(sql, ToParameterArray(spList));
        }

        /// <summary>Provider-specific query hook; the base implementation is not implemented.</summary>
        /// <exception cref="NotImplementedException">Always, unless overridden.</exception>
        public virtual DataTable GetDataTableWithSql(string sql)
        {
            throw new NotImplementedException();
        }

        /// <summary>Provider-specific parameterized query hook; the base implementation is not implemented.</summary>
        /// <exception cref="NotImplementedException">Always, unless overridden.</exception>
        public virtual DataTable GetDataTableWithSql(string sql, List<DbParameter> spList)
        {
            throw new NotImplementedException();
        }

        #endregion
    }

批量操作实现类(因为不同数据库的 SQL 语句不一样,所以针对不同的数据库,分别继承 BaseRepository 这个基类,并重写其中执行 SQL 语句的方法):

/// <summary>
/// SQL Server specific repository: overrides the provider-dependent hooks of
/// <see cref="BaseRepository{TEntity, TKey}"/> using <see cref="SqlBulkCopy"/> and
/// <see cref="SqlDataAdapter"/> on a connection opened from <c>_connectionString</c>.
/// </summary>
/// <typeparam name="TEntity">Entity type; must be a reference type.</typeparam>
/// <typeparam name="TKey">Type of the entity's primary key.</typeparam>
public class SqlServerRepository<TEntity, TKey> : BaseRepository<TEntity, TKey>, IRepository<TEntity, TKey> where TEntity : class
    {
        /// <summary>Rows sent to the server per bulk-copy batch.</summary>
        private const int BulkCopyBatchSize = 100000;

        /// <summary>Value assigned to <c>BulkCopyTimeout</c>.</summary>
        private const int BulkCopyTimeoutSeconds = 0;

        /// <summary>Options read from the named configuration "config".</summary>
        protected ConfigOption _dbOpion;

        /// <summary>
        /// Creates the repository and takes the connection string from the
        /// <c>ReadWriteHosts</c> value of the "config" options.
        /// </summary>
        /// <param name="generalDbContext">Context passed to the base repository.</param>
        /// <param name="options">Options snapshot providing the "config" entry.</param>
        public SqlServerRepository(GeneralDbContext generalDbContext, IOptionsSnapshot<ConfigOption> options)
            : base(generalDbContext)
        {
            _dbOpion = options.Get("config");
            _connectionString = _dbOpion.ReadWriteHosts;
        }

        #region Insert

        /// <summary>
        /// Bulk-inserts <paramref name="entities"/> with <see cref="SqlBulkCopy"/>.
        /// The destination table is the <see cref="TableAttribute"/> name of
        /// <typeparamref name="T"/> when present, otherwise the type name.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="entities">Rows to insert; null or empty is a no-op.</param>
        public override void BulkInsert<T>(List<T> entities)
        {
            if (entities == null || entities.Count == 0)
            {
                // Nothing to write; avoid opening a connection for no work.
                return;
            }

            var tableAttribute = typeof(T).GetCustomAttributes(typeof(TableAttribute), true).FirstOrDefault() as TableAttribute;
            string tableName = tableAttribute != null ? tableAttribute.Name : typeof(T).Name;

            using (var conn = new SqlConnection(_connectionString))
            {
                conn.Open();
                using (var bulkCopy = new SqlBulkCopy(conn))
                {
                    bulkCopy.BatchSize = BulkCopyBatchSize;
                    bulkCopy.BulkCopyTimeout = BulkCopyTimeoutSeconds;
                    bulkCopy.DestinationTableName = tableName;
                    bulkCopy.WriteToServer(entities.ToDataTable());
                }
            }
        }

        /// <summary>Runs <paramref name="sql"/> without parameters and returns the rows as a table.</summary>
        /// <param name="sql">SQL query text.</param>
        public override DataTable GetDataTableWithSql(string sql)
        {
            // Must name the two-argument overload explicitly: a bare
            // GetDataTableWithSql(sql) binds back to this method and recurses forever.
            return GetDataTableWithSql(sql, null);
        }

        /// <summary>Runs <paramref name="sql"/> with the given parameters and returns the rows as a table.</summary>
        /// <param name="sql">SQL query text.</param>
        /// <param name="spList">Parameters; null or empty means none.</param>
        /// <returns>A new <see cref="DataTable"/> filled by the adapter.</returns>
        public override DataTable GetDataTableWithSql(string sql, List<DbParameter> spList = null)
        {
            var table = new DataTable();
            using (var conn = new SqlConnection(_connectionString))
            using (var adapter = new SqlDataAdapter(sql, conn))
            {
                var command = adapter.SelectCommand;
                command.CommandType = CommandType.Text;
                if (spList != null && spList.Count > 0)
                {
                    command.Parameters.AddRange(spList.ToArray());
                }
                try
                {
                    adapter.Fill(table);
                }
                finally
                {
                    // Detach the caller's parameter objects so they can be reused with another command.
                    command.Parameters.Clear();
                }
            }
            return table;
        }

        #endregion
    }