基于SQLAlchemy实现的堡垒机

堡垒机

堡垒机是一种运维安全审计系统。主要的功能是对运维人员的运维操作进行审计和权限控制。同时堡垒机还有账号集中管理,单点登陆的功能。

堡垒机的实现我们主要采用paramiko和SQLalchemy,可以参考前面的paramiko博客。

 

堡垒机实现的流程

  1. 管理员为用户在服务器上创建账号(将公钥放置服务器,或者使用用户名密码),堡垒机服务器创建jumpserver.py的主程序文件,修改用户家目录 ./bashrc: /usr/bin/python   jumperserverPATH/jumpserver.py;logout
  2. 用户登陆堡垒机,输入堡垒机用户名密码,首先会执行jumpserver.py这个主程序。这个程序会直接打印出:ip列表,用户列表,主机列表。代表用户可以选择以什么什么登录哪台主机
  3. 当用户选择机器之后,系统自动输入用户名和密码,进行登录和相关操作。jumperserver会记录用户所有的操作。

 

表结构

jumperserver中在用户登录上来之后,会直接显示出来用户的可以一什么身份,登录哪台服务器的一个列表,这里面一系列对应关系。我们用SQLALchemy的ORM框架来解决。

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import paramiko
import sys
import os
import socket
import getpass
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint, Index, Table, DateTime, Boolean, Enum
from sqlalchemy import or_, and_
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy import create_engine
from sqlalchemy.sql import func
# from sqlalchemy_utils import ChoiceType  # 需要安装sqlalchemy_utils这个组件
from paramiko.py3compat import u

# 首先建库,并且让mysql给权限登录
engine = create_engine("mysql+pymysql://root:Aa132123.@10.1.1.200:3306/s14?charset=utf8", max_overflow=5, )
Base = declarative_base()  # 创建orm的基类


class Group(Base):  # 主机组
    __tablename__ = 'group'
    group_id = Column(Integer, primary_key=True, autoincrement=True)
    group_name = Column(String(32), nullable=False, unique=True)


class Host(Base):  # 主机和主机组,是一对多关系:即一个主机必须属于一个组
    __tablename__ = 'host'
    host_id = Column(Integer, primary_key=True, autoincrement=True)
    host_ip = Column(String(32), nullable=False, unique=True)
    hostname = Column(String(255), nullable=False, unique=True)
    port = Column(String(32), nullable=False)
    group_id = Column(Integer, ForeignKey('group.group_id'), nullable=False)
    re_group_id = relationship('Group', backref='a')


class UserFile(Base):  # 用户表
    __tablename__ = 'user_file'
    user_id = Column(Integer, primary_key=True, autoincrement=True)
    user_name = Column(String(32), nullable=False, unique=True)
    gender = Column(Enum('male', 'female'), server_default='male')  # 让用户选择只能输入其中一个
    password = Column(String(32), nullable=False)


class UserToGroup(Base):  # 用户和主机组的关系表,多对多:一个用户可以属于多个主机组,一个主机组可以有多个用户
    __tablename__ = 'user_to_group'
    HU_id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('user_file.user_id'), nullable=False)
    group_id = Column(Integer, ForeignKey('group.group_id'), nullable=False)
    __table_args__ = (
        UniqueConstraint('user_id', 'group_id'),
    )  # 为了减少错误,我们这里使用联合唯一索引
    re_group_id = relationship('Group', backref='b')
    re_user_id = relationship('UserFile', backref='c')


class RemoteUserFile(Base):  # 主机用户
    __tablename__ = 'remote_user_file'
    user_id = Column(Integer, primary_key=True, autoincrement=True)
    user_name = Column(String(32), nullable=False, unique=True)
    password = Column(String(32), nullable=False)
    # AuthTypes = [
    #     ('p', 'SSH/Password'),
    #     ('r', 'SSH/KEY')
    # ]
    # auth_type = Column(ChoiceType(AuthTypes))
    auth_type = Column(Enum('SSH/Password', 'SSH/KEY'), default='SSH/Password')


class RemoteUserToHost(Base):  # 主机用户和主机的关系表,多对多的关系:一个主机可以有多个主机用户,一个用户有多个主机
    __tablename__ = 'remote_user_to_host'
    RH_id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('remote_user_file.user_id'), nullable=False)
    host_id = Column(Integer, ForeignKey('host.host_id'), nullable=False)
    __table_args__ = (
        UniqueConstraint('user_id', 'host_id'),
    )
    re_user_id = relationship('RemoteUserFile', backref='d')
    re_host_id = relationship('Host', backref='e')


class AuditLog(Base):
    __tablename__ = 'audit_log'
    audit_id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('user_file.user_id'), nullable=False)
    remote_user_id = Column(Integer, ForeignKey('remote_user_file.user_id'), nullable=False)
    host_id = Column(Integer, ForeignKey('host.host_id'), nullable=False)
    cmd = Column(String(65535))
    handle_time = Column(DateTime, server_default=func.now())  
    # server_default表示交给数据库处理,default表示交给程序处理
    
    re_user_id = relationship('UserFile', backref='f')
    re_remote_user_id = relationship('RemoteUserFile', backref='g')
    re_host_id = relationship('Host', backref='h')


Base.metadata.create_all(engine)

Session = sessionmaker(bind=engine)
session = Session()

  

 表结构设计需要注意的几个点:

1, 多对多和一对多关系的设计

2, 联合唯一索引的使用

3, ENUM让用户选择输入其中一个

4, nullable=False之后用户依然可以输入空值,但是不能输入NULL,注意区别

 

 

记录用户操作

根据用户名和私钥登录服务器

tran = paramiko.Transport((hostname, port,))
tran.start_client()
default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_rsa')
key = paramiko.RSAKey.from_private_key_file(default_path)
tran.auth_publickey('yang', key)
 
# 打开一个通道
chan = tran.open_session()
# 获取一个终端
chan.get_pty()
# 激活器
chan.invoke_shell()
 
#########
# 利用sys.stdin,肆意妄为执行操作
# 用户在终端输入内容,并将内容发送至远程服务器
# 远程服务器执行命令,并将结果返回
# 用户终端显示内容
#########
 
chan.close()
tran.close()

  

先让我们来看一下paramiko进行远程交互的核心代码

while True:
    # 监视用户输入和服务器返回数据
    # sys.stdin 处理用户输入
    # chan 是之前创建的通道,用于接收服务器返回信息
    readable, writeable, error = select.select([chan, sys.stdin, ],[],[],1)
    if chan in readable:
        try:
            x = chan.recv(1024)
            if len(x) == 0:
                print '\r\n*** EOF\r\n',
                break
            sys.stdout.write(x)
            sys.stdout.flush()
        except socket.timeout:
            pass
    if sys.stdin in readable:
        inp = sys.stdin.readline()
        chan.sendall(inp)

  

再来看一下在linux中paramiko源码是如何跟远程主机交互的

def interactive_shell(chan):
    if has_termios:
        posix_shell(chan)
    else:
        windows_shell(chan)


def posix_shell(chan):
    import select

    oldtty = termios.tcgetattr(sys.stdin)
    try:
        tty.setraw(sys.stdin.fileno())
        tty.setcbreak(sys.stdin.fileno())
        chan.settimeout(0.0)

        while True:
            r, w, e = select.select([chan, sys.stdin], [], [])
            if chan in r:
                try:
                    x = u(chan.recv(1024))
                    if len(x) == 0:
                        sys.stdout.write('\r\n*** EOF\r\n')
                        break
                    sys.stdout.write(x)
                    sys.stdout.flush()
                except socket.timeout:
                    pass
            if sys.stdin in r:
                x = sys.stdin.read(1)
                if len(x) == 0:
                    break
                chan.send(x)

    finally:
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)

  

 源码默认是用户用户写完命令,一行一行的发送到远程主机,并把结果显示到界面。而我们需要改变一下,让终端监听用户输入的每一个键盘字符就开始记录。

 

修改paramiko源码

# 获取原tty属性
oldtty = termios.tcgetattr(sys.stdin)
try:
    # 为tty设置新属性
    # 默认当前tty设备属性:
    #   输入一行回车,执行
    #   CTRL+C 进程退出,遇到特殊字符,特殊处理。

    # 这是为原始模式,不认识所有特殊符号
    # 放置特殊字符应用在当前终端,如此设置,将所有的用户输入均发送到远程服务器
    tty.setraw(sys.stdin.fileno())
    chan.settimeout(0.0)

    while True:
        # 监视 用户输入 和 远程服务器返回数据(socket)
        # 阻塞,直到句柄可读
        r, w, e = select.select([chan, sys.stdin], [], [], 1)
        if chan in r:
            try:
                x = chan.recv(1024)
                if len(x) == 0:
                    print '\r\n*** EOF\r\n',
                    break
                sys.stdout.write(x)
                sys.stdout.flush()
            except socket.timeout:
                pass
        if sys.stdin in r:
            x = sys.stdin.read(1)
            if len(x) == 0:
                break
            chan.send(x)

finally:
    # 重新设置终端属性
    termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)

  

windows操作系统中由于利用不了终端,所以导致监听不到服务器返回的数据。在windwos中paramiko源码是以线程的方式来接受返回数据的,详见源码。

def windows_shell(chan):
    import threading

    sys.stdout.write("Line-buffered terminal emulation. Press F6 or ^Z to send EOF.\r\n\r\n")

    def writeall(sock):
        while True:
            data = sock.recv(256)
            if not data:
                sys.stdout.write('\r\n*** EOF ***\r\n\r\n')
                sys.stdout.flush()
                break
            sys.stdout.write(data)
            sys.stdout.flush()

    writer = threading.Thread(target=writeall, args=(chan,))
    writer.start()

    try:
        while True:
            d = sys.stdin.read(1)
            if not d:
                break
            chan.send(d)
    except EOFError:
        # user hit ^Z or F6
        pass

  

 

贴一段我再写简单的jumperserver的核心代码

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import paramiko
import sys
import os
import socket
import getpass
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint, Index, Table, DateTime, Boolean, Enum
from sqlalchemy import or_, and_
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy import create_engine
from sqlalchemy.sql import func
import yaml
import logging
# from sqlalchemy_utils import ChoiceType  # 需要安装sqlalchemy_utils这个组件
from paramiko.py3compat import u

config_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),"conf","jumper_server.conf")
info = yaml.load(open(config_file))
#print(info)

#--------定义日志-----------------
# 定义文件
log_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),"log","info.log")
file_1_1 = logging.FileHandler(log_file, 'a', encoding='utf-8')
fmt = logging.Formatter(fmt="%(asctime)s - %(name)s - %(levelname)s -%(module)s: %(message)s")
file_1_1.setFormatter(fmt)
if info["Loglevel"] == "INFO":
    logger1 = logging.Logger('s2', level=logging.INFO)
logger1.addHandler(file_1_1)

#logger2.info('2222')
#logger1.critical('1111')


# 首先建库,并且让mysql给权限登录
engine = create_engine("mysql+pymysql://{}:{}@{}:{}/s14".format(
    info["MysqlUser"], info["MysqlPassWord"], info["JumperServer"], info["MysqlPort"]), max_overflow=5)
Base = declarative_base()  # 创建orm的基类


class Group(Base):  # 主机组
    __tablename__ = 'group'
    group_id = Column(Integer, primary_key=True, autoincrement=True)
    group_name = Column(String(32), nullable=False, unique=True)


class Host(Base):  # 主机和主机组,是一对多关系:即一个主机必须属于一个组
    __tablename__ = 'host'
    host_id = Column(Integer, primary_key=True, autoincrement=True)
    host_ip = Column(String(32), nullable=False, unique=True)
    hostname = Column(String(255), nullable=False, unique=True)
    port = Column(String(32), nullable=False)
    group_id = Column(Integer, ForeignKey('group.group_id'), nullable=False)
    re_group_id = relationship('Group', backref='a')


class UserFile(Base):   # 用户表
    __tablename__ = 'user_file'
    user_id = Column(Integer, primary_key=True, autoincrement=True)
    user_name = Column(String(32), nullable=False, unique=True)
    gender = Column(Enum('male', 'female'), server_default='male')
    password = Column(String(32), nullable=False)


class UserToGroup(Base):  # 用户和主机组的关系表,多对多:一个用户可以属于多个主机组,一个主机组可以有多个用户
    __tablename__ = 'user_to_group'
    HU_id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('user_file.user_id'), nullable=False)
    group_id = Column(Integer, ForeignKey('group.group_id'), nullable=False)
    __table_args__ = (
        UniqueConstraint('user_id', 'group_id'),
    )
    re_group_id = relationship('Group', backref='b')
    re_user_id = relationship('UserFile', backref='c')


class RemoteUserFile(Base):  # 主机用户
    __tablename__ = 'remote_user_file'
    user_id = Column(Integer, primary_key=True, autoincrement=True)
    user_name = Column(String(32), nullable=False, unique=True)
    password = Column(String(32), nullable=False)
    # AuthTypes = [
    #     ('p', 'SSH/Password'),
    #     ('r', 'SSH/KEY')
    # ]
    # auth_type = Column(ChoiceType(AuthTypes))
    auth_type = Column(Enum('SSH/Password', 'SSH/KEY'), default='SSH/Password')


class RemoteUserToHost(Base):  # 主机用户和主机的关系表,多对多的关系:一个主机可以有多个主机用户,一个用户有多个主机
    __tablename__ = 'remote_user_to_host'
    RH_id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('remote_user_file.user_id'), nullable=False)
    host_id = Column(Integer, ForeignKey('host.host_id'), nullable=False)
    __table_args__ = (
        UniqueConstraint('user_id', 'host_id'),
    )
    re_user_id = relationship('RemoteUserFile', backref='d')
    re_host_id = relationship('Host', backref='e')


class AuditLog(Base):
    __tablename__ = 'audit_log'
    audit_id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('user_file.user_id'), nullable=False)
    remote_user_id = Column(Integer, ForeignKey('remote_user_file.user_id'), nullable=False)
    host_id = Column(Integer, ForeignKey('host.host_id'), nullable=False)
    cmd = Column(String(65000))
    handle_time = Column(DateTime, server_default=func.now())
    re_user_id = relationship('UserFile', backref='f')
    re_remote_user_id = relationship('RemoteUserFile', backref='g')
    re_host_id = relationship('Host', backref='h')

Base.metadata.create_all(engine)

Session = sessionmaker(bind=engine)
session = Session()
# session.commit()
# session.close()

# windows does not have termios...
# 判断是windows还是linux操作系统,主要是根据有没有termios模块来判断的
try:
    import termios
    import tty
    has_termios = True
except ImportError:
    has_termios = False


def interactive_shell(chan, result1):
    if has_termios:
        posix_shell(chan, result1)
    else:
        windows_shell(chan, result1)


def posix_shell(chan, result1):
    import select

    oldtty = termios.tcgetattr(sys.stdin)
    try:
        tty.setraw(sys.stdin.fileno())
        tty.setcbreak(sys.stdin.fileno())
        chan.settimeout(0.0)
#        f = open('handle.log', 'a+')
        tab_flag = False
        temp_list = [] # 存放命令的列表
        while True:
            r, w, e = select.select([chan, sys.stdin], [], [])
            if chan in r:
                try:
                    x = chan.recv(1024).decode()
                    if len(x) == 0:
                        sys.stdout.write('\r\n*** EOF\r\n')
                        break
                    if tab_flag:  # 当用户输入回车键时,需要要靠返回值才能记录,输入的内容
                        if x.startswith('\r\n'):
                            pass
                        else:
                            temp_list.append(x)
#                            f.write(x)
#                            f.flush()
                        tab_flag = False
                    sys.stdout.write(x)
                    sys.stdout.flush()
                except socket.timeout:
                    pass
            if sys.stdin in r:
                x = sys.stdin.read(1)
                if len(x) == 0:
                    break
                if x == '\t':  # 当用户输入tab键时,不做任何记录。因此此时不会记录补全的内容,需要依靠返回值记录(如上)
                    tab_flag = True
                else:
                    if ord(x) == 13:  # 当用户输入回车键时,把他变成换行符
#                        f.write("*")
                        temp_list.append("*")
#                        temp_list.append("\n")
                    else:
                        temp_list.append(x)
#                        f.write(x)
#                        f.flush()
                chan.send(x)
        #把操作命令写入数据库
        cmd_str = ''.join(temp_list)
        session.add(AuditLog(
            user_id=result1[0][0],
            remote_user_id=result1[0][1],
            host_id=result1[0][2],
            cmd=cmd_str))
    finally:
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)


def windows_shell(chan, result1):
    import threading

    sys.stdout.write("Line-buffered terminal emulation. Press F6 or ^Z to send EOF.\r\n\r\n")

    def writeall(sock):
        while True:
            data = sock.recv(256)
            if not data:
                sys.stdout.write('\r\n*** EOF ***\r\n\r\n')
                sys.stdout.flush()
                break
            sys.stdout.write(data.decode())
            sys.stdout.flush()

    writer = threading.Thread(target=writeall, args=(chan,))
    writer.start()

    f1 = open("handle.log", "a+", encoding="utf-8")   # 记录命令
    try:
        while True:
            d = sys.stdin.read(1)
            if not d:
                break
            f1.write(d)
            f1.flush()
            chan.send(d)
        f1.close()
    except EOFError:
        # user hit ^Z or F6
        pass

def run():
        active_user = getpass.getuser()
        logger1.info('{} login in'.format(active_user))
        print('\033[32mActive_user: {}\n\033[0m'.format(active_user))
        # 思路:先通过用户名-->group_id--->host_id--->remote_user, host_ip
        r = session.query(UserFile).filter(UserFile.user_name == active_user).all()
        group_id_list = list(map(lambda x: x.group_id, r[0].c))
        ret1 = session.query(Host.host_id).filter(Host.group_id.in_(group_id_list)).all()
        host_id_list = list(zip(*ret1))[0]
        host_obj = session.query(Host).filter(Host.host_id.in_(host_id_list)).all()

        host_list = []
        for i in host_obj:
            for j in i.e:
                result = session.query(RemoteUserFile.user_name,
                                       RemoteUserFile.password, Host.host_ip
                                       ).filter(RemoteUserFile.user_id == j.user_id,
                                                Host.host_id == j.host_id).all()
                host_list.append({
                    'host': result[0][2],
                    'username': result[0][0],
                    'pwd': result[0][1]
                })
        for item in enumerate(host_list, 1):
            print(item[0], item[1]['username'], item[1]['host'])

        num = input('序号:')
        sel_host = host_list[int(num)-1]
        hostname = sel_host['host']
        username = sel_host['username']
        pwd = sel_host['pwd']
        print(hostname, username, pwd)

        tran = paramiko.Transport((hostname, 22,))
        tran.start_client()
        tran.auth_password(username, pwd)
        # 打开一个通道
        chan = tran.open_session()
        # 获取一个终端
        chan.get_pty()
        # 激活器
        chan.invoke_shell()

        # 查找出当前用户是谁,操作的哪台主机,以什么主机用户的身份去操作的,操作了什么命令,都需要记录到数据库中。
        result1 = session.query(
            UserFile.user_id, RemoteUserFile.user_id, Host.host_id).filter(
            UserFile.user_name == active_user,
            RemoteUserFile.user_name == username,
            Host.host_ip == hostname).all()
        logger1.info('begain to record: {}'.format(result1))
        print(result1)
#            session.add(AuditLog(
#            user_id=result1[0][0],
#            remote_user_id=result1[0][1],
#            host_id=result1[0][2],
#            cmd=cmd_str))

        interactive_shell(chan, result1)

        session.commit()
        session.close()
        chan.close()
        tran.close()


#if __name__ == '__main__':
#    run()
View Code

 

关于日志的审计还需要写一个后端程序,来查询用户都做了那些操作

posted @ 2017-04-20 18:00  早晨我在雨中采花  阅读(281)  评论(0编辑  收藏  举报