TOP

Django 架构最佳实践相关整理

取消 CSRF 中间件

# -*- coding:utf-8 -*-
"""
取消Django的CSRF验证中间件
"""
from django.conf import settings
from django.utils.deprecation import MiddlewareMixin


class ApiDisableCsrfMiddleware(MiddlewareMixin):
    """
    Skip Django's CSRF check for API requests.

    A request is an API request when its path starts with
    ``settings.BASE_API_PREFIX`` (a string or a tuple of strings). The match is
    case-insensitive on both sides: lower-casing only the request path meant a
    prefix containing upper-case letters could never match.
    """

    def is_api_request(self, request):
        """
        Return True if *request* targets the API.

        :param request: http request
        :return: True or False
        """
        path = request.path.lower()
        prefix = settings.BASE_API_PREFIX
        # str.startswith also accepts a tuple of prefixes; normalise either form.
        if isinstance(prefix, str):
            prefix = prefix.lower()
        else:
            prefix = tuple(p.lower() for p in prefix)
        return path.startswith(prefix)

    def process_request(self, request):
        """Flag API requests so CSRF token validation is skipped for them."""
        if self.is_api_request(request):
            # Private flag checked by django.middleware.csrf.CsrfViewMiddleware.
            setattr(request, '_dont_enforce_csrf_checks', True)

跨域问题处理

INSTALLED_APPS = [
    # 'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    # 'django.contrib.sessions',
    # 'django.contrib.messages',
    'django.contrib.staticfiles',
    # third-party apps
    'rest_framework',
    'corsheaders',  # django-cors-headers; pairs with CorsMiddleware in MIDDLEWARE
    'django_filters',
    # project apps
    'xxxx.apps.xxxxConfig',
    # ... further project apps go here (a bare `.......` is a SyntaxError)
]
MIDDLEWARE = [
    # django-cors-headers: CorsMiddleware must be as high as possible, and in
    # particular above CommonMiddleware, which can return a response (e.g. an
    # APPEND_SLASH redirect) before CORS headers would be added.
    'corsheaders.middleware.CorsMiddleware',
    # 'django.middleware.security.SecurityMiddleware',
    # 'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    # 'django.contrib.auth.middleware.AuthenticationMiddleware',
    # 'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
    # project middleware goes here, e.g. the dotted path of
    # ApiDisableCsrfMiddleware (a bare `.....` is a SyntaxError)
]
# --- Cross-origin (CORS) settings for django-cors-headers --------------------
# Only URLs matching CORS_URLS_REGEX receive CORS headers.
# NOTE(review): allowing all origins together with credentials lets any website
# make credentialed cross-origin requests to these URLs; prefer an explicit
# list of allowed origins in production.
CORS_ORIGIN_ALLOW_ALL = True
CORS_URLS_REGEX = r'^/xxx/api/.*$'
CORS_ALLOW_CREDENTIALS = True

# HTTP methods permitted for cross-origin requests.
CORS_ALLOW_METHODS = ['OPTIONS', 'GET', 'POST', 'PUT', 'PATCH', 'DELETE']

# Request headers a cross-origin client may send.
# 'access-control-allow-headers' is really a response header; presumably it is
# listed for clients that send it by mistake - confirm before removing.
CORS_ALLOW_HEADERS = [
    'accept', 'accept-encoding', 'authorization',
    'content-type', 'dnt', 'origin',
    'user-agent', 'x-csrftoken', 'x-requested-with',
    'access-control-allow-headers',
]

多配置参数

参数存放位置:config/db.ini 文件

[system]
debug=true
[mysql]
db_name=yangtuo_beedev
db_user=root
db_password=123456
db_host=127.0.0.1
db_port=33069
[redis]
redis_host_port=127.0.0.1:6379

读取配置参数

from configparser import ConfigParser


# Read connection settings from config/db.ini. Any failure (missing file,
# missing section/option, unparsable value, undefined BASE_DIR) falls back to
# None / False so the environment variables consulted later can supply values.
try:
    cfg = ConfigParser()
    db_path = os.path.join(BASE_DIR, 'config/db.ini')
    # Explicit encoding so the file is read the same way regardless of OS locale.
    cfg.read(db_path, encoding='utf-8')

    # system: getboolean() parses true/false/yes/no/1/0. cfg.get() returned the
    # raw string, and the non-empty string "false" is truthy.
    debug = cfg.getboolean('system', 'debug')

    # mysql
    db_name = cfg.get('mysql', 'db_name')
    host = cfg.get('mysql', 'db_host')
    port = cfg.get('mysql', 'db_port')
    user = cfg.get('mysql', 'db_user')
    password = cfg.get('mysql', 'db_password')

    # redis
    redis_host_port = cfg.get('redis', 'redis_host_port')
except Exception as e:
    db_name = host = port = user = password = redis_host_port = None
    debug = False
    print(f"配置文件读取出错,原因:{e},开始使用环境变量读取")

配置参数生效

# MySQL connection: each value is taken from an environment variable first and
# falls back to the value read from config/db.ini.
DATABASES = dict(
    default=dict(
        ENGINE='django.db.backends.mysql',
        NAME=os.environ.get("CICD_DEVELOP_DB", db_name),
        USER=os.environ.get("MYSQL_USER", user),
        PASSWORD=os.environ.get("MYSQL_PASSWORD", password),
        HOST=os.environ.get("MYSQL_HOST", host),
        PORT=os.environ.get("MYSQL_PORT", port),
        # Executed on each new connection: make InnoDB the default engine.
        OPTIONS=dict(init_command='SET default_storage_engine=INNODB'),
    ),
)

时区配置

# Timezone-aware datetimes, stored and processed in UTC.
# Alternatives noted in the original config: TIME_ZONE = 'Asia/Shanghai', USE_TZ = False.
USE_TZ = True
TIME_ZONE = 'UTC'

Django REST framework 配置

REST_FRAMEWORK = dict(
    # Project pagination (stock alternative:
    # 'rest_framework.pagination.LimitOffsetPagination').
    DEFAULT_PAGINATION_CLASS='dorylus.pagination.SelfPagination',
    PAGE_SIZE=10,
    DEFAULT_RENDERER_CLASSES=(
        'rest_framework.renderers.JSONRenderer',
        # Browsable API is for debugging; remove this entry in production.
        'rest_framework.renderers.BrowsableAPIRenderer',
    ),
    # DATETIME_FORMAT is intentionally unset, so DateTimeField uses DRF's
    # default output format (e.g. '%Y-%m-%d %H:%M:%S' could be set here).
    # Empty: no global permission check; views must set permission_classes.
    DEFAULT_PERMISSION_CLASSES=[],
    DEFAULT_AUTHENTICATION_CLASSES=(
        # JWT first, then session and basic auth.
        'rest_framework_jwt.authentication.JSONWebTokenAuthentication',
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.BasicAuthentication',
    ),
    DEFAULT_SCHEMA_CLASS='rest_framework.schemas.coreapi.AutoSchema',
)

JWT_AUTH = dict(
    # Tokens expire after one day.
    JWT_EXPIRATION_DELTA=datetime.timedelta(hours=24),
    JWT_RESPONSE_PAYLOAD_HANDLER='account.auth.jwt_response_handler',
    # Clients send "Authorization: Bearer <token>" (library default prefix: JWT).
    JWT_AUTH_HEADER_PREFIX='Bearer',
)

redis 相关配置

# Redis connection settings. The REDIS_HOST_PORT environment variable wins over
# the [redis] redis_host_port value read from config/db.ini.
REDIS_HOST_PORT = os.environ.get("REDIS_HOST_PORT", redis_host_port)
# REDIS_HOST_PORT = "127.0.0.1:6479"
if not REDIS_HOST_PORT or ':' not in REDIS_HOST_PORT:
    # Fail with an actionable message instead of an AttributeError/IndexError
    # from splitting None or a value without a port.
    raise ValueError(
        "REDIS_HOST_PORT must be 'host:port' (env REDIS_HOST_PORT or "
        "[redis] redis_host_port in config/db.ini), got {!r}".format(REDIS_HOST_PORT)
    )
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", 'xxxxx')
# Parse once; rpartition splits on the last ':' so the port is always the tail.
_redis_host, _, _redis_port = REDIS_HOST_PORT.rpartition(':')
REDIS_CONFIG = {
    "host": _redis_host,
    # redis-py expects an integer port.
    "port": int(_redis_port),
    "db": 10,
    "password": REDIS_PASSWORD,
}

django-redis 使用

pip install django-redis
# Use django-redis as the Django cache backend (Redis db 10).
from urllib.parse import quote as _url_quote

CACHES = {
    "default": {
        "BACKEND": "django_redis.cache.RedisCache",
        # Percent-encode the password: characters such as '@', ':' or '/' would
        # otherwise corrupt the URL. redis-py unquotes it when parsing the URL.
        "LOCATION": "redis://:{}@{}/10".format(_url_quote(REDIS_PASSWORD, safe=''), REDIS_HOST_PORT),
        "OPTIONS": {
            "CLIENT_CLASS": "django_redis.client.DefaultClient",
            "CONNECTION_POOL_KWARGS": {"max_connections": 100},
            # NOTE(review): confirm django-redis actually reads this key; decoded
            # (str) responses are presumably incompatible with its default pickle
            # serializer.
            "DECODE_RESPONSES": True,
            "PASSWORD": REDIS_PASSWORD,
        },
    },
}

代码中使用

from django.core.cache import cache  # Django's default cache (configured by CACHES)

cache.set('v', '555', 60 * 60)  # store key 'v' = '555' with a timeout of 60*60 s = 1 hour
cache.has_key('v')  # check whether key 'v' exists
cache.get('v')      # read key 'v' (None when missing or expired)

celery 相关配置

from urllib.parse import quote as _url_quote

# Broker on Redis db 11, results on db 12. The password is percent-encoded so
# characters such as '@' or ':' cannot break the URL.
CELERY_BROKER_URL = 'redis://:{}@{}/11'.format(_url_quote(REDIS_PASSWORD, safe=''), REDIS_HOST_PORT)
CELERY_RESULT_BACKEND = 'redis://:{}@{}/12'.format(_url_quote(REDIS_PASSWORD, safe=''), REDIS_HOST_PORT)
# SECURITY NOTE(review): accepting 'pickle' lets anyone who can write to the
# broker run code in the workers. It appears to be enabled so tasks can receive
# model instances; passing primary keys instead would allow JSON-only.
CELERY_ACCEPT_CONTENT = ['application/json', 'pickle']
CELERY_RESULT_SERIALIZER = 'json'
CELERY_TASK_SERIALIZER = 'pickle'
"""
使用方式:
1. 进入项目源码目录
2. 启动Celery:celery -A cicd worker -l info
"""

import os

from celery import Celery
from django.conf import settings

# Point Django at the project settings before they are first accessed
# (an already-set environment value is kept).
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'cicd.settings')

app = Celery(main='cicd')

# Load Celery configuration from Django settings; only CELERY_-prefixed names apply.
app.config_from_object('django.conf:settings', namespace='CELERY')


def _installed_apps():
    """Return INSTALLED_APPS; passed as a callable so settings are read lazily."""
    return settings.INSTALLED_APPS


# Discover task modules in every installed app.
app.autodiscover_tasks(_installed_apps)
from .celery import app as celery_app

定义

@shared_task
def xxxxx(application: Application):
    """
    Placeholder Celery task.

    NOTE(review): it receives a model instance, which presumably relies on the
    pickle task serializer; passing the primary key and re-fetching inside the
    task would allow JSON serialization.
    """

执行

# Queue the task for a Celery worker instead of running it inline.
xxxxx.delay(instance)

日志相关配置

settings 配置

# Log directory: the LOG_FILES_DIR_PATH env var, else <BASE_DIR>/logs.
LOG_FILES_DIR_PATH = os.environ.get("LOG_FILES_DIR_PATH", os.path.join(BASE_DIR, "logs"))
# exist_ok=True already tolerates an existing directory; no separate exists() check needed.
os.makedirs(LOG_FILES_DIR_PATH, exist_ok=True)

# Django logging configuration, defined in dorylus.log. The import binds
# LOGGING in this module directly, so no re-assignment is needed.
# NOTE(review): dorylus.log reads django.conf.settings while this module is
# still executing, so only settings defined above this line are visible to it.
from dorylus.log import LOGGING  # noqa: E402

日志的定义

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Django ``LOGGING`` configuration for the dorylus loggers.

Each file handler writes ``<LOG_DIR>/<prefix><YYYY-MM-DD>.log``.
"""
import os

from django.conf import settings
from django.utils import timezone

# Log directory: settings.LOG_FILES_DIR_PATH when defined (the settings module
# lets the LOG_FILES_DIR_PATH env var override it), else <BASE_DIR>/logs.
# Hard-coding <BASE_DIR>/logs here made that override ineffective for these files.
LOG_DIR = getattr(settings, 'LOG_FILES_DIR_PATH', None) or os.path.join(settings.BASE_DIR, 'logs')
# exist_ok=True makes a separate os.path.exists() check unnecessary.
os.makedirs(LOG_DIR, exist_ok=True)

# Date stamp for file names ('YYYY-MM-DD'); timezone.now() is UTC when USE_TZ=True.
# NOTE: computed once at import, so a long-running process keeps writing to the
# file named after the day it started.
now_day = timezone.now().strftime('%Y-%m-%d')


def _file_handler(prefix):
    """Return an INFO-level FileHandler config writing to ``<LOG_DIR>/<prefix><now_day>.log``."""
    return {
        'level': 'INFO',
        'class': 'logging.FileHandler',
        'filename': os.path.join(LOG_DIR, f'{prefix}{now_day}.log'),
        'formatter': 'verbose',
        # Messages may contain non-ASCII text; don't rely on the platform's
        # default file encoding.
        'encoding': 'utf-8',
    }


LOGGING = {
    'version': 1,
    # NOTE(review): True disables loggers that already exist when this config is
    # applied and are not named below (nor children of named ones); Django's docs
    # suggest False unless that silencing is intended.
    'disable_existing_loggers': True,
    'formatters': {
        'verbose': {
            'format': '{asctime} {levelname} {module} {filename} {funcName} {lineno:d} {message}',
            'style': '{',
        },
        'simple': {
            'format': '{levelname} {asctime} {message}',
            'style': '{',
        },
    },
    'filters': {
        'require_debug_true': {
            '()': 'django.utils.log.RequireDebugTrue',
        },
    },
    'handlers': {
        'console': {
            'level': 'INFO',
            'class': 'logging.StreamHandler',
            'formatter': 'simple'
        },
        'dorylus_file': _file_handler('dorylus'),
        'lib_file': _file_handler('lib'),
        'request_file': _file_handler('request'),
        'task_file': _file_handler('task'),
        'process_file': _file_handler('process'),
    },
    'loggers': {
        'django': {
            'handlers': ['console'],
            'level': 'INFO',
            'propagate': True,
        },
        'django.server': {
            'handlers': ['console'],
            'level': 'INFO',
            'propagate': False,
        },
        'django.request': {
            'handlers': ['console'],
            'level': 'INFO',
            'propagate': False,
        },
        # Parent logger; children with propagate=False do not reach it.
        'dorylus': {
            'handlers': ['dorylus_file'],
            'level': 'INFO',
            'propagate': True,
        },
        'dorylus.request': {
            'handlers': ['request_file', 'dorylus_file'],
            'level': 'INFO',
            'propagate': False,
        },
        'dorylus.lib': {
            'handlers': ['console', 'lib_file', 'dorylus_file'],
            'level': 'INFO',
            'propagate': False,
        },
        'dorylus.task': {
            # Task logs go only to their own file (not console / dorylus_file).
            'handlers': ['task_file'],
            'level': 'INFO',
            'propagate': False,
        },
        'dorylus.process': {
            'handlers': ['console', 'process_file', 'dorylus_file'],
            'level': 'INFO',
            'propagate': False,
        }
    }
}

日志的重写:把请求参数相关的信息也记录下来

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import logging


class LoggerMixin(object):
    """
    DRF view mixin that logs each request to the 'dorylus.request' logger.

    Each entry is ``"<username> <path> <METHOD> <params as JSON>"``; params are
    the query parameters for GET and the parsed request body otherwise.
    Values of ``secret_fields`` keys are masked.
    """
    logger = logging.getLogger('dorylus.request')
    # Top-level keys whose values are never written to the log.
    secret_fields = ('password', 'admin_pwd')
    # Replacement text for masked values.
    _masked_value = '******'

    def _loggable_params(self, params):
        """Return a copy of *params* with secret fields masked (non-dicts unchanged)."""
        if isinstance(params, dict):  # also covers QueryDict (one value per key)
            return {
                key: self._masked_value if key in self.secret_fields else value
                for key, value in params.items()
            }
        return params

    def initial(self, request, *args, **kwargs):
        """Run DRF's normal initial() checks, then log the request."""
        super(LoggerMixin, self).initial(request, *args, **kwargs)
        method = request.method.upper()
        user = getattr(request, 'user', None)
        # str(): a custom user model may have a None username.
        username = str(getattr(user, 'username', '') or '')
        params = request.query_params if method == 'GET' else request.data
        try:
            # default=str: uploaded files and other non-JSON values would
            # otherwise raise TypeError and turn the request into a 500.
            payload = json.dumps(self._loggable_params(params), default=str)
        except (TypeError, ValueError):
            payload = repr(self._loggable_params(params))
        self.logger.info(' '.join([username, request.path, method, payload]))
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Public logging API of the dorylus package.

Re-exports the Django LOGGING dict and LoggerMixin, and provides one logger per
area: the 'dorylus' parent plus its request/task/lib/process children.
"""
import logging

from .log import LOGGING
from .mixins import LoggerMixin

# Parent logger; the dotted names below are its children.
dorylus_log = logging.getLogger('dorylus')
# Per-area child loggers.
request_log = logging.getLogger('dorylus.request')
task_log = logging.getLogger('dorylus.task')
lib_log = logging.getLogger('dorylus.lib')
process_log = logging.getLogger('dorylus.process')

__all__ = [
    'LOGGING',
    'LoggerMixin',
    'task_log',
    'lib_log',
    'request_log',
    'process_log',
    'dorylus_log',
]

在代码中的使用

# request_log is the 'dorylus.request' logger exported by the dorylus.log package.
from dorylus.log import request_log

request_log.error("....")

邮箱相关设置

# SMTP settings used by EmailApi (SMTP over SSL on port 465).
EMAIL_SMTP_HOST = 'smtp.exmail.qq.com'
EMAIL_SMTP_PORT = 465
EMAIL_USER = os.environ.get("EMAIL_USER", 'ops@linklogis.com')
# SECURITY: never hard-code the SMTP password; it must come from the environment.
# A real-looking password used to be the default here - treat it as leaked and rotate it.
EMAIL_PASSWORD = os.environ.get("EMAIL_PASSWORD", '')

邮箱封装

# -*- coding:utf-8 -*-
"""
发送邮件
"""
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.application import MIMEApplication
from email.header import Header

from django.conf import settings


class EmailApi:
    """
    Small wrapper around ``smtplib.SMTP_SSL`` for sending plain-text or HTML
    mail with optional attachments.

    Defaults come from the EMAIL_SMTP_HOST / EMAIL_SMTP_PORT / EMAIL_USER /
    EMAIL_PASSWORD settings (evaluated once, when the class is defined).
    Errors are printed and reported through return values, not raised.
    """
    # From header for every message. A plain RFC 5322 address: wrapping it in
    # email.header.Header(..., 'utf-8') turns the whole address into an RFC 2047
    # encoded-word, which is not allowed inside an address.
    from_header = 'BeeCloud <beecloud@linklogis.com>'

    def __init__(self, smtp_host=settings.EMAIL_SMTP_HOST,
                 smtp_port=settings.EMAIL_SMTP_PORT, user=settings.EMAIL_USER,
                 password=settings.EMAIL_PASSWORD):
        self.smtp_host = smtp_host
        self.smtp_port = smtp_port
        self.user = user
        self.password = password
        # Was misspelled ``self.smpt``: ``self.smtp`` stayed undefined until
        # connect(), and close() checked the wrong name so it never closed.
        self.smtp = None
        self.connected = False

    def connect(self):
        """
        Open an SSL connection and log in.

        ``self.connected`` becomes True only when login answers with code 235.
        :return: None, or False if an exception occurred
        """
        try:
            self.smtp = smtplib.SMTP_SSL(host=self.smtp_host, port=self.smtp_port)
            results = self.smtp.login(user=self.user, password=self.password)
            if isinstance(results, tuple) and len(results) == 2 and results[0] == 235:
                self.connected = True
                return
            print("email connect: {}".format(results))
            self.connected = False
        except Exception as e:
            print(e)
            # Don't leave a half-open connection behind (e.g. login rejected).
            self.close()
            return False

    @staticmethod
    def _build_attachment(f):
        """Return a MIMEApplication attachment for *f*, or None if it cannot be built."""
        if hasattr(f, 'read'):
            item = MIMEApplication(f.read())
            item.add_header("content-disposition", "attachment",
                            filename=getattr(f, 'name', 'attachment'))
            return item
        if isinstance(f, dict) and 'path' in f and 'name' in f:
            try:
                with open(f['path'], 'rb') as f_obj:
                    item = MIMEApplication(f_obj.read())
            except Exception as e:
                print('打开文件出错:', str(e))
                return None
            item.add_header("content-disposition", "attachment", filename=f['name'])
            return item
        print('未知文件类型:', f)
        return None

    def send_email(self, title, receivers, content, category="text", files=None):
        """
        Send one message, connecting first if needed.

        @param title: subject line (may contain non-ASCII text)
        @param receivers: list of addresses, or a comma-separated string
        @param content: message body
        @param category: "html" for an HTML body; anything else sends plain text
        @param files: optional list; each item is a readable file-like object or
                      a dict with 'path' and 'name' keys. Items that cannot be
                      read are skipped (previously a failed item could re-attach
                      the previous attachment or raise NameError).
        @return: the result of ``smtplib.SMTP.sendmail`` (refused recipients),
                 or False on error
        """
        try:
            if not self.connected:
                self.connect()

            if isinstance(receivers, str):
                receivers = [r.strip() for r in receivers.split(",") if r.strip()]

            subtype = 'html' if category == 'html' else 'plain'

            body = MIMEMultipart()
            body.attach(MIMEText(content, subtype, 'utf-8'))
            body['From'] = self.from_header
            # Addresses stay plain text; only the subject needs RFC 2047 encoding.
            body['To'] = ",".join(receivers)
            body['Subject'] = Header(title, 'utf-8')

            if files and isinstance(files, list):
                for f in files:
                    item = self._build_attachment(f)
                    if item is not None:
                        body.attach(item)

            return self.smtp.sendmail(self.user, receivers, body.as_string())

        except Exception as e:
            print(e)
            return False

    def close(self):
        """Close the SMTP connection if one is open and reset the connection state."""
        try:
            if self.smtp:
                self.smtp.close()
        except Exception as e:
            print(e)
        finally:
            self.smtp = None
            self.connected = False

 

在代码中的使用

                api = EmailApi()
                try:
                    result = api.send_email(
                        title=self.title,
                        content=self.content,
                        receivers=receivers, category=self.category,
                        files=files,
                    )
                finally:
                    # Release the SMTP connection opened by send_email().
                    api.close()
                # Append this result on a new line; `results += fmt(results, result)`
                # re-added the whole accumulated text on every pass.
                results = "{}\n{}".format(results, result)

 

基础用法

# Django's built-in mail helper. It is configured by Django's own EMAIL_HOST /
# EMAIL_PORT / EMAIL_HOST_USER / EMAIL_HOST_PASSWORD / EMAIL_USE_SSL settings,
# not by the custom EMAIL_SMTP_* names above, so those must be set as well.
# from_email=None means settings.DEFAULT_FROM_EMAIL; fail_silently=False raises on errors.
from django.core.mail import send_mail

send_mail("title", "msg", None, [user1, user2], fail_silently=False)

审计日志(视图 Mixin,不是 Django 中间件)

class LoggingViewSetMixin:
    """
    Audit-log mixin for DRF viewsets.

    Writes an ``AuditLog`` row after create (action=2), update (action=3) and
    destroy (action=4), and after retrieve (action=1) when
    ``retrieve_log_toogle`` is True. Auditing is best-effort: if writing the
    audit entry fails, the error is logged to the 'dorylus' logger and the
    request itself still succeeds.
    """
    # Secret fields: their values never appear in audit messages.
    secret_fields = ('password', 'admin_pwd')
    # Whether retrieve() is audited (attribute name kept for compatibility).
    retrieve_log_toogle = False

    @staticmethod
    def _report_audit_failure(action):
        """Log, with traceback, that auditing *action* failed (call from an except block)."""
        import logging
        logging.getLogger('dorylus').exception('audit log for %s failed', action)

    @staticmethod
    def _is_many_related(value):
        """Return True if *value* looks like a many-to-many related manager (judged by its repr)."""
        return value.__repr__().find('ManyRelatedManager') > 0

    def get_ip_address(self):
        """
        Return the client IP address, or None if unknown.

        X-Forwarded-For may carry a proxy chain ("client, proxy1, proxy2"); only
        the left-most entry is returned. Falls back to REMOTE_ADDR.
        NOTE: X-Forwarded-For is client-controlled unless a trusted proxy sets it.
        """
        try:
            meta = self.request.META
            forwarded = meta.get('HTTP_X_FORWARDED_FOR')
            if forwarded:
                client = forwarded.split(',')[0].strip()
                if client:
                    return client
            return meta.get('REMOTE_ADDR') or None
        except Exception:
            return None

    def retrieve(self, request, *args, **kwargs):
        """Return the normal retrieve response; audit the read when retrieve_log_toogle is set."""
        result = super().retrieve(request, *args, **kwargs)

        if self.retrieve_log_toogle:
            # Everything is inside the try so an audit problem cannot turn a
            # successful read into an error response.
            try:
                instance = self.get_object()
                content_type = ContentType.objects.get_for_model(instance)
                AuditLog.objects.create(
                    user=self.request.user,
                    action=1,
                    app_label=content_type.app_label,
                    model=content_type.model,
                    object_id=instance.pk,
                    object_repr=repr(instance),
                    message_type="text",
                    message="查看对象:{}".format(instance),
                    address=self.get_ip_address(),
                )
            except Exception:
                self._report_audit_failure('retrieve')

        return result

    def perform_create(self, serializer):
        """Create the object, then audit it with its serialized data (secret fields masked)."""
        super().perform_create(serializer)
        try:
            model = serializer.Meta.model
            content_type = ContentType.objects.get_for_model(model)
            # Round-trip through JSON to get plain, serializable types.
            data = json.loads(json.dumps(serializer.data))
            for field in data:
                if field in self.secret_fields:
                    data[field] = "保密字段"
            # serializer.save() leaves the created object on serializer.instance;
            # fall back to a lookup only if a subclass did not save through it.
            obj = serializer.instance
            if obj is None:
                obj = model.objects.get(pk=data['id'])

            AuditLog.objects.create(
                user=self.request.user,
                action=2,
                app_label=content_type.app_label,
                model=content_type.model,
                object_id=obj.pk,
                object_repr=repr(obj),
                message=json.dumps(data),
                address=self.get_ip_address(),
            )
        except Exception:
            self._report_audit_failure('create')

    def perform_update(self, serializer):
        """
        Update the object, then audit every validated field whose value changed.

        :param serializer: serializer bound to the object being updated
        """
        # Step 1: snapshot old values from a fresh copy of the object. Many-to-many
        # managers are reduced to pk lists, because after the save the manager
        # would already reflect the new state.
        obj_old = self.get_object()
        obj_old_dic = {}
        try:
            for field in serializer.validated_data:
                field_v_old = getattr(obj_old, field)
                if self._is_many_related(field_v_old):
                    obj_old_dic[field] = list(field_v_old.values_list('pk', flat=True))
                else:
                    obj_old_dic[field] = field_v_old
        except Exception:
            # Snapshot failed: still perform the update, just without auditing.
            self._report_audit_failure('update (snapshot)')
            super().perform_update(serializer)
            return

        # Step 2: the actual update; its errors propagate to DRF as usual.
        super().perform_update(serializer)

        # Step 3: best-effort audit of the changes.
        try:
            self._audit_update(serializer, obj_old_dic)
        except Exception:
            self._report_audit_failure('update')

    def _audit_update(self, serializer, obj_old_dic):
        """Compare snapshot values with the saved object and write one AuditLog row listing changes."""
        # Use the saved object from serializer.save(). Re-fetching with
        # get_object() could 404 after a successful update that moves the object
        # out of the view's queryset (e.g. a soft-delete flag).
        obj_new = serializer.instance
        if obj_new is None:
            obj_new = self.get_object()
        model = serializer.Meta.model
        content_type = ContentType.objects.get_for_model(model)
        try:
            data = json.loads(json.dumps(serializer.data))
        except (TypeError, ValueError):
            data = serializer.data

        message = []
        for field in serializer.validated_data:
            field_v_old = obj_old_dic[field]
            field_v_new = getattr(obj_new, field)
            secret = field in self.secret_fields

            if self._is_many_related(field_v_new):
                # Many-to-many: compare pk lists (the old value is already a pk list).
                list_pk_new = list(field_v_new.values_list('pk', flat=True))
                if field_v_old != list_pk_new:
                    message.append({
                        'action': 'changed',
                        'field': field,
                        'value_new': '值修改了' if secret else data[field],
                        'value_old': '值修改了' if secret else field_v_old,
                    })
            elif field_v_old != field_v_new:
                message.append({
                    'action': 'changed',
                    'field': field,
                    'value_new': '保密字段(new)' if secret else data[field],
                    'value_old': '保密字段(old)' if secret else field_v_old.__repr__(),
                })

        if message:
            AuditLog.objects.create(
                user=self.request.user,
                action=3,
                app_label=content_type.app_label,
                model=content_type.model,
                object_id=obj_new.pk,
                object_repr=repr(obj_new),
                message=json.dumps(message),
                address=self.get_ip_address(),
            )

    def perform_destroy(self, instance):
        """Delete *instance*, then audit the deletion."""
        # Capture identifying details first: Django clears the pk on delete().
        user = self.request.user
        content_type = ContentType.objects.get_for_model(instance)
        object_id = instance.pk
        object_repr = repr(instance)

        super().perform_destroy(instance)

        try:
            AuditLog.objects.create(
                user=user,
                action=4,
                app_label=content_type.app_label,
                model=content_type.model,
                object_id=object_id,
                object_repr=object_repr,
                message_type="text",
                message="删除对象:{}".format(instance.__class__),
                address=self.get_ip_address(),
            )
        except Exception:
            self._report_audit_failure('destroy')

封装视图

class GenericViewSet(QuerysetMixin, RGenericViewSet):
    """
    Project base viewset: filter/search/ordering backends, project pagination,
    authenticated-only access, and per-request serializer selection.

    A request may pick a serializer from ``serializer_class_set`` by index via
    the ``serializer_class_index_key`` query parameter (falling back to
    ``?detail=``), e.g. ``?detail=1``. Invalid or out-of-range indexes fall back
    to index 0.
    """
    serializer_detail_class = None  # detail serializer class (not used by this class itself)
    filter_backends = (DjangoFilterBackend, SearchFilter, OrderingFilter)
    search_fields = []  # searchable fields
    pagination_class = SelfPagination
    permission_classes = (IsAuthenticated,)

    # Serializer selection
    serializer_class_set = ()  # candidate serializers; the request picks one by index
    serializer_class_index = 0
    serializer_class_index_key = 'detail'  # query parameter carrying the index; change it if 'detail' clashes

    def get_serializer(self, *args, **kwargs):
        """
        Return the serializer instance that should be used for validating and
        deserializing input, and for serializing output.

        The serializer class is chosen from the request's query parameters
        (``serializer_class_index_key``, then ``detail``, default '0').
        """
        query_params = self.request.query_params
        if self.serializer_class_index_key in query_params:
            serializer_class_index = query_params.get(self.serializer_class_index_key, '0')
        else:
            serializer_class_index = query_params.get('detail', '0')

        serializer_class = self.get_serializer_class(serializer_class_index)
        kwargs['context'] = self.get_serializer_context()
        return serializer_class(*args, **kwargs)

    def get_serializer_class(self, serializer_class_index=0):
        """
        Return the class to use for the serializer.

        :param serializer_class_index: index into ``serializer_class_set``
            (int or numeric string). Non-numeric, negative or out-of-range
            values select index 0. Pass None to use ``serializer_class``.
        :raises AssertionError: when the view defines no serializer at all
            (a developer misconfiguration, as in DRF).
        """
        if serializer_class_index is not None:
            try:
                serializer_class_index = int(serializer_class_index)
            except (TypeError, ValueError, OverflowError):
                # e.g. ?detail=true: use the default serializer.
                serializer_class_index = 0

            if len(self.serializer_class_set) == 0 and self.serializer_class:
                self.serializer_class_set = (self.serializer_class,)

            assert len(self.serializer_class_set) > 0, (
                "'%s' should either include a `serializer_class` or `serializer_class_set` "
                "attribute, or override the `get_serializer_class()` method."
                % self.__class__.__name__
            )

            # The index usually comes from a query parameter, i.e. client input:
            # an out-of-range value must not raise AssertionError (HTTP 500), and
            # a negative one must not silently pick from the end of the tuple.
            if not 0 <= serializer_class_index < len(self.serializer_class_set):
                serializer_class_index = 0

            return self.serializer_class_set[serializer_class_index]

        if not self.serializer_class and self.serializer_class_set:
            self.serializer_class = self.serializer_class_set[0]

        assert self.serializer_class is not None, (
                "'%s' should either include a `serializer_class` attribute, "
                "or override the `get_serializer_class()` method."
                % self.__class__.__name__
        )

        return self.serializer_class

class ModelViewSet(
    LoggingViewSetMixin,
    CreateModelMixin,
    RetrieveModelMixin,
    UpdateModelMixin,
    DestroyModelMixin,
    ListModelMixin,
    GenericViewSet,
):
    """
    Full CRUD viewset built on the project GenericViewSet.

    LoggingViewSetMixin is first in the bases, so its retrieve/perform_* hooks
    run first and reach the standard mixin implementations through super().
    """

对称加密

# -*- coding:utf-8 -*-
"""
密码相关的工具
1. random_password: 随机生成一个密码(默认16位)
2. Cryptography: 对称加密
"""
import random
import string
from Crypto.Cipher import AES
from binascii import b2a_hex, a2b_hex

from django.conf import settings


def random_password(length=16):
    """
    Generate a random password made of ASCII letters and digits.

    Characters are drawn independently from the OS CSPRNG (random.SystemRandom).
    The former random.sample() used the non-cryptographic Mersenne Twister,
    could never repeat a character (lowering entropy), and raised ValueError
    for length > 62.

    :param length: password length, default 16
    :return: string of ``length`` characters
    """
    alphabet = string.ascii_letters + string.digits
    rng = random.SystemRandom()
    return ''.join(rng.choice(alphabet) for _ in range(length))


class Cryptography:
    """
    Symmetric AES encryption of text values, returned as hex strings.

    SECURITY NOTE(review): ECB mode with a space-padded key is deterministic
    (equal plaintexts give equal ciphertexts) and unauthenticated. It is kept so
    existing ciphertexts remain decryptable; moving to an authenticated mode
    (e.g. AES-GCM) would need a data migration.
    """

    def __init__(self, key=None):
        self.key = key or settings.PASSWORD_KEY
        self.mode = AES.MODE_ECB
        # Pad the *encoded* key to a 16-byte multiple (AES accepts 16/24/32-byte
        # keys, so keys longer than 32 bytes are rejected by AES.new). Padding by
        # characters gave the wrong byte length for non-ASCII keys; ASCII keys
        # produce exactly the same bytes as before.
        self.cryptor = AES.new(self._pad_bytes(self.key.encode()), self.mode)

    @staticmethod
    def pad(text):
        """Right-pad *text* with spaces to a multiple of 16 characters (kept for compatibility)."""
        return text + ' ' * (-len(text) % 16)

    @staticmethod
    def _pad_bytes(data):
        """Right-pad *data* with b' ' to a multiple of 16 bytes (the AES block size)."""
        return data + b' ' * (-len(data) % 16)

    def encrypt(self, text):
        """
        Encrypt *text* and return the ciphertext as a lower-case hex string.

        The text is UTF-8 encoded first and the *bytes* are padded; padding by
        characters produced unaligned input for non-ASCII text, which AES
        rejects with ValueError. ASCII input yields the same ciphertext as before.
        """
        return b2a_hex(self.cryptor.encrypt(self._pad_bytes(text.encode()))).decode()

    def decrypt(self, text):
        """
        Decrypt a hex string produced by :meth:`encrypt`.

        Only trailing spaces (the padding) are removed, so leading spaces in the
        plaintext survive; a plaintext that itself ended in spaces still loses them.
        """
        res = self.cryptor.decrypt(a2b_hex(text))
        return res.decode().rstrip(' ')

    def check_can_decrypt(self, value):
        """
        Report whether *value* looks like an encrypted value.

        :return: (True, plaintext) if decryption succeeds, else (False, None).
            ValueError covers non-hex input (binascii.Error), unaligned lengths
            and results that are not valid UTF-8.
        NOTE(review): a plaintext that happens to be valid hex of a suitable
        length could decrypt to valid UTF-8 garbage and be reported as encrypted.
        """
        try:
            de_p = self.decrypt(value)
            return True, de_p
        except ValueError:
            return False, None
        except Exception as e:
            print('解密错误:', e)
            return False, None


if __name__ == '__main__':
    # Manual round-trip demo with a throwaway key.
    demo = Cryptography("AAAA")

    cipher_hex = demo.encrypt("0123456789ABCDEF")
    plain = demo.decrypt(cipher_hex)
    print(cipher_hex, plain)
    # The ciphertext should decrypt; the plaintext itself should not.
    for candidate in (cipher_hex, plain):
        print(demo.check_can_decrypt(candidate))

    # A value longer than one AES block.
    cipher_hex = demo.encrypt("00000000000000000000000000")
    plain = demo.decrypt(cipher_hex)
    print(cipher_hex, plain)

文件上传

class FileStorage(FileSystemStorage):
    """
    FileSystemStorage that renames each saved file to
    ``<DDHHMMSS>_<rand>_<rand><ext>`` inside the requested directory and
    writes files with 0o644 permissions.

    Usage on a model field (this line was previously pasted inside this class
    body, where it referenced FileStorage before the class existed -> NameError):

        file = models.FileField(verbose_name="文件", help_text="上传文件",
                                upload_to="files/%Y/%m", storage=FileStorage())
    """

    def __init__(self, location=settings.FILE_STORAGE_ROOT, base_url=settings.FILE_STORAGE_URL):
        # Set the permission mode once here rather than on every _save() call.
        super().__init__(location, base_url, file_permissions_mode=0o644)

    def _save(self, name, content):
        """Save *content* under a generated name, keeping the directory and extension of *name*."""
        ext = os.path.splitext(name)[1]
        directory = os.path.dirname(name)
        # Day-of-month + time and two random numbers; year/month are expected to
        # come from the directory (e.g. upload_to="files/%Y/%m").
        stem = '{}_{}_{}'.format(time.strftime('%d%H%M%S'),
                                 random.randint(0, 1000), random.randint(0, 1000))
        return super()._save(name=os.path.join(directory, stem + ext), content=content)

请求封装

import requests
from dorylus.log import request_log


class APIClient(object):
    """
    Thin wrapper around ``requests`` that returns ``(error, response)`` /
    ``(result, status)`` tuples instead of raising.
    """

    # Supported HTTP methods (lower-case).
    _METHODS = ('get', 'post', 'put', 'delete', 'patch')

    @staticmethod
    def start_requests(url, data=None, headers=None, files=None, methods='get', timeout=None):
        """
        Send one HTTP request.

        :param url: target URL
        :param data: query params for GET; JSON body for POST/PUT/PATCH; unused for DELETE
        :param headers: extra request headers (copied; the caller's dict is not modified)
        :param files: PUT only, passed to ``requests`` as ``files=``
        :param methods: HTTP method name, case-insensitive
        :param timeout: forwarded to ``requests``; None (default) waits indefinitely
        :return: ``(None, response)`` on success, ``(error_message, None)`` on failure
        """
        method = str(methods).lower()
        if method not in APIClient._METHODS:
            # Previously a dict was returned in the response slot, which made
            # APIClient.response() fail with AttributeError.
            return '请求methods错误', None
        headers = dict(headers) if headers else {}
        try:
            if method == 'get':
                response = requests.get(url=url, params=data, headers=headers, timeout=timeout)
            elif method == 'post':
                response = requests.post(url=url, json=data, headers=headers, timeout=timeout)
            elif method == 'put':
                if files:
                    # NOTE(review): requests encodes files= as multipart/form-data,
                    # but this header overrides it (losing the multipart boundary).
                    # Confirm what the server expects before changing.
                    headers['Content-Type'] = 'application/octet-stream'
                    response = requests.put(url=url, files=files, headers=headers, timeout=timeout)
                else:
                    response = requests.put(url=url, json=data, headers=headers, timeout=timeout)
            elif method == 'delete':
                response = requests.delete(url=url, headers=headers, timeout=timeout)
            else:  # 'patch'
                response = requests.patch(url=url, json=data, headers=headers, timeout=timeout)
        except requests.exceptions.ConnectionError as err:
            request_log.error(str(err))
            return "连接用户服务失败", None
        except Exception as err:
            request_log.error(str(err))
            return str(err), None
        return None, response

    @staticmethod
    def response(response, res_json=True):
        """
        Decode *response* into ``[result, status_code]``.

        The body is parsed as JSON (falling back to text) unless the status is
        2xx/3xx and ``res_json`` is False, in which case raw text is returned.
        """
        if res_json or not 200 <= response.status_code < 400:
            # noinspection PyBroadException
            try:
                result = response.json()
            except Exception:
                result = response.text
        else:
            result = response.text
        return [result, response.status_code]

    def _call(self, method, url, data, parse, headers, timeout):
        """Shared body of get()/post(): send, decode, optionally unwrap the envelope."""
        err, response = self.start_requests(url=url, data=data, methods=method,
                                            headers=headers, timeout=timeout)
        if err:
            return err, 500
        res, code = self.response(response)
        if parse:
            return self.result_parse(res)
        return res, code

    def get(self, url, params=None, parse=False, headers=None, timeout=None):
        """GET *url*; returns ``(result, status)`` or, with ``parse=True``, ``result_parse(result)``."""
        return self._call('get', url, params, parse, headers, timeout)

    def post(self, url, data=None, parse=False, headers=None, timeout=None):
        """POST *data* as JSON; returns ``(result, status)`` or ``result_parse(result)`` with ``parse=True``."""
        return self._call('post', url, data, parse, headers, timeout)

    @staticmethod
    def result_parse(data):
        """
        Unwrap a ``{"code": ..., "result": ..., "msg": ...}`` envelope.

        :return: ``(-1, msg)`` when code is -1; otherwise ``(code, result)``
                 (code defaults to -1, result to {}). Non-dict payloads (text,
                 JSON lists/numbers) return ``(-1, data)``.
        """
        if not isinstance(data, dict):
            return -1, data
        if data.get("code") == -1:
            return -1, data.get("msg", {})
        return data.get("code", -1), data.get("result", {})

基础Model

# -*- coding:utf-8 -*-
from django.db import models
from django.contrib.contenttypes.models import ContentType
from django.utils import timezone
from dorylus.tools.password import Cryptography


class BaseModel(models.Model):
    """
    Abstract base model shared by the project's models.

    - A ``deleted`` flag, plus an overridden ``delete`` that soft-deletes
      (marks the row) instead of removing it.
    - Fields named in ``SECRET_FIELDS`` are encrypted on every ``save``.
    - Helpers to fetch "related" objects by model class or content type, for
      schemas that do not use real foreign keys.
    """
    # Names of methods that do_delete_action calls, in order, before flagging the row deleted.
    delete_tasks = []
    # Names of fields stored encrypted; subclasses override this (empty here).
    SECRET_FIELDS = []
    deleted = models.BooleanField(verbose_name="删除", blank=True, default=False)
    time_added = models.DateTimeField(verbose_name="添加时间", blank=True, auto_now_add=True, null=True)

    @staticmethod
    def strftime(fmt='%Y%m%d%H%M%S'):
        """
        Return the current UTC time formatted with ``fmt``.

        Uses ``datetime.timezone.utc`` directly. The ``django.utils.timezone.utc``
        alias was deprecated in Django 4.1 and removed in Django 5.0.
        """
        from datetime import datetime, timezone as dt_timezone
        return datetime.now(tz=dt_timezone.utc).strftime(fmt)

    def set_decrypt_value(self):
        """
        Encrypt, in place, each ``SECRET_FIELDS`` value that is not encrypted yet.

        Despite its name, this method encrypts. A value that
        ``Cryptography.check_can_decrypt`` accepts is treated as ciphertext
        already and left alone, so repeated saves do not re-encrypt it.
        Empty values and empty field names are skipped.
        """
        if not (self.SECRET_FIELDS and isinstance(self.SECRET_FIELDS, (list, tuple))):
            return
        p = Cryptography()
        for field in self.SECRET_FIELDS:
            if not field:
                continue
            value = getattr(self, field)
            if not value:
                continue
            already_encrypted, _ = p.check_can_decrypt(value)
            if not already_encrypted:
                setattr(self, field, p.encrypt(text=value))

    def get_decrypt_value(self, field: str) -> str:
        """
        Return the plaintext of ``field``.

        If the stored value can be decrypted, the decrypted text is returned;
        otherwise the stored value itself is returned. Empty values (``None``,
        ``''``) are returned unchanged.
        """
        value = getattr(self, field)
        if not value:
            return value
        success, plain = Cryptography().check_can_decrypt(value=value)
        return plain if success else value

    def save(self, force_insert=False, force_update=False, using=None,
             update_fields=None):
        """Encrypt ``SECRET_FIELDS`` and then delegate to Django's ``save``."""
        self.set_decrypt_value()
        return super().save(force_insert=force_insert, force_update=force_update, using=using,
                            update_fields=update_fields)

    def get_relative_object_by_model(self, model, args=None, value=None, many=False, field="pk"):
        """
        Look up objects of ``model`` that this record refers to without a foreign key.

        :param model: a Django model class
        :param args: filter kwargs as a dict; takes precedence over ``value``
        :param value: value to match against ``field`` (may be a list, e.g. with ``field="id__in"``)
        :param many: return a queryset (``filter``) instead of a single object (``get``)
        :param field: lookup used with ``value``; defaults to ``pk``
        :return: a queryset when ``many`` is true, otherwise one model instance
        :raises ValueError: if neither ``args`` nor a truthy ``value`` is given,
                            or if ``model`` is not a Django model class
        :raises model.DoesNotExist / model.MultipleObjectsReturned: from ``get`` when
                            ``many`` is false; callers handle these
        """
        if args and isinstance(args, dict):
            data = args
        elif value:
            data = {field: value}
        else:
            raise ValueError("args或者value必须传递一个")

        # Check isinstance first: issubclass() raises TypeError for non-class arguments
        # (e.g. None from ContentType.model_class()).
        if not (isinstance(model, type) and issubclass(model, models.Model)):
            raise ValueError("传入的model必须是django.db.models.Model的子类")

        if many:
            return model.objects.filter(**data)
        return model.objects.get(**data)

    def get_relative_object_by_content_type(self, app_label, model, args=None, value=None, many=False, field="pk"):
        """
        Same as ``get_relative_object_by_model``, with the model identified by
        ``ContentType`` (``app_label`` + lowercase ``model`` name).

        :raises ContentType.DoesNotExist: if no such content type exists
        """
        model_cls = ContentType.objects.get(app_label=app_label, model=model).model_class()
        return self.get_relative_object_by_model(model=model_cls, args=args, value=value, many=many, field=field)

    def do_delete_action(self):
        """
        Soft-delete this record.

        Records are only flagged deleted, never physically removed. Before
        flagging, every method named in ``delete_tasks`` runs (e.g. to stamp a
        field or clean up dependent data). A record that is already deleted is
        left untouched. Only the ``deleted`` column is written.
        """
        if self.deleted:
            return
        if isinstance(self.delete_tasks, (list, tuple)):
            for task_name in self.delete_tasks:
                # Skip empty/unknown names and never re-enter this method.
                if not task_name or task_name == 'do_delete_action' or not hasattr(self, task_name):
                    continue
                task_func = getattr(self, task_name)
                if callable(task_func):
                    task_func()
                else:
                    print('{}不是可调用的函数'.format(task_name))
        self.deleted = True
        self.save(update_fields=('deleted',))

    def delete(self, using=None, keep_parents=False):
        """
        Soft-delete via ``do_delete_action`` when available; otherwise fall back to a real delete.

        Because BaseModel defines ``do_delete_action``, instances normally take
        the soft-delete path, which returns None rather than Django's
        ``(count, {...})`` tuple.
        """
        if hasattr(self, "do_delete_action"):
            self.do_delete_action()
        else:
            super().delete(using=using, keep_parents=keep_parents)

    class Meta:
        abstract = True

自定义分页

# -*- coding:utf-8 -*-
from rest_framework.pagination import PageNumberPagination


class SelfPagination(PageNumberPagination):
    """
    REST framework page-number paginator.

    - Pagination is opt-in: without the page query parameter (``page`` by
      default), ``paginate_queryset`` returns None and all data is returned.
    - Page size defaults to 10. Clients can change it with ``?page_size=``,
      up to 1000.
    - The "unordered object_list" check of Django's ``Paginator`` is disabled
      for this paginator only.
    """
    page_size = 10
    max_page_size = 1000
    page_size_query_param = 'page_size'
    # Subclass instead of assigning to Paginator._check_object_list_is_ordered: that
    # assignment would mutate Django's shared class and affect every paginator in the process.
    django_paginator_class = type(
        'UnorderedCheckFreePaginator',
        (PageNumberPagination.django_paginator_class,),
        {'_check_object_list_is_ordered': lambda self: None},
    )

    def paginate_queryset(self, queryset, request, view=None):
        """Paginate only when the page query parameter is present; otherwise return None."""
        if not request.query_params.get(self.page_query_param):
            return None
        return super().paginate_queryset(queryset, request, view=view)

requests 套壳(API 透传代理)

用于统一入口: 将 /dev/api/v1/cloud/kube/ 下的请求通过 requests 原样转发到 kube 服务(settings.BEC_KUBE_SERVER)。

# -*- coding:utf-8 -*-
import requests
from django.http.response import HttpResponse
from django.conf import settings

from dorylus.views import APIView


class KubeApiView(APIView):
    """
    Forward requests under ``kube_path_prefix`` to ``settings.BEC_KUBE_SERVER``.

    The path after the prefix is appended to the kube server URL, and
    ``/tenants/`` is scoped to the requesting user's ``tenant_id``. The
    upstream status, body and (most) headers are returned as-is.
    """

    # Incoming path prefix that is stripped before the rest is appended to BEC_KUBE_SERVER.
    kube_path_prefix = '/dev/api/v1/cloud/kube/'
    # Request headers copied onto the upstream request.
    forward_headers = ('Authorization', 'Host', 'User-Agent', 'Accept-Encoding', 'Accept', 'Content-Type')
    # Upstream headers not copied back: the body is re-emitted already decoded and re-framed.
    excluded_response_headers = ('content-encoding', 'content-length', 'transfer-encoding', 'connection')

    def _build_kube_url(self, request):
        """Map the incoming path onto the kube server and insert the caller's tenant id after ``/tenants/``."""
        kube_url = '{}/{}'.format(settings.BEC_KUBE_SERVER, request.path.replace(self.kube_path_prefix, ''))
        return kube_url.replace('/tenants/', '/tenants/{}/'.format(request.user.tenant_id))

    def _build_headers(self, request):
        """Copy the whitelisted headers and derive ``Host`` from BEC_KUBE_SERVER when missing or local."""
        headers = {k: v for k, v in request.headers.items() if k in self.forward_headers}
        if 'Host' not in headers or headers['Host'].startswith('127.0.0.1'):
            host = settings.BEC_KUBE_SERVER
            for fragment in ('https://', 'http://', '/bke-apiserver/apis/v1'):
                host = host.replace(fragment, '')
            headers['Host'] = host
        return headers

    def redirect_to_kube_server(self, request):
        """
        Proxy ``request`` to the kube server.

        JSON bodies are re-sent as JSON; anything else is sent as form data.
        Uploaded files are not forwarded.

        :return: ``HttpResponse`` mirroring the upstream response, or a 502
                 response when the upstream request itself fails
        """
        kube_url = self._build_kube_url(request)
        headers = self._build_headers(request)

        body = {
            "url": kube_url,
            "method": request.method,
            "headers": headers,
            # QueryDict.items() keeps only the last value of a repeated key;
            # lists() keeps every value (requests expands list values).
            "params": list(request.query_params.lists()),
            "cookies": request.COOKIES,
            "stream": True,
        }
        if headers.get('Content-Type', '').startswith("application/json"):
            body['json'] = request.data
        else:
            body['data'] = request.data

        try:
            # `with` closes the upstream connection even if building the reply fails.
            with requests.request(**body) as resp:
                resp_headers = [(name, value) for (name, value) in resp.raw.headers.items()
                                if name.lower() not in self.excluded_response_headers]
                return HttpResponse(content=resp.content, status=resp.status_code, headers=resp_headers)
        except Exception as e:
            print('向kube发起接口报错', e)
            # 502 Bad Gateway: the upstream call failed, so do not report success (200).
            return HttpResponse(content="向kube发起接口报错", status=502)

    def get(self, request):
        return self.redirect_to_kube_server(request=request)

    def post(self, request):
        return self.redirect_to_kube_server(request=request)

 

GitHub 对接

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import AnyStr, List, Dict
from urllib import parse

import requests

from dorylus.log import lib_log


class Github(object):
    """
    Minimal GitHub REST v3 client for listing branches and reading the latest commit.

    :param url: API root, e.g. ``https://api.github.com``
    :param token: access token sent in the ``authorization`` header
    """

    # Page size requested from list endpoints; a full page means another page may follow.
    per_page = 100

    def __init__(self, url: AnyStr, token: AnyStr):
        self.url = url
        self.token = token

        self.headers = {
            'Accept': 'application/vnd.github.v3+json',
            'authorization': 'Token {}'.format(token)
        }

    @staticmethod
    def _encode_path(application: str):
        """URL-encode a project path (``/`` becomes ``%2F``)."""
        return parse.quote_plus(application)

    def _request_branch(self, application: AnyStr):
        """
        Fetch every branch name of one repository, following pagination.

        Pages are requested one after another until a page comes back with fewer
        than ``per_page`` entries. A page that cannot be decoded is logged and
        ends the listing; names gathered so far are kept.

        :param application: ``owner/repo``
        :return: ``{application: [branch_name, ...]}``
        """
        branches = []
        page = 1
        while True:
            # https://api.github.com/repos/<owner>/<repo>/branches
            url = '{}/repos/{}/branches?per_page={}&page={}'.format(self.url, application, self.per_page, page)
            with requests.get(url, headers=self.headers) as resp:
                try:
                    names = [i.get('name', '') for i in resp.json() if isinstance(i, dict)]
                except Exception as e:
                    lib_log.error('application:{} url:{} '.format(application, url) + e.__str__())
                    break
            branches.extend(names)
            if len(names) < self.per_page:
                break
            page += 1
        return {application: branches}

    def get_branches(self, application):
        """
        Get the branches of one or more repositories.

        https://docs.github.com/en/rest/reference/repos#branches
        :param application: ``owner/repo`` or a list of them
        :return: ``{repo: [branch_name, ...]}``; empty dict for empty input
        """
        result: Dict = {}
        if not application:
            return result
        if isinstance(application, str):
            application = [application]
        for i in application:
            result.update(self._request_branch(i))
        return result

    def last_commit(self, project, branch, page_size=1):
        """
        Return the most recent commit on ``branch`` of ``project``.

        :param project: ``owner/repo``
        :param branch: branch name (or any ref accepted as ``sha``)
        :param page_size: number of commits requested
        :return: the first commit object when GitHub returns a non-empty list;
                 otherwise the decoded body as-is (e.g. an error object or ``[]``),
                 or ``{}`` when the body is not JSON
        """
        # https://api.github.com/repos/<owner>/<repo>/commits
        url = '{}/repos/{}/commits'.format(self.url, project)
        # Let requests encode the query so branch names containing '#', '&', '+' etc. survive.
        params = {'sha': branch, 'per_page': page_size}
        commit = {}
        with requests.get(url, headers=self.headers, params=params, timeout=3) as resp:
            try:
                commit = resp.json()
                if isinstance(commit, list) and len(commit) > 0:
                    commit = commit[0]
            except Exception as e:
                error = 'get {} application {} branch last_commit error: {}' \
                    .format(project, branch, e.__str__())
                lib_log.error(error)
        return commit


def test_github():
    """Manual smoke check against the live GitHub API; fill in a real repo and token before running."""
    repo, token = 'xxx', 'xxxxx'
    client = Github('https://api.github.com', token)
    print(client.last_commit(repo, 'master'))
    print(client.get_branches(application=repo))

 

GitLab 对接

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import os
# os.environ.setdefault("DJANGO_SETTINGS_MODULE", "nvwa.settings")
# import django
# django.setup()

from typing import AnyStr, List, Dict
from urllib import parse

import requests

from dorylus.log import lib_log


class Git(object):
    """
    Minimal GitLab REST v4 client: branches, tags, commits and project webhooks.

    The token is sent in the ``PRIVATE-TOKEN`` header rather than the query
    string, so URLs written to the log never contain it.

    :param url: GitLab root, e.g. ``https://gitlab.example.com``
    :param token: private access token
    """

    # Page size requested from list endpoints; a full page means another page may follow.
    per_page = 100
    # Endpoint registered by add_webhook / put_webhook for push events.
    webhook_url = 'http://172.16.90.177:8000/' + 'application/' + 'auto_build/'

    def __init__(self, url: AnyStr, token: AnyStr):
        self.url = url
        self.token = token

    @property
    def _auth_headers(self):
        """Authentication header for GitLab v4 requests."""
        return {'PRIVATE-TOKEN': self.token}

    @staticmethod
    def _encode_path(application: str):
        """URL-encode a ``namespace/project`` path (``/`` becomes ``%2F``) for use as the project id."""
        return parse.quote_plus(application)

    def _paginate(self, url, project, pick):
        """
        Collect values from every page of a GitLab list endpoint.

        :param url: endpoint URL without pagination parameters
        :param project: project path, used only in log messages
        :param pick: maps one JSON object of a page to the value to collect
        :return: collected values; when a page cannot be decoded the error is
                 logged and the values gathered so far are returned
        """
        items = []
        page = 1
        while True:
            params = {'per_page': self.per_page, 'page': page}
            with requests.get(url, headers=self._auth_headers, params=params) as resp:
                try:
                    current = [pick(i) for i in resp.json() if isinstance(i, dict)]
                except Exception as e:
                    lib_log.error('application:{} url:{} '.format(project, url) + e.__str__())
                    break
            items.extend(current)
            if len(current) < self.per_page:
                break
            page += 1
        return items

    def _request_branch(self, application: AnyStr):
        """
        Fetch every branch name of one project.

        :param application: ``namespace/project``
        :return: ``{application: [branch_name, ...]}``
        """
        url = '{}/api/v4/projects/{}/repository/branches/'.format(self.url, self._encode_path(application))
        return {application: self._paginate(url, application, lambda i: i.get('name', ''))}

    def get_branches(self, application):
        """
        Get the branches of one or more projects.

        :param application: ``namespace/project`` or a list of them
        :return: ``{project: [branch_name, ...]}``; empty dict for empty input
        """
        result: Dict = {}
        if not application:
            return result
        if isinstance(application, str):
            application = [application]
        for i in application:
            result.update(self._request_branch(i))
        return result

    def _request_tag(self, project: AnyStr, info="name"):
        """
        Fetch every tag of one project.

        :param project: ``namespace/project``
        :param info: tag field to collect; ``"all"`` (or a field the tag lacks)
                     collects the whole tag object
        :return: ``{project: [value, ...]}``
        """
        url = '{}/api/v4/projects/{}/repository/tags/'.format(self.url, self._encode_path(project))

        def pick(tag):
            if info == "all" or info not in tag:
                return tag
            return tag.get(info, '')

        return {project: self._paginate(url, project, pick)}

    def get_tags(self, project, info="name"):
        """
        Get the tags of one or more projects.

        :param project: ``namespace/project`` or a list of them
        :param info: see ``_request_tag``; defaults to tag names
        :return: ``{project: [value, ...]}``; empty dict for empty input
        """
        result: Dict = {}
        if not project:
            return result
        # A list was previously left unbound here and raised UnboundLocalError.
        projects = [project] if isinstance(project, str) else project
        for i in projects:
            result.update(self._request_tag(i, info))
        return result

    def create_tag(self, project, tag_name, ref, message=""):
        """
        Create a tag in ``project``.

        :param project: ``namespace/project``
        :param tag_name: name of the tag to create
        :param ref: what to tag, e.g. ``master``, ``develop`` or a commit id
        :param message: optional tag message
        :return: the created tag as returned by GitLab (first element if a list
                 comes back), or False when the request fails
        """
        url = '{}/api/v4/projects/{}/repository/tags'\
            .format(self.url, self._encode_path(project))

        # Gotcha: this call does not work over plain http on this host, so force https.
        if url.startswith("http://git.hrlyit"):
            url = url.replace("http://git.hrlyit", "https://git.hrlyit")
        headers = {
            "PRIVATE-TOKEN": self.token,
            'Content-Type': 'application/json'
        }
        data = {
            "tag_name": tag_name,
            "ref": ref,
            "message": message
        }

        r = requests.request("POST", url, headers=headers, data=json.dumps(data))

        if not r.ok:
            print("创建标签失败:", r.text)
            return False
        result = r.json()
        if isinstance(result, list) and len(result) > 0:
            return result[0]
        return result

    def last_commit(self, application, branch):
        """
        Return the latest commit of ``branch`` in ``application``.

        :param application: ``namespace/project``
        :param branch: branch name or commit id
        :return: decoded JSON body, or ``{}`` when it is not JSON
        """
        # Encode the ref as one path segment so branches like "feature/x" resolve correctly.
        url = '{}/api/v4/projects/{}/repository/commits/{}' \
            .format(self.url, self._encode_path(application), parse.quote(str(branch), safe=''))
        commit = {}
        with requests.get(url, headers=self._auth_headers, timeout=3) as resp:
            try:
                commit = resp.json()
            except Exception as e:
                error = 'get {} project {} branch last_commit error: {}' \
                    .format(application, url, e.__str__())
                lib_log.error(error)
        return commit

    def add_webhook(self, application, tokens, branch):
        """
        Add a push-event webhook pointing at ``webhook_url``.

        :param application: ``namespace/project``
        :param tokens: value for the hook's ``token`` field
        :param branch: ``push_events_branch_filter`` value
        :return: id of the created hook
        :raises KeyError: if the response carries no ``id``
        """
        # Project hooks live under /projects/:id/hooks (/applications is GitLab's OAuth-app API).
        url = '{}/api/v4/projects/{}/hooks'.format(self.url, self._encode_path(application))
        new_hook = {'token': tokens,
                    'url': self.webhook_url,
                    'push_events': True,
                    'push_events_branch_filter': branch}
        webhook = requests.post(url, headers=self._auth_headers, timeout=3, data=new_hook)
        return webhook.json()['id']

    def put_webhook(self, application, webhookId, branch):
        """Update hook ``webhookId`` of ``application`` to push events on ``branch`` to ``webhook_url``."""
        url = '{}/api/v4/projects/{}/hooks/{}'.format(self.url, self._encode_path(application), webhookId)
        new_data = {
            'url': self.webhook_url,
            'push_events': True,
            'push_events_branch_filter': branch}
        requests.put(url, headers=self._auth_headers, timeout=3, data=new_data)

    def delete_webhook(self, application, webhookId):
        """Delete hook ``webhookId`` from ``application``."""
        url = '{}/api/v4/projects/{}/hooks/{}'.format(self.url, self._encode_path(application), webhookId)
        requests.delete(url, headers=self._auth_headers, timeout=3)


if __name__ == '__main__':
    # Manual smoke check. Credentials come from the environment so that no access
    # token is committed to source control (revoke any token that ever was).
    git1 = Git(os.environ.get('GITLAB_URL', 'https://git.hrlyit.com'), os.environ['GITLAB_TOKEN'])
    print(git1.last_commit(os.environ.get('GITLAB_PROJECT', 'luojinyu/test'), 'master'))

 

Jenkins 对接

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import re
from typing import List, Dict, AnyStr
from xml.etree import ElementTree
import jenkins
from lib.redis import redis_client
from dorylus.log import lib_log
from lib.gitlab import Git
from django.core.cache import cache


class Jenkins(object):
    """
    Wrapper around ``python-jenkins`` for triggering builds, reading job
    parameters and editing job / view configuration.

    ``__init__`` creates a new asyncio event loop and installs it as the
    current thread's loop. ``build`` and ``params`` schedule their per-job
    coroutines on that loop (``self.loop``). The Jenkins client calls inside
    those coroutines are synchronous; only the ``asyncio.sleep`` polling
    interleaves.
    """

    def __init__(self, url, username, password):
        self.url = url
        self.username = username
        self.password = password
        self.j = jenkins.Jenkins(self.url, self.username, self.password)

        asyncio.set_event_loop(asyncio.new_event_loop())
        self.loop = asyncio.get_event_loop()

    async def _do_build(self, name: AnyStr, params: Dict, registry_address: AnyStr, get_console: bool = False) -> Dict:
        """
        Trigger one build and optionally wait for it to finish.

        :param name: job name
        :param params: build parameters passed to ``build_job``
        :param registry_address: copied into the result unchanged
        :param get_console: wait for the build and include its console output
        :return: ``{name: {'status', 'result', 'build_name', 'registry_address'[, 'build_number']}}``.
                 ``status`` is ``''`` when not waiting; ``'FAILED'`` if the job is missing
                 or could not be queued; ``'CANCELLED'`` if polling raised; otherwise
                 the build's ``result`` from Jenkins.
        """
        status: AnyStr = ''
        console: AnyStr = ''
        result: Dict = {name: {'status': status, 'result': console,
                               'build_name': name, 'registry_address': registry_address}}
        if not self.j.job_exists(name):
            result[name]['status'] = 'FAILED'
            result[name]['result'] = 'Job {} do not exist'.format(name)
            return result
        # Read the number the build is about to receive before queueing it.
        build_number = self.j.get_job_info(name)['nextBuildNumber']
        result[name]['build_number'] = build_number
        jenkins_tmp_build_id = '/jenkins/build/{}'.format(name)
        cache.set(jenkins_tmp_build_id, build_number, 300)  # expires after at most five minutes
        try:
            queue = self.j.build_job(name, params)
        except Exception as e:
            lib_log.error(e)
            result[name]['status'] = 'FAILED'
            result[name]['result'] = 'Job {} request build failed'.format(name)
            return result
        if not get_console:
            return result
        # Poll until the queue item's 'why' field is empty (presumably: it has left the queue).
        while True:
            await asyncio.sleep(1)
            if not self.j.get_queue_item(queue)['why']:
                break
        # Poll build info and console output until the build stops running.
        while True:
            await asyncio.sleep(0.5)
            try:
                build_info = self.j.get_build_info(name, build_number)
                res = self.j.get_build_console_output(name, build_number)
                jenkins_tmp_build_res = '/jenkins/build/{}/res/{}'.format(build_number, name)
                # Latest console snapshot, cached with a 1-second timeout.
                # NOTE(review): confirm 1 s is intended (a five-minute lock was once described here).
                cache.set(jenkins_tmp_build_res, res, 1)

                if not build_info['building']:
                    result[name]['status'] = build_info['result']
                    result[name]['result'] = res
                    break
            except Exception as e:
                lib_log.error(e)
                result[name]['status'] = 'CANCELLED'
                break
        return result

    def build(self, application: List[Dict], get_console: bool = False) -> Dict:
        """
        Build several jobs concurrently on ``self.loop``.

        :param application: ``[{'name': ..., 'params': {...}, 'registry_address': ...}, ...]``
        :param get_console: wait for each build and include its console output
        :return: merged ``_do_build`` results keyed by job name; ``{}`` for no jobs
        """
        # create_task on self.loop: ensure_future() would use whatever loop is current
        # for the thread, which differs from self.loop once another instance is created.
        tasks = [
            self.loop.create_task(self._do_build(i['name'], i['params'], i['registry_address'], get_console))
            for i in application
        ]
        if not tasks:
            # asyncio.wait() rejects an empty collection.
            return {}
        self.loop.run_until_complete(asyncio.wait(tasks))
        result: Dict = {}
        for task in tasks:
            result.update(task.result())
        return result

    async def _do_method(self, param_methods: List, conf, git: Dict):
        """Run each ``*_param_get`` method on ``conf`` and merge their results; git-aware ones also get ``git``."""
        result: Dict = {}
        for method_name in param_methods:
            method = getattr(self, method_name)
            if 'git' in method_name:
                result.update(method(conf, git))
            else:
                result.update(method(conf))
        return result

    async def _do_param(self, name: str, param_methods: List, git: Dict) -> Dict:
        """Collect parameters of job ``name``; on error the value is the error text."""
        result: Dict = {name: {}}
        try:
            conf = ElementTree.fromstring(self.j.get_job_config(name))
            result[name] = await self._do_method(param_methods, conf, git)
        except Exception as e:
            lib_log.error(e)
            result[name] = e.__str__()
        return result

    def params(self, name, git: Dict = None) -> Dict:
        """
        Get the parameter definitions of one or more jobs.

        :param name: job name or list of job names
        :param git: ``{'url': ..., 'token': ...}`` GitLab credentials used to list branches
        :return: ``{job: {param: {'option': [...], 'default': ...}}}``; ``{}`` for no jobs
        """
        if isinstance(name, str):
            name = [name]
        param_methods = [i for i in self.__dir__() if i.endswith('_param_get')]
        tasks = [self.loop.create_task(self._do_param(i, param_methods, git)) for i in name]
        if not tasks:
            return {}
        self.loop.run_until_complete(asyncio.wait(tasks))
        result: Dict = {}
        for task in tasks:
            result.update(task.result())
        return result

    @staticmethod
    def _git_repo_path(remote_url):
        """
        Extract ``namespace/project`` from a git remote URL.

        ``git@host:ns/proj.git`` -> ``ns/proj``; ``https://host/ns/proj.git`` -> ``ns/proj``
        (the ``.git`` suffix is optional for http(s) URLs). Returns None when nothing matches.
        """
        if remote_url.startswith(('http://', 'https://')):
            pattern = r'^https?://[^/]*/(.*?)(?:\.git)?$'
        else:
            pattern = r':(.*)\.git'
        found = re.findall(pattern, remote_url)
        return found[-1] if found else None

    @staticmethod
    def registry(conf) -> str:
        """
        Return the ``namespace/project`` path of the job's first git remote.

        :param conf: parsed job config (``ElementTree`` element)
        :return: the path, or None if the job has no git remote URL or it does not match
        """
        remote_config = conf.find('./scm/userRemoteConfigs/hudson.plugins.git.UserRemoteConfig')
        url_node = remote_config.find('url') if remote_config is not None else None
        if url_node is None or not url_node.text:
            return None
        return Jenkins._git_repo_path(url_node.text)

    def _git_param_get(self, conf, git: Dict) -> Dict:
        """
        Build the options of a git-parameter named ``branch`` from the project's GitLab branches.

        :param conf: parsed job config
        :param git: ``{'url': ..., 'token': ...}``; nothing is fetched when empty
        :return: ``{'branch': {'option': ['origin/<name>', ...], 'default': ...}}`` or ``{}``
        """
        result: Dict = {}
        if not git:
            return result
        parameter = conf.find(
            './properties/hudson.model.ParametersDefinitionProperty/parameterDefinitions'
        )
        # Explicit checks: Element truthiness is deprecated (and False for childless elements).
        if parameter is None or len(parameter) == 0:
            return result
        for i in parameter:
            if i.tag == 'net.uaznia.lukanus.hudson.plugins.gitparameter.GitParameterDefinition' and i.find(
                    'name').text == 'branch':
                result['branch'] = {}
                g = Git(url=git['url'], token=git['token'])
                registry = self.registry(conf)
                if not registry:
                    break
                result['branch']['option'] = ['origin/' + k for k in g.get_branches(registry)[registry]]
                result['branch']['default'] = i.find('defaultValue').text
                break
        return result

    @staticmethod
    def _default_param_get(conf) -> Dict:
        """
        Collect choice parameters.

        :param conf: parsed job config
        :return: ``{param: {'option': [...], 'default': first option or None}}``
        """
        result: Dict = {}
        parameter = conf.find(
            './properties/hudson.model.ParametersDefinitionProperty/parameterDefinitions'
        )
        if parameter is None or len(parameter) == 0:
            return result
        for i in parameter:
            if i.tag == 'hudson.model.ChoiceParameterDefinition':
                name = i.find('name').text
                option = [k.text for k in i.iter('string')]
                result[name] = {}
                result[name]['option'] = option
                result[name]['default'] = option[0] if option else None
        return result

    def get_config(self, name):
        """Return the raw config XML of job ``name``."""
        return self.j.get_job_config(name=name)

    def create_job(self, name, config_xml):
        """
        Create job ``name`` from ``config_xml``.

        :return: ``{'status': 'SUCCESS'|'FAILED', 'result': '' or the failure reason}``
        """
        result: Dict = {
            'status': 'SUCCESS',
            'result': ''
        }
        if self.j.job_exists(name):
            result['status'] = 'FAILED'
            result['result'] = 'Job {} exists'.format(name)
            return result
        try:
            self.j.create_job(name=name, config_xml=config_xml)
        except Exception as e:
            lib_log.error(e)
            result['status'] = 'FAILED'
            result['result'] = str(e)
        return result

    def copy_job(self, copy_name, name):
        """
        Copy job ``copy_name`` to a new job ``name``.

        :return: ``{'status': 'SUCCESS'|'FAILED', 'result': '' or the failure reason}``
        """
        result: Dict = {
            'status': 'SUCCESS',
            'result': ''
        }
        if self.j.job_exists(name):
            result['status'] = 'FAILED'
            result['result'] = 'Job {} exists'.format(name)
            return result
        try:
            self.j.copy_job(copy_name, name)
        except Exception as e:
            lib_log.error(e)
            result['status'] = 'FAILED'
            result['result'] = str(e)
        return result

    def get_all_job(self):
        """Return the names of all jobs."""
        return [n['name'] for n in self.j.get_all_jobs()]

    def update_config(self, name, **kwargs):
        """
        Return job ``name``'s config XML with the requested changes applied (not pushed to Jenkins).

        :keyword registry_address: new ``namespace/project`` path to substitute into the
            first git remote URL; ssh and http(s) remotes are both handled, and a
            ``.git`` suffix is preserved
        :return: config XML as ``str``
        """
        conf = ElementTree.fromstring(self.j.get_job_config(name))
        if 'registry_address' in kwargs:
            try:
                registry_address = kwargs['registry_address']
                registry_node = conf.find(
                    './scm/userRemoteConfigs/hudson.plugins.git.UserRemoteConfig'
                ).find('url')
                if registry_node is not None:
                    registry_tag = registry_node.text
                    current = self._git_repo_path(registry_tag)
                    if current:
                        registry_node.text = registry_tag.replace(current, registry_address)
                        print('Update Git {} OK'.format(registry_address))
            except Exception as e:
                # Jobs copied from templates without a git SCM section raise here; skip them.
                print('替换registry_address有误:', str(e))
        return ElementTree.tostring(conf).decode()

    def add_to_views(self, name_list, jenkins_name):
        """
        Insert ``jenkins_name`` into a list view, keeping the view's job names sorted.

        :param name_list: candidate view names ``[project English name, project Chinese name]``;
            the last one that exists in Jenkins is used
        :param jenkins_name: job name to add
        :return: True after reconfiguring the view; None when no view or no ``jobNames`` node exists
        """
        view_name = ''
        for n in name_list:
            if n and self.j.view_exists(n):
                view_name = n
        if not view_name:
            return
        view_config = ElementTree.fromstring(self.j.get_view_config(view_name))
        job_node = view_config.find('./jobNames')
        # Compare with None: a childless Element is falsy, which would wrongly skip empty views.
        if job_node is None:
            return
        job_list = [s.text for s in job_node.findall('string')]
        if not job_list:
            index = len(job_node)
        else:
            # Child position of the first <string>; other children (e.g. a comparator) may precede it.
            first_index = next(i for i, child in enumerate(job_node) if child.tag == 'string')
            job_list.append(jenkins_name)
            job_list.sort()
            index = job_list.index(jenkins_name) + first_index
        elem = ElementTree.Element("string")
        elem.text = jenkins_name
        elem.tail = '\n'
        job_node.insert(index, elem)
        view_config_xml = ElementTree.tostring(view_config).decode()
        self.j.reconfig_view(view_name, view_config_xml)
        return True

 

Kubernetes (k8s) 对接

# -*- coding:utf-8 -*-
import re

import yaml
from kubernetes import client, config
from kubernetes.client import ApiClient
from kubernetes.client import V1Namespace, V1ObjectMeta
from kubernetes.client.rest import ApiException


def hump_to_underline(hump):
    """
    Convert a CamelCase name to snake_case by prefixing each ASCII capital with ``_``.

    The leading underscore is intentional: ``'ConfigMap'`` becomes ``'_config_map'``,
    ready to be appended to a verb such as ``'list'``.
    """
    return ''.join('_' + ch.lower() if 'A' <= ch <= 'Z' else ch for ch in hump)


def serialization(y):
    """
    Decorate a method so its return value is converted to plain JSON-compatible
    data via ``ApiClient().sanitize_for_serialization``.

    ``functools.wraps`` keeps the wrapped method's name and docstring.
    """
    import functools

    @functools.wraps(y)
    def z(self, *arg, **kwargs):
        return ApiClient().sanitize_for_serialization(y(self, *arg, **kwargs))

    return z


def log_handle(y) -> (str, str):
    """
    Decorate a method to return ``(result, None)``, or ``(None, error_text)`` when
    it raises ``ApiException``. Other exceptions propagate.

    ``functools.wraps`` keeps the wrapped method's name and docstring.
    """
    import functools

    @functools.wraps(y)
    def z(self, *args, **kwargs):
        try:
            return y(self, *args, **kwargs), None
        except ApiException as e:
            return None, e.__str__()

    return z


class GlobalBaseModel:
    """
    Generic wrapper for cluster-scoped Kubernetes resources.

    The subclass name selects the client methods. For a subclass ``Namespace``,
    the calls made on ``c`` are ``create_namespace``, ``read_namespace``,
    ``list_namespace``, ``replace_namespace``, ``delete_namespace`` and
    ``patch_namespace``. Methods decorated with ``log_handle`` return
    ``(result, error_text)``.

    :param c: kubernetes API client object exposing those methods
    """

    def __init__(self, c):
        self.class_name = hump_to_underline(type(self).__name__)
        self.client = c
        self._create_method = 'create' + self.class_name
        self._read_method = 'read' + self.class_name
        self._delete_method = 'delete' + self.class_name
        self._list_method = 'list' + self.class_name
        self._replace_method = 'replace' + self.class_name
        # Used by patch(); without it patch() raises AttributeError.
        self._patch_method = 'patch' + self.class_name

    def all(self, limit: int = 10, page: int = 1, search: str = ''):
        """
        List resources, optionally filtered by a substring of ``metadata.name`` and paginated.

        :param limit: page size; 0 returns every matching item
        :param page: 1-based page number
        :param search: substring that ``metadata.name`` must contain
        :return: ``{'data': [...], 'count': matches before pagination}``; an empty
                 result when ``filter`` reports an error
        """
        listed, _ = self.filter()
        # filter() yields (None, message) on ApiException; treat that as an empty list.
        items = (listed or {}).get('items', [])
        if search:
            items = [i for i in items if search in i.get('metadata', {}).get('name', '')]
        count = len(items)
        if limit:
            items = items[limit * (page - 1):limit * page]
        return {'data': items, 'count': count}

    @log_handle
    @serialization
    def filter(self, label_selector: str = '', field_selector: str = '', **kwargs):
        """List resources matching the label / field selectors."""
        return getattr(self.client, self._list_method)(
            label_selector=label_selector, field_selector=field_selector, **kwargs)

    @log_handle
    @serialization
    def get(self, name: str = None):
        """Read one resource by name."""
        return getattr(self.client, self._read_method)(name=name)

    @log_handle
    def create(self, body, **kwargs):
        """Create a resource from ``body``."""
        return getattr(self.client, self._create_method)(body=body, **kwargs)

    @log_handle
    def delete(self, name, body=None, **kwargs):
        """Delete resource ``name``; ``body`` defaults to a fresh ``V1DeleteOptions()``."""
        if body is None:
            # Built per call: a default argument would be created once at import and shared.
            body = client.V1DeleteOptions()
        return getattr(self.client, self._delete_method)(name=name, body=body, **kwargs)

    @log_handle
    def replace(self, name, body, **kwargs):
        """Replace resource ``name`` with ``body``."""
        return getattr(self.client, self._replace_method)(name=name, body=body, **kwargs)

    @log_handle
    def patch(self, name, body, **kwargs):
        """Patch resource ``name`` with ``body``."""
        return getattr(self.client, self._patch_method)(name=name, body=body, **kwargs)

    def apply(self, name, body):
        """Replace ``name`` if it can be read, otherwise create it from ``body``."""
        if self.get(name=name)[0]:
            return self.replace(name=name, body=body)
        return self.create(body=body)


class BaseModel:
    """
    Generic wrapper for namespaced Kubernetes resources.

    Client method names are built from the subclass name, for example
    ``'list_namespaced' + self.class_name``. ``class_name`` comes from
    ``hump_to_underline`` and presumably yields a leading-underscore snake_case
    suffix such as ``_config_map`` (confirm against ``hump_to_underline``).
    """

    def __init__(self, c):
        self.class_name = hump_to_underline(type(self).__name__)
        self.client = c
        self._namespaced_list = 'list_namespaced' + self.class_name
        self._all_list = 'list' + self.class_name + '_for_all_namespaces'
        self._read_method = 'read_namespaced' + self.class_name
        self._create_method = 'create_namespaced' + self.class_name
        self._replace_method = 'replace_namespaced' + self.class_name
        self._delete_method = 'delete_namespaced' + self.class_name
        self._patch_method = 'patch_namespaced' + self.class_name

    def all(self, namespace: str = None, field_selector: str = '', label_selector: str = '',
            limit: int = 10, page: int = 1, search: str = ''):
        """
        List objects with optional selectors, name search and paging.

        :param namespace: namespace to list; falsy lists across all namespaces
        :param field_selector: field selector
        :param label_selector: label selector, e.g. ``app=ubuntu``
        :param limit: page size; 0 (or any falsy value) returns every matching item
        :param page: 1-based page number; a page < 1 yields an empty page
        :param search: substring that must occur in ``metadata.name``; empty disables the search
        :return: ``{'data': <items on the requested page>, 'count': <number of matching items>}``
        """
        f, _ = self.filter(namespace=namespace, field_selector=field_selector, label_selector=label_selector)
        # Tolerate a falsy result from filter() (presumably None when the call
        # failed -- TODO confirm against @serialization/@log_handle) and a missing
        # or null 'items' entry, instead of raising TypeError.
        items_list = (f or {}).get('items') or []
        if search:
            items = [i for i in items_list
                     if search in ((i.get('metadata') or {}).get('name') or '')]
        else:
            items = items_list
        count = len(items)
        # limit == 0 means "return all data"
        if limit:
            # Clamp both bounds at 0: for page < 1 the raw bounds are negative and
            # Python slicing would wrap around and return items from the tail.
            start = max(limit * (page - 1), 0)
            end = max(min(limit * page, count), 0)
            items = items[start:end]
        return {
            'data': items,
            'count': count
        }

    @log_handle
    @serialization
    def filter(self, namespace: str = None, label_selector: str = '', field_selector: str = '', **kwargs):
        """
        List objects, in one namespace or across all namespaces.

        :param namespace: namespace to list; falsy lists across all namespaces
        :param label_selector: label selector, e.g. ``app=ubuntu``
        :param field_selector: field selector
        :param kwargs: extra keyword arguments forwarded to the client's list method
        :return: the client's list result, post-processed by ``@serialization``
        """
        if namespace:
            return getattr(self.client, self._namespaced_list)(namespace=namespace, label_selector=label_selector,
                                                               field_selector=field_selector, **kwargs)
        return getattr(self.client, self._all_list)(label_selector=label_selector, field_selector=field_selector,
                                                    **kwargs)

    @log_handle
    @serialization
    def get(self, name: str = None, namespace: str = None):
        """Read one object by ``name`` in ``namespace``."""
        return getattr(self.client, self._read_method)(name=name, namespace=namespace)

    @log_handle
    @serialization
    def create(self, namespace, body, **kwargs):
        """Create an object from ``body`` in ``namespace``."""
        return getattr(self.client, self._create_method)(namespace=namespace, body=body, **kwargs)

    @log_handle
    def delete(self, name, namespace, body=None, **kwargs):
        """
        Delete object ``name`` in ``namespace``.

        :param body: delete options; when omitted, a fresh ``client.V1DeleteOptions()``
            is built for this call. The previous default was a single instance
            created at import time and shared by every call.
        """
        if body is None:
            body = client.V1DeleteOptions()
        return getattr(self.client, self._delete_method)(name=name, namespace=namespace, body=body, **kwargs)

    @log_handle
    @serialization
    def replace(self, name, namespace, body, **kwargs):
        """Replace object ``name`` in ``namespace`` with ``body``."""
        return getattr(self.client, self._replace_method)(name=name, namespace=namespace, body=body, **kwargs)

    @log_handle
    @serialization
    def patch(self, name, namespace, body, **kwargs):
        """Patch object ``name`` in ``namespace`` with ``body``."""
        return getattr(self.client, self._patch_method)(name=name, namespace=namespace, body=body, **kwargs)

    def apply(self, name, namespace, body):
        """Replace the object if the first element of ``get(...)`` is truthy; otherwise create it."""
        if self.get(name=name, namespace=namespace)[0]:
            return self.replace(name=name, namespace=namespace, body=body)
        else:
            return self.create(namespace=namespace, body=body)


class Namespace(GlobalBaseModel):
    """Namespace: a cluster-scoped resource handled by the non-namespaced client methods."""


class Node(GlobalBaseModel):
    """Node: a cluster-scoped resource handled by the non-namespaced client methods."""


class PersistentVolume(GlobalBaseModel):
    """Persistent volume: a cluster-scoped resource handled by the non-namespaced client methods."""


class PersistentVolumeClaim(BaseModel):
    """Persistent volume claim: a namespaced resource."""


class Deployment(BaseModel):
    """Deployment: a namespaced resource."""


class StatefulSet(BaseModel):
    """StatefulSet: a namespaced resource."""


class Service(BaseModel):
    """Service: a namespaced resource whose ``apply`` updates an existing Service in place."""

    def apply(self, name, namespace, body):
        """
        Create the Service, or update the existing one in place (fixed 2021-09-11).

        For an existing Service, the current object is fetched and only its ports,
        labels, annotations and selector are taken from ``body``. The merged object
        is then sent back with ``replace``.

        :param name: Service name
        :param namespace: namespace of the Service
        :param body: desired Service; must provide ``to_dict()`` (e.g. a ``V1Service``)
        :return: result of ``self.replace(...)`` or ``self.create(...)``
        """
        service_old = self.get(name=name, namespace=namespace)[0]
        if not service_old:
            return self.create(namespace=namespace, body=body)

        # Convert the new object to a dict.
        new_body = body.to_dict()
        # spec.ports may be missing or None; treat that as "no ports".
        old_ports = service_old['spec'].get('ports') or []
        new_ports = new_body['spec'].get('ports') or []
        if len(old_ports) == len(new_ports):
            # Same number of ports: update them with a JSON patch (the original
            # author notes that patch is required here), then re-read the Service
            # so the replace below starts from the patched object.
            # Operations for *every* port are collected into one list. Previously
            # the list was rebuilt on each iteration, so only the last port was
            # ever patched, and two Services without ports hit UnboundLocalError.
            # Known issue (2021-09-29, original author): if the first port maps
            # 80->80 and the second one also maps 80->80, this fails.
            patch_port_data = []
            for i, item in enumerate(new_ports):
                patch_port_data.extend([
                    {"op": "replace", "path": f"/spec/ports/{i}/name", "value": item['name']},
                    {"op": "replace", "path": f"/spec/ports/{i}/port", "value": item['port']},
                    {"op": "replace", "path": f"/spec/ports/{i}/targetPort", "value": item['target_port']},
                ])
            if patch_port_data:
                self.client.patch_namespaced_service(name, namespace, patch_port_data)
                # Re-fetch the Service data after patching.
                service_old = self.get(name=name, namespace=namespace)[0]
        else:
            # NOTE(review): new_body comes from to_dict() (snake_case keys such as
            # 'target_port'), while service_old is presumably in API (camelCase)
            # form, since it is sent back to the API server as-is. Confirm that the
            # copied port dicts are accepted by replace.
            service_old['spec']['ports'] = new_body['spec']['ports']

        # Take labels, annotations and the selector from the new object.
        service_old['metadata']['labels'] = new_body['metadata']['labels']
        service_old['metadata']['annotations'] = new_body['metadata']['annotations']
        service_old['spec']['selector'] = new_body['spec']['selector']

        return self.replace(name, namespace, service_old)


class ReplicaSet(BaseModel):
    """ReplicaSet: a namespaced resource."""


class ConfigMap(BaseModel):
    """ConfigMap: a namespaced resource."""


class Event(BaseModel):
    """Event: a namespaced resource."""


class StorageClass(GlobalBaseModel):
    """StorageClass: a cluster-scoped resource handled by the non-namespaced client methods."""


class Secret(BaseModel):
    """Secret: a namespaced resource."""


class Pod(BaseModel):
    """Pod: a namespaced resource with an extra helper for reading its log."""

    @log_handle
    def log(self, name, namespace, **kwargs):
        """
        Read the log of pod ``name`` in ``namespace``.

        Extra keyword arguments are forwarded unchanged to ``read_namespaced_pod_log``.
        """
        read_log = self.client.read_namespaced_pod_log
        return read_log(name=name, namespace=namespace, **kwargs)


class CustomObjects:
    """Wrapper around a CustomObjectsApi client for one custom resource type."""

    def __init__(self, group, version, plural, kind, c):
        """
        :param group: API group, e.g. ``traefik.containo.us``
        :param version: version within the group, e.g. ``v1alpha1``
        :param plural: plural resource name, e.g. ``ingressroutes``
        :param kind: object kind written into bodies built by ``apply``
        :param c: the CustomObjectsApi client instance
        """
        self.group = group
        self.version = version
        self.plural = plural
        self.kind = kind
        self.api_version = '{group}/{version}'.format(group=self.group, version=self.version)
        # group/version/plural are required by every CustomObjectsApi call.
        self.must_arg = {
            'group': self.group,
            'version': self.version,
            'plural': self.plural
        }
        self.client = c

    @log_handle
    @serialization
    def filter(self, namespace: str = None, label_selector: str = ''):
        """
        List objects, cluster-wide or in one namespace.

        :param namespace: namespace to list; falsy lists cluster-wide
        :param label_selector: label selector, e.g. ``app=ubuntu``
        :return: the client's list result, post-processed by ``@serialization``
        """
        arg = {
            'label_selector': label_selector,
            **self.must_arg
        }
        method = self.client.list_cluster_custom_object
        if namespace:
            method = self.client.list_namespaced_custom_object
            arg['namespace'] = namespace
        return method(**arg)

    @log_handle
    @serialization
    def get(self, name: str = None, namespace: str = None):
        """Read one object by name (always via the namespaced endpoint)."""
        return self.client.get_namespaced_custom_object(name=name, namespace=namespace, **self.must_arg)

    @log_handle
    @serialization
    def create(self, namespace, body, **kwargs):
        """
        Create an object in ``namespace``.

        Example body::

            {'name': 'test-api-create',
             'metadata': {'name': 'test-api-create'},
             'spec': {'headers': {'accessControlMaxAge': 100,
                                  'accessControlAllowMethods': '*'}},
             'apiVersion': 'traefik.containo.us/v1alpha1',
             'kind': 'Middleware'}
        """
        return self.client.create_namespaced_custom_object(namespace=namespace, body=body, **self.must_arg, **kwargs)

    @log_handle
    def delete(self, name, namespace, body=None, **kwargs):
        """
        Delete object ``name`` in ``namespace``.

        :param body: delete options; when omitted, a fresh ``client.V1DeleteOptions()``
            is built for this call. The previous default was a single instance
            created at import time and shared by every call.
        """
        if body is None:
            body = client.V1DeleteOptions()
        return self.client.delete_namespaced_custom_object(
            name=name, namespace=namespace, body=body, **self.must_arg, **kwargs)

    @log_handle
    @serialization
    def replace(self, name, namespace, body):
        """Replace (fully update) object ``name`` in ``namespace`` with ``body``."""
        return self.client.replace_namespaced_custom_object(name=name, namespace=namespace, body=body, **self.must_arg)

    @log_handle
    @serialization
    def patch(self, name, namespace, spec):
        """
        Patch only the ``spec`` of object ``name``.

        Example spec: ``{'headers': {'accessControlMaxAge': 300, 'accessControlAllowMethods': '*'}}``
        """
        body = {'spec': spec}
        return self.client.patch_namespaced_custom_object(name=name, namespace=namespace, body=body, **self.must_arg)

    def apply(self, name, namespace, spec):
        """
        Create the object, or replace it if ``get`` finds an existing one.

        When replacing, the existing ``metadata.resourceVersion`` is copied into
        the new body.
        """
        res, _ = self.get(name=name, namespace=namespace)
        # NOTE(review): the top-level 'name' key is not a standard field of a
        # Kubernetes object (the name lives in metadata); kept as-is for
        # backward compatibility.
        body = {
            'spec': spec,
            'name': name,
            'metadata': {'name': name, 'namespace': namespace},
            'apiVersion': self.api_version,
            'kind': self.kind,
        }
        if res:
            resource_version = res.get('metadata', {}).get('resourceVersion')
            body['metadata']['resourceVersion'] = resource_version
            return self.replace(name=name, namespace=namespace, body=body)
        else:
            return self.create(namespace=namespace, body=body)


class Ingress:
    """Traefik ingress objects: IngressRoute (``router``) and Middleware (``middleware``) custom resources."""

    def __init__(self, c):
        self.group = 'traefik.containo.us'
        self.version = 'v1alpha1'
        self.router_plural, self.router_kind = 'ingressroutes', 'IngressRoute'
        self.middleware_plural, self.middleware_kind = 'middlewares', 'Middleware'
        # Both resources share the client, API group and version.
        shared = {'c': c, 'group': self.group, 'version': self.version}
        self.router = CustomObjects(plural=self.router_plural, kind=self.router_kind, **shared)
        self.middleware = CustomObjects(plural=self.middleware_plural, kind=self.middleware_kind, **shared)


class Kubernetes:
    """Facade that bundles the API clients and resource wrappers for one kubeconfig file."""

    def __init__(self, config_file: str):
        """:param config_file: path to the kubeconfig file used to build the API client"""
        # A dedicated ApiClient is built from config_file (the global
        # config.load_kube_config approach is not used).
        self.api_client = config.new_client_from_config(config_file=config_file)
        self.clientV1 = client.CoreV1Api(self.api_client)
        # AppsV1beta2Api is missing from newer releases of the kubernetes Python
        # client (apps/v1beta2 stopped being served in Kubernetes 1.16), and
        # referencing it unconditionally made this constructor raise
        # AttributeError. Keep the attribute, but set it to None when the class
        # is unavailable.
        apps_v1beta2_api = getattr(client, 'AppsV1beta2Api', None)
        self.clientV1beta2 = apps_v1beta2_api(self.api_client) if apps_v1beta2_api is not None else None
        self.clientAppV1 = client.AppsV1Api(self.api_client)
        self.clientStorageV1 = client.StorageV1Api(self.api_client)
        self.clientCustomObject = client.CustomObjectsApi(self.api_client)
        self.Node = Node(self.clientV1)
        self.Deploy = Deployment(self.clientAppV1)
        self.StatefulSet = StatefulSet(self.clientAppV1)
        self.Service = Service(self.clientV1)
        self.Pod = Pod(self.clientV1)
        self.RS = ReplicaSet(self.clientAppV1)
        self.Namespace = Namespace(self.clientV1)
        self.ConfigMap = ConfigMap(self.clientV1)
        self.StorageClass = StorageClass(self.clientStorageV1)
        self.PersistentVolume = PersistentVolume(self.clientV1)
        self.PersistentVolumeClaim = PersistentVolumeClaim(self.clientV1)
        self.Event = Event(self.clientV1)
        self.Secret = Secret(self.clientV1)
        self.ingress = Ingress(self.clientCustomObject)

    @staticmethod
    def api_version_to_client(api_version: str) -> str:
        """
        Map an ``apiVersion`` string to the kubernetes client API class name.

        ``'v1'`` -> ``'CoreV1Api'``; ``'apps/v1'`` -> ``'AppsV1Api'``;
        ``'rbac.authorization.k8s.io/v1'`` -> ``'RbacAuthorizationV1Api'``.
        """
        group, _, version = api_version.partition('/')
        if not version:
            # A bare version such as 'v1' belongs to the core group.
            version = group
            group = 'core'
        # Drop only the trailing '.k8s.io' suffix (e.g. 'networking.k8s.io').
        group = ''.join(group.rsplit('.k8s.io', 1))
        # Dotted groups become CamelCase: 'rbac.authorization' -> 'RbacAuthorization'.
        group = ''.join(word.capitalize() for word in group.split('.'))
        return "{0}{1}Api".format(group, version.capitalize())

    def create_from_yaml(self, stream):
        """
        Work-in-progress: for each document in the YAML ``stream``, print the
        resolved client API class name and ``kind``. Nothing is created yet.
        """
        for i in yaml.safe_load_all(stream=stream):
            print(self.api_version_to_client(i['apiVersion']), i['kind'])


if __name__ == '__main__':
    # Manual smoke test against a local kubeconfig.
    kube = Kubernetes('~/.kube/config-lls-test')
    # Example: kube.Namespace.create(body=V1Namespace(metadata=V1ObjectMeta(name='qa')))
    first_page_of_pods = kube.Pod.all()
    print(first_page_of_pods)

 

posted @ 2022-04-19 16:51  羊驼之歌  阅读(159)  评论(0)  编辑  收藏  举报