Jtester+unitils+testng: Automatically Generating DAO Unit Test File Templates


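     Audience

     This article is for readers who would rather generate DAO unit tests automatically than write them by hand. The result cannot be copied as-is, but the underlying idea of "create a template, fill in the content, generate automatically" can be reused. Along the way it shows how to read a configuration file and replace strings in Python.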

 

     When jtester+unitils+testng is used as the unit test framework for database interfaces, one often has to write a number of wiki files and a DAOTest Java file, for example:

public class XXXDefaultDAOTest extends BaseRegionDbDAOTestCase {

    @SpringBeanByName
    private XXXDefaultDAO XXXDefaultDAO;
    
    @Test
    @DbFit(when="XXXDefaultDAOTest.initBlank.when.wiki", then="XXXDefaultDAOTest.queryOneRecord.then.wiki")
    public void testInsertXXXDefaultDO() {
        XXXDefaultDO XXXDefaultDO = new XXXDefaultDO();
        XXXDefaultDO.setId(1L);
        XXXDefaultDO.setCidrBlock("192.168.10.10");
        XXXDefaultDO.setIpProtocol("tcp");
        XXXDefaultDO.setPortRange("3000:4000");
        XXXDefaultDO.setPolicy(Policy.POLICY_ACCEPT);
        XXXDefaultDO.setNic(Nic.INTRANET);
        XXXDefaultDO.setPriority(65533L);
        XXXDefaultDO.setType(1L);
        XXXDefaultDO.setIsDeleted(0L);
        XXXDefaultDO.setDescription("test1");
        XXXDefaultDO.setGmtCreate(new Date());
        XXXDefaultDO.setGmtModify(new Date());
        XXXDefaultDAO.insertXXXDefaultDO(XXXDefaultDO);
    }

    @Test
    @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki")
    public void testCountXXXDefaultDOByExample() {
        XXXDefaultDO XXXDefaultDO = new XXXDefaultDO();
        Assert.assertTrue(XXXDefaultDAO.countXXXDefaultDOByExample(XXXDefaultDO) == 1);
    }

    @Test
    @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki", then="XXXDefaultDAOTest.testUpdate.then.wiki")
    public void testUpdateXXXDefaultDO() {
        XXXDefaultDO found = XXXDefaultDAO.findXXXDefaultDOByPrimaryKey(6L);
        found.setIpProtocol("udp");
        found.setNic(Nic.INTERNET);
        found.setDescription("desc");
        XXXDefaultDAO.updateXXXDefaultDO(found);
    }

    @Test
    @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki")
    public void testFindListByExample() {
        String cidrBlock = "10.152.126.83";
        Policy policy = Policy.POLICY_ACCEPT;
        XXXDefaultDO XXXDefault = new XXXDefaultDO();
        XXXDefault.setCidrBlock(cidrBlock);
        XXXDefault.setPolicy(policy);
        List<XXXDefaultDO> list = XXXDefaultDAO.findListByExample(XXXDefault);
        Assert.assertEquals(list.size(), 1);
        for (XXXDefaultDO XXXDefaultDO: list) {
            Assert.assertEquals(XXXDefaultDO.getCidrBlock(), cidrBlock);
            Assert.assertEquals(XXXDefaultDO.getPolicy(), policy);
        }
    }

    @Test
    @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki")
    public void testFindXXXDefaultDOByPrimaryKey() {
        XXXDefaultDO found = XXXDefaultDAO.findXXXDefaultDOByPrimaryKey(6L);
        Assert.assertEquals(found.getCidrBlock(), "10.152.126.83");
        Assert.assertEquals(found.getIpProtocol(), "all");
        Assert.assertEquals(found.getPortRange(), "");
        Assert.assertEquals(found.getPolicy(), Policy.POLICY_ACCEPT);
        Assert.assertEquals(found.getNic(), Nic.BOTH);
        Assert.assertEquals(found.getPriority().longValue(),1L);
        Assert.assertEquals(found.getType().intValue(), 1);
        Assert.assertEquals(found.getIsDeleted().intValue(), 0);
        Assert.assertEquals(found.getDescription(), "bie dong");
    }

    @Test
    @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki", then="")
    public void testDeleteXXXDefaultDOByPrimaryKey() {
        Integer count = XXXDefaultDAO.deleteXXXDefaultDOByPrimaryKey(6L);
        Assert.assertEquals(count.intValue(), 1);
        
        Integer nodelete = XXXDefaultDAO.deleteXXXDefaultDOByPrimaryKey(6L);
        Assert.assertEquals(nodelete.intValue(), 0);
    }

}

       

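     Here the data preparation lives in the *.when.wiki files and the data verification in the *.then.wiki files; the database itself only needs to have the correct table structure. Every unit test run is then automated and repeatable.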

XXXDefaultDAOTest.initBlank.when.wiki
|connect|
|clean table|xxx_default|
      
XXXDefaultDAOTest.initRecords.when.wiki
|connect|
|clean table|xxx_default|
|clean table|xxx|
|insert|xxx_default|
| id | gmt_create | gmt_modify | cidr_block | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
| 6 | 2014-04-08 20:18:04 | 2014-04-08 20:18:04 | 10.152.126.83 | all | | accept | 3 | 1 | 1 | 0 | bie dong |

XXXDefaultDAOTest.queryOneRecord.then.wiki
|connect|
|query|select cidr_block, ip_protocol, port_range, policy, nic, priority, type, is_deleted, description from xxx_default |
|cidr_block | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
|192.168.10.10 | tcp | 3000:4000 | accept | 2 | 65533 | 1 | 0 | test1 |

XXXDefaultDAOTest.testUpdate.then.wiki
|connect|
|query|select cidr_block, ip_protocol, port_range, policy, nic, priority, type, is_deleted, description from xxx_default|
| cidr_block | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
| 10.152.126.83 | udp | | accept | 1 | 1 | 1 | 0 | desc |

         

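     Clearly, writing these wiki files and DAO test classes by hand for every DAO adds up to a lot of work (the set/get calls for each field are especially tedious). It is better to generate these files, or templates for them, automatically and cut down the manual effort.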

 

     So I wrote a Python program that, given a configuration, automatically generates the related test file templates.

     readcfg.py: reads the configuration file that describes the DAO test classes

from ConfigParser import ConfigParser

config = ConfigParser()
config.read("daotest.conf")

def getAllDAOTestInfo():
    allDAOTest = {}
    secs = config.sections() 
    for sec in secs:
        allDAOTest[sec] = getDAOTestInfo(sec)
    return allDAOTest        

def getDAOTestInfo(daoTestName) :
    daoTestInfo = { 
        'DaoTestName': config.get(daoTestName, 'DaoTestName') ,
        'TableName': config.get(daoTestName, 'TableName'), 
        'FieldArray': config.get(daoTestName, 'FieldArray'),
        'NumTypeFields': config.get(daoTestName, 'NumTypeFields'),
    }
    return daoTestInfo     
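
     For readers on Python 3, the same read-every-section idea can be sketched as follows. This is only an illustrative sketch, not the script above: it assumes Python 3's configparser module (the original targets Python 2's ConfigParser), and the names SAMPLE_CONF, KEYS and load_dao_test_info are made up for the example. The configuration is read from a string so the snippet runs without a daotest.conf file.

```python
# Python 3 sketch; the original readcfg.py targets Python 2.
from configparser import ConfigParser

# Hypothetical, shortened sample in the same shape as daotest.conf.
SAMPLE_CONF = """
[VmDAOTest]
DaoTestName=VmDAOTest
TableName=vm
FieldArray=id,gmt_create,gmt_modify,vm_name
NumTypeFields=id
"""

KEYS = ('DaoTestName', 'TableName', 'FieldArray', 'NumTypeFields')

def load_dao_test_info(conf_text):
    """Return {section name: {key: raw string value}} for every section."""
    parser = ConfigParser()
    parser.read_string(conf_text)
    return {sec: {key: parser.get(sec, key) for key in KEYS}
            for sec in parser.sections()}

all_info = load_dao_test_info(SAMPLE_CONF)
print(all_info['VmDAOTest']['TableName'])
```

     Option names are matched case-insensitively by configparser, so looking up 'TableName' works even though the parser stores the key in lower case.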

 

     create_daotest_wiki.py: generates the test file templates for the DAO tests:

import readcfg
import time
import re

def gene_daotest(daoTestInfo):

    daoTestName = daoTestInfo['DaoTestName']
    tableName = daoTestInfo['TableName']
    # raw strings keep the regex backslashes from being read as string escapes
    fieldArray = re.split(r'\s*,\s*', daoTestInfo['FieldArray'])
    numTypeFields = set(re.split(r'\s*,\s*', daoTestInfo['NumTypeFields']))
    
    print ' *** ', daoTestName , ' start...\n'
    
    startTime = time.clock()
    gene_daotest_wiki_really(daoTestName, tableName, fieldArray, numTypeFields)
    gene_daotest_java(daoTestName, tableName, fieldArray, numTypeFields)
    endTime = time.clock()
    
    print ' *** ', daoTestName,  ' finished.\n'
    print 'time cost: ', str((endTime - startTime)*1000) + 'ms.\n'

def gene_daotest_wiki_really(daoTestName, tableName, fieldArray, numTypeFields):
    
    '''
        generate the wikies used for DAO test java file
    '''
    
    conn = '|connect|'
    clean_table = '|clean table|' + tableName + '|'
    insert_table = '|insert|' + tableName + '|'
    all_fields = '|' + getfieldsWithSep(fieldArray, 0, '|') + '|'
    query_stmt = '|query|' + 'select ' + getfieldsWithSep(fieldArray, 0, ', ', filterTimeAndIdFieldFunc) + ' from ' + tableName + '|'
    query_fields = '|' + getfieldsWithSep(fieldArray, 0, '|', filterTimeAndIdFieldFunc) + '|'
    all_fields_default_values = '|' + getfieldValuesWithSep(fieldArray, numTypeFields, 0, '|') + '|'
    query_fields_default_values = '|' + getfieldValuesWithSep(fieldArray, numTypeFields, 0, '|', filterTimeAndIdFieldFunc) + '|'
    
    # 'with' closes each file on exit; the bare 'f.close' (no parentheses) never called close()

    # create DaoTestName.initBlank.when.wiki
    with open(daoTestName+".initBlank.when.wiki", 'w') as f_initBlank:
        f_initBlank.write('\n'.join([conn, clean_table]))
    
    # create DaoTestName.initRecords.when.wiki
    with open(daoTestName+".initRecords.when.wiki", 'w') as f_initRecs:
        f_initRecs.write('\n'.join([conn, clean_table, insert_table, all_fields, all_fields_default_values]))
    
    # create DaoTestName.queryOneRecord.then.wiki
    with open(daoTestName+".queryOneRecord.then.wiki", 'w') as f_qor:
        f_qor.write('\n'.join([conn, query_stmt, query_fields, query_fields_default_values]))
    
    # create DaoTestName.testUpdate.then.wiki
    with open(daoTestName+".testUpdate.then.wiki", 'w') as f_update:
        f_update.write('\n'.join([conn, query_stmt, query_fields, query_fields_default_values]))
    
def gene_daotest_java(daoTestName, tableName, fieldArray, numTypeFields):
    
    # read the whole template; 'with' makes sure the file is closed
    with open('TemplateDefaultDAOTest.java') as f_daotest_tmpl:
        content = f_daotest_tmpl.read()
    # 'VmDAOTest' ==> prefix 'Vm' (replaces XXX) and 'vm' (replaces YYY)
    daoPrefixIndex = daoTestName.find('DAOTest')
    daoPrefix = daoTestName[0: daoPrefixIndex]
    XXXReplacer = daoPrefix
    YYYReplacer = firstLowerCase(XXXReplacer)
    filteredFieldArray = getFilteredFields(fieldArray, filterTimeFieldFunc)
    contentReplaced = content.replace('XXX', XXXReplacer).replace('YYY', YYYReplacer)  \
                             .replace('$setFields', geneSetFields(filteredFieldArray, numTypeFields, YYYReplacer)) \
                             .replace('$AssertGetValues', geneAssertGetValues(filteredFieldArray, numTypeFields, YYYReplacer))
    # write the generated test class and close the output file
    with open(daoTestName+'.java', 'w') as f_daotest_java:
        f_daotest_java.write(contentReplaced)


def geneAssertGetValues(fieldArray, numTypeFields, YYYReplacer):
    content = ''
    for field in fieldArray:
        quoteStr = '' if field in numTypeFields else '"'
        content += 'Assert.assertEquals(%s.get%s(), %s%s%s);\n%s' %  \
                   (YYYReplacer, transformField(field), quoteStr, getDefaultValueForField(field, numTypeFields), quoteStr, indentTimes(2)) 
    return content
    
def geneSetFields(fieldArray, numTypeFields, YYYReplacer):
    content = ''
    for field in fieldArray:
        quoteStr = '' if field in numTypeFields else '"'
        content += '%s.set%s(%s%s%s);\n%s' % \
                  (YYYReplacer, transformField(field), quoteStr, getDefaultValueForField(field, numTypeFields), quoteStr, indentTimes(2))
    return content
    
def transformField(field):
    '''
       convert field with UnderLine form to Camel Form
       eg.  cidr_block ==> CidrBlock        
    '''
    
    parts = field.split('_')
    content = ''
    for part in parts: 
        content += firstSuperCase(part)
    return content

def indentTimes(num):
    # num tabs; an empty string when num <= 0
    return '\t' * num
    
def firstLowerCase(input):
    '''
        the first letter lowered. eg. NcDAOTest ==> ncDAOTest
    '''    
    return input[0].lower() + input[1:]
    
def firstSuperCase(input):
    '''
        the first letter uppered. eg. ncDAOTest ==> NcDAOTest
    '''    
    return input[0].upper() + input[1:]    
    
def nopFunc(field):
    return True    

def currTime():
    return time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))        
    
def getDefaultValueForField(field, numTypeFields):
    '''
        get the default value of field and return a value
        if you want to set more proper value , do it here.
    '''
    if field == 'id' or field.find('_id') != -1:
        return '1'
    elif field.find('gmt_') != -1:
        return currTime()
    elif field.find('ip') != -1 or field.find('addr') != -1:
        return '172.16.0.1'
    elif field.find('cidr') != -1:
        return '172.16.0.0/22'
    elif field in numTypeFields:
        return '0'
    else :
        return 'test-'+field
        
def getfieldsWithSep(fieldArray, index=0, sep='|', filterFunc=nopFunc):
    if index < 0 or index > len(fieldArray):
        raise Exception('index ' + str(index) + ' invalid: must be in [0,' + str(len(fieldArray)) + ']')
    fieldFilteredArray = getFilteredFields(fieldArray, filterFunc)
    return sep.join(fieldFilteredArray[index:])    
    
def getfieldValuesWithSep(fieldArray, numTypeFields, index=0, sep='|', filterFunc=nopFunc):
    if index < 0 or index > len(fieldArray):
        raise Exception('index ' + str(index) + ' invalid: must be in [0,' + str(len(fieldArray)) + ']')
    fieldFilteredArray = getFilteredFields(fieldArray, filterFunc)
    fieldDefaultValues = []
    for field in fieldFilteredArray:
        fieldDefaultValues.append(getDefaultValueForField(field, numTypeFields)) 
    return sep.join(fieldDefaultValues)
    
def filterTimeAndIdFieldFunc(field):
    return field.find('gmt_') == -1 and field != 'id'

def filterTimeFieldFunc(field):
    return field.find('gmt_') == -1 
    
def getFilteredFields(fieldArray, filterFunc):
    return filter(filterFunc, fieldArray)
    
if __name__ == '__main__':
    allDAOTest = readcfg.getAllDAOTestInfo()
    for daoTestName, daoTestInfo in allDAOTest.iteritems():
        gene_daotest(daoTestInfo)
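
     The core "create a template, fill in the content, generate" step can be seen in isolation in the Python 3 sketch below. It is a simplified illustration rather than the script above: the names snake_to_camel, render_setters and render are hypothetical, the template is a shortened fragment, and the default values are reduced to 0 for numeric fields and "test-<field>" for everything else.

```python
def snake_to_camel(field):
    """Convert an underscore name to camel case, e.g. cidr_block -> CidrBlock."""
    return ''.join(part[:1].upper() + part[1:] for part in field.split('_'))

def render_setters(var_name, fields, num_fields, indent='        '):
    """Build one Java setter call per field, joined with the template's indent."""
    lines = []
    for field in fields:
        # simplified defaults: 0 for numeric fields, a quoted string otherwise
        value = '0' if field in num_fields else '"test-%s"' % field
        lines.append('%s.set%s(%s);' % (var_name, snake_to_camel(field), value))
    return ('\n' + indent).join(lines)

# Shortened, hypothetical fragment in the style of the Java template.
TEMPLATE = """public void testInsertXXXDO() {
        XXXDO YYY = new XXXDO();
        $setFields
        YYYDAO.insertXXXDO(YYY);
    }"""

def render(template, prefix, fields, num_fields):
    """Replace XXX, YYY and $setFields in the template text."""
    var_name = prefix[0].lower() + prefix[1:]
    return (template.replace('XXX', prefix)
                    .replace('YYY', var_name)
                    .replace('$setFields', render_setters(var_name, fields, num_fields)))

java = render(TEMPLATE, 'Vm', ['vm_name', 'cores'], {'cores'})
print(java)
```

     Joining the setter lines with the indent, instead of appending the indent after every line, avoids the trailing whitespace-only line that follows each generated block.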

 

     daotest.conf: the configuration file

[VmDAOTest]
DaoTestName=VmDAOTest
TableName=vm
FieldArray=id,gmt_create,gmt_modify,vm_name,cores,mem,disk,status,nc_id,is_deleted
NumTypeFields=id,cores,mem,disk,status,nc_id,is_deleted

[NcDAOTest]
DaoTestName=NcDAOTest
TableName=nc
FieldArray=id,gmt_create,gmt_modify,hostname,ip,avail_cpu, avail_mem, avail_disk
NumTypeFields=id,avail_cpu, avail_mem, avail_disk
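
     Since every name in NumTypeFields should also appear in FieldArray, a small sanity check can catch typos in the configuration before any file is generated. The Python 3 sketch below is an optional addition, not part of the original scripts; split_fields and check_section are hypothetical helper names.

```python
import re

def split_fields(raw):
    """Split a comma-separated option value, ignoring whitespace around items."""
    return [item for item in re.split(r'\s*,\s*', raw.strip()) if item]

def check_section(field_array, num_type_fields):
    """Return (fields, num_fields); fail if a numeric field is not a known field."""
    fields = split_fields(field_array)
    num_fields = set(split_fields(num_type_fields))
    unknown = sorted(num_fields - set(fields))
    if unknown:
        raise ValueError('NumTypeFields not in FieldArray: ' + ', '.join(unknown))
    return fields, num_fields

# Values copied from the [NcDAOTest] section, including the stray spaces.
fields, num_fields = check_section(
    'id,gmt_create,gmt_modify,hostname,ip,avail_cpu, avail_mem, avail_disk',
    'id,avail_cpu, avail_mem, avail_disk')
print(fields)
```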

 

     DAO Java file template:

package xxx.dao.regiondb.impl;

import java.util.Date;
import java.util.List;

import org.jtester.unitils.dbfit.DbFit;
import org.testng.Assert;
import org.testng.annotations.Test;
import org.unitils.spring.annotation.SpringBeanByName;

import xxx.BaseRegionDbDAOTestCase;
import xxx.constant.group.Nic;
import xxx.constant.group.Policy;
import xxx.dao.regiondb.XXXDAO;
import xxx.model.db.XXXDO;

public class XXXDAOTest extends BaseRegionDbDAOTestCase {

    @SpringBeanByName
    private XXXDAO YYYDAO;
    
    @Test
    @DbFit(when="XXXDAOTest.initBlank.when.wiki", then="XXXDAOTest.queryOneRecord.then.wiki")
    public void testInsertXXXDO() {
        XXXDO YYY = new XXXDO();
        $setFields
        YYY.setGmtCreate(new Date());
        YYY.setGmtModify(new Date());
        YYYDAO.insertXXXDO(YYY);
    }

    @Test
    @DbFit(when="XXXDAOTest.initRecords.when.wiki")
    public void testCountXXXDOByExample() {
        XXXDO YYY = new XXXDO();
        Assert.assertTrue(YYYDAO.countXXXDOByExample(YYY).intValue() == 1);
    }

    @Test
    @DbFit(when="XXXDAOTest.initRecords.when.wiki", then="XXXDAOTest.testUpdate.then.wiki")
    public void testUpdateXXXDO() {
        XXXDO YYY = YYYDAO.findXXXDOByPrimaryKey(1L);
        $setFields
        YYYDAO.updateXXXDO(YYY);
    }

    @Test
    @DbFit(when="XXXDAOTest.initRecords.when.wiki")
    public void testFindListByExample() {
        XXXDO YYY = new XXXDO();
        $setFields
        List<XXXDO> list = YYYDAO.findListByExample(YYY);
        Assert.assertEquals(list.size(), 1);
        for (XXXDO YYYDO: list) {
            $AssertGetValues
        }
    }

    @Test
    @DbFit(when="XXXDAOTest.initRecords.when.wiki")
    public void testFindXXXDOByPrimaryKey() {
        XXXDO YYY = YYYDAO.findXXXDOByPrimaryKey(1L);
        $AssertGetValues
    }

    @Test
    @DbFit(when="XXXDAOTest.initRecords.when.wiki")
    public void testDeleteXXXDOByPrimaryKey() {
        Integer count = YYYDAO.deleteXXXDOByPrimaryKey(1L);
        Assert.assertEquals(count.intValue(), 1);
        
        Integer nodelete = YYYDAO.deleteXXXDOByPrimaryKey(1L);
        Assert.assertEquals(nodelete.intValue(), 0);
    }

}

 

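     Run: $ python create_daotest_wiki.py

     For each section in daotest.conf, this generates four wiki files and one DAOTest Java file. For VmDAOTest, for example: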
     

VmDAOTest.initBlank.when.wiki  

|connect|
|clean table|vm|
 

VmDAOTest.initRecords.when.wiki    

|connect|
|clean table|vm|
|insert|vm|
|id|gmt_create|gmt_modify|vm_name|cores|mem|disk|status|nc_id|is_deleted|
|1|2014-05-22 12:51:38|2014-05-22 12:51:38|test-vm_name|0|0|0|0|1|0|
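
     A file like the one above is just a handful of rows joined with "|". The Python 3 sketch below shows that assembly on its own; wiki_row and init_records_wiki are hypothetical names, and the field and value lists are shortened for the example.

```python
def wiki_row(cells):
    """Join cells into one table row of the form |a|b|c|."""
    return '|' + '|'.join(cells) + '|'

def init_records_wiki(table, fields, values):
    """Build the text of an initRecords.when.wiki file with a single data row."""
    if len(fields) != len(values):
        raise ValueError('expected %d values, got %d' % (len(fields), len(values)))
    return '\n'.join([
        '|connect|',
        wiki_row(['clean table', table]),
        wiki_row(['insert', table]),
        wiki_row(fields),   # header row: column names
        wiki_row(values),   # data row: one default value per column
    ])

wiki = init_records_wiki('vm', ['id', 'vm_name', 'cores'], ['1', 'test-vm_name', '0'])
print(wiki)
```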
 

 VmDAOTest.queryOneRecord.then.wiki / VmDAOTest.testUpdate.then.wiki

|connect|
|query|select vm_name, cores, mem, disk, status, nc_id, is_deleted from vm|
|vm_name|cores|mem|disk|status|nc_id|is_deleted|
|test-vm_name|0|0|0|0|1|0| 

 

     The generated DAOTest Java file:

package xxx.dao.regiondb.impl;

import java.util.Date;
import java.util.List;

import org.jtester.unitils.dbfit.DbFit;
import org.testng.Assert;
import org.testng.annotations.Test;
import org.unitils.spring.annotation.SpringBeanByName;

import xxx.BaseRegionDbDAOTestCase;
import xxx.constant.group.Nic;
import xxx.constant.group.Policy;
import xxx.dao.regiondb.VmDAO;
import xxx.model.db.VmDO;

public class VmDAOTest extends BaseRegionDbDAOTestCase {

    @SpringBeanByName
    private VmDAO vmDAO;
    
    @Test
    @DbFit(when="VmDAOTest.initBlank.when.wiki", then="VmDAOTest.queryOneRecord.then.wiki")
    public void testInsertVmDO() {
        VmDO vm = new VmDO();
        vm.setId(1);
        vm.setVmName("test-vm_name");
        vm.setCores(0);
        vm.setMem(0);
        vm.setDisk(0);
        vm.setStatus(0);
        vm.setNcId(1);
        vm.setIsDeleted(0);
        
        vm.setGmtCreate(new Date());
        vm.setGmtModify(new Date());
        vmDAO.insertVmDO(vm);
    }

    @Test
    @DbFit(when="VmDAOTest.initRecords.when.wiki")
    public void testCountVmDOByExample() {
        VmDO vm = new VmDO();
        Assert.assertTrue(vmDAO.countVmDOByExample(vm).intValue() == 1);
    }

    @Test
    @DbFit(when="VmDAOTest.initRecords.when.wiki", then="VmDAOTest.testUpdate.then.wiki")
    public void testUpdateVmDO() {
        VmDO vm = vmDAO.findVmDOByPrimaryKey(1L);
        vm.setId(1);
        vm.setVmName("test-vm_name");
        vm.setCores(0);
        vm.setMem(0);
        vm.setDisk(0);
        vm.setStatus(0);
        vm.setNcId(1);
        vm.setIsDeleted(0);
        
        vmDAO.updateVmDO(vm);
    }

    @Test
    @DbFit(when="VmDAOTest.initRecords.when.wiki")
    public void testFindListByExample() {
        VmDO vm = new VmDO();
        vm.setId(1);
        vm.setVmName("test-vm_name");
        vm.setCores(0);
        vm.setMem(0);
        vm.setDisk(0);
        vm.setStatus(0);
        vm.setNcId(1);
        vm.setIsDeleted(0);
        
        List<VmDO> list = vmDAO.findListByExample(vm);
        Assert.assertEquals(list.size(), 1);
        for (VmDO vmDO: list) {
            Assert.assertEquals(vm.getId(), 1);
        Assert.assertEquals(vm.getVmName(), "test-vm_name");
        Assert.assertEquals(vm.getCores(), 0);
        Assert.assertEquals(vm.getMem(), 0);
        Assert.assertEquals(vm.getDisk(), 0);
        Assert.assertEquals(vm.getStatus(), 0);
        Assert.assertEquals(vm.getNcId(), 1);
        Assert.assertEquals(vm.getIsDeleted(), 0);
        
        }
    }

    @Test
    @DbFit(when="VmDAOTest.initRecords.when.wiki")
    public void testFindVmDOByPrimaryKey() {
        VmDO vm = vmDAO.findVmDOByPrimaryKey(1L);
        Assert.assertEquals(vm.getId(), 1);
        Assert.assertEquals(vm.getVmName(), "test-vm_name");
        Assert.assertEquals(vm.getCores(), 0);
        Assert.assertEquals(vm.getMem(), 0);
        Assert.assertEquals(vm.getDisk(), 0);
        Assert.assertEquals(vm.getStatus(), 0);
        Assert.assertEquals(vm.getNcId(), 1);
        Assert.assertEquals(vm.getIsDeleted(), 0);
        
    }

    @Test
    @DbFit(when="VmDAOTest.initRecords.when.wiki")
    public void testDeleteVmDOByPrimaryKey() {
        Integer count = vmDAO.deleteVmDOByPrimaryKey(1L);
        Assert.assertEquals(count.intValue(), 1);
        
        Integer nodelete = vmDAO.deleteVmDOByPrimaryKey(1L);
        Assert.assertEquals(nodelete.intValue(), 0);
    }

}

 

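     Conclusion:

     Automate manual work wherever possible. Automation needs two things: first, standardized conventions; second, an eye for the recurring patterns.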

 

posted @ 2014-05-21 19:55  琴水玉