基于Entity Framework 4.1实现一个适用于单元测试的MockDbContext(下)
Posted on 2011-07-24 05:13 Saar 阅读(2236) 评论(7) 编辑 收藏 举报上篇中提到,我们在利用修改做的MockDbContext进行单元测试时,在获取数据时出现了问题。原因在于,在写获取的代码时,我们直接调用了DbContext的Set<T>()方法,而这个方法会从数据库中取数据。
我们来看一个例子:
这是一个简化的业务逻辑代码来获取全部的Handbook:
public IList<Handbook> GetMyHandbooks()
{
var handbookSet = dbContext.Set<Handbook>();
return handbookSet.ToList();
}
对应的单元测试:
1: [TestMethod]
2: public void GetMyHandBooksTest()
3: {
4: var mockDbContext = new MockMTBContainer();
5: BizHandbook target = new BizHandbook(mockDbContext);
6:
7: mockDbContext.Handbooks1.Add(new Handbook() { HandbookID = 1 });
8: mockDbContext.Handbooks1.Add(new Handbook() { HandbookID = 2 });
9:
10: var result = target.GetMyHandbooks();
11: Assert.IsNotNull(result);
12: Assert.AreEqual(2, result.Count);
13: }
行4-5创建一个业务逻辑对象,也是我们的测试目标target,第5行中的构造函数以以mockDbContext为参数,此业务逻辑对象会使用mockDbContext而非默认的DbContext。第10行代码调用了GetMyHandbooks()方法。第11行和12行分别验能够获得结果并且结果集中有两个对象。大家已经知道使用上篇中的mockDbContext,这个单元测试会Fail。第12行Assert预期为2,实际为0。
直接问题出在GetMyHandbooks方法的return语句中的.ToList()方法。这是一个System.Linq中提供的扩展方法,它所做的事情是:调用Expression对象进行数据查询得到一个IEnumerator<TEntity>,然后创建一个List<TEntity>对象,进行迭代将数据填入列表,最后返回。谁提供了Expression?DbContext中的DbSet<TEntity>。
在上篇中,我们仅仅是简单的重写了(或者说,废掉了^_^)DbContext的SaveChanges(),但是在查询数据的时候,Expression仍然会从对应的DB中获取数据。另外,对Local属性的理解也有问题,Local里并不是所有数据,但是新增或变更过的数据。于是,单元测试12行出现actual=0的现象也就得到合理解释了。看来,之前的想法太过简单了。需要重新构思一个mockDbContext。
由于DbSet<TEntity>的Expression属性不是virtual属性,简单重写一下Expression的想法行不通。曾经想过用一个同名属性覆盖它,但是,这样应该也有问题(之所以说应该,是因为没有试过),因为DbContext中使用的都是DbSet<TEntity>类,会调用原来的Expression属性的。
通过以上的分析,我们知道,要让MockDbContext好用,涉及两个类:DbContext和DbSet<TEntity>,因此,现在的思路是,实现两个新的类,替换掉它们。
由于要在业务逻辑不作任何修改的情况下调用DbContext和MockDbContext,这两个类要求实现同一接口;
同理,DbSet<TEntity>和新写的类(叫MockDbSet<TEntity>吧)同样也要实现同一个接口。
我们从MockDbSet<TEntity>着手。由于DbSet<TEntity>实现了IDbSet<TEntity>接口,因此,对于MockDbSet<TEntity>来说,实现IDbSet<TEntity>即可(这个类方法比较多,但都是非常基本的方法,文章最后有下载):
1: public class MockDbSet<TEntity> : IDbSet<TEntity>
2: where TEntity : class
3: {
4: private ObservableCollection<TEntity> storage = new ObservableCollection<TEntity>();
5:
6:
7: public TEntity Add(TEntity entity)
8: {
9: if (entity != null)
10: {
11: storage.Add(entity);
12: }
13: return entity;
14: }
15:
16: public TEntity Attach(TEntity entity)
17: {
18: storage.Add(entity);
19: return entity;
20: }
21:
22: public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, TEntity
23: {
24: return Activator.CreateInstance<TDerivedEntity>();
25: }
26:
27: public TEntity Create()
28: {
29: return Activator.CreateInstance<TEntity>();
30: }
31:
32: public TEntity Find(params object[] keyValues)
33: {
34: int currentKeyPropertyIndex;
35: foreach (var entity in storage)
36: {
37: currentKeyPropertyIndex = 0;
38: foreach (var property in typeof(TEntity).GetProperties())
39: {
40: if (property.Name.Contains("ID"))
41: {
42: if (property.GetValue(entity).Equals(keyValues[currentKeyPropertyIndex]))
43: {
44: currentKeyPropertyIndex++;
45: if (currentKeyPropertyIndex == keyValues.Length)
46: {
47: return entity;
48: }
49: }
50: }
51: }
52: }
53: return null;
54: }
55:
56: public System.Collections.ObjectModel.ObservableCollection<TEntity> Local
57: {
58: get { return new ObservableCollection<TEntity>(this.storage); }
59: }
60:
61: public TEntity Remove(TEntity entity)
62: {
63: storage.Remove(entity);
64: return entity;
65: }
66:
67: public IEnumerator<TEntity> GetEnumerator()
68: {
69: return this.storage.GetEnumerator();
70: }
71:
72: System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
73: {
74: return this.storage.GetEnumerator();
75: }
76:
77: public Type ElementType
78: {
79: get
80: {
81: return storage.AsQueryable().ElementType;
82: }
83: }
84:
85: public System.Linq.Expressions.Expression Expression
86: {
87: get
88: {
89: return storage.AsQueryable().Expression;
90: }
91: }
92:
93: public IQueryProvider Provider
94: {
95: get
96: {
97: return storage.AsQueryable().Provider;
98: }
99: }
100: }
延续前文思路,在第4行添加了一个ObservableCollection<T>的集合,用作本地存储;增、删、改全部针对这个集合来完成。第85到91行,在返回Expression的时候,把storage.AsQueryable()的表达式返回出去,这样,查询的的时候就会获取到查询storage中元素的表达式。
接下来,我们要把使用DbSet<TEntity>的DbContext和使用MockDbSet<TEntity>的MockDbContext统一起来。查看一下DbContext……呃~没有实现任何接口-.-||(曾经有位伟人说过……算了,不说了)既然没有接口,我们就来创建一个:
using System.Data.Entity;
namespace MTB.Data
{
public interface ITestableDbContext
{
IDbSet<TEntity> Set<TEntity>() where TEntity : class;
int SaveChanges();
}
}
这个接口做两件事,第一,可以实现数据持久化——SaveChanges()。第二,可以获取到IDbSet<TEntity>集合以对集合中数据进行操作。然后,做两件事:第一,让具体的DbContext实现这个接口;第二,写一个MockDbContext实现这个接口。
我们先来看DbContext实现接口的部分:
1: public partial class MTBContainer : DbContext, ITestableDbContext
2: {
3: public MTBContainer()
4: : base("name=MTBContainer")
5: {
6: }
7:
8: public new IDbSet<TEntity> Set<TEntity>() where TEntity : class
9: {
10: return base.Set<TEntity>();
11: }
12:
13: //... Other code...
14:
15: public IDbSet<Handbook> Handbooks1 { get; set; }
16: public IDbSet<Trip> Trips { get; set; }
17: // ... more item
18: }
其中,第8到11行,虽然覆盖了Set<TEntity>(),但调用的仍然是DbContext类中的Set<TEntity>()方法;然后把对应的DbSet<TEntity对象集合全部改为IDbSet<TEntity>,大功告成。由于用的是Model First的EF4.1,因此,其实这个类是通过修改模板而来的。修改过的模板会在文章结束时附上。
MockDbContext实现:
1: public partial class MockMTBContainer : DbContext, ITestableDbContext
2: {
3: public MockMTBContainer()
4: : base("name=MTBContainer")
5: {
6: Handbooks1 = new MockDbSet<Handbook>();
7: Trips = new MockDbSet<Trip>();
// Other code
15: }
16:
17: public new IDbSet<TEntity> Set<TEntity>() where TEntity : class
18: {
19: foreach (PropertyInfo property in typeof(MockMTBContainer).GetProperties())
20: {
21: if (property.PropertyType == typeof(IDbSet<TEntity>))
22: {
23: return property.GetValue(this, null) as IDbSet<TEntity>;
24: }
25: }
26: throw new Exception("Type collection not found");
27: }
28:
29: public override int SaveChanges()
30: {
31: //Do nothing
32: return 0;
33: }
34:
35: // ...
36:
37: public IDbSet<Handbook> Handbooks1 { get; set; }
38: public IDbSet<Trip> Trips { get; set; }
39: // ...
40: }
这个MockDbContext同样继承自DbContext类(嗯,可以少写不少代码呢)。但是,覆盖了Set<TEntity>()方法;当然,第29到33行废掉SaveChanges()的事仍然不可不做——如果想让测试更全面一些,看看那些个增、删、改方法有没有忘调用SaveChanges(),利用这个重写设置一个标志为也不错啊
Okay,一切就绪,我们把这一切综合起来使用:
首先,业务逻辑:
1: public class BizHandbook
2: {
3: ITestableDbContext dbContext = null;
4:
5: #region Constructions
6: public BizHandbook()
7: : this(null)
8: {
9: }
10:
11:
12: public BizHandbook(ITestableDbContext dbContextl)
13: {
14: if (dbContext == null)
15: {
16: dbContext = new MTBContainer();
17: }
18: this.dbContext = dbContext;
19: }
20: #endregion
21:
22: #region Publich Methods
23: public IList<Handbook> GetMyHandbooks()
24: {
25: var handbookSet = dbContext.Set<Handbook>();
26: return handbookSet.ToList();
27: }
28: #endregion
29: }
第12行构造函数会要求一个实现了ITestableDbContext的对象,如果为null,那么使用默认的DbContext。增删改查一律该怎么写就怎么写。模板下载,使用的时候记得修改对应的using和inputFile变量——点击下载。
内容:
1. 使用IDbSet<T>的DbContext的模板;
2. MockDbContext模板;
3. MockDbSet<TEntity>类;
4. ITestableDbContext接口;
对了,顺便提一下,在MockDbSet<TEntity>实现的时候,Find()方法为了方便起见,使用了一个Hack的方法来判断属性是否为Key属性——属性名称中是否含有ID。
Little knowledge is dangerous.