先说一下Unit of Work 是什么:
Unit of Work(工作单元)是一种设计模式,通常用于管理数据库事务和持久化操作。它有助于确保数据操作的一致性和完整性,同时减少不必要的数据库操作,提高性能。
在软件开发中,Unit of Work 模式通常与 Repository 模式一起使用。下面是 Unit of Work 模式的一些关键概念和优点:
关键概念:
-
工作单元(Unit of Work):代表一组相关的操作,通常涉及对数据库的一系列读取和写入操作。它跟踪这些操作,并在事务完成时一起提交或回滚。
-
事务管理:Unit of Work 负责管理事务的开始、提交和回滚。这确保了一组操作要么全部成功提交,要么全部回滚。
-
持久化操作:Unit of Work 负责协调多个 Repository 对象(数据访问层)的操作,以确保数据的一致性。
优点:
-
事务控制:Unit of Work 管理事务,确保一组操作要么全部成功,要么全部失败。
-
性能优化:通过批量提交操作,减少与数据库的交互次数,提高性能。
-
业务逻辑的解耦:将数据访问逻辑与业务逻辑分离,使代码更易于维护和测试。
-
数据一致性:Unit of Work 确保在一组相关操作中数据的一致性,避免不一致状态。
-
实现领域驱动设计:在领域驱动设计中,Unit of Work 可以帮助实现聚合根的一致性。
在实际应用中,开发人员可以根据具体需求实现自己的 Unit of Work 模式,或者使用现有的 ORM 框架(如Entity Framework)提供的 Unit of Work 功能来简化数据操作和事务管理。
经过多次实践,我自己写了一个关于Unit of Work的简单封装(支持多数据源),如下:
1 /// <summary> 2 /// 泛型接口定义工作单元模式,用于管理仓储和事务 3 /// </summary> 4 /// <typeparam name="TContext">工作单元所需的DbContext类型</typeparam> 5 public interface IUnitOfWork<TContext> : IDisposable where TContext : DbContext 6 { 7 /// <summary> 8 /// 获取特定实体类型的仓储 9 /// </summary> 10 IBaseRepository<TEntity> Repository<TEntity>() where TEntity : class, new(); 11 12 /// <summary> 13 /// 开始事务 14 /// </summary> 15 void BeginTran(); 16 17 /// <summary> 18 /// 提交事务 19 /// </summary> 20 void Commit(); 21 22 /// <summary> 23 /// 异步提交事务 24 /// </summary> 25 Task CommitAsync(); 26 } 27 28 /// <summary> 29 /// 实现工作单元模式的基本功能 30 /// </summary> 31 /// <typeparam name="TContext">工作单元所需的DbContext类型</typeparam> 32 public class UnitOfWork<TContext> : IUnitOfWork<TContext> where TContext : DbContext 33 { 34 private readonly TContext _context; 35 private readonly ConcurrentDictionary<string, dynamic> _repositories; 36 private IDbContextTransaction _transaction; 37 private bool _disposed; 38 39 /// <summary> 40 /// 构造函数 41 /// </summary> 42 /// <param name="context">工作单元所需的DbContext实例</param> 43 public UnitOfWork(TContext context) 44 { 45 _context = context; 46 _repositories = new ConcurrentDictionary<string, dynamic>(); 47 } 48 49 /// <inheritdoc/> 50 public IBaseRepository<TEntity> Repository<TEntity>() where TEntity : class, new() 51 { 52 var type = typeof(TEntity).Name; 53 54 return _repositories.GetOrAdd(type, (t) => 55 { 56 var repositoryType = typeof(BaseRepository<>); 57 return Activator.CreateInstance(repositoryType.MakeGenericType(typeof(TEntity)), _context) as IBaseRepository<TEntity>; 58 }); 59 } 60 61 /// <inheritdoc/> 62 public void BeginTran() 63 { 64 if (_transaction != null)// 防止重复开始事务 65 { 66 throw new InvalidOperationException("A transaction is already in progress."); 67 } 68 _transaction = _context.Database.BeginTransaction(); 69 } 70 71 /// <inheritdoc/> 72 public void Commit() 73 { 74 try 75 { 76 _context.SaveChanges(); 77 _transaction?.Commit(); 78 } 79 catch 80 { 81 _transaction?.Rollback(); 82 throw; 83 } 84 finally 85 { 86 _transaction?.Dispose(); 87 _transaction = null; 88 } 89 } 90 91 /// <inheritdoc/> 92 public async Task CommitAsync() 93 { 94 try 95 { 96 await _context.SaveChangesAsync(); 97 if (_transaction != null) 98 { 99 await _transaction.CommitAsync(); 100 } 101 } 102 catch 103 { 104 if (_transaction != null) 105 { 106 await _transaction.RollbackAsync(); 107 } 108 109 throw; 110 } 111 finally 112 { 113 if (_transaction != null) 114 { 115 await _transaction.DisposeAsync(); 116 _transaction = null; 117 } 118 } 119 } 120 121 /// <inheritdoc/> 122 public void Dispose() 123 { 124 if (!_disposed) 125 { 126 _transaction?.Dispose(); 127 _context.Dispose(); 128 _disposed = true; 129 } 130 131 GC.SuppressFinalize(this); 132 } 133 }
1 /// <summary> 2 /// 表示基础仓储的接口,用于对实体进行增删改查操作。 3 /// </summary> 4 /// <typeparam name="TEntity">实体类型。</typeparam> 5 /// <remarks> 6 /// 作者:我只吃饭不洗碗 7 /// 创建日期:2024/1/29 8 /// </remarks> 9 public interface IBaseRepository<TEntity> where TEntity : class, new() 10 { 11 /// <summary> 12 /// 根据 ID 获取实体。 13 /// </summary> 14 /// <param name="id">实体的 ID。</param> 15 /// <returns>找到的实体,如果找不到则为 null。</returns> 16 TEntity GetById(object id); 17 18 /// <summary> 19 /// 获取所有实体。 20 /// </summary> 21 /// <returns>所有实体的集合。</returns> 22 List<TEntity> GetAll(); 23 24 /// <summary> 25 /// 根据条件查找实体。 26 /// </summary> 27 /// <param name="predicate">筛选条件的表达式。</param> 28 /// <returns>符合条件的实体的集合。</returns> 29 List<TEntity> QueryList(Expression<Func<TEntity, bool>> predicate); 30 31 /// <summary> 32 /// 根据条件查找实体的查询对象。 33 /// </summary> 34 /// <param name="predicate">筛选条件的表达式。</param> 35 /// <returns>符合条件的实体的查询对象。</returns> 36 IQueryable<TEntity> Query(Expression<Func<TEntity, bool>> predicate); 37 38 /// <summary> 39 /// 根据条件查找实体的第一个或默认实体。 40 /// </summary> 41 /// <param name="predicate">筛选条件的表达式。</param> 42 /// <returns>符合条件的第一个或默认实体。</returns> 43 TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate); 44 45 /// <summary> 46 /// 添加实体。 47 /// </summary> 48 /// <param name="entity">要添加的实体。</param> 49 void Add(TEntity entity); 50 51 /// <summary> 52 /// 批量添加实体。 53 /// </summary> 54 /// <param name="entities">要添加的实体集合。</param> 55 void AddRange(IEnumerable<TEntity> entities); 56 57 /// <summary> 58 /// 删除实体。 59 /// </summary> 60 /// <param name="entity">要删除的实体。</param> 61 void Remove(TEntity entity); 62 63 /// <summary> 64 /// 批量删除实体。 65 /// </summary> 66 /// <param name="entities">要删除的实体集合。</param> 67 void RemoveRange(IEnumerable<TEntity> entities); 68 69 /// <summary> 70 /// 批量删除实体。 71 /// </summary> 72 /// <param name="entities">要删除的实体主键。</param> 73 void RemoveRangeByPks(IEnumerable<object> pks); 74 75 /// <summary> 76 /// 更新实体。 77 /// </summary> 78 /// <param name="entity">要更新的实体。</param> 79 void Update(TEntity entity); 80 81 82 /// <summary> 83 /// 根据 ID 获取实体。 84 /// </summary> 85 /// <param name="id">实体的 ID。</param> 86 /// <returns>找到的实体,如果找不到则为 null。</returns> 87 Task<TEntity> GetByIdAsync(object id); 88 89 /// <summary> 90 /// 获取所有实体。 91 /// </summary> 92 /// <returns>所有实体的集合。</returns> 93 Task<List<TEntity>> GetAllAsync(); 94 95 /// <summary> 96 /// 根据条件查找实体。 97 /// </summary> 98 /// <param name="predicate">筛选条件的表达式。</param> 99 /// <returns>符合条件的实体的集合。</returns> 100 Task<List<TEntity>> QueryListAsync(Expression<Func<TEntity, bool>> predicate); 101 102 103 /// <summary> 104 /// 根据条件查找实体的第一个或默认实体。 105 /// </summary> 106 /// <param name="predicate">筛选条件的表达式。</param> 107 /// <returns>符合条件的第一个或默认实体。</returns> 108 Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate); 109 110 /// <summary> 111 /// 添加实体。 112 /// </summary> 113 /// <param name="entity">要添加的实体。</param> 114 Task AddAsync(TEntity entity); 115 116 /// <summary> 117 /// 批量添加实体。 118 /// </summary> 119 /// <param name="entities">要添加的实体集合。</param> 120 Task AddRangeAsync(IEnumerable<TEntity> entities); 121 } 122 123 /// <summary> 124 /// 泛型仓储实现,实现基本的仓储操作。 125 /// </summary> 126 public class BaseRepository<TEntity> : IBaseRepository<TEntity> 127 where TEntity : class, new() 128 129 { 130 private readonly DbContext _dbContext; 131 132 public BaseRepository(DbContext context) 133 { 134 _dbContext = context ?? throw new ArgumentNullException(nameof(context), "The DbContext cannot be null."); 135 } 136 137 public TEntity GetById(object id) 138 { 139 return _dbContext.Set<TEntity>().Find(id); 140 } 141 142 public List<TEntity> GetAll() 143 { 144 return _dbContext.Set<TEntity>().ToList(); 145 } 146 147 public List<TEntity> QueryList(Expression<Func<TEntity, bool>> predicate) 148 { 149 return _dbContext.Set<TEntity>().Where(predicate).ToList(); 150 } 151 152 153 public void Add(TEntity entity) 154 { 155 _dbContext.Set<TEntity>().Add(entity); 156 } 157 158 public void AddRange(IEnumerable<TEntity> entities) 159 { 160 _dbContext.Set<TEntity>().AddRange(entities); 161 } 162 163 public void Remove(TEntity entity) 164 { 165 _dbContext.Set<TEntity>().Remove(entity); 166 } 167 168 public void RemoveRange(IEnumerable<TEntity> entities) 169 { 170 _dbContext.Set<TEntity>().RemoveRange(entities); 171 } 172 173 public void RemoveRangeByPks(IEnumerable<object> pks) 174 { 175 foreach (var pk in pks) 176 { 177 Remove(GetById(pk)); 178 } 179 } 180 181 public void Update(TEntity entity) 182 { 183 _dbContext.Set<TEntity>().Update(entity); 184 } 185 186 public TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate) 187 { 188 return _dbContext.Set<TEntity>().Where(predicate).FirstOrDefault(); 189 } 190 191 /// <inheritdoc /> 192 public async Task<TEntity> GetByIdAsync(object id) 193 { 194 return await _dbContext.Set<TEntity>().FindAsync(id); 195 } 196 197 /// <inheritdoc /> 198 public async Task<List<TEntity>> GetAllAsync() 199 { 200 return await _dbContext.Set<TEntity>().ToListAsync(); 201 } 202 203 /// <inheritdoc /> 204 public async Task<List<TEntity>> QueryListAsync(Expression<Func<TEntity, bool>> predicate) 205 { 206 return await _dbContext.Set<TEntity>().Where(predicate).ToListAsync(); 207 } 208 209 /// <inheritdoc /> 210 public IQueryable<TEntity> Query(Expression<Func<TEntity, bool>> predicate) 211 { 212 return _dbContext.Set<TEntity>().Where(predicate); 213 } 214 215 /// <inheritdoc /> 216 public async Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate) 217 { 218 return await _dbContext.Set<TEntity>().FirstOrDefaultAsync(predicate); 219 } 220 221 /// <inheritdoc /> 222 public async Task AddAsync(TEntity entity) 223 { 224 await _dbContext.Set<TEntity>().AddAsync(entity); 225 await _dbContext.SaveChangesAsync(); 226 } 227 228 /// <inheritdoc /> 229 public async Task AddRangeAsync(IEnumerable<TEntity> entities) 230 { 231 await _dbContext.Set<TEntity>().AddRangeAsync(entities); 232 await _dbContext.SaveChangesAsync(); 233 } 234 }
1 public class BContext : DbContext 2 { 3 public BContext(DbContextOptions<BContext> options) : base(options) { } 4 5 public DbSet<EntityB> EntitiesB { get; set; } 6 // ... 其他数据集 7 } 8 9 public class AContext : DbContext 10 { 11 public AContext(DbContextOptions<AContext> options) : base(options) { } 12 13 public DbSet<EntityA> EntitiesA { get; set; } 14 // ... 其他数据集 15 }
1 public HomeController(ILogger<HomeController> logger, IUnitOfWork<AContext> unitOfWorkA, IUnitOfWork<BContext> unitOfWorkB) 2 { 3 unitOfWorkA.BeginTran();//A数据库开始事务 4 unitOfWorkA.Repository<EntityA>().RemoveRangeByPks(new object[] { 6, 7 });//A数据库批量删除 5 unitOfWorkA.Repository<EntityA>().Add(new EntityA() { Id = 6, Name = "nnnn", CreatedAt = DateTime.Now }); 6 unitOfWorkA.Repository<EntityA>().Add(new EntityA() { Id = 7, Name = "nnnn7", CreatedAt = DateTime.Now }); 7 unitOfWorkA.Commit(); 8 9 unitOfWorkB.BeginTran(); //B数据库开始事务 10 unitOfWorkB.Repository<EntityB>().RemoveRangeByPks(new object[] { 16, 27 }); //B数据库批量删除 11 unitOfWorkB.Repository<EntityB>().Add(new EntityB() { Id = 16, Name = "nnnn", CreatedAt = DateTime.Now }); 12 unitOfWorkB.Repository<EntityB>().Add(new EntityB() { Id = 27, Name = "nnnn7", CreatedAt = DateTime.Now }); 13 unitOfWorkB.Commit(); 14 15 var adata = unitOfWorkA.Repository<EntityA>().QueryList(c => c.Id > 1);//A数据库查询结果 16 var bdata = unitOfWorkB.Repository<EntityB>().QueryList(c => c.Id > 1);//B数据库查询结果 17 _logger = logger; 18 }
总结一下代码内容
-
泛型设计:
UnitOfWork<TContext>
的泛型设计提供了灵活性,使得它能够与不同的DbContext
实现一起工作,这增加了代码的重用性。 -
事务管理:通过提供
BeginTran
、Commit
和Rollback
方法,UnitOfWork
类封装了事务管理逻辑,这有助于保持代码的整洁和易于管理。 -
线程安全:使用
ConcurrentDictionary
来缓存仓库实例是一个线程安全的选择,这有助于在多线程环境中避免潜在的并发问题。 -
异步支持:添加异步方法(如
CommitAsync
)有助于提高应用程序在处理数据库操作时的响应性和吞吐量。 -
资源管理:通过实现
IDisposable
接口并正确释放资源,UnitOfWork
类有助于防止内存泄漏和其他资源管理问题。 -
分布式事务:考虑到分布式事务的实现,能够支持更复杂的事务场景。
源码地址 :https://github.com/yycb1994/MultiDataSourceExample