常用的工具类:读取 Excel、CSV、MySQL

Excel - Python xlrd

# -*- coding: utf-8 -*-
"""
@Time    : 2022/10/27 14:00
@FileName: execl_handler.py
"""

import xlrd, json


class Excel(object):
    """Read-only helper around an ``xlrd`` workbook.

    Every method that takes ``sheet_index_or_name`` accepts either a sheet
    index (int) or a sheet name (str).
    """

    def __init__(self, filename):
        """
        Open the workbook stored at ``filename``.
        :param filename: path of the Excel file to open
        """
        # NOTE(review): xlrd documents formatting_info=True as unsupported for
        # .xlsx input — confirm that only .xls files are passed in.
        self.workbook = xlrd.open_workbook(filename, formatting_info=True)

    def get_sheet_names(self):
        """
        Return the names of all sheets in the workbook.
        :return: whatever ``workbook.sheet_names()`` returns
        """
        return self.workbook.sheet_names()

    def __get_sheet(self, sheet_index_or_name):
        """
        Resolve a sheet index or sheet name to the sheet object.
        :param sheet_index_or_name: sheet index (int) or sheet name (str)
        :return: sheet object
        :raises Exception: the index is too large or the name is unknown
        :raises TypeError: the selector is neither an int nor a str
        """
        names = self.workbook.sheet_names()
        if isinstance(sheet_index_or_name, int):
            if len(names) > sheet_index_or_name:
                return self.workbook.sheet_by_index(sheet_index_or_name)
            raise Exception("Invalid Sheet Index!")
        if isinstance(sheet_index_or_name, str):
            if sheet_index_or_name in names:
                return self.workbook.sheet_by_name(sheet_index_or_name)
            raise Exception("Invalid Sheet Name!")
        # Previously this fell through and returned None, which surfaced later
        # as an unrelated AttributeError on the caller's side.
        raise TypeError(
            "sheet_index_or_name must be int or str, got %s"
            % type(sheet_index_or_name).__name__
        )

    def get_rows_num(self, sheet_index_or_name):
        """
        Return the number of rows of the given sheet.
        :param sheet_index_or_name: sheet name or index
        :return: the sheet's ``nrows``
        """
        return self.__get_sheet(sheet_index_or_name).nrows

    def get_cols_num(self, sheet_index_or_name):
        """
        Return the number of columns of the given sheet.
        :param sheet_index_or_name: sheet name or index
        :return: the sheet's ``ncols``
        """
        return self.__get_sheet(sheet_index_or_name).ncols

    def get_cell_value(self, sheet_index_or_name, row_index, col_index):
        """
        Return the value stored at one cell of the given sheet.
        :param sheet_index_or_name: sheet name or index
        :param row_index: row index, starting at 0
        :param col_index: column index, starting at 0
        :return: the cell value
        :raises IndexError: the position lies outside the sheet
        """
        sheet = self.__get_sheet(sheet_index_or_name)
        # Check the requested position itself, not merely that the sheet is
        # non-empty. Negative indexes inside the sheet are still passed through
        # so callers that relied on them keep working.
        row_ok = -sheet.nrows <= row_index < sheet.nrows
        col_ok = -sheet.ncols <= col_index < sheet.ncols
        if not (row_ok and col_ok):
            raise IndexError("Index out of range!")
        return sheet.cell_value(row_index, col_index)

    @staticmethod
    def _parse_json_cell(cell_data):
        """
        Decode a cell that looks like a JSON object/array; return anything else unchanged.
        :param cell_data: raw cell value
        :return: decoded JSON value, or ``cell_data`` itself
        """
        if not isinstance(cell_data, str):
            return cell_data
        looks_like_object = "{" in cell_data and "}" in cell_data
        looks_like_array = "[" in cell_data and "]" in cell_data
        if not (looks_like_object or looks_like_array):
            return cell_data
        try:
            return json.loads(cell_data)
        except ValueError:
            # Plain text that merely contains braces/brackets is not JSON;
            # keep it as text instead of aborting the whole read.
            return cell_data

    def get_data(self, sheet_index_or_name, fields, first_line_is_header=True):
        """
        Return all rows of a sheet as a list of dicts.
        :param sheet_index_or_name: sheet name or index
        :param fields: key names for the returned dicts, one per column in column order
        :param first_line_is_header: when true, the first row is skipped as a header
        :return: list with one ``{field: value}`` dict per data row
        :raises IndexError: ``fields`` has fewer entries than the sheet has columns
        """
        # Resolve the sheet once instead of once per cell.
        sheet = self.__get_sheet(sheet_index_or_name)
        data = []
        for row in range(int(first_line_is_header), sheet.nrows):
            row_data = {}
            for col in range(sheet.ncols):
                row_data[fields[col]] = self._parse_json_cell(sheet.cell_value(row, col))
            data.append(row_data)

        return data


if __name__ == '__main__':
    # Keys for each returned row dict, listed in the sheet's column order.
    fields = [
        "case_id", "module_name", "case_name", "method", "url", "headers",
        "params_desc", "params", "assert_result", "real_result", "remark",
    ]
    xls = Excel(r"./data/case_user.xls")

    # Read the first sheet (index 0) and print every data row.
    rows = xls.get_data(0, fields)
    print(rows)

"""
[

    {'case_id': 1.0, 'module_name': '用户模块', 'case_name': '用户登录-测试用户名为空的情况', 'method': 'post', 'url': 'http://127.0.0.1:8000/user/login', 'headers': '', 'params_desc': 'username: 用户名\npassword: 密码', 'params': {'username': '', 'password': '123456'}, 'assert_result': 'code==400', 'real_result': '', 'remark': ''}, 
    {'case_id': 2.0, 'module_name': '用户模块', 'case_name': '用户登录-测试密码为空的情况', 'method': 'post', 'url': 'http://127.0.0.1:8000/user/login', 'headers': '', 'params_desc': 'username: 用户名\npassword: 密码', 'params': {'username': 'xiaoming', 'password': ''}, 'assert_result': 'code==400', 'real_result': '', 'remark': ''}, 
    {'case_id': 3.0, 'module_name': '用户模块', 'case_name': '用户登录-测试账号密码正确的情况', 'method': 'post', 'url': 'http://127.0.0.1:8000/user/login', 'headers': '', 'params_desc': 'username: 用户名\npassword: 密码', 'params': {'username': 'xiaoming', 'password': '123456'}, 'assert_result': ['code==200', "'data' in json"], 'real_result': '', 'remark': ''}, 
    {'case_id': 4.0, 'module_name': '用户模块', 'case_name': '用户登录-测试使用手机号码登录', 'method': 'post', 'url': 'http://127.0.0.1:8000/user/login', 'headers': '', 'params_desc': 'username: 手机号\npassword: 密码', 'params': {'username': '13312345678', 'password': '123456'}, 'assert_result': 'code==200', 'real_result': '', 'remark': ''}
]
"""


csv - python csv

csv 是 Python 标准库自带的模块,用来解析 CSV 文件。

1.打开CSV文件:用 open() 传入CSV文件的路径,得到文件对象

with open('./case.csv') as casefile:

2.csv.DictReader() 将CSV读取成字典的形式

rows2 = csv.DictReader(casefile)
print rows2
# [{'paxID': '111', 'daxID': '222', 'merID': '333'}, {'paxID': '444', 'daxID': '555', 'merID': '666'}]

3.csv.reader() 将CSV读取成list

rows = csv.reader(casefile)
row_list = [row for row in rows]
        print row_list
# [['paxID', 'daxID', 'merID'], ['111', '222', '333'], ['444', '555', '666']]

4.封装一个csv 工具类

#coding=UTF-8
import csv
import traceback

class CSV:
    """Load a CSV file into memory as a list of per-row dicts (via ``csv.DictReader``).

    Loading is best-effort: if the file cannot be read, the traceback is
    printed and ``allRows`` stays ``None``.
    """

    def __init__(self, filePath):
        """Read every row of ``filePath``.

        :param filePath: path of the CSV file to load
        """
        self.filePath = filePath
        self.allRows = None
        # Header names, kept separately so they are available even when the
        # file has a header line but no data rows.
        self._fieldNames = []
        try:
            # newline='' is the mode the csv module documents for reader input.
            with open(self.filePath, newline='') as csvfile:
                reader = csv.DictReader(csvfile)
                self.allRows = list(reader)
                self._fieldNames = list(reader.fieldnames or [])
        except (OSError, UnicodeError, csv.Error):
            traceback.print_exc()

    def getAll(self):
        """Return all loaded rows (``None`` if loading failed)."""
        return self.allRows

    def getCell(self, rowNum, colunmName):
        """Return one cell, or ``''`` when it cannot be found.

        :param rowNum: 1-based data-row number (the header line is not counted)
        :param colunmName: column header name
        :return: the cell text, or ``''`` on invalid arguments or a missing row/column
        """
        if rowNum <= 0 or colunmName is None:
            print('rowNum should begin from 1 or colunmName is invalid')
            return ''
        try:
            return self.allRows[rowNum - 1][colunmName]
        except (IndexError, KeyError, TypeError):
            # IndexError: row past the end; KeyError: unknown column;
            # TypeError: allRows is None because loading failed.
            print('colunmName is inexistent')
            return ''

    # Header row of the CSV, i.e. the keys of the row dicts.
    def getFirstRow(self):
        """Return the column names as a list (empty if nothing was loaded)."""
        if self.allRows:
            return list(self.allRows[0].keys())
        return list(self._fieldNames)

    # Values of one column, top to bottom.
    def getColunmName(self, name):
        """Return the values of column ``name`` from every row.

        Rows that lack the column are skipped after printing the traceback.

        :param name: column header name
        :return: list of cell values (empty when nothing was loaded)
        """
        result = []
        for row in self.allRows or []:
            try:
                result.append(row[name])
            except KeyError:
                traceback.print_exc()
        return result


class OperationCSV:
    """Read and write UTF-8 CSV files as lists of lists or lists of dicts."""

    def __init__(self):
        pass

    @staticmethod
    def is_exist(file_path):
        """Return True when ``file_path`` exists; otherwise log an error and raise."""
        if not os.path.exists(file_path):
            logger.error("file is not exist!")
            raise Exception("file is not exist!")
        return True

    def list_reader(self, file_path):
        """Return every row of the file as a list of string lists.

        A failure while reading is logged and ``None`` is returned.
        """
        if not self.is_exist(file_path):
            return None
        try:
            with open(file_path, 'r', newline="", encoding='utf-8') as f:
                return list(csv.reader(f))
        except Exception as e:
            logger.error(e)
        return None

    @staticmethod
    def list_writer(file_path, headers, data):
        """Write ``headers`` followed by the rows in ``data``; failures are only logged."""
        try:
            with open(file_path, 'w', encoding='utf-8', newline='') as f:
                out = csv.writer(f)
                out.writerow(headers)
                out.writerows(data)
                logger.info(f"成功生成文件:{file_path}!")
        except Exception as e:
            logger.error(e)

    def dict_reader(self, file_path):
        """Return every data row as a dict keyed by the header line."""
        if not self.is_exist(file_path):
            return None
        with open(file_path, 'r', newline="", encoding='utf-8') as f:
            return list(csv.DictReader(f))

    @staticmethod
    def dict_writer(file_path, data, fields):
        """Write a header line from ``fields`` and then the dict rows in ``data``."""
        with open(file_path, 'w', encoding='utf-8', newline='') as f:
            out = csv.DictWriter(f, fieldnames=fields)
            out.writeheader()
            out.writerows(data)
            logger.info(f"成功生成文件:{file_path}!")

"""
@Time    : 2023/11/29 10:40
@FileName: csv_helper.py
"""
import csv
import os


class CsvWriter:
    """csv helper - write

    Opens ``filename`` (resolved against the current working directory) for
    writing and hands out ``csv`` writer objects bound to it. Use it as a
    context manager, or call ``close()`` explicitly, to release the file.
    """

    def __init__(self, filename):
        """
        :param filename: output file name; joined onto ``os.getcwd()``
        """
        self.filename = filename
        self.csvfile = self.__open_csv_file()

    def __open_csv_file(self):
        """Open the output file for writing, truncating any existing content.

        A plain method rather than a property, so the file is never re-opened
        (and truncated again) by a mere attribute access.
        """
        export_file = os.path.join(os.getcwd(), self.filename)
        # 'utf-8-sig' writes a UTF-8 BOM at the start of the file; newline=''
        # leaves line endings to the csv module.
        return open(export_file, mode='w', encoding='utf-8-sig', newline='')

    def writer(self):
        """Return a ``csv.writer`` bound to the open file."""
        return csv.writer(self.csvfile)

    def write_dict(self, field_names):
        """Return a ``csv.DictWriter`` bound to the open file, with the header already written.

        :param field_names: column names, in output order
        """
        writer = csv.DictWriter(self.csvfile, fieldnames=field_names)
        writer.writeheader()
        return writer

    def close(self):
        """Close the underlying file."""
        self.csvfile.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Always release the file handle; never suppress the caller's exception.
        self.close()
        return False


if __name__ == '__main__':
    # Example 1: plain list rows through csv.writer.
    list_target = CsvWriter('csv_template.csv')
    row_writer = list_target.writer()
    for record in (['姓名', '年龄', '城市'], ['张三', 25, '北京']):
        row_writer.writerow(record)
    list_target.close()

    # Example 2: dict rows through csv.DictWriter (header is written by write_dict).
    fieldnames = ['姓名', '年龄', '城市']
    data = [{'姓名': '张三', '年龄': 25, '城市': '北京'}, {'姓名': '李四', '年龄': 30, '城市': '上海'}]
    dict_target = CsvWriter('csv_template_dict.csv')
    dict_target.write_dict(fieldnames).writerows(data)
    dict_target.close()

读取 mysql

import os
import pymysql
import base64
import yaml
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad

from loguru import logger

# Module-level side effect: registers a 'sync.log' file sink at INFO level,
# UTF-8 encoded, with rotation="50 MB" and retention=3.
# NOTE(review): per loguru's docs an integer retention keeps that many old files — confirm.
logger.add('sync.log', level='INFO', encoding="utf-8", rotation="50 MB", retention=3)


class Configure:
    """Settings loaded from ``config.yml`` in the current working directory.

    Top-level keys of the YAML mapping are readable as attributes
    (``conf.MYSQL_HOST``); a key that is not present reads as ``None``.
    """

    def __init__(self):
        self.config = self.get_config()

    @staticmethod
    def get_config():
        """Load and return the configuration mapping.

        :return: dict parsed from ``config.yml`` (empty dict for an empty file)
        :raises FileNotFoundError: the file does not exist
        :raises OSError: the file cannot be read
        :raises yaml.YAMLError: the file is not valid YAML
        :raises ValueError: the top-level YAML value is not a mapping
        """
        config_file = os.path.join(os.getcwd(), "config.yml")
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                info = yaml.safe_load(f)
        except FileNotFoundError:
            logger.error('error, configuration file not found')
            raise
        except (OSError, UnicodeError, yaml.YAMLError):
            # Previously a bare except swallowed the failure and the function
            # then died on an unbound local; surface the real cause instead.
            logger.error('error, read configuration file error')
            raise
        if info is None:
            # safe_load returns None for an empty document.
            return {}
        if not isinstance(info, dict):
            raise ValueError('error, configuration file must contain a mapping')
        return info

    def __getattr__(self, item):
        """Return the configuration value for ``item``, or ``None`` if the key is absent."""
        # __getattr__ only runs when normal lookup fails. If 'config' itself is
        # missing (e.g. __init__ did not finish), reading self.config below
        # would recurse forever; dunder probes from copy/pickle must also get a
        # real AttributeError rather than None.
        if item == 'config' or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(item)
        return self.config.get(item)


class PyMySQL:
    """Thin wrapper around one ``pymysql`` connection.

    Every query method accepts an optional ``args`` value that is forwarded
    to ``cursor.execute`` so values can be bound by the driver instead of
    being formatted into the SQL string.
    """

    def __init__(self, host, port, user, passwd, db, cursor_class=pymysql.cursors.Cursor, charset='utf8'):
        """
        Store the connection settings and connect immediately.
        :param host: server host
        :param port: server port
        :param user: login user
        :param passwd: login password
        :param db: database name
        :param cursor_class: pymysql cursor class used for every cursor
        :param charset: connection charset
        """
        self.host = host
        self.port = port
        self.user = user
        self.passwd = passwd
        self.db = db
        self.charset = charset
        self.cursor_class = cursor_class
        self.conn = self.connect()

    def connect(self):
        """
        Create a connection from the stored settings.
        :return: the new pymysql connection
        """
        conn = pymysql.connect(host=self.host, port=self.port, user=self.user, passwd=self.passwd, db=self.db,
                               cursorclass=self.cursor_class,
                               charset=self.charset)
        return conn

    def execute(self, sql, args=None):
        """
        Run a write statement and commit; roll back and re-raise on failure.
        :param sql: SQL text, optionally with driver placeholders
        :param args: optional parameters bound by the driver
        :return: None
        """
        try:
            # The with-block closes the cursor even when execute() raises.
            with self.conn.cursor() as cursor:
                cursor.execute(sql, args)
            self.conn.commit()
        except Exception:
            # Do not leave a half-applied transaction open on the connection.
            self.conn.rollback()
            raise

    def query_one(self, sql, args=None):
        """
        Run a query and return its first row.
        :param sql: SQL text, optionally with driver placeholders
        :param args: optional parameters bound by the driver
        :return: result of ``cursor.fetchone()``
        """
        with self.conn.cursor() as cursor:
            cursor.execute(sql, args)
            return cursor.fetchone()

    def query_all(self, sql, args=None):
        """
        Run a query and return all of its rows.
        :param sql: SQL text, optionally with driver placeholders
        :param args: optional parameters bound by the driver
        :return: result of ``cursor.fetchall()``
        """
        with self.conn.cursor() as cursor:
            cursor.execute(sql, args)
            return cursor.fetchall()

    def query_many(self, sql, size=100, args=None):
        """
        Run a query, fetch it ``size`` rows at a time and return all rows as one list.

        Note that every row still ends up in the returned list, so this does
        not bound memory use; use ``query_many_by_generator`` to stream.
        :param sql: SQL text, optionally with driver placeholders
        :param size: rows per ``fetchmany`` call
        :param args: optional parameters bound by the driver
        :return: list of all rows
        """
        _result = []
        with self.conn.cursor() as cursor:
            cursor.execute(sql, args)
            while True:
                rows = cursor.fetchmany(size)
                if not rows:
                    break
                _result.extend(rows)

        return _result

    def query_many_by_generator(self, sql, size=10000, args=None):
        """
        Run a query and yield its rows one by one, fetching ``size`` rows per batch.

        NOTE(review): with pymysql's default buffered cursor the whole result
        is presumably already held client-side after execute(); an unbuffered
        cursor class (SSCursor) would be needed for true streaming — confirm.
        :param sql: SQL text, optionally with driver placeholders
        :param size: rows per ``fetchmany`` call
        :param args: optional parameters bound by the driver
        :return: generator of rows
        """
        # The with-block also closes the cursor when the consumer abandons the
        # generator early or an exception propagates through it.
        with self.conn.cursor() as cursor:
            cursor.execute(sql, args)
            while True:
                rows = cursor.fetchmany(size)
                if not rows:
                    break
                for row in rows:
                    yield row

    def close(self):
        """
        Close the connection.
        :return: None
        """
        self.conn.close()


class Handler:
    """Processing workflow: loads the configuration and opens the database connection."""

    def __init__(self):
        # Load settings first; every connection argument below comes from them.
        self.conf = Configure()
        settings = self.conf
        connect_kwargs = {
            'host': settings.MYSQL_HOST,
            'port': settings.MYSQL_PORT,
            'user': settings.MYSQL_USER,
            # The configured password is stored encrypted; decrypt it before use.
            'passwd': self.__decrypt_pwd(settings.MYSQL_PASSWORD),
            'db': settings.MYSQL_DB_A,
            'cursor_class': pymysql.cursors.DictCursor,
            'charset': 'gbk',
        }
        self.db = PyMySQL(**connect_kwargs)

    @staticmethod
    def __decrypt_pwd(encrypted):
        """Decode a base64 string, AES-ECB decrypt it, strip padding (block size 16) and return UTF-8 text."""
        # NOTE(review): the AES key is hard-coded in source and ECB mode is used;
        # anyone with this file can decrypt the configured password.
        secret = "8227973744885042".encode('utf8')
        aes = AES.new(secret, AES.MODE_ECB)
        padded = aes.decrypt(base64.b64decode(encrypted))
        plain = unpad(padded, 16)
        return plain.decode('utf8')


def run():
    """Entry point; currently a placeholder that does nothing."""
    return None


# Call run() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    run()

posted @ 2022-10-27 14:27  hanfe1  阅读(88)  评论(0编辑  收藏  举报