A Walkthrough of the unittest Source

While studying DRF I read through parts of its implementation, and on a whim decided to dig into the unittest source as groundwork for building a test platform later. At the end of the post I attach a simple unittest HTML report generator I wrote myself.

To get started, let's begin with a basic unittest usage example:

import unittest
from test_demo import *  # test_demo defines the TestDemo test case class

if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    suite = unittest.TestSuite()
    suite.addTest(TestDemo('test_mul'))
    suite.addTest(TestDemo('test_div'))
    result = runner.run(suite)
    print(f'Total tests run: {result.testsRun}')
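
The suite above is assembled by hand; as an aside, unittest's loader can build the same suite automatically (a minimal sketch, assuming the same TestDemo class from test_demo):

import unittest
from test_demo import TestDemo  # the example test class assumed above

if __name__ == '__main__':
    # loadTestsFromTestCase collects every test_* method of TestDemo into a suite
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestDemo)
    result = unittest.TextTestRunner().run(suite)
    print(f'Total tests run: {result.testsRun}')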

First, pick an entry point. I'll start tracing the code flow from runner.run; before reading on, it helps to be familiar with the basic unittest workflow.

    def run(self, test):
        # build a result object; which result class gets used is decided by the runner
        result = self._makeResult()
        startTime = time.perf_counter()
        startTestRun = getattr(result, 'startTestRun', None)
        if startTestRun is not None:  # hook fired once before any test runs; a no-op by default
            startTestRun()
        try:
            test(result)  # test is a TestSuite or TestCase instance; calling it goes through __call__
        finally:
            stopTestRun = getattr(result, 'stopTestRun', None)  # hook fired once after all tests finish; a no-op by default
            if stopTestRun is not None:
                stopTestRun()
        stopTime = time.perf_counter()
        timeTaken = stopTime - startTime  # total elapsed time
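
The result object in the first comment comes from _makeResult, which simply instantiates self.resultclass; that attribute is exactly the hook the HtmlTestResult at the end of this post plugs into. A minimal sketch of the hook (VerboseResult is just an illustrative name):

import unittest

class VerboseResult(unittest.TextTestResult):
    """A do-nothing subclass, only here to show the resultclass hook."""
    pass

# either set the resultclass attribute on a runner subclass or pass it as an argument
runner = unittest.TextTestRunner(resultclass=VerboseResult)
print(runner._makeResult().__class__.__name__)  # -> VerboseResult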

Next, the TestSuite logic:

def __iter__(self):
    return iter(self._tests)

def addTest(self, test):
    self._tests.append(test)

def run(self, result, debug=False):
    topLevel = False
    if getattr(result, '_testRunEntered', False) is False:
        result._testRunEntered = topLevel = True

    for index, test in enumerate(self):  # iterate over the added tests; see addTest and __iter__ above
        if result.shouldStop:  # stop early if the result was asked to stop
            break

        if _isnotsuite(test):
            self._tearDownPreviousClass(test, result)  # run tearDownClass of the previous test class, if there was one
            self._handleModuleFixture(test, result)  # run tearDownModule of the previous module (if any) and setUpModule of the current module (if present and not yet run)
            self._handleClassSetUp(test, result)  # run setUpClass of the current test class (if present and not yet run)
            result._previousTestClass = test.__class__  # remember the current class so a class switch can trigger tearDownClass next time

            if (getattr(test.__class__, '_classSetupFailed', False) or
                getattr(result, '_moduleSetUpFailed', False)):  # if class or module setup failed, skip the test
                continue

        if not debug:
            test(result)  # hand off to the per-test execution path
        else:
            test.debug()

        if self._cleanup:
            self._removeTestAtIndex(index)

    if topLevel:  # after the last test, run the pending class and module tearDowns
        self._tearDownPreviousClass(None, result)
        self._handleModuleTearDown(result)
        result._testRunEntered = False
    return result

def __call__(self, result):  # calling a suite instance dispatches here
    return self.run(result)
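
A quick way to see the class-fixture switching described in the comments above is a tiny two-class suite (a self-contained demo, not part of the post's test_demo module):

import unittest

class TestA(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        print('TestA.setUpClass')

    @classmethod
    def tearDownClass(cls):
        print('TestA.tearDownClass')

    def test_one(self):
        print('TestA.test_one')

class TestB(unittest.TestCase):
    def test_two(self):
        print('TestB.test_two')

if __name__ == '__main__':
    suite = unittest.TestSuite([TestA('test_one'), TestB('test_two')])
    # expected order: TestA.setUpClass, TestA.test_one, TestA.tearDownClass
    # (fired by _tearDownPreviousClass when the class changes), then TestB.test_two
    unittest.TextTestRunner(verbosity=0).run(suite)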

Now the TestCase logic:

class _Outcome(object):  
    def __init__(self, result=None):  
        self.expecting_failure = False  
        self.result = result  
        self.result_supports_subtests = hasattr(result, "addSubTest")  
        self.success = True  
        self.skipped = []  
        self.expectedFailure = None  
        self.errors = []  
  
    @contextlib.contextmanager  
    def testPartExecutor(self, test_case, isTest=False):  # the yield runs the body of the with block; once that body finishes (or raises), execution resumes after the yield
        old_success = self.success  
        self.success = True  
        try:  
            yield  
        except KeyboardInterrupt:  
            raise  
        except SkipTest as e:  
            self.success = False  
            self.skipped.append((test_case, str(e)))  
        except _ShouldStop:  
            pass  
        except:  
            exc_info = sys.exc_info()  
            if self.expecting_failure:  
                self.expectedFailure = exc_info  
            else:  
                self.success = False  
                self.errors.append((test_case, exc_info))  
            # explicitly break a reference cycle:
            # exc_info -> frame -> exc_info
            exc_info = None
        else:  
            if self.result_supports_subtests and self.success:  
                self.errors.append((test_case, None))  
        finally:  
            self.success = self.success and old_success  
  
class TestCase(object):  
  
    def __init__(self, methodName='runTest'):  
        self._testMethodName = methodName  
        self._outcome = None  
        testMethod = getattr(self, methodName) 
        self._testMethodDoc = testMethod.__doc__ 
        self._cleanups = []  
        self._subtest = None  
  
    def shortDescription(self):  
        doc = self._testMethodDoc  
        return doc.strip().split("\n")[0].strip() if doc else None  
  
  
    def id(self):  
        return "%s.%s" % (strclass(self.__class__), self._testMethodName)  

    def __str__(self):  
        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))  
  
    def __repr__(self):  
        return "<%s testMethod=%s>" % \  
               (strclass(self.__class__), self._testMethodName)  
  
    def run(self, result=None):  
  
        result.startTest(self)  # notify the result that this test is starting
        try:  
            testMethod = getattr(self, self._testMethodName)  
            if (getattr(self.__class__, "__unittest_skip__", False) or  
                getattr(testMethod, "__unittest_skip__", False)):  
                # If the class or method was skipped.  
                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')  
                            or getattr(testMethod, '__unittest_skip_why__', ''))  
                self._addSkip(result, self, skip_why)  # skip handling
                return result  
  
            expecting_failure = (  
                getattr(self, "__unittest_expecting_failure__", False) or  
                getattr(testMethod, "__unittest_expecting_failure__", False)  
            )  # expectedFailure handling
            outcome = _Outcome(result)  
            try:  
                self._outcome = outcome  
  
                with outcome.testPartExecutor(self):  # see the notes on testPartExecutor above
                    self._callSetUp()  # run setUp
                if outcome.success:  # only run the test body if setUp succeeded
                    outcome.expecting_failure = expecting_failure  # record whether a failure is expected: True or False
                    with outcome.testPartExecutor(self, isTest=True):  
                        """
                        抛出异常时,如果期望失败只记录异常信息,否则结果置为失败
                        if self.expecting_failure:  
                             self.expectedFailure = exc_info  
                        else:  
                              self.success = False  
                              self.errors.append((test_case, exc_info)) 
                        """
                        self._callTestMethod(testMethod)  
                    outcome.expecting_failure = False  
                    with outcome.testPartExecutor(self):  
                        self._callTearDown()  # run tearDown
                for test, reason in outcome.skipped:  
                    self._addSkip(result, test, reason)  
                self._feedErrorsToResult(result, outcome.errors)  
                if outcome.success:  
                    if expecting_failure:  
                        if outcome.expectedFailure:  
                            self._addExpectedFailure(result, outcome.expectedFailure)  
                        else:  
                            self._addUnexpectedSuccess(result)  
                    else:  
                        result.addSuccess(self)  
                return result  
            finally:  
                outcome.expectedFailure = None  
  
                # clear the outcome, no more needed  
                self._outcome = None  
  
        finally:
            result.stopTest(self)  # notify the result that this test has finished
  
    
    def __call__(self, *args, **kwds):  
        return self.run(*args, **kwds)  
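
The testPartExecutor trick is easier to see in isolation: the context manager's yield runs the with block's body, and any exception raised there lands in the except clauses after the yield instead of propagating. A stripped-down sketch of the same idea (not the real unittest code):

import contextlib

class Outcome:
    def __init__(self):
        self.success = True
        self.errors = []

    @contextlib.contextmanager
    def part(self, label):
        try:
            yield                      # the with body (setUp / test / tearDown) runs here
        except Exception as e:         # an exception from the body resumes execution here and is swallowed
            self.success = False
            self.errors.append((label, e))

outcome = Outcome()
with outcome.part('setUp'):
    pass                               # setUp succeeded
with outcome.part('test'):
    raise AssertionError('boom')       # recorded, not propagated
print(outcome.success, outcome.errors)  # -> False [('test', AssertionError('boom'))]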

With the flow mapped out, we can customize the result and the runner to generate an HTML report. The final code follows (I'll cover the result logic in more detail in a later post, since the result object threads through everything; once that is clear you can pretty much roll your own test framework):

import re  
import os  
import time  
import getpass  
import unittest  
import warnings  
  
  
class HtmlTestResult(unittest.TextTestResult):  
  
    def __init__(self, stream, descriptions, verbosity=2):  
        super().__init__(stream, descriptions, verbosity)  
        self.timeInfo = {}  
  
    def startTest(self, test):  
        super().startTest(test)  
        self.timeInfo[str(test)] = {'startDate': time.strftime("%Y-%m-%d %H:%M:%S")}  
        self.startTime = time.perf_counter()  
  
    def stopTest(self, test):  
        super().stopTest(test)  
        self.timeInfo[str(test)].update(  
            {  
                'takenTime': round(time.perf_counter() - self.startTime,2) or 0.01,  
                'endDate': time.strftime("%Y-%m-%d %H:%M:%S")  
            })  
  
  
class HtmlTestRunner(unittest.TextTestRunner):  
    pattern = re.compile(r'\((.+)\)')
    resultclass = HtmlTestResult  
  
    def _get_tests(self, test):  
        "Return a list of tests for the given test case or test suite."  
        if isinstance(test, unittest.TestSuite):  
            return [t for t in test]  
        else:  
            return [test]  
  
    def _new_errors(self, result, tests):  
        new_errors = []  
        sub_run = 0  
        for error in result.errors:  
            if isinstance(error[0], unittest.TestCase):  
                new_errors.append(error)  
            else:  
                for test in tests:  
                    module_or_class = self.pattern.search(str(error[0])).group(1)  
                    if module_or_class == unittest.util.strclass(test.__class__) \  
                            or module_or_class == test.__class__.__module__:  
                        new_errors.append((test, error[1]))  
                        if 'setUp' in error[1]:  
                            sub_run += 1  
        result.errors = new_errors  
        result.testsSubRun = sub_run  
  
    def run(self, test):  
        self._tests = self._get_tests(test)  # build the flat list of test cases to run
        result = self._makeResult()  # from here on this mirrors the stock TextTestRunner.run
        unittest.signals.registerResult(result)
        result.failfast = self.failfast  
        result.buffer = self.buffer  
        result.tb_locals = self.tb_locals  
        with warnings.catch_warnings():  
            if self.warnings:  
                warnings.simplefilter(self.warnings)  
                if self.warnings in ['default', 'always']:  
                    warnings.filterwarnings('module',  
                                            category=DeprecationWarning,  
                                            message=r'Please use assert\w+ instead.')  
            startTime = time.perf_counter()  
            startDate = time.strftime('%Y-%m-%d %H:%M:%S')  
            startTestRun = getattr(result, 'startTestRun', None)  
            if startTestRun is not None:  
                startTestRun()  
            try:  
                test(result)  
            finally:  
                stopTestRun = getattr(result, 'stopTestRun', None)  
                if stopTestRun is not None:  
                    stopTestRun()  
            stopTime = time.perf_counter()  
        timeTaken = stopTime - startTime  
  
        # when setUpClass or setUpModule raises, the recorded "test" is a _ErrorHolder rather than a case instance; rewrite those entries to point at real case objects
        self._new_errors(result, self._tests)  
        err_count = len(set([error[0] for error in result.errors if 'tearDown' not in error[1]]))  
  
        result.printErrors()  
        if hasattr(result, 'separator2'):  
            self.stream.writeln(result.separator2)  
        run = result.testsRun  
        sub_run = result.testsSubRun  # cases that never ran because class or module setup failed
        sub_run_str = '类或模块初始化失败的用例数 %d ' % sub_run if sub_run else ''
        run_str = "%s 执行到的用例数 %d in %.3fs" % (sub_run_str, run, timeTaken)
        self.stream.writeln(run_str)
        self.stream.writeln()  
  
        expectedFails = unexpectedSuccesses = skipped = 0  
        try:  
            results = map(len, (result.expectedFailures,  
                                result.unexpectedSuccesses,  
                                result.skipped))  
        except AttributeError:  
            pass  
        else:  
            expectedFails, unexpectedSuccesses, skipped = results  
  
        infos = []  
        info_dict = {'total': len(self._tests), 'fail': 0, 'error': 0, 'skip': skipped,  
                     'expectedFail': expectedFails,  
                     'unexpectedSuccess': unexpectedSuccesses}  
        result.infos = infos  # attach the info list to the result
        if not result.wasSuccessful():  
            self.stream.write("FAILED")  
            failed, errored = len(result.failures), err_count  
            info_dict.update({'fail': failed, 'error': err_count,  
                              'pass': len(self._tests)-skipped-unexpectedSuccesses-expectedFails-failed-err_count})  
            info_dict['passRate'] = round(info_dict['pass']/info_dict['total'],4)*100  
            if failed:  
                infos.append("failures=%d" % failed)  
            if errored:  
                infos.append("errors=%d" % errored)  
        else:
            self.stream.write("OK")
            # also record the pass count and pass rate on a fully successful run, so the report gets real numbers
            info_dict['pass'] = len(self._tests) - skipped - unexpectedSuccesses - expectedFails
            info_dict['passRate'] = round(info_dict['pass'] / info_dict['total'], 4) * 100
        if skipped:  
            infos.append("skipped=%d" % skipped)  
        if expectedFails:  
            infos.append("expected failures=%d" % expectedFails)  
        if unexpectedSuccesses:  
            infos.append("unexpected successes=%d" % unexpectedSuccesses)  
        if infos:  
            self.stream.writeln()  
        else:  
            self.stream.write("\n")  
        self.stream.flush()  
        # build the summary info
  
        summary = {  
            'tester': os.getenv('tester') or getpass.getuser(),  
            'startDate': startDate,  
            'timeTaken': timeTaken,  
            'runResults': f'执行用例总数:{run + sub_run},失败:{len(result.failures)},异常:{err_count}'  
        }  
        result.summary = summary  # attach the summary to the result
        result.info_dict = info_dict  
        return result  
  
    def _handle_errors(self, result, test_str, case_data):  
        self._error_count = getattr(self, '_error_count', 0)  
        for error in result.errors.copy():  
            if test_str == str(error[0]):  
                result.errors.remove(error)  
                if 'setUp' in error[1]:  
                    case_data[test_str].update({'status': 'error', 'setUpError': error[1]})  
                    break  
                elif 'tearDown' in error[1]:  
                    case_data[test_str].update({'tearDownError': error[1]})  
                else:  
                    case_data[test_str].update({'status': 'error', 'runError': error[1]})  
                    break  
        if case_data[test_str]['status'] == 'error':  
            self._error_count += 1  
            return True  
  
    def _handle_failures(self, result, test_str, case_data):  
        for failure in result.failures:  
            if test_str == str(failure[0]):  
                result.failures.remove(failure)  
                case_data[test_str].update({'status': 'fail', 'failure': failure[1]})  
                return True  
  
    def _handle_skipped(self, result, test_str, case_data):  
        for skip in result.skipped:  
            if test_str == str(skip[0]):  
                result.skipped.remove(skip)  
                case_data[test_str].update({'status': 'skip', 'skipReason': skip[1]})  
                return True  
  
    def _handle_expectedFailures(self, result, test_str, case_data):  
        for expectedFailure in result.expectedFailures:  
            if test_str == str(expectedFailure[0]):  
                result.expectedFailures.remove(expectedFailure)  
                case_data[test_str].update({'status': 'expectedFail', 'expectedFail': expectedFailure[1]})  
                return True  
  
    def _handle_unexpectedSuccesses(self, result, test_str, case_data):  
        for unexpectedFailure in result.unexpectedSuccesses:  
            if test_str == str(unexpectedFailure[0]):  
                result.unexpectedSuccesses.remove(unexpectedFailure)  
                case_data[test_str].update({'status': 'unexpectedSuccess'})  
                return True  
  
    def _info_description(self, info):
        if info.get('fail') or info.get('error'):
            return f'测试失败,执行{info["total"]}用例,失败:{info["fail"]},异常:{info["error"]}'
        return f'测试执行{info["total"]}用例,结果通过'
  
    def _write_html(self, data, report_path):
        base_html = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'utest.html')
        pattern = re.compile('({{(.+?)}})')
        with open(base_html, encoding='utf-8') as fp:
            with open(report_path, 'w', encoding='utf-8') as wf:
                line = fp.readline()  
                while line:  
                    for old_value, key in pattern.findall(line):  
                        key = key.strip()  
                        if key == 'data':  
                            line = line.replace(old_value, str(data))  
                        else:  
                            line = line.replace(old_value, str(data[key]))  
                    wf.write(line)  
                    line = fp.readline()  
  
    def html(self, result, report_path=None, title='测试报告', description=None, attrs=None):  
        case_data = {}  
        attrs = attrs or ['errors', 'failures', 'skipped',  
                          'expectedFailures', 'unexpectedSuccesses']  
        for test in self._tests:  
            test_str = str(test)  
            case_data[test_str] = dict(status='pass', name=test._testMethodDoc or test._testMethodName,  
                                       **result.timeInfo.get(test_str, {}))  
            for attr in attrs:  
                if getattr(self, f'_handle_{attr}')(result, test_str, case_data):  
                    break  
        description = description or self._info_description(result.info_dict)  
        html_data={  
            'case_data': case_data,  
            'summary': result.summary,  
            'info': result.info_dict,  
            'title': title,  
            'description': description,  
            'passRate': result.info_dict.pop('passRate', None),  
        }  
        report_path = report_path or os.path.join(os.getcwd(), f'report_{time.strftime("%Y%m%d_%H%M%S")}.html')
        self._write_html(html_data, report_path)
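
Wiring it together looks like the opening example, only with the custom runner; after run() you call html() to write the report (a usage sketch, assuming the TestDemo case from the start of the post):

import unittest
from test_demo import TestDemo           # the example tests from the beginning of the post
from html_runner import HtmlTestRunner   # hypothetical module name: wherever the classes above were saved

if __name__ == '__main__':
    suite = unittest.TestSuite([TestDemo('test_mul'), TestDemo('test_div')])
    runner = HtmlTestRunner(verbosity=2)
    result = runner.run(suite)
    runner.html(result, title='demo测试报告')  # renders utest.html into a report file in the working directory

The utest.html template that _write_html fills in is shown below: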
<!DOCTYPE html>  
<html lang="en">  
<head>  
    <meta charset="UTF-8">  
    <title>{{title}}</title>  
  
    <style>  
         *{  
            margin: 0;  
            padding: 0;  
        }  
         .top{  
             padding:20px 0 0 10px;  
         }  
        .title {  
            font-size:40px;  
            color: blue;  
        }  
        .description {  
            color: gray;  
            font-size: 16px;  
        }  
        .second-title{  
            color: aquamarine;  
            font-size: 25px;  
        }  
        li{  
            list-style-type: none;  
            padding: 5px;  
        }  
        .sum-title{  
            width: 100px;  
            display: inline-block;  
        }  
        .sum-content {  
            width: 300px;  
            display: inline-block;  
            margin-left:10px;  
            color: gray;  
        }  
        table{  
            width: 100%;  
            border-collapse: collapse;  
            text-align: center;  
            margin-bottom:20vh;  
            margin-top:10px;  
        }  
        thead{  
            background-color: aqua;  
        }  
        td,th{  
            padding: 3px;  
        }  
        button{  
            width: 80px;  
            height: 40px;  
            border: none;  
            margin-right:10px;  
            border-radius:5px;  
        }  
        .pass,.expectedFail{  
            background-color: green;  
        }  
        .fail,.unexpectedSuccess{  
            background-color: red;  
        }  
        .error{  
            background-color: yellow;  
        }  
        .skip{  
            background-color: gray;  
        }  
        .rate-pass{  
            color: green;  
        }  
        .rate-fail{  
            color: red;  
        }  
        .hide{  
            display: none;  
        }  
        .error_info {
            width: 100%;              /* fill the whole table cell */
            background-color: black;  /* background colour (optional) */
            color: red;               /* text colour (optional) */
            padding: 10px;            /* inner padding */
            box-sizing: border-box;   /* include the padding in the width */
            font-size: 14px;          /* adjust as needed */
            white-space: pre-line;    /* keep line breaks */
            text-align: left;
        }
    </style>  
  
</head>  
<body>  
    <h1 class="title top">{{title}}</h1>  
    <p class="description top">{{description}}</p>  
    <h2 class="summary-title top">测试概要</h2>  
    <ul class="summary-content top">  
  
    </ul>  
    <h2 class="top">测试详情</h2>  
    <p class="top pass-rate">通过率:{{ passRate }}%</p>  
    <p class="top btn-list"></p>  
    <table class="top">  
        <thead>  
            <tr>  
                <th>序号</th>  
                <th>用例名称</th>  
                <th>测试结果</th>  
                <th>开始时间</th>  
                <th>结束时间</th>  
                <th>执行时长</th>  
            </tr>  
        </thead>  
        <tbody class="case-content"></tbody>  
    </table>  
<script>  
    var data= {{ data }}  
  
    addSummary(data.summary)  
    addCasesTr(data.case_data)  
    addTrClick()  
    addBtnList(data.info)  
    eachRowColor()  
    filterCases()  
    passRateClass()  
  
    function addSummary(summary){  
        for(let key in summary){  
            var li=`<li><span class="sum-title">${key}</span>:<span class="sum-content">${summary[key]}</span></li>`  
            document.querySelector('.summary-content').innerHTML+=li;  
        }  
    }  
  
    function addCasesTr(case_data) {  
        err_key=['setUpError','runError','tearDownError','failure','skipReason']  
        for(let key in case_data) {  
            let status = case_data[key].status  
            let startDate = case_data[key].startDate  
            let endDate = case_data[key].endDate  
            let takenTime = case_data[key].takenTime  
            let tr = `<tr>  
                <td>${key}</td>  
                <td>${case_data[key].name}</td>  
                <td class="${status}">${status}</td>  
                <td>${startDate==undefined?'--':startDate}</td>  
                <td>${endDate==undefined?'--':endDate}</td>  
                <td>${takenTime==undefined?'--':takenTime}</td>  
            </tr>`;  
            var div_tag;  
            err_key.forEach(function(k) {  
                if(case_data[key][k]!=undefined){  
                    let text=k+'\n'+case_data[key][k]  
                    div_tag=`<tr class="hide err"><td colspan="6"><div class="error_info">${text}</div></td></tr>`;  
                }  
            })  
            document.querySelector('.case-content').innerHTML += tr;  
            if(div_tag!=undefined) {  
                document.querySelector('.case-content').innerHTML += div_tag;  
                div_tag=undefined;  
            }else{  
                document.querySelector('.case-content').innerHTML += `<tr class="hide err"><td colspan="6">无错误信息</td></tr>`;  
            }  
  
        }  
    }  
  
    function addTrClick() {
        document.querySelectorAll('.case-content tr').forEach(function(row){
            row.addEventListener('click', function(){
                var next_node = this.nextElementSibling;
                if(next_node == null){
                    return;
                }
                if(next_node.classList.contains('hide')){
                    next_node.classList.remove('hide');
                }else {
                    next_node.classList.add('hide');
                }
            })
        })

    }
  
    function addBtnList(info) {  
        for(let key in info) {  
            if (info[key]>0){  
                let btn = `<button class="${key} filter">${key}:${info[key]}</button>`;  
            document.querySelector('.btn-list').innerHTML += btn;  
            }  
  
        }  
    }  
  
    function filterCases() {  
        document.querySelectorAll('.filter').forEach(function(ele){  
            ele.addEventListener('click', function(){  
                let status = this.className.split(' ')[0];  
                let rows = document.querySelectorAll('.case-content tr');  
                rows.forEach(function(row){  
                    if(row.classList.contains('err')){  
                        row.classList.add('hide');  
                        return  
                    }else{  
                        row.classList.remove('hide');  
                    } 
                    if(status!='total' && row.querySelector('td:nth-child(3)').innerHTML !== status){
                        row.classList.add('hide');
                    }
                })  
            })  
        })  
    }  
  
    function passRateClass() {  
        let el=document.querySelector('.pass-rate')  
  
        pass_rate=el.innerText.split(':')[1].split('%')[0];  
        if(pass_rate>85){  
            el.classList.add('rate-pass')  
            el.innerHTML += '(通过率已达标)'  
        }else{  
            el.classList.add('rate-fail')  
            el.innerHTML += '(通过率未达标)'  
        }  
    }  
  
    function eachRowColor() {  
        const colors = ['whitesmoke','white','#eee','#ddd'];  
        const tbody = document.querySelector('.case-content');  
        Array.from(tbody.rows).forEach((row, index) => {  
          row.style.backgroundColor = colors[index % colors.length];  
        });  
    }  
  
</script>  
</body>  
</html>

Closing thoughts: at bottom the implementation is just try/except blocks that run setUp, the test method, and tearDown in turn and record the outcome into the matching lists on the result object. That suggests two ways to generate reports: produce one report for the whole run, as implemented above, or produce one report per test case, which makes it easy for a platform's case list to link straight to each case's report. Later I plan to write a customized TestCase that generates a standalone per-case report, dropping the class and module fixtures to keep things light; a stripped-down sketch of that core flow is shown below.
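
A minimal sketch of that flow, using only the standard TestResult hooks (run_case is an illustrative helper, not something unittest provides):

import sys

def run_case(case, result):
    # the essence of TestCase.run: setUp -> test method -> tearDown,
    # with every exception routed into the result's lists
    result.startTest(case)
    try:
        try:
            case.setUp()
        except Exception:
            result.addError(case, sys.exc_info())
            return
        try:
            getattr(case, case._testMethodName)()
        except case.failureException:
            result.addFailure(case, sys.exc_info())
        except Exception:
            result.addError(case, sys.exc_info())
        else:
            result.addSuccess(case)
        finally:
            try:
                case.tearDown()
            except Exception:
                result.addError(case, sys.exc_info())
    finally:
        result.stopTest(case)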
