参考资料：
https://www.cnblogs.com/youhui/articles/10813468.html
仓储接口（IRepository）：
/// <summary>
/// Generic repository contract for <typeparamref name="TEntity"/>, whose key type is <typeparamref name="TKey"/>.
/// </summary>
/// <remarks>
/// The optional-parameter defaults declared here are the same ones <c>BaseRepository</c> declares,
/// so calls made through the interface behave like calls made on the class directly.
/// </remarks>
/// <typeparam name="TEntity">Entity type; must be a reference type.</typeparam>
/// <typeparam name="TKey">Key type accepted by <c>GetAsync(TKey)</c>.</typeparam>
public interface IRepository<TEntity, TKey> where TEntity : class
{
    #region Query

    /// <summary>Counts entities matching <paramref name="predicate"/>.</summary>
    /// <param name="predicate">Filter; may be null (the base implementation then counts all entities).</param>
    long Count(Expression<Func<TEntity, bool>> predicate = null);

    /// <summary>Asynchronously counts entities matching <paramref name="predicate"/>.</summary>
    /// <param name="predicate">Filter; may be null (the base implementation then counts all entities).</param>
    Task<long> CountAsync(Expression<Func<TEntity, bool>> predicate = null);

    /// <summary>Returns the first entity matching <paramref name="predicate"/>, or null when none matches.</summary>
    /// <param name="predicate">Filter to apply.</param>
    /// <param name="isNoTracking">When true, the query is executed with AsNoTracking.</param>
    TEntity Get(Expression<Func<TEntity, bool>> predicate, bool isNoTracking = true);

    /// <summary>Asynchronously returns the first entity matching <paramref name="predicate"/>, or null.</summary>
    /// <param name="predicate">Filter to apply.</param>
    /// <param name="isNoTracking">When true, the query is executed with AsNoTracking.</param>
    Task<TEntity> GetAsync(Expression<Func<TEntity, bool>> predicate, bool isNoTracking = true);

    /// <summary>Asynchronously looks an entity up by its key.</summary>
    /// <param name="id">Key value.</param>
    Task<TEntity> GetAsync(TKey id);

    /// <summary>Returns an un-executed query over the entities matching <paramref name="predicate"/>.</summary>
    /// <param name="predicate">Filter; may be null (the base implementation then selects all entities).</param>
    /// <param name="isNoTracking">When true, the query is AsNoTracking.</param>
    IQueryable<TEntity> Load(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true);

    /// <summary>Task-returning form of <c>Load</c>.</summary>
    /// <param name="predicate">Filter; may be null (the base implementation then selects all entities).</param>
    /// <param name="isNoTracking">When true, the query is AsNoTracking.</param>
    Task<IQueryable<TEntity>> LoadAsync(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true);

    /// <summary>Materializes the entities matching <paramref name="predicate"/> into a list.</summary>
    /// <param name="predicate">Filter to apply.</param>
    /// <param name="ordering">Ordering text handed to the project's <c>OrderByBatch</c> extension; empty means no ordering.</param>
    /// <param name="isNoTracking">When true, the query is executed with AsNoTracking.</param>
    List<TEntity> GetList(Expression<Func<TEntity, bool>> predicate, string ordering = "", bool isNoTracking = true);

    /// <summary>Asynchronously materializes the entities matching <paramref name="predicate"/> into a list.</summary>
    /// <param name="predicate">Filter to apply.</param>
    /// <param name="ordering">Ordering text handed to the project's <c>OrderByBatch</c> extension; empty means no ordering.</param>
    /// <param name="isNoTracking">When true, the query is executed with AsNoTracking.</param>
    Task<List<TEntity>> GetListAsync(Expression<Func<TEntity, bool>> predicate, string ordering = "", bool isNoTracking = true);

    #endregion

    #region Insert

    // For every write method: when isSaveChange is true the context is saved immediately and the
    // result reports whether anything was written; when false the change is only staged
    // (BaseRepository then returns false).

    /// <summary>Adds one entity.</summary>
    bool Insert(TEntity entity, bool isSaveChange = true);

    /// <summary>Adds one entity (async save).</summary>
    Task<bool> InsertAsync(TEntity entity, bool isSaveChange = true);

    /// <summary>Adds several entities.</summary>
    bool Insert(List<TEntity> entitys, bool isSaveChange = true);

    /// <summary>Adds several entities (async save).</summary>
    Task<bool> InsertAsync(List<TEntity> entitys, bool isSaveChange = true);

    #endregion

    #region Delete (query the entities before deleting them)

    /// <summary>Removes one entity.</summary>
    bool Delete(TEntity entity, bool isSaveChange = true);

    /// <summary>Removes several entities.</summary>
    bool Delete(List<TEntity> entitys, bool isSaveChange = true);

    /// <summary>Removes one entity (async save).</summary>
    Task<bool> DeleteAsync(TEntity entity, bool isSaveChange = true);

    /// <summary>Removes several entities (async save).</summary>
    Task<bool> DeleteAsync(List<TEntity> entitys, bool isSaveChange = true);

    #endregion

    #region Update

    /// <summary>Marks an entity as modified.</summary>
    /// <param name="entity">Entity to update.</param>
    /// <param name="isSaveChange">Save immediately when true.</param>
    /// <param name="updatePropertyList">Names of the properties to update; null updates every column.</param>
    bool Update(TEntity entity, bool isSaveChange = true, List<string> updatePropertyList = null);

    /// <summary>Marks an entity as modified (async save).</summary>
    /// <param name="entity">Entity to update.</param>
    /// <param name="isSaveChange">Save immediately when true.</param>
    /// <param name="updatePropertyList">Names of the properties to update; null updates every column.</param>
    Task<bool> UpdateAsync(TEntity entity, bool isSaveChange = true, List<string> updatePropertyList = null);

    /// <summary>Marks several entities as fully modified.</summary>
    bool Update(List<TEntity> entitys, bool isSaveChange = true);

    /// <summary>Marks several entities as fully modified (async save).</summary>
    Task<bool> UpdateAsync(List<TEntity> entitys, bool isSaveChange = true);

    #endregion

    #region Raw SQL

    /// <summary>Bulk-inserts rows; the base implementation is a no-op that provider subclasses override.</summary>
    void BulkInsert<T>(List<T> entities);

    /// <summary>Executes a SQL command and returns its result code.</summary>
    int ExecuteSql(string sql);

    /// <summary>Asynchronously executes a SQL command and returns its result code.</summary>
    Task<int> ExecuteSqlAsync(string sql);

    /// <summary>Executes a parameterized SQL command and returns its result code.</summary>
    int ExecuteSql(string sql, List<DbParameter> spList);

    /// <summary>Asynchronously executes a parameterized SQL command and returns its result code.</summary>
    Task<int> ExecuteSqlAsync(string sql, List<DbParameter> spList);

    /// <summary>Runs a query and returns the result set as a <see cref="DataTable"/>.</summary>
    DataTable GetDataTableWithSql(string sql);

    /// <summary>Runs a parameterized query and returns the result set as a <see cref="DataTable"/>.</summary>
    DataTable GetDataTableWithSql(string sql, List<DbParameter> spList);

    #endregion
}
基础实现类（抽象基类 BaseRepository）：
/// <summary>
/// Entity Framework based implementation of <see cref="IRepository{TEntity, TKey}"/>.
/// Provider-specific subclasses override the raw-SQL members
/// (<c>BulkInsert</c> and the <c>GetDataTableWithSql</c> overloads).
/// </summary>
/// <remarks>
/// Write methods call <see cref="SaveChanges"/> only when <c>isSaveChange</c> is true, and then
/// return whether the save reported more than zero entries. When <c>isSaveChange</c> is false the
/// change is only staged on the context and the method returns <c>false</c>; the caller is expected
/// to save later.
/// </remarks>
/// <typeparam name="TEntity">Entity type mapped by the context.</typeparam>
/// <typeparam name="TKey">Key type used by <c>Get(TKey)</c> / <c>GetAsync(TKey)</c>.</typeparam>
public abstract class BaseRepository<TEntity, TKey> : IRepository<TEntity, TKey> where TEntity : class
{
    private readonly DbSet<TEntity> _dbSet;

    /// <summary>The EF context this repository reads from and writes to.</summary>
    public GeneralDbContext _dbContext { get; }

    /// <summary>Connection string for raw ADO.NET access; assigned by derived classes.</summary>
    protected string _connectionString { get; set; }

    /// <summary>Creates the repository over <paramref name="context"/>.</summary>
    /// <param name="context">Context that provides the set for <typeparamref name="TEntity"/>.</param>
    protected BaseRepository(GeneralDbContext context)
    {
        _dbContext = context;
        _dbSet = _dbContext.Set<TEntity>();
    }

    /// <summary>Database facade of the underlying context.</summary>
    public DatabaseFacade Database => _dbContext.Database;

    /// <summary>All entities of this type as a no-tracking query.</summary>
    public IQueryable<TEntity> Entities => _dbSet.AsQueryable().AsNoTracking();

    /// <summary>Saves pending changes; returns the value reported by <c>DbContext.SaveChanges</c>.</summary>
    public int SaveChanges()
    {
        return _dbContext.SaveChanges();
    }

    /// <summary>Asynchronously saves pending changes; returns the value reported by <c>DbContext.SaveChangesAsync</c>.</summary>
    public async Task<int> SaveChangesAsync()
    {
        return await _dbContext.SaveChangesAsync();
    }

    /// <summary>Returns whether any entity matches <paramref name="whereLambd"/>.</summary>
    public bool Any(Expression<Func<TEntity, bool>> whereLambd)
    {
        return _dbSet.Where(whereLambd).Any();
    }

    #region Insert

    /// <summary>Adds <paramref name="entity"/>; saves when <paramref name="isSaveChange"/> is true.</summary>
    public bool Insert(TEntity entity, bool isSaveChange = true)
    {
        _dbSet.Add(entity);
        return CommitIf(isSaveChange);
    }

    /// <summary>Adds <paramref name="entity"/>; saves asynchronously when <paramref name="isSaveChange"/> is true.</summary>
    public async Task<bool> InsertAsync(TEntity entity, bool isSaveChange = true)
    {
        _dbSet.Add(entity);
        return await CommitIfAsync(isSaveChange);
    }

    /// <summary>Adds every entity in <paramref name="entitys"/>; saves when <paramref name="isSaveChange"/> is true.</summary>
    public bool Insert(List<TEntity> entitys, bool isSaveChange = true)
    {
        _dbSet.AddRange(entitys);
        return CommitIf(isSaveChange);
    }

    /// <summary>Adds every entity in <paramref name="entitys"/>; saves asynchronously when <paramref name="isSaveChange"/> is true.</summary>
    public async Task<bool> InsertAsync(List<TEntity> entitys, bool isSaveChange = true)
    {
        _dbSet.AddRange(entitys);
        return await CommitIfAsync(isSaveChange);
    }

    #endregion

    #region Delete

    /// <summary>Attaches and removes <paramref name="entity"/>; saves when <paramref name="isSaveChange"/> is true.</summary>
    public bool Delete(TEntity entity, bool isSaveChange = true)
    {
        MarkDeleted(entity);
        return CommitIf(isSaveChange);
    }

    /// <summary>Attaches and removes every entity in <paramref name="entitys"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entitys"/> is null.</exception>
    public bool Delete(List<TEntity> entitys, bool isSaveChange = true)
    {
        if (entitys == null)
        {
            throw new ArgumentNullException(nameof(entitys));
        }

        foreach (var entity in entitys)
        {
            MarkDeleted(entity);
        }
        return CommitIf(isSaveChange);
    }

    /// <summary>Attaches and removes <paramref name="entity"/>; saves asynchronously when requested.</summary>
    public virtual async Task<bool> DeleteAsync(TEntity entity, bool isSaveChange = true)
    {
        MarkDeleted(entity);
        return await CommitIfAsync(isSaveChange);
    }

    /// <summary>Attaches and removes every entity in <paramref name="entitys"/>; saves asynchronously when requested.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entitys"/> is null.</exception>
    public virtual async Task<bool> DeleteAsync(List<TEntity> entitys, bool isSaveChange = true)
    {
        if (entitys == null)
        {
            throw new ArgumentNullException(nameof(entitys));
        }

        foreach (var entity in entitys)
        {
            MarkDeleted(entity);
        }
        return await CommitIfAsync(isSaveChange);
    }

    #endregion

    #region Update

    /// <summary>Marks <paramref name="entity"/> as modified; returns false for a null entity.</summary>
    /// <param name="entity">Entity to update.</param>
    /// <param name="isSaveChange">Save immediately when true.</param>
    /// <param name="updatePropertyList">Property names to update; null updates every column.</param>
    public bool Update(TEntity entity, bool isSaveChange = true, List<string> updatePropertyList = null)
    {
        if (entity == null)
        {
            return false;
        }

        MarkUpdated(entity, updatePropertyList);
        return CommitIf(isSaveChange);
    }

    /// <summary>Marks every non-null entity in <paramref name="entitys"/> as fully modified; returns false for a null or empty list.</summary>
    public bool Update(List<TEntity> entitys, bool isSaveChange = true)
    {
        if (entitys == null || entitys.Count == 0)
        {
            return false;
        }

        foreach (var entity in entitys)
        {
            if (entity != null)
            {
                MarkUpdated(entity, null);
            }
        }
        return CommitIf(isSaveChange);
    }

    /// <summary>Marks <paramref name="entity"/> as modified and saves asynchronously when requested; returns false for a null entity.</summary>
    /// <param name="entity">Entity to update.</param>
    /// <param name="isSaveChange">Save immediately when true.</param>
    /// <param name="updatePropertyList">Property names to update; null updates every column.</param>
    public async Task<bool> UpdateAsync(TEntity entity, bool isSaveChange = true, List<string> updatePropertyList = null)
    {
        if (entity == null)
        {
            return false;
        }

        MarkUpdated(entity, updatePropertyList);
        return await CommitIfAsync(isSaveChange);
    }

    /// <summary>Marks every non-null entity in <paramref name="entitys"/> as fully modified and saves asynchronously when requested.</summary>
    public async Task<bool> UpdateAsync(List<TEntity> entitys, bool isSaveChange = true)
    {
        if (entitys == null || entitys.Count == 0)
        {
            return false;
        }

        // Same per-entity handling as the synchronous Update(List): null elements are skipped.
        foreach (var entity in entitys)
        {
            if (entity != null)
            {
                MarkUpdated(entity, null);
            }
        }
        return await CommitIfAsync(isSaveChange);
    }

    #endregion

    #region Query

    /// <summary>Counts entities matching <paramref name="predicate"/>; null counts all entities.</summary>
    public long Count(Expression<Func<TEntity, bool>> predicate = null)
    {
        if (predicate == null)
        {
            predicate = c => true;
        }
        return _dbSet.LongCount(predicate);
    }

    /// <summary>Asynchronously counts entities matching <paramref name="predicate"/>; null counts all entities.</summary>
    public async Task<long> CountAsync(Expression<Func<TEntity, bool>> predicate = null)
    {
        if (predicate == null)
        {
            predicate = c => true;
        }
        return await _dbSet.LongCountAsync(predicate);
    }

    /// <summary>Looks an entity up with <c>DbSet.Find</c>; returns null when <paramref name="id"/> is null.</summary>
    public TEntity Get(TKey id)
    {
        if (id == null)
        {
            return default(TEntity);
        }
        return _dbSet.Find(id);
    }

    /// <summary>Returns the first entity matching <paramref name="predicate"/> (all entities when null), or null.</summary>
    public TEntity Get(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true)
    {
        return BuildQuery(predicate, isNoTracking).FirstOrDefault();
    }

    /// <summary>Looks an entity up with <c>DbSet.FindAsync</c>; returns null when <paramref name="id"/> is null.</summary>
    public async Task<TEntity> GetAsync(TKey id)
    {
        if (id == null)
        {
            return default(TEntity);
        }
        return await _dbSet.FindAsync(id);
    }

    /// <summary>Asynchronously returns the first entity matching <paramref name="predicate"/> (all entities when null), or null.</summary>
    public async Task<TEntity> GetAsync(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true)
    {
        return await BuildQuery(predicate, isNoTracking).FirstOrDefaultAsync();
    }

    /// <summary>Asynchronously lists entities matching <paramref name="predicate"/>, optionally ordered by <paramref name="ordering"/>.</summary>
    /// <param name="predicate">Filter; null selects all entities.</param>
    /// <param name="ordering">Text passed to the project's <c>OrderByBatch</c> extension; empty skips ordering.</param>
    /// <param name="isNoTracking">When true, the query is AsNoTracking.</param>
    public async Task<List<TEntity>> GetListAsync(Expression<Func<TEntity, bool>> predicate = null, string ordering = "", bool isNoTracking = true)
    {
        var query = BuildQuery(predicate, isNoTracking);
        if (!string.IsNullOrEmpty(ordering))
        {
            query = query.OrderByBatch(ordering);
        }
        return await query.ToListAsync();
    }

    /// <summary>Lists entities matching <paramref name="predicate"/>, optionally ordered by <paramref name="ordering"/>.</summary>
    /// <param name="predicate">Filter; null selects all entities.</param>
    /// <param name="ordering">Text passed to the project's <c>OrderByBatch</c> extension; empty skips ordering.</param>
    /// <param name="isNoTracking">When true, the query is AsNoTracking.</param>
    public List<TEntity> GetList(Expression<Func<TEntity, bool>> predicate = null, string ordering = "", bool isNoTracking = true)
    {
        var query = BuildQuery(predicate, isNoTracking);
        if (!string.IsNullOrEmpty(ordering))
        {
            query = query.OrderByBatch(ordering);
        }
        return query.ToList();
    }

    /// <summary>
    /// Task-returning form of <see cref="Load"/>. Building the query does no I/O, so the result is
    /// returned as an already-completed task instead of being pushed onto the thread pool.
    /// </summary>
    public Task<IQueryable<TEntity>> LoadAsync(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true)
    {
        return Task.FromResult(Load(predicate, isNoTracking));
    }

    /// <summary>Returns an un-executed query over entities matching <paramref name="predicate"/> (all entities when null).</summary>
    public IQueryable<TEntity> Load(Expression<Func<TEntity, bool>> predicate = null, bool isNoTracking = true)
    {
        return BuildQuery(predicate, isNoTracking);
    }

    #endregion

    #region Raw SQL

    /// <summary>Bulk-insert hook. The base implementation does nothing; provider-specific subclasses override it.</summary>
    public virtual void BulkInsert<T>(List<T> entities) { }

    /// <summary>Executes <paramref name="sql"/> through <c>Database.ExecuteSqlCommand</c>.</summary>
    public int ExecuteSql(string sql)
    {
        return _dbContext.Database.ExecuteSqlCommand(sql);
    }

    /// <summary>Executes <paramref name="sql"/> through <c>Database.ExecuteSqlCommandAsync</c>.</summary>
    public Task<int> ExecuteSqlAsync(string sql)
    {
        return _dbContext.Database.ExecuteSqlCommandAsync(sql);
    }

    /// <summary>Executes parameterized <paramref name="sql"/>; a null <paramref name="spList"/> means no parameters.</summary>
    public int ExecuteSql(string sql, List<DbParameter> spList)
    {
        return _dbContext.Database.ExecuteSqlCommand(sql, ToParameterArray(spList));
    }

    /// <summary>Asynchronously executes parameterized <paramref name="sql"/>; a null <paramref name="spList"/> means no parameters.</summary>
    public Task<int> ExecuteSqlAsync(string sql, List<DbParameter> spList)
    {
        return _dbContext.Database.ExecuteSqlCommandAsync(sql, ToParameterArray(spList));
    }

    /// <summary>Provider-specific query to a DataTable; must be overridden.</summary>
    /// <exception cref="NotImplementedException">Always, in the base class.</exception>
    public virtual DataTable GetDataTableWithSql(string sql)
    {
        throw new NotImplementedException();
    }

    /// <summary>Provider-specific parameterized query to a DataTable; must be overridden.</summary>
    /// <exception cref="NotImplementedException">Always, in the base class.</exception>
    public virtual DataTable GetDataTableWithSql(string sql, List<DbParameter> spList)
    {
        throw new NotImplementedException();
    }

    #endregion

    #region Helpers

    /// <summary>Builds the filtered (null predicate = unfiltered), optionally no-tracking query used by the read methods.</summary>
    private IQueryable<TEntity> BuildQuery(Expression<Func<TEntity, bool>> predicate, bool isNoTracking)
    {
        IQueryable<TEntity> query = _dbSet;
        if (predicate != null)
        {
            query = query.Where(predicate);
        }
        return isNoTracking ? query.AsNoTracking() : query;
    }

    /// <summary>Saves when <paramref name="isSaveChange"/> is true; returns whether the save reported more than zero entries.</summary>
    private bool CommitIf(bool isSaveChange)
    {
        return isSaveChange && SaveChanges() > 0;
    }

    /// <summary>Asynchronous counterpart of <see cref="CommitIf"/>; skips the save entirely when not requested.</summary>
    private async Task<bool> CommitIfAsync(bool isSaveChange)
    {
        return isSaveChange && await SaveChangesAsync() > 0;
    }

    /// <summary>Attaches <paramref name="entity"/> and marks it for removal.</summary>
    private void MarkDeleted(TEntity entity)
    {
        _dbSet.Attach(entity);
        _dbSet.Remove(entity);
    }

    /// <summary>Attaches <paramref name="entity"/> and marks either all columns or only the listed properties as modified.</summary>
    private void MarkUpdated(TEntity entity, List<string> updatePropertyList)
    {
        _dbSet.Attach(entity);
        var entry = _dbContext.Entry(entity);
        if (updatePropertyList == null)
        {
            entry.State = EntityState.Modified; // full-column update
        }
        else
        {
            foreach (var propertyName in updatePropertyList)
            {
                entry.Property(propertyName).IsModified = true; // partial update: only listed columns
            }
        }
    }

    /// <summary>Converts an optional parameter list to the array form the EF raw-SQL API expects.</summary>
    private static DbParameter[] ToParameterArray(List<DbParameter> spList)
    {
        return spList == null ? Array.Empty<DbParameter>() : spList.ToArray();
    }

    #endregion
}
批量操作实现类 SqlServerRepository（由于不同数据库的 SQL 语法不同，需针对每种数据库分别继承 BaseRepository 基类，并重写与 SQL 语句相关的方法）：
/// <summary>
/// SQL Server repository: adds SqlBulkCopy-based bulk insert and plain ADO.NET DataTable queries
/// on top of <see cref="BaseRepository{TEntity, TKey}"/>.
/// </summary>
public class SqlServerRepository<TEntity, TKey> : BaseRepository<TEntity, TKey>, IRepository<TEntity, TKey> where TEntity : class
{
    /// <summary>The options instance named "config" that this repository was created with.</summary>
    protected ConfigOption _dbOpion;

    /// <summary>Creates the repository and takes its connection string from the "config" options.</summary>
    public SqlServerRepository(GeneralDbContext generalDbContext, IOptionsSnapshot<ConfigOption> options)
        : base(generalDbContext)
    {
        _dbOpion = options.Get("config");
        // NOTE(review): ReadWriteHosts is used as the full ADO.NET connection string — confirm its format.
        _connectionString = _dbOpion.ReadWriteHosts;
    }

    #region Insert

    /// <summary>
    /// Bulk-inserts <paramref name="entities"/> with SqlBulkCopy (suited to large data volumes).
    /// </summary>
    /// <typeparam name="T">Row type. Its [Table] name, or else its CLR type name, is the destination table.</typeparam>
    /// <param name="entities">Rows to insert; an empty list returns without opening a connection.</param>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
    public override void BulkInsert<T>(List<T> entities)
    {
        if (entities == null)
        {
            throw new ArgumentNullException(nameof(entities));
        }
        if (entities.Count == 0)
        {
            return;
        }

        var tableAttribute = typeof(T).GetCustomAttributes(typeof(TableAttribute), true).FirstOrDefault();
        string tableName = tableAttribute != null ? ((TableAttribute)tableAttribute).Name : typeof(T).Name;

        using (var conn = new SqlConnection(_connectionString))
        {
            conn.Open();
            using (var bulkCopy = new SqlBulkCopy(conn)
            {
                BatchSize = 100000,
                BulkCopyTimeout = 0, // 0 disables the timeout
                DestinationTableName = tableName
            })
            {
                // NOTE(review): no ColumnMappings are configured, so columns are matched by ordinal —
                // confirm ToDataTable emits columns in the destination table's order.
                bulkCopy.WriteToServer(entities.ToDataTable());
            }
        }
    }

    /// <summary>Runs <paramref name="sql"/> without parameters and returns the result as a DataTable.</summary>
    public override DataTable GetDataTableWithSql(string sql)
    {
        // The null argument is required: a one-argument call here binds to this same overload
        // and would recurse until the stack overflows.
        return GetDataTableWithSql(sql, null);
    }

    /// <summary>Runs <paramref name="sql"/> with optional parameters and returns the result as a DataTable.</summary>
    /// <param name="sql">Query text, executed as <see cref="CommandType.Text"/>.</param>
    /// <param name="spList">Query parameters; null or empty means none. Expected to be SqlParameter instances.</param>
    public override DataTable GetDataTableWithSql(string sql, List<DbParameter> spList = null)
    {
        var table = new DataTable();
        using (var conn = new SqlConnection(_connectionString))
        using (var adapter = new SqlDataAdapter(sql, conn))
        {
            adapter.SelectCommand.CommandType = CommandType.Text;
            if (spList != null && spList.Count > 0)
            {
                adapter.SelectCommand.Parameters.AddRange(spList.ToArray());
            }

            try
            {
                adapter.Fill(table);
            }
            finally
            {
                // Detach the caller's parameter objects so they can be added to another command later.
                adapter.SelectCommand.Parameters.Clear();
            }
        }
        return table;
    }

    #endregion
}