import pymysql
from sshtunnel import SSHTunnelForwarder

class SSH_DBHandler:
    """MySQL connection reached through an SSH tunnel.

    Creating an instance starts an ``SSHTunnelForwarder`` that forwards a local
    port to ``remote_bind_address:remote_bind_port`` via the SSH server, then
    opens a PyMySQL connection to ``127.0.0.1`` on that local port and creates
    a cursor on it.

    Instances can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(self, ssh_ip, ssh_port, ssh_username, ssh_pkey,
                 remote_bind_address, remote_bind_port,
                 mysql_user, mysql_password):
        """Start the SSH tunnel and connect to MySQL through it.

        Args:
            ssh_ip: Host name or IP address of the SSH server.
            ssh_port: Port of the SSH server.
            ssh_username: User name for the SSH login.
            ssh_pkey: Value passed to ``SSHTunnelForwarder`` as ``ssh_pkey``.
            remote_bind_address: MySQL host, as reachable from the SSH server.
            remote_bind_port: MySQL port on ``remote_bind_address``.
            mysql_user: MySQL user name.
            mysql_password: MySQL password.

        Raises:
            Whatever ``SSHTunnelForwarder`` or ``pymysql.connect`` raise. If the
            MySQL connection cannot be opened, the already-started tunnel is
            closed before the exception propagates.
        """
        self.tunnel = SSHTunnelForwarder(
            (ssh_ip, ssh_port),
            ssh_username=ssh_username,
            ssh_pkey=ssh_pkey,
            remote_bind_address=(remote_bind_address, remote_bind_port),
        )
        self.tunnel.start()
        try:
            # Pass host by keyword: recent PyMySQL releases make all
            # connect() arguments keyword-only.
            self.conn = pymysql.connect(
                host='127.0.0.1',
                port=self.tunnel.local_bind_port,
                user=mysql_user,
                password=mysql_password,
            )
        except Exception:
            # Don't leave the forwarder running when the DB login fails.
            self.tunnel.close()
            raise
        self.cursor = self.conn.cursor()

    def query(self, sql, args=None, fetch=1):
        """Execute ``sql``, commit, and return fetched rows.

        Args:
            sql: SQL statement to execute.
            args: Optional parameters forwarded to ``cursor.execute``.
            fetch: ``1`` returns ``fetchone()``; a value greater than 1 returns
                ``fetchmany(fetch)``; any other value (e.g. ``0``) returns
                ``fetchall()``.

        Returns:
            The result of the selected fetch call.
        """
        self.cursor.execute(sql, args)
        # Commit after every statement so writes are persisted immediately.
        self.conn.commit()
        if fetch == 1:
            return self.cursor.fetchone()
        if fetch > 1:
            return self.cursor.fetchmany(fetch)
        return self.cursor.fetchall()

    def close(self):
        """Close the cursor, the MySQL connection and the SSH tunnel, in order.

        Every step runs even if an earlier one raises; the exception is
        re-raised once the remaining resources have been closed.
        """
        try:
            self.cursor.close()
        finally:
            try:
                self.conn.close()
            finally:
                self.tunnel.close()

    def __enter__(self):
        """Return the handler itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close all resources; never suppress an in-flight exception."""
        self.close()
        return False

def ssh_dbhandler(conn_db):
    """Create an ``SSH_DBHandler`` from the ``conn_db`` entry of ``yaml_data``.

    The entry is read for ``ssh_ip``, ``ssh_port``, ``ssh_username``,
    ``remote_bind_address``, ``remote_bind_port``, ``mysql_user`` and
    ``mysql_password``; both port values are converted with ``int()``.
    The SSH private key is taken from ``TestConfig.ssh_pkey``.

    NOTE(review): ``yaml_data`` and ``TestConfig`` are neither defined nor
    imported in the visible part of this module — confirm where they come from.
    """
    section = yaml_data[conn_db]
    # SSH side of the tunnel: host, port, login user, private key.
    ssh_args = (
        section["ssh_ip"],
        int(section["ssh_port"]),
        section["ssh_username"],
        TestConfig.ssh_pkey,
    )
    # Database side: where MySQL lives behind the SSH host, and credentials.
    db_args = (
        section["remote_bind_address"],
        int(section["remote_bind_port"]),
        section["mysql_user"],
        section["mysql_password"],
    )
    return SSH_DBHandler(*ssh_args, *db_args)
# MySQL handler instantiation. This runs at import time, so importing this
# module starts the SSH tunnel and opens the database connection.
ssh_mysql = ssh_dbhandler("ssh_mysql")

if __name__ == '__main__':
    try:
        print(type(ssh_mysql.query('show databases', fetch=1)))
    finally:
        # Release cursor, connection and tunnel even if the query fails.
        ssh_mysql.close()