E:\song3\agv_backend_demo\gunicorn.conf.py
# Gunicorn configuration for serving the FastAPI app.
# Start with: gunicorn -c gunicorn.conf.py main:app

# Listen address (all interfaces, internal port 8001)
bind = '0.0.0.0:8001'
# Directory gunicorn changes into before loading the app
# NOTE(review): main.py imports `backend.app.*`; confirm the directory that contains
# `backend/` is importable from here, otherwise `main:app` fails to import.
chdir = '/fsm/backend/app'
# Number of worker processes
workers = 4
# Threads per worker
# NOTE(review): per the gunicorn settings docs `threads` only affects the gthread
# worker; it is likely ignored with the UvicornWorker configured below — confirm.
threads = 4
# Maximum number of pending connections in the listen queue
backlog = 512
# Worker timeout, in seconds
timeout = 120
# Do not daemonize: the process is managed by supervisor. If this is True, supervisor logs
# "gave up: fastapi_server entered FATAL state, too many start retries too quickly",
# so it must stay False.
daemon = False
# Worker class: uvicorn's async ASGI worker
worker_class = 'uvicorn.workers.UvicornWorker'
# Maximum simultaneous clients per worker
# NOTE(review): per the gunicorn settings docs this only affects the eventlet/gevent
# workers; likely no effect with UvicornWorker — confirm.
worker_connections = 2000
# PID file location
pidfile = '/fsm/gunicorn.pid'
# Access log and error log paths
accesslog = '/var/log/fastapi_server/gunicorn_access.log'
errorlog = '/var/log/fastapi_server/gunicorn_error.log'
# Must be True for the app's stdout/stderr (e.g. print output) to end up in the error log
capture_output = True
# Log verbosity
# NOTE(review): 'debug' is very verbose for a production deployment — consider 'info'.
loglevel = 'debug'
# Extra directory added to the Python import path
pythonpath = '/usr/local/lib/python3.8/site-packages'
# Start command: gunicorn -c gunicorn.conf.py main:app
E:\song3\agv_backend_demo\init_test_data.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
from email_validator import EmailNotValidError, validate_email
from faker import Faker
from backend.app.api.jwt import get_hash_password
from backend.app.common.log import log
from backend.app.database.db_mysql import async_db_session
from backend.app.models import User
class InitData:
    """Seed the database with an interactively created admin plus a few fake users."""

    def __init__(self):
        # Chinese-locale Faker used to generate usernames, passwords and e-mails
        self.fake = Faker('zh_CN')

    @staticmethod
    async def _save_user(user_obj: User) -> None:
        """Add ``user_obj`` to the database in its own transaction."""
        async with async_db_session.begin() as db:
            db.add(user_obj)
            await db.commit()

    @staticmethod
    async def create_superuser_by_yourself():
        """Prompt on stdin for username, password and e-mail, then create an admin account.

        The e-mail prompt repeats until ``validate_email`` accepts the input.
        """
        print('请输入用户名:')
        username = input()
        print('请输入密码:')
        password = input()
        print('请输入邮箱:')
        while True:
            email = input()
            try:
                validate_email(email, check_deliverability=False)
            except EmailNotValidError:
                print('邮箱不符合规范,请重新输入:')
            else:
                break
        user_obj = User(
            username=username,
            password=get_hash_password(password),
            email=email,
            is_superuser=True,
        )
        await InitData._save_user(user_obj)
        # Do not write the typed-in admin password to the (persistent) log file.
        log.info(f'管理员用户创建成功,账号:{username}')

    async def _create_fake_user(self, *, is_superuser: bool, is_active: bool, label: str) -> None:
        """Create one user with Faker-generated credentials and log them.

        :param is_superuser: value for ``User.is_superuser``
        :param is_active: when False the user is created locked; when True the
            ``is_active`` column is left to the model default (as the original code did)
        :param label: user kind shown in the log message, e.g. '普通用户'
        """
        username = self.fake.user_name()
        password = self.fake.password()
        email = self.fake.email()
        # Only pass is_active when locking the account, so active users keep the model default.
        extra = {} if is_active else {'is_active': False}
        user_obj = User(
            username=username,
            password=get_hash_password(password),
            email=email,
            is_superuser=is_superuser,
            **extra,
        )
        await self._save_user(user_obj)
        # The generated password is logged on purpose: it is the only way to learn it.
        log.info(f"{label}创建成功,账号:{username},密码:{password}")

    async def fake_user(self):
        """Create an active regular user with fake credentials."""
        await self._create_fake_user(is_superuser=False, is_active=True, label='普通用户')

    async def fake_no_active_user(self):
        """Create a locked (inactive) regular user with fake credentials."""
        await self._create_fake_user(is_superuser=False, is_active=False, label='普通锁定用户')

    async def fake_superuser(self):
        """Create an active superuser with fake credentials."""
        await self._create_fake_user(is_superuser=True, is_active=True, label='管理员用户')

    async def fake_no_active_superuser(self):
        """Create a locked (inactive) superuser with fake credentials."""
        await self._create_fake_user(is_superuser=True, is_active=False, label='管理员锁定用户')

    async def init_data(self):
        """Run every seeding step in order: interactive admin first, then the fake users."""
        log.info('⏳ 开始初始化数据')
        await self.create_superuser_by_yourself()
        await self.fake_user()
        await self.fake_no_active_user()
        await self.fake_superuser()
        await self.fake_no_active_superuser()
        log.info('✅ 数据初始化完成')
if __name__ == '__main__':
    init = InitData()
    # asyncio.run creates a fresh event loop, runs the coroutine and closes the loop;
    # asyncio.get_event_loop() without a running loop is deprecated on newer Pythons.
    asyncio.run(init.init_data())
E:\song3\agv_backend_demo\run.py
import uvicorn
if __name__ == "__main__":
    # Development entry point: serve the ASGI app on localhost:8000 with auto-reload.
    uvicorn.run(
        "backend.app.main:app",
        host="127.0.0.1",
        port=8000,
        reload=True,
    )
E:\song3\agv_backend_demo\alembic\env.py
from backend.app.models import MappedBase
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# Alembic Config object: gives access to the values in the .ini file in use.
config = context.config
# Logging configuration from the .ini file is currently disabled; uncomment the line
# below to let the ini file configure Python logging.
# fileConfig(config.config_file_name)
# Metadata of the project's models, used by `--autogenerate` to compare against the database.
target_metadata = MappedBase.metadata
# Further ini options can be read when needed with
# config.get_main_option("option_name").
def run_migrations_offline():
    """Generate migration SQL without a live database ("offline" mode).

    Only the ``sqlalchemy.url`` ini option is used, so no Engine (and no DBAPI
    driver) is needed. Statements given to ``context.execute()`` are written to
    the script output instead of being executed.
    """
    offline_options = dict(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    context.configure(**offline_options)
    transaction = context.begin_transaction()
    with transaction:
        context.run_migrations()
def run_migrations_online():
    """Run migrations against a live database connection ("online" mode).

    An Engine is built from the ``sqlalchemy.*`` keys of the ini section and
    one connection is handed to the migration context.
    """
    ini_section = config.get_section(config.config_ini_section)
    # NullPool: connections are not pooled or reused by this short-lived engine
    engine = engine_from_config(ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool)
    with engine.connect() as conn:
        context.configure(connection=conn, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
# Dispatch on the mode Alembic is running in.
if not context.is_offline_mode():
    run_migrations_online()
else:
    run_migrations_offline()
E:\song3\agv_backend_demo\alembic\versions\4ac986bb1ace_create_user_table.py
"""create user table
Revision ID: 4ac986bb1ace
Revises:
Create Date: 2023-03-11 18:11:19.312709
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# Revision identifiers used by Alembic to place this migration in the history.
revision = '4ac986bb1ace'
# None: this migration has no predecessor (it is the base of the chain)
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
    """Create the ``sys_user`` table and its indexes.

    Unique constraints: ``uid`` (table constraint), ``email`` and ``username``
    (unique indexes).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('sys_user',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('uid', sa.String(length=50), nullable=False, comment='唯一标识'),
    sa.Column('username', sa.String(length=20), nullable=False, comment='用户名'),
    sa.Column('password', sa.String(length=255), nullable=False, comment='密码'),
    sa.Column('email', sa.String(length=50), nullable=False, comment='邮箱'),
    sa.Column('is_superuser', sa.Boolean(), nullable=False, comment='超级权限'),
    sa.Column('is_active', sa.Boolean(), nullable=False, comment='用户账号状态'),
    sa.Column('avatar', sa.String(length=255), nullable=True, comment='头像'),
    sa.Column('mobile_number', sa.String(length=11), nullable=True, comment='手机号'),
    sa.Column('wechat', sa.String(length=20), nullable=True, comment='微信'),
    sa.Column('qq', sa.String(length=10), nullable=True, comment='QQ'),
    sa.Column('blog_address', sa.String(length=255), nullable=True, comment='博客地址'),
    sa.Column('introduction', sa.String(length=1000), nullable=True, comment='自我介绍'),
    sa.Column('time_joined', sa.DateTime(), nullable=False, comment='注册时间'),
    sa.Column('last_login', sa.DateTime(), nullable=True, comment='上次登录'),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('uid')
    )
    op.create_index(op.f('ix_sys_user_email'), 'sys_user', ['email'], unique=True)
    # NOTE(review): `id` is already the primary key, so this extra index looks redundant.
    op.create_index(op.f('ix_sys_user_id'), 'sys_user', ['id'], unique=False)
    op.create_index(op.f('ix_sys_user_username'), 'sys_user', ['username'], unique=True)
    # ### end Alembic commands ###
def downgrade():
    """Undo ``upgrade``: drop the three ``sys_user`` indexes, then the table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_sys_user_username'), table_name='sys_user')
    op.drop_index(op.f('ix_sys_user_id'), table_name='sys_user')
    op.drop_index(op.f('ix_sys_user_email'), table_name='sys_user')
    op.drop_table('sys_user')
    # ### end Alembic commands ###
E:\song3\agv_backend_demo\backend\app\main.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import uvicorn
from path import Path
from backend.app.api.registrar import register_app
from backend.app.common.log import log
from backend.app.core.conf import settings
# ASGI application object (imported by uvicorn/gunicorn as `main:app`).
app = register_app()

if __name__ == '__main__':
    # Direct run: print the banner, then serve this module's `app` with the settings' host/port/reload.
    try:
        log.info(
            """\n
/$$$$$$$$ /$$ /$$$$$$ /$$$$$$$ /$$$$$$
| $$_____/ | $$ /$$__ $$| $$__ $$|_ $$_/
| $$ /$$$$$$ /$$$$$$$ /$$$$$$ | $$ | $$| $$ | $$ | $$
| $$$$$|____ $$ /$$_____/|_ $$_/ | $$$$$$$$| $$$$$$$/ | $$
| $$__/ /$$$$$$$| $$$$$$ | $$ | $$__ $$| $$____/ | $$
| $$ /$$__ $$ |____ $$ | $$ /$$| $$ | $$| $$ | $$
| $$ | $$$$$$$ /$$$$$$$/ | $$$$/| $$ | $$| $$ /$$$$$$
|__/ |_______/|_______/ |___/ |__/ |__/|__/ |______/
"""
        )
        uvicorn.run(
            app=f'{Path(__file__).stem}:app',
            host=settings.UVICORN_HOST,
            port=settings.UVICORN_PORT,
            reload=settings.UVICORN_RELOAD,
        )
    except Exception as e:
        # log.exception also records the traceback, not only the message.
        log.exception(f'❌ FastAPI start failed: {e}')
E:\song3\agv_backend_demo\backend\app\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\api\jwt.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime, timedelta
from typing import Any, Union
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from jose import jwt # noqa
from passlib.context import CryptContext
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.common.exception.errors import TokenError, AuthorizationError
from backend.app.core.conf import settings
from backend.app.crud import crud_user
from backend.app.database.db_mysql import get_db
from backend.app.models import User
# Password hashing context: bcrypt, with deprecated schemes handled automatically.
pwd_context = CryptContext(
    schemes=['bcrypt'],
    deprecated='auto',
)
# OAuth2 bearer scheme; tokenUrl is the endpoint clients call to obtain a token.
oauth2_schema = OAuth2PasswordBearer(
    tokenUrl='/v1/users/login',
)
def get_hash_password(password: str) -> str:
    """
    Hash a plain-text password with the module's bcrypt context.

    :param password: plain-text password
    :return: the hash string produced by ``pwd_context.hash``
    """
    hashed = pwd_context.hash(password)
    return hashed
def verity_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password against a stored hash.

    :param plain_password: password supplied by the client
    :param hashed_password: hash previously produced by ``get_hash_password``
    :return: True when they match
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def create_access_token(data: Union[int, Any], expires_delta: Union[timedelta, None] = None) -> str:
    """
    Create a signed JWT whose ``sub`` claim is ``str(data)``.

    :param data: value stored (stringified) in the ``sub`` claim; the login flows pass the user id
    :param expires_delta: token lifetime; when None, ``settings.TOKEN_EXPIRE_MINUTES`` minutes.
        An explicit ``timedelta(0)`` is honoured rather than replaced by the default.
    :return: the encoded token
    """
    if expires_delta is None:
        # timedelta's first positional argument is *days*; the setting is expressed in minutes.
        expires_delta = timedelta(minutes=settings.TOKEN_EXPIRE_MINUTES)
    expires = datetime.utcnow() + expires_delta
    to_encode = {"exp": expires, "sub": str(data)}
    return jwt.encode(to_encode, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM)
async def get_current_user(db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_schema)) -> User:
    """
    Resolve the authenticated user from the bearer token.

    :param db: database session from ``get_db``
    :param token: bearer token extracted by ``oauth2_schema``
    :raises TokenError: the token cannot be decoded, has no ``sub`` claim,
        or names a user id that does not exist
    :return: the ``User`` whose id is the token's ``sub`` claim
    """
    try:
        claims = jwt.decode(token, settings.TOKEN_SECRET_KEY, algorithms=[settings.TOKEN_ALGORITHM])
    except (jwt.JWTError, ValidationError):
        raise TokenError
    subject = claims.get('sub')
    if not subject:
        raise TokenError
    found = await crud_user.get_user_by_id(db, subject)
    if not found:
        raise TokenError
    return found
async def get_current_is_superuser(user: User = Depends(get_current_user)):
    """
    Dependency that only lets superusers through.

    :param user: the authenticated user
    :raises AuthorizationError: the user is not a superuser
    :return: the user's ``is_superuser`` value (truthy whenever this returns)
    """
    flag = user.is_superuser
    if flag:
        return flag
    raise AuthorizationError
E:\song3\agv_backend_demo\backend\app\api\registrar.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import FastAPI
from fastapi_pagination import add_pagination
from fastapi.responses import RedirectResponse
from fastapi.openapi.docs import get_swagger_ui_html
from backend.app.api.routers import v1
from backend.app.common.exception.exception_handler import register_exception
from backend.app.common.redis import redis_client
from backend.app.core.conf import settings
from backend.app.database.db_mysql import create_table
from backend.app.middleware import register_middleware
def register_app():
    """
    Build the FastAPI application and attach everything it needs.

    Registration order: static files (when enabled), Swagger UI, middleware,
    routes, startup/shutdown hooks, pagination, exception handlers.

    :return: the configured FastAPI instance
    """
    application = FastAPI(
        title=settings.TITLE,
        version=settings.VERSION,
        description=settings.DESCRIPTION,
        docs_url=None,  # built-in docs disabled; register_docs serves a custom Swagger UI
        redoc_url=settings.REDOCS_URL,
        openapi_url=settings.OPENAPI_URL,
    )
    if settings.STATIC_FILES:
        register_static_file(application)
    registrars = (
        register_docs,
        register_middleware,
        register_router,
        register_init,
        register_page,
        register_exception,
    )
    for register in registrars:
        register(application)
    return application
def register_docs(app: FastAPI):
    """
    Add a root redirect to the docs page and a Swagger UI page at ``settings.DOCS_URL``.

    The Swagger assets are loaded from ``/static/swagger-ui/``.
    NOTE(review): ``/static`` is only mounted when ``settings.STATIC_FILES`` is set.

    :param app: FastAPI
    :return:
    """

    # "/" simply redirects to the docs page
    @app.get("/", tags=['Docs'])
    def index():
        return RedirectResponse(settings.DOCS_URL)

    # Swagger UI page using locally served assets
    @app.get(settings.DOCS_URL, include_in_schema=False)
    async def custom_swagger_ui_html():
        page_title = app.title + " - Swagger UI "
        asset_dir = "/static/swagger-ui/"
        return get_swagger_ui_html(
            openapi_url=app.openapi_url,
            title=page_title,
            oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
            swagger_js_url=asset_dir + "swagger-ui-bundle.js",
            swagger_css_url=asset_dir + "swagger-ui.css",
            swagger_favicon_url=asset_dir + "favicon.png",
        )
def register_router(app: FastAPI):
    """
    Mount the versioned API router (``v1``) on the application.

    :param app: FastAPI
    :return:
    """
    app.include_router(router=v1)
def register_static_file(app: FastAPI):
    """
    Serve the ``./static`` directory at ``/static``.

    Intended for development; in production static assets should be served by nginx.
    The directory is created if it does not exist.

    :param app: FastAPI
    :return:
    """
    import os
    from fastapi.staticfiles import StaticFiles

    # exist_ok avoids the check-then-create race when several worker processes
    # start at the same time (a plain mkdir raises FileExistsError in the loser).
    os.makedirs("./static", exist_ok=True)
    app.mount("/static", StaticFiles(directory="static"), name="static")
def register_init(app: FastAPI):
    """
    Attach startup/shutdown hooks.

    Startup creates the database tables and, when ``settings.REDIS_OPEN`` is set,
    checks the Redis connection. Shutdown closes Redis under the same condition.

    :param app: FastAPI
    :return:
    """

    @app.on_event("startup")
    async def startup_event():
        await create_table()
        if not settings.REDIS_OPEN:
            return
        await redis_client.init_redis_connect()

    @app.on_event("shutdown")
    async def shutdown_event():
        if not settings.REDIS_OPEN:
            return
        await redis_client.close()
def register_page(app: FastAPI):
    """
    Enable fastapi-pagination support on the application.

    :param app: FastAPI
    :return:
    """
    result = add_pagination(app)
    del result  # return value not needed
E:\song3\agv_backend_demo\backend\app\api\routers.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from backend.app.api.v1.auth.user import user
# Version 1 API router; user endpoints are mounted under /v1/users.
v1 = APIRouter(prefix='/v1')
v1.include_router(router=user, prefix='/users', tags=['用户'])
E:\song3\agv_backend_demo\backend\app\api\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\api\service\user_service.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from hashlib import sha256
from email_validator import validate_email, EmailNotValidError
from fast_captcha import text_captcha
from fastapi import Request, HTTPException, Response, UploadFile
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_pagination.ext.async_sqlalchemy import paginate
from backend.app.api import jwt
from backend.app.common.exception import errors
from backend.app.common.log import log
from backend.app.common.redis import redis_client
from backend.app.common.response.response_code import CodeEnum
from backend.app.core.conf import settings
from backend.app.core.path_conf import AvatarPath
from backend.app.crud import crud_user
from backend.app.database.db_mysql import async_db_session
from backend.app.models import User
from backend.app.schemas.user import CreateUser, ResetPassword, UpdateUser, ELCode, Auth2
from backend.app.utils import re_verify
from backend.app.utils.format_string import cut_path
from backend.app.utils.generate_string import get_current_timestamp, get_uuid
from backend.app.utils.send_email import send_verification_code_email, SEND_EMAIL_LOGIN_TEXT
async def login(form_data: OAuth2PasswordRequestForm):
    """
    Authenticate with the username/password from an OAuth2 password form.

    :param form_data: OAuth2 form carrying ``username`` and ``password``
    :raises NotFoundError: unknown username
    :raises AuthorizationError: wrong password, or the account is locked
    :return: tuple of (access token, is_superuser flag)
    """
    async with async_db_session() as db:
        account = await crud_user.get_user_by_username(db, form_data.username)
        if not account:
            raise errors.NotFoundError(msg='用户名不存在')
        if not jwt.verity_password(form_data.password, account.password):
            raise errors.AuthorizationError(msg='密码错误')
        if not account.is_active:
            raise errors.AuthorizationError(msg='该用户已被锁定,无法登录')
        # record the login time before issuing the token
        await crud_user.update_user_login_time(db, form_data.username)
        token = jwt.create_access_token(account.id)
        return token, account.is_superuser
# async def login(obj: Auth):
# async with async_db_session() as db:
# current_user = await crud_user.get_user_by_username(db, obj.username)
# if not current_user:
# raise errors.NotFoundError(msg='用户名不存在')
# elif not jwt.verity_password(obj.password, current_user.password):
# raise errors.AuthorizationError(msg='密码错误')
# elif not current_user.is_active:
# raise errors.AuthorizationError(msg='该用户已被锁定,无法登录')
# # 更新登陆时间
# await crud_user.update_user_login_time(db, obj.username)
# # 创建token
# access_token = jwt.create_access_token(current_user.id)
# return access_token, current_user.is_superuser
def _email_login_code_key(email: str) -> str:
    """Return the Redis key holding the pending login code for ``email`` (case-insensitive)."""
    return f'email_login_code:{email.lower()}'


async def login_email(*, request: Request, obj: Auth2):
    """
    Log in with an e-mail address and the code previously mailed to it.

    The code is read from a Redis key derived from ``obj.email``. A code mailed to
    one address therefore cannot be used to log in as another. This also works
    across worker processes, unlike per-process ``app.state``. The code is deleted
    after a successful login so it cannot be replayed.

    :param request: current request (kept for interface compatibility)
    :param obj: e-mail address and verification code
    :raises NotFoundError: unknown e-mail, or no unexpired code for it
    :raises AuthorizationError: the account is locked
    :raises CodeError: the code does not match
    :return: tuple of (access token, is_superuser flag)
    """
    async with async_db_session() as db:
        if not await crud_user.check_email(db, obj.email):
            raise errors.NotFoundError(msg='邮箱不存在')
        username = await crud_user.get_username_by_email(db, obj.email)
        current_user = await crud_user.get_user_by_username(db, username)
        if not current_user.is_active:
            raise errors.AuthorizationError(msg='该用户已被锁定,无法登录')
        code_key = _email_login_code_key(obj.email)
        r_code = await redis_client.get(code_key)
        if not r_code:
            raise errors.NotFoundError(msg='验证码失效,请重新获取')
        if r_code != obj.code:
            raise errors.CodeError(error=CodeEnum.CAPTCHA_ERROR)
        # one-time code: remove it so it cannot be reused
        await redis_client.delete(code_key)
        await crud_user.update_user_login_time(db, username)
        access_token = jwt.create_access_token(current_user.id)
        return access_token, current_user.is_superuser


async def send_login_email_captcha(request: Request, obj: ELCode):
    """
    Mail a login code to ``obj.email`` and store it in Redis under that address's key.

    The code is stored with ``settings.EMAIL_LOGIN_CODE_MAX_AGE`` passed as the
    Redis expiry.

    :param request: current request (kept for interface compatibility)
    :param obj: the e-mail address to send to
    :raises NotFoundError: unknown e-mail
    :raises ForbiddenError: the account is locked
    :raises ServerError: the e-mail could not be sent
    """
    async with async_db_session() as db:
        if not await crud_user.check_email(db, obj.email):
            raise errors.NotFoundError(msg='邮箱不存在')
        username = await crud_user.get_username_by_email(db, obj.email)
        current_user = await crud_user.get_user_by_username(db, username)
        if not current_user.is_active:
            raise errors.ForbiddenError(msg='该用户已被锁定,无法登录,发送验证码失败')
        try:
            code = text_captcha()
            await send_verification_code_email(obj.email, code, SEND_EMAIL_LOGIN_TEXT)
        except Exception as e:
            log.error('验证码发送失败 {}', e)
            raise errors.ServerError(msg=f'验证码发送失败: {e}')
        else:
            await redis_client.set(_email_login_code_key(obj.email), code, settings.EMAIL_LOGIN_CODE_MAX_AGE)
async def register(obj: CreateUser):
    """
    Create a new user account.

    :param obj: registration data
    :raises ForbiddenError: username or e-mail already registered, or malformed e-mail
    """
    async with async_db_session.begin() as db:
        if await crud_user.get_user_by_username(db, obj.username):
            raise errors.ForbiddenError(msg='该用户名已注册')
        if await crud_user.check_email(db, obj.email):
            raise errors.ForbiddenError(msg='该邮箱已注册')
        try:
            validate_email(obj.email, check_deliverability=False)
        except EmailNotValidError:
            raise errors.ForbiddenError(msg='邮箱格式错误')
        await crud_user.create_user(db, obj)
def _reset_pwd_code_digest(username: str, code: str) -> str:
    """
    Return an HMAC-SHA256, keyed with ``settings.TOKEN_SECRET_KEY``, binding a reset code to a username.

    The result is stored in the ``fastapi_reset_pwd_code`` cookie. A client cannot
    compute it without the server secret. It therefore cannot forge a
    username/code cookie pair that lets it reset someone else's password, which
    it could with the previous plain ``sha256(code)``.
    """
    import hmac

    key = str(settings.TOKEN_SECRET_KEY).encode('utf-8')
    # Length-prefix the username so different (username, code) pairs never produce
    # the same message (e.g. 'a:b' + 'c' vs 'a' + 'b:c').
    message = f'{len(username)}:{username}:{code}'.encode('utf-8')
    return hmac.new(key, message, sha256).hexdigest()


def _reset_pwd_code_matches(cookie_digest: str, username: str, code: str) -> bool:
    """Constant-time check that ``cookie_digest`` was issued for ``username`` and ``code``."""
    import hmac

    expected = _reset_pwd_code_digest(username, code)
    return hmac.compare_digest(cookie_digest.encode('utf-8'), expected.encode('utf-8'))


async def get_pwd_rest_captcha(*, username_or_email: str, response: Response):
    """
    Mail a password-reset code and record, in cookies, which account it was issued for.

    ``username_or_email`` is first looked up as a username; otherwise it must be a
    syntactically valid, registered e-mail address.

    Cookies set (max age ``settings.COOKIES_MAX_AGE``):
    - ``fastapi_reset_pwd_username``: the account's username
    - ``fastapi_reset_pwd_code``: HMAC binding the mailed code to that username

    :raises HTTPException: 404 when the value is neither a username nor a registered e-mail
    """
    async with async_db_session() as db:
        code = text_captcha()
        if await crud_user.get_user_by_username(db, username_or_email):
            username = username_or_email
            email = await crud_user.get_email_by_username(db, username_or_email)
        else:
            try:
                validate_email(username_or_email, check_deliverability=False)
            except EmailNotValidError:
                raise HTTPException(status_code=404, detail='用户名不存在')
            if not await crud_user.check_email(db, username_or_email):
                raise HTTPException(status_code=404, detail='邮箱不存在')
            username = await crud_user.get_username_by_email(db, username_or_email)
            email = username_or_email
        # set_cookie replaces any previous value of the same cookie
        response.set_cookie(
            key='fastapi_reset_pwd_code',
            value=_reset_pwd_code_digest(username, code),
            max_age=settings.COOKIES_MAX_AGE,
        )
        response.set_cookie(
            key='fastapi_reset_pwd_username',
            value=username,
            max_age=settings.COOKIES_MAX_AGE,
        )
        await send_verification_code_email(email, code)


async def pwd_reset(*, obj: ResetPassword, request: Request, response: Response):
    """
    Reset the password of the account named in the reset cookies.

    NOTE: code expiry is still enforced only by the cookies' max age.

    :param obj: the two password entries and the mailed code
    :param request: used to read the reset cookies
    :param response: used to clear the reset cookies on success
    :raises ForbiddenError: passwords differ, or the code is wrong
    :raises NotFoundError: reset cookies missing (expired or never requested)
    """
    async with async_db_session.begin() as db:
        if obj.password1 != obj.password2:
            raise errors.ForbiddenError(msg='两次密码输入不一致')
        cookie_reset_pwd_code = request.cookies.get('fastapi_reset_pwd_code')
        cookie_reset_pwd_username = request.cookies.get('fastapi_reset_pwd_username')
        if cookie_reset_pwd_username is None or cookie_reset_pwd_code is None:
            raise errors.NotFoundError(msg='验证码已失效,请重新获取验证码')
        if not _reset_pwd_code_matches(cookie_reset_pwd_code, cookie_reset_pwd_username, obj.code):
            raise errors.ForbiddenError(msg='验证码错误')
        await crud_user.reset_password(db, cookie_reset_pwd_username, obj.password2)
        response.delete_cookie(key='fastapi_reset_pwd_code')
        response.delete_cookie(key='fastapi_reset_pwd_username')
async def get_user_info(username: str):
    """
    Load a user by username.

    When the user has an avatar, the ``avatar`` attribute of the returned object is
    replaced by element [1] of ``cut_path(AvatarPath + avatar)``.
    NOTE(review): presumably a client-facing path — confirm against ``cut_path``.

    :raises NotFoundError: no such user
    """
    async with async_db_session() as db:
        found = await crud_user.get_user_by_username(db, username)
        if not found:
            raise errors.NotFoundError(msg='用户不存在')
        stored_avatar = found.avatar
        if stored_avatar is not None:
            found.avatar = cut_path(AvatarPath + stored_avatar)[1]
        return found
async def update(*, username: str, current_user: User, obj: UpdateUser):
    """
    Update a user's profile. Non-superusers may only update their own profile.

    :param username: the user to update
    :param current_user: the authenticated caller
    :param obj: new profile values
    :raises AuthorizationError: caller may not edit this user
    :raises NotFoundError: no such user
    :raises ForbiddenError: username/e-mail taken, or a field fails validation
    :return: the result of ``crud_user.update_userinfo`` (treated as a row count by the router)
    """
    async with async_db_session.begin() as db:
        if not current_user.is_superuser and username != current_user.username:
            raise errors.AuthorizationError
        target = await crud_user.get_user_by_username(db, username)
        if not target:
            raise errors.NotFoundError(msg='用户不存在')
        if target.username != obj.username and await crud_user.get_user_by_username(db, obj.username):
            raise errors.ForbiddenError(msg='该用户名已存在')
        if target.email != obj.email:
            if await crud_user.check_email(db, obj.email):
                raise errors.ForbiddenError(msg='该邮箱已注册')
            try:
                validate_email(obj.email, check_deliverability=False)
            except EmailNotValidError:
                raise errors.ForbiddenError(msg='邮箱格式错误')
        # Optional contact fields are only checked when supplied
        if obj.mobile_number is not None and not re_verify.is_mobile(obj.mobile_number):
            raise errors.ForbiddenError(msg='手机号码输入有误')
        if obj.wechat is not None and not re_verify.is_wechat(obj.wechat):
            raise errors.ForbiddenError(msg='微信号码输入有误')
        if obj.qq is not None and not re_verify.is_qq(obj.qq):
            raise errors.ForbiddenError(msg='QQ号码输入有误')
        return await crud_user.update_userinfo(db, target, obj)
async def update_avatar(*, username: str, current_user: User, avatar: UploadFile):
    """
    Replace a user's avatar image. Non-superusers may only change their own avatar.

    The upload is validated before anything on disk changes. It is stored under
    ``AvatarPath`` as ``<timestamp>_<basename of client filename>``, and only then
    is the previous file removed (best-effort). When ``avatar`` is None, the
    current file name is kept.

    :raises AuthorizationError: caller may not edit this user
    :raises NotFoundError: no such user
    :raises ForbiddenError: the upload is not an image
    :return: the result of ``crud_user.update_avatar`` (treated as a row count by the router)
    """
    async with async_db_session.begin() as db:
        if not current_user.is_superuser and username != current_user.username:
            raise errors.AuthorizationError
        input_user = await crud_user.get_user_by_username(db, username)
        if not input_user:
            raise errors.NotFoundError(msg='用户不存在')
        input_user_avatar = input_user.avatar
        if avatar is None:
            file_name = input_user_avatar
        else:
            # Validate first, so a rejected upload does not destroy the current avatar.
            # content_type may be None for some clients.
            if 'image' not in (avatar.content_type or ''):
                raise errors.ForbiddenError(msg='图片格式错误,请重新选择图片')
            new_file = await avatar.read()
            # Keep only the last path component of the client-supplied name, so it
            # cannot escape AvatarPath (e.g. '../../x' or 'C:\\x').
            safe_name = os.path.basename((avatar.filename or '').replace('\\', '/'))
            file_name = str(get_current_timestamp()) + '_' + safe_name
            os.makedirs(AvatarPath, exist_ok=True)
            with open(AvatarPath + file_name, 'wb') as f:
                f.write(new_file)
            # Remove the previous file only once the new one is written (and never the new one itself).
            if input_user_avatar is not None and input_user_avatar != file_name:
                try:
                    os.remove(AvatarPath + input_user_avatar)
                except Exception as e:
                    log.error('用户 {} 更新头像时,原头像文件 {} 删除失败\n{}', current_user.username, input_user_avatar,
                              e)
        count = await crud_user.update_avatar(db, input_user, file_name)
        return count
async def delete_avatar(*, username: str, current_user: User):
    """
    Remove a user's avatar file (best-effort) and clear it in the database.

    Non-superusers may only delete their own avatar.

    :raises AuthorizationError: caller may not edit this user
    :raises NotFoundError: no such user, or the user has no avatar
    :return: the result of ``crud_user.delete_avatar`` (treated as a row count by the router)
    """
    async with async_db_session.begin() as db:
        if not current_user.is_superuser and username != current_user.username:
            raise errors.AuthorizationError
        target = await crud_user.get_user_by_username(db, username)
        if not target:
            raise errors.NotFoundError(msg='用户不存在')
        avatar_name = target.avatar
        if avatar_name is None:
            raise errors.NotFoundError(msg='用户没有头像文件,请上传头像文件后再执行此操作')
        try:
            os.remove(AvatarPath + avatar_name)
        except Exception as e:
            log.error('用户 {} 删除头像文件 {} 失败\n{}', target.username, avatar_name, e)
        return await crud_user.delete_avatar(db, target.id)
async def get_user_list():
    """Paginate the ``crud_user.get_users()`` query with fastapi-pagination's async SQLAlchemy helper."""
    query = crud_user.get_users()
    async with async_db_session() as db:
        return await paginate(db, query)
async def update_permission(pk: int):
    """
    Apply ``crud_user.super_set`` to the user with primary key ``pk``.

    :raises NotFoundError: no such user
    :return: the result of ``crud_user.super_set`` (treated as a row count by the router)
    """
    async with async_db_session.begin() as db:
        if not await crud_user.get_user_by_id(db, pk):
            raise errors.NotFoundError(msg='用户不存在')
        return await crud_user.super_set(db, pk)
async def update_active(pk: int):
    """
    Apply ``crud_user.active_set`` to the user with primary key ``pk``.

    :raises NotFoundError: no such user
    :return: the result of ``crud_user.active_set`` (treated as a row count by the router)
    """
    async with async_db_session.begin() as db:
        if not await crud_user.get_user_by_id(db, pk):
            raise errors.NotFoundError(msg='用户不存在')
        return await crud_user.active_set(db, pk)
async def delete(*, username: str, current_user: User):
    """
    Delete a user account; the avatar file is removed best-effort first.

    Non-superusers may only delete themselves.

    :raises AuthorizationError: caller may not delete this user
    :raises NotFoundError: no such user
    :return: the result of ``crud_user.delete_user`` (treated as a row count by the router)
    """
    async with async_db_session.begin() as db:
        if not current_user.is_superuser and username != current_user.username:
            raise errors.AuthorizationError
        doomed = await crud_user.get_user_by_username(db, username)
        if not doomed:
            raise errors.NotFoundError(msg='用户不存在')
        avatar_name = doomed.avatar
        if avatar_name is not None:
            try:
                os.remove(AvatarPath + avatar_name)
            except Exception as e:
                log.error(f'删除用户 {doomed.username} 头像文件:{avatar_name} 失败\n{e}')
        return await crud_user.delete_user(db, doomed.id)
E:\song3\agv_backend_demo\backend\app\api\service\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\api\v1\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\api\v1\auth\user.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import APIRouter, Depends, Request, Response, UploadFile
from fastapi.security import OAuth2PasswordRequestForm
from backend.app.api import jwt
from backend.app.api.service import user_service
from backend.app.common.pagination import Page
from backend.app.common.response.response_schema import response_base
from backend.app.models import User
from backend.app.schemas.token import Token
from backend.app.schemas.user import Auth, CreateUser, GetUserInfo, ResetPassword, UpdateUser, ELCode, Auth2
# Router for the user endpoints (mounted under /v1/users)
user = APIRouter()


# Form-encoded login, usable from the interactive API docs; returns a bearer token.
@user.post('/login', summary='表单登录', response_model=Token, description='form 格式登录支持直接在 api 文档调试接口')
async def user_login(form_data: OAuth2PasswordRequestForm = Depends()):
    access_token, superuser_flag = await user_service.login(form_data)
    return Token(access_token=access_token, is_superuser=superuser_flag)
# @user.post('/login', summary='用户登录', response_model=Token,
# description='json 格式登录, 不支持api文档接口调试, 需使用第三方api工具, 例如: postman')
# async def user_login(obj: Auth):
# token, is_super = await user_service.login(obj)
# return Token(access_token=token, is_superuser=is_super)
# Mail a login code to the given address.
@user.post('/login/email/captcha', summary='发送邮箱登录验证码')
async def user_login_email_captcha(request: Request, obj: ELCode):
    await user_service.send_login_email_captcha(request=request, obj=obj)
    return response_base.response_200(msg='验证码发送成功')
# Log in with e-mail + mailed code; returns a bearer token.
@user.post('/login/email', summary='邮箱登录', description='邮箱登录', response_model=Token)
async def user_login_email(request: Request, obj: Auth2):
    access_token, superuser_flag = await user_service.login_email(request=request, obj=obj)
    return Token(is_superuser=superuser_flag, access_token=access_token)
# Logout only requires a valid token; nothing is revoked server-side here.
@user.post('/logout', summary='用户退出', dependencies=[Depends(jwt.get_current_user)])
async def user_logout():
    message = '退出登录成功'
    return response_base.response_200(msg=message)
# Create a new account.
@user.post('/register', summary='用户注册')
async def user_register(obj: CreateUser):
    await user_service.register(obj=obj)
    return response_base.response_200(msg='用户注册成功')
# Mail a password-reset code; accepts either a username or an e-mail address.
@user.post('/password/reset/captcha', summary='获取密码重置验证码', description='可以通过用户名或者邮箱重置密码')
async def password_reset_captcha(username_or_email: str, response: Response):
    await user_service.get_pwd_rest_captcha(response=response, username_or_email=username_or_email)
    return response_base.response_200(msg='验证码发送成功')
# Reset the password using the mailed code and the reset cookies.
@user.post('/password/reset', summary='密码重置请求')
async def password_reset(obj: ResetPassword, request: Request, response: Response):
    await user_service.pwd_reset(response=response, request=request, obj=obj)
    return response_base.response_200(msg='密码重置成功')
# Static confirmation endpoint for a completed reset.
@user.get('/password/reset/done', summary='重置密码完成')
async def password_reset_done():
    message = '重置密码完成'
    return response_base.response_200(msg=message)
# Any authenticated user may view a user's info; the password field is excluded from the payload.
@user.get('/{username}', summary='查看用户信息', dependencies=[Depends(jwt.get_current_user)])
async def get_userinfo(username: str):
    info = await user_service.get_user_info(username)
    return response_base.response_200(msg='查看用户信息成功', data=info, exclude={'password'})
# Update profile fields; success only when the service reports affected rows.
@user.put('/{username}', summary='更新用户信息')
async def update_userinfo(username: str, obj: UpdateUser, current_user: User = Depends(jwt.get_current_user)):
    updated = await user_service.update(username=username, current_user=current_user, obj=obj)
    return response_base.response_200(msg='更新用户信息成功') if updated > 0 else response_base.fail()
# Upload a new avatar image.
@user.put('/{username}/avatar', summary='更新头像')
async def update_avatar(username: str, avatar: UploadFile, current_user: User = Depends(jwt.get_current_user)):
    updated = await user_service.update_avatar(username=username, current_user=current_user, avatar=avatar)
    return response_base.response_200(msg='更新头像成功') if updated > 0 else response_base.fail()
# Delete the avatar file and clear it on the user record.
@user.delete('/{username}/avatar', summary='删除头像文件')
async def delete_avatar(username: str, current_user: User = Depends(jwt.get_current_user)):
    removed = await user_service.delete_avatar(username=username, current_user=current_user)
    return response_base.response_200(msg='删除用户头像成功') if removed > 0 else response_base.fail()
# Paginated user list for any authenticated user.
@user.get('', summary='获取用户列表', response_model=Page[GetUserInfo], dependencies=[Depends(jwt.get_current_user)])
async def get_users():
    page = await user_service.get_user_list()
    return page
# Superuser-only: change another user's superuser flag.
@user.post('/{pk}/super', summary='修改用户超级权限', dependencies=[Depends(jwt.get_current_is_superuser)])
async def super_set(pk: int):
    changed = await user_service.update_permission(pk)
    return response_base.response_200(msg='修改超级权限成功') if changed > 0 else response_base.fail()
# Superuser-only: change another user's active (locked/unlocked) state.
@user.post('/{pk}/action', summary='修改用户状态', dependencies=[Depends(jwt.get_current_is_superuser)])
async def active_set(pk: int):
    changed = await user_service.update_active(pk)
    return response_base.response_200(msg='修改用户状态成功') if changed > 0 else response_base.fail()
# Permanently delete the account (unlike /logout, which deletes nothing).
@user.delete('/{username}', summary='用户注销', description='用户注销 != 用户退出,注销之后用户将从数据库删除')
async def delete_user(username: str, current_user: User = Depends(jwt.get_current_user)):
    removed = await user_service.delete(username=username, current_user=current_user)
    return response_base.response_200(msg='用户注销成功') if removed > 0 else response_base.fail()
E:\song3\agv_backend_demo\backend\app\api\v1\auth\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\common\log.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from loguru import logger
from backend.app.core import path_conf
class Logger:
    """Configures the application-wide loguru logger."""

    @staticmethod
    def log() -> logger:
        """
        Add a file sink (``<LogPath>/FastBlog.log``) to loguru and return the logger.

        The log directory is created if missing.

        :return: the global loguru ``logger`` with the file sink attached
        """
        # exist_ok: several worker processes (gunicorn workers = 4) may import this
        # module at the same moment; a check-then-mkdir would race and raise.
        os.makedirs(path_conf.LogPath, exist_ok=True)
        log_file = os.path.join(path_conf.LogPath, "FastBlog.log")
        # more: https://github.com/Delgan/loguru#ready-to-use-out-of-the-box-without-boilerplate
        logger.add(
            log_file,
            encoding='utf-8',
            level="DEBUG",
            rotation='00:00',  # start a new file every day at midnight
            retention="7 days",  # delete rotated files older than 7 days
            enqueue=True,  # route messages through a queue (multiprocess-safe)
            backtrace=True,  # extended traceback on errors
            # NOTE(review): diagnose=True prints variable values in tracebacks; the
            # loguru docs advise disabling it in production to avoid leaking data.
            diagnose=True,
        )
        return logger


# Module-level logger shared by the application.
log = Logger().log()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
from typing import TypeVar, Generic, Sequence
from fastapi import Query
from fastapi_pagination.bases import AbstractPage, AbstractParams, RawParams
from pydantic import BaseModel
# Generic item type carried by Page.
T = TypeVar("T")
# Note (translation of the string below): this module overrides the
# fastapi-pagination library's page/params classes.
# Usage examples: https://github.com/uriyyo/fastapi-pagination/tree/main/examples
"""
重写分页库: fastapi-pagination
使用方法:example link: https://github.com/uriyyo/fastapi-pagination/tree/main/examples
"""
class Params(BaseModel, AbstractParams):
    # Page-number pagination parameters read from the query string.
    page: int = Query(1, ge=1, description="Page number")
    size: int = Query(20, gt=0, le=100, description="Page size")  # 20 records per page by default

    def to_raw_params(self) -> RawParams:
        # Convert (page, size) into the limit/offset form the library consumes.
        skipped = (self.page - 1) * self.size
        return RawParams(limit=self.size, offset=skipped)
class Page(AbstractPage[T], Generic[T]):
    # Page envelope returned by paginated endpoints.
    data: Sequence[T]  # items on the current page
    total: int  # total number of records
    page: int  # current page number (1-based)
    size: int  # page size
    next: str  # query string of the next page, or "null" on the last page
    previous: str  # query string of the previous page, or "null" on the first page
    total_pages: int  # number of pages

    __params_type__ = Params  # use the custom Params defined above

    @classmethod
    def create(
        cls,
        data: Sequence[T],
        total: int,
        params: Params,
    ) -> Page[T]:
        """
        Build a page from already-sliced items.

        :param data: items belonging to the requested page
        :param total: total number of records across all pages
        :param params: pagination parameters of the request
        :return: the populated page
        """
        page = params.page
        size = params.size
        # Integer ceiling division: exact for any integer size, unlike
        # math.ceil(total / size) which goes through float division.
        total_pages = (total + size - 1) // size
        next = f"?page={page + 1}&size={size}" if (page + 1) <= total_pages else "null"  # noqa
        previous = f"?page={page - 1}&size={size}" if (page - 1) >= 1 else "null"
        return cls(
            data=data,
            total=total,
            page=page,
            size=size,
            next=next,
            previous=previous,
            total_pages=total_pages,
        )
E:\song3\agv_backend_demo\backend\app\common\redis.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
from aioredis import Redis, TimeoutError, AuthenticationError
from backend.app.common.log import log
from backend.app.core.conf import settings
class RedisCli(Redis):
    """aioredis client configured from the application settings."""

    def __init__(self):
        super().__init__(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            password=settings.REDIS_PASSWORD,
            db=settings.REDIS_DATABASE,
            socket_timeout=settings.REDIS_TIMEOUT,
            decode_responses=True,  # decode replies to str (utf-8)
        )

    async def init_redis_connect(self):
        """
        Trigger the initial connection by pinging the server.

        On any failure the error is logged and the process exits with status 1.
        A bare ``sys.exit()`` exits with status 0, which makes a failed startup
        look like a clean shutdown to the shell or a process supervisor.
        :return:
        """
        try:
            await self.ping()
        except TimeoutError:
            log.error("连接redis超时")
            sys.exit(1)
        except AuthenticationError:
            log.error("连接redis认证失败")
            sys.exit(1)
        except Exception as e:
            log.error('连接redis异常 {}', e)
            sys.exit(1)


# Shared redis client; init_redis_connect() verifies the connection.
redis_client = RedisCli()
E:\song3\agv_backend_demo\backend\app\common\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\common\exception\errors.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any
from fastapi import HTTPException
from backend.app.common.response.response_code import CodeEnum
class BaseExceptionMixin(Exception):
    """
    Base class of the project's custom exceptions.

    Subclasses provide ``code``; instances carry ``msg`` and optional ``data``.
    ``__str__`` returns the message, so logs and tracebacks show it instead of
    an empty string. ``args`` is intentionally left empty: the keyword-only
    constructor could not be re-invoked with a positional message when the
    exception is copied or unpickled.
    """

    code: int

    def __init__(self, *, msg: str = None, data: Any = None):
        super().__init__()
        self.msg = msg
        self.data = data

    def __str__(self) -> str:
        return '' if self.msg is None else str(self.msg)
class HTTPError(HTTPException):
    """Thin alias of FastAPI's ``HTTPException`` with no extra behaviour."""
class RequestError(BaseExceptionMixin):
    """Bad request (code 400)."""

    code = 400

    def __init__(self, *, msg: str = 'Bad Request', data: Any = None):
        super().__init__(data=data, msg=msg)
class ForbiddenError(BaseExceptionMixin):
    """Access forbidden (code 403)."""

    code = 403

    def __init__(self, *, msg: str = 'Forbidden', data: Any = None):
        super().__init__(data=data, msg=msg)
class NotFoundError(BaseExceptionMixin):
    """Resource not found (code 404)."""

    code = 404

    def __init__(self, *, msg: str = 'Not Found', data: Any = None):
        super().__init__(data=data, msg=msg)
class ServerError(BaseExceptionMixin):
    """Internal server error (code 500)."""

    code = 500

    def __init__(self, *, msg: str = 'Internal Server Error', data: Any = None):
        super().__init__(data=data, msg=msg)
class GatewayError(BaseExceptionMixin):
    """Bad gateway (code 502)."""

    code = 502

    def __init__(self, *, msg: str = 'Bad Gateway', data: Any = None):
        super().__init__(data=data, msg=msg)
class CodeError(BaseExceptionMixin):
    """Business error whose code and message come from a ``CodeEnum`` member."""

    def __init__(self, *, error: CodeEnum, data: Any = None):
        code, message = error.code, error.msg
        # Per-instance code: it depends on the enum member, not on the class.
        self.code = code
        super().__init__(msg=message, data=data)
class AuthorizationError(BaseExceptionMixin):
    """Permission denied (code 401)."""

    code = 401

    def __init__(self, *, msg: str = 'Permission denied', data: Any = None):
        super().__init__(data=data, msg=msg)
class TokenError(BaseExceptionMixin):
    """Invalid token (code 401)."""

    code = 401

    def __init__(self, *, msg: str = 'Token is invalid', data: Any = None):
        super().__init__(data=data, msg=msg)
E:\song3\agv_backend_demo\backend\app\common\exception\exception_handler.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError, HTTPException
from pydantic import ValidationError
from starlette.responses import JSONResponse
from uvicorn.protocols.http.h11_impl import STATUS_PHRASES
from backend.app.common.exception.errors import BaseExceptionMixin
from backend.app.common.response.response_schema import response_base
from backend.app.core.conf import settings
def _get_exception_code(status_code):
    """
    Map an exception's code to an HTTP status code usable in the response.

    Codes that uvicorn has a reason phrase for (``STATUS_PHRASES``) are
    returned unchanged; anything else, e.g. a business code such as 40001,
    falls back to 400 so the status line stays valid.
    See Python's ``http.HTTPStatus`` and the IANA registry:
    https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml

    :param status_code: the code carried by the exception
    :return: ``status_code`` if it is a known HTTP status, otherwise 400
    """
    try:
        STATUS_PHRASES[status_code]
    except (KeyError, TypeError):
        # KeyError: unknown status code; TypeError: unhashable value.
        return 400
    return status_code
def _format_validation_errors(exc: RequestValidationError) -> str:
    """
    Turn a ``RequestValidationError`` into one comma-separated message.

    For pydantic ``ValidationError`` causes each entry is ``"<field> <msg>"``,
    where ``<field>`` is the field's declared ``title`` if the model has one,
    otherwise the field name. A JSON body that failed to parse adds
    ``'json解析失败'``. Parts are joined with ``','``, so no trailing separator
    has to be trimmed afterwards. Trimming with ``[:-1]`` would cut the last
    character of a part that does not end with a comma.

    :param exc: the validation error raised for the request
    :return: the joined message, or ``''`` if nothing could be extracted
    """
    parts = []
    titles = {}
    for raw_error in exc.raw_errors:
        cause = raw_error.exc
        if isinstance(cause, ValidationError):
            if hasattr(cause, 'model'):
                # `or {}` guards models without a __fields__ mapping.
                fields = cause.model.__dict__.get('__fields__') or {}
                for field_key, field in fields.items():
                    field_title = field.field_info.title
                    titles[field_key] = field_title if field_title else field_key
            for error in cause.errors():
                field = str(error.get('loc')[-1])
                _msg = error.get('msg')
                parts.append(f'{titles.get(field, field)} {_msg}')
        elif isinstance(cause, json.JSONDecodeError):
            parts.append('json解析失败')
    return ','.join(parts)


def register_exception(app: FastAPI):
    """
    Register the project's global exception handlers on ``app``.

    ``all_exception_handler`` is registered for ``RequestValidationError`` and
    ``BaseExceptionMixin`` explicitly, not only for ``Exception``. FastAPI
    installs its own default handler for ``RequestValidationError``. Starlette
    runs an ``Exception`` handler in its outermost error middleware, which
    re-raises afterwards. With only the ``Exception`` registration, the
    validation branch was never reached, and project exceptions were treated
    as unhandled server errors.

    :param app: the FastAPI application
    """

    @app.exception_handler(HTTPException)
    def http_exception_handler(request: Request, exc: HTTPException):  # noqa
        """
        Global handler for FastAPI ``HTTPException``.

        The response status is ``exc.status_code`` when it is a known HTTP
        status (otherwise 400); the body carries the original code and detail.
        """
        # NOTE(review): keyed on fastapi's HTTPException; errors raised as
        # starlette.exceptions.HTTPException (e.g. unknown routes) may still be
        # handled by FastAPI's default handler — verify.
        return JSONResponse(
            status_code=_get_exception_code(exc.status_code),
            content=response_base.fail(code=exc.status_code, msg=exc.detail),
            headers=exc.headers
        )

    @app.exception_handler(RequestValidationError)
    @app.exception_handler(BaseExceptionMixin)
    @app.exception_handler(Exception)
    def all_exception_handler(request: Request, exc):  # noqa
        """
        Global handler for request validation errors, project exceptions and
        any other unhandled exception.
        """
        # Request validation errors -> 422
        if isinstance(exc, RequestValidationError):
            message = _format_validation_errors(exc)
            return JSONResponse(
                status_code=422,
                content=response_base.fail(
                    msg=f"请求参数非法:{message}" if message else '请求参数非法',
                    # Raw errors are only included when UVICORN_RELOAD is True.
                    data={'errors': exc.errors()} if not message and settings.UVICORN_RELOAD is True else None
                )
            )
        # Project exceptions
        if isinstance(exc, BaseExceptionMixin):
            return JSONResponse(
                status_code=_get_exception_code(exc.code),
                content=response_base.fail(
                    code=exc.code,
                    msg=str(exc.msg),
                    data=exc.data if exc.data else None
                )
            )
        # Anything else -> 500; the exception text is only exposed when UVICORN_RELOAD is on.
        return JSONResponse(
            status_code=500,
            content=response_base.fail(code=500, msg=str(exc)) if settings.UVICORN_RELOAD else
            response_base.fail(code=500, msg='Internal Server Error')
        )
E:\song3\agv_backend_demo\backend\app\common\exception\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\common\response\response_code.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from enum import Enum
class CodeEnum(Enum):
    """Custom business error codes; each member's value is ``(code, message)``."""

    CAPTCHA_ERROR = (40001, '验证码错误')

    @property
    def code(self):
        """Numeric error code (first element of the value)."""
        first, *_ = self.value
        return first

    @property
    def msg(self):
        """Error message (second element of the value)."""
        _, second, *_ = self.value
        return second
E:\song3\agv_backend_demo\backend\app\common\response\response_schema.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional, Any, Union, Set, Dict
from fastapi.encoders import jsonable_encoder
from pydantic import validate_arguments, BaseModel
# Accepted shapes for ``exclude`` on the ``data`` field: a set of keys/indexes
# or a mapping of keys/indexes to nested exclusions.
_JsonEncoder = Union[
    Set[Union[int, str]],
    Dict[Union[int, str], Any],
]

__all__ = ['ResponseModel', 'response_base']
class ResponseModel(BaseModel):
    """
    统一返回模型, 可以在 FastAPI 接口请求中使用 response_model=ResponseModel 及更多操作, 前提是当它是一个非 200 响应时
    """
    # Unified response envelope {code, msg, data}. It can be used as an
    # endpoint's response_model (the original note says: for non-200 responses).
    # The docstring above is kept verbatim because pydantic uses it as the
    # schema description.
    code: int = 200
    msg: str = 'Success'
    data: Optional[Any] = None

    class Config:
        # Serialize datetimes as 'YYYY-MM-DD HH:MM:SS' when encoding to JSON.
        json_encoders = {
            datetime: lambda x: x.strftime("%Y-%m-%d %H:%M:%S")
        }
class ResponseBase:
    """
    Builders for the unified response dict ``{'code', 'msg', 'data'}``.

    ``data`` is passed through ``jsonable_encoder`` (datetimes rendered as
    ``'%Y-%m-%d %H:%M:%S'``), and ``exclude`` removes the given keys/indexes
    from ``data`` when the model is dumped. The three public builders differ
    only in their defaults and share ``__build``.
    """

    @staticmethod
    def __encode_json(data: Any):
        # Convert data to JSON-compatible types, formatting datetimes uniformly.
        return jsonable_encoder(
            data,
            custom_encoder={
                datetime: lambda x: x.strftime("%Y-%m-%d %H:%M:%S")
            }
        )

    @staticmethod
    def __build(code: int, msg: str, data: Any, exclude: Optional[_JsonEncoder]) -> dict:
        # Shared body of success/fail/response_200; a None payload is left as None.
        data = data if data is None else ResponseBase.__encode_json(data)
        return ResponseModel(code=code, msg=msg, data=data).dict(exclude={'data': exclude})

    @staticmethod
    @validate_arguments
    def success(*, code: int = 200, msg: str = 'Success', data: Optional[Any] = None,
                exclude: Optional[_JsonEncoder] = None):
        """
        Generic success response.

        :param code: response code
        :param msg: response message
        :param data: response payload
        :param exclude: keys/indexes of ``data`` to exclude from the output
        :return: the response as a dict
        """
        return ResponseBase.__build(code, msg, data, exclude)

    @staticmethod
    @validate_arguments
    def fail(*, code: int = 400, msg: str = 'Bad Request', data: Any = None, exclude: Optional[_JsonEncoder] = None):
        """
        Generic failure response (defaults: code 400, 'Bad Request').

        :param code: response code
        :param msg: response message
        :param data: response payload
        :param exclude: keys/indexes of ``data`` to exclude from the output
        :return: the response as a dict
        """
        return ResponseBase.__build(code, msg, data, exclude)

    @staticmethod
    @validate_arguments
    def response_200(*, msg: str = 'Success', data: Optional[Any] = None, exclude: Optional[_JsonEncoder] = None):
        """
        Success response with the code fixed to 200.

        :param msg: response message
        :param data: response payload
        :param exclude: keys/indexes of ``data`` to exclude from the output
        :return: the response as a dict
        """
        return ResponseBase.__build(200, msg, data, exclude)


response_base = ResponseBase()
E:\song3\agv_backend_demo\backend\app\common\response\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\core\conf.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from functools import lru_cache
from pydantic import BaseSettings
class Settings(BaseSettings):
    """ Application settings (pydantic BaseSettings: defaults below, overridable via environment variables). """
    # FastAPI
    TITLE: str = 'FastAPI'
    VERSION: str = 'v0.0.1'
    DESCRIPTION: str = """fastapi_sqlalchemy_mysql"""
    DOCS_URL: str = '/v1/docs'
    # None default: pydantic treats the field as optional; presumably passed as
    # FastAPI's redoc_url, where None disables ReDoc — verify.
    REDOCS_URL: str = None
    OPENAPI_URL: str = '/v1/openapi'
    # Uvicorn
    UVICORN_HOST: str = '127.0.0.1'
    UVICORN_PORT: int = 8000
    # Also used by the exception handlers to decide whether error details are
    # included in responses.
    UVICORN_RELOAD: bool = True
    # If this is True, an exception raised in @app.on_event("startup") will not terminate the program;
    # details: https://github.com/encode/starlette/issues/486
    # Static Server
    STATIC_FILES: bool = True
    # DB (used by the commented-out MySQL URL in db_mysql; the active URL is SQLite)
    DB_ECHO: bool = False
    DB_HOST: str = '127.0.0.1'
    DB_PORT: int = 3306
    DB_USER: str = 'root'
    # NOTE(security): hard-coded default password; override via environment in production.
    DB_PASSWORD: str = '123456'
    DB_DATABASE: str = 'fsm'
    DB_CHARSET: str = 'utf8mb4'
    # Redis
    REDIS_OPEN: bool = False  # presumably toggles Redis usage at startup — verify
    REDIS_HOST: str = '127.0.0.1'
    REDIS_PORT: int = 6379
    REDIS_PASSWORD: str = ''
    REDIS_DATABASE: int = 0
    REDIS_TIMEOUT: int = 5  # socket timeout in seconds for the redis client
    # Token
    TOKEN_ALGORITHM: str = 'HS256'  # signing algorithm
    # Signing key, generate with secrets.token_urlsafe(32).
    # NOTE(security): a key committed to source control should be considered
    # leaked; set it via environment in production.
    TOKEN_SECRET_KEY: str = '1VkVF75nsNABBjK_7-qz7GtzNy3AMvktc9TCPwKczCk'
    TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 1  # token lifetime: 60 * 24 * 1 minutes = 1 day
    # Email
    EMAIL_DESCRIPTION: str = 'fastapi_sqlalchemy_mysql'  # default subject line of sent mails
    EMAIL_SERVER: str = 'smtp.qq.com'
    EMAIL_PORT: int = 465
    EMAIL_USER: str = '729519678@qq.com'
    # SMTP authorization code (not the mailbox password).
    # NOTE(security): real-looking credential committed in source; rotate it and load from environment.
    EMAIL_PASSWORD: str = 'gmrvkkppberzbega'
    EMAIL_SSL: bool = True
    # Validity of the email login verification code
    EMAIL_LOGIN_CODE_MAX_AGE: int = 60 * 2  # 60 * 2 seconds = 2 minutes
    # Cookies (also used as the validity stated in the reset-password email)
    COOKIES_MAX_AGE: int = 60 * 5  # 60 * 5 seconds = 5 minutes
    # Middleware switches (read by register_middleware)
    MIDDLEWARE_CORS: bool = True
    MIDDLEWARE_GZIP: bool = True
    MIDDLEWARE_ACCESS: bool = False
@lru_cache()
def get_settings():
    """Build the Settings object on the first call; later calls return the cached instance."""
    settings_obj = Settings()
    return settings_obj


settings = get_settings()
E:\song3\agv_backend_demo\backend\app\core\path_conf.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from pathlib import Path
# Project root: the `backend` directory (three parents up from this file:
# core -> app -> backend). Note: this is a pathlib.Path, the others below are str.
# Alternatively use an absolute path ending at the backend directory, e.g. on Windows: BasePath = D:\git_project\fastapi_mysql\backend
BasePath = Path(__file__).resolve().parent.parent.parent
# Directory for alembic migration files
Versions = os.path.join(BasePath, 'app', 'alembic', 'versions')
# Log directory (a sibling of `backend`)
LogPath = os.path.join(BasePath.parent, 'log')
# Image upload directory: /static/media/uploads/
ImgPath = os.path.join(BasePath, 'app', 'static', 'media', 'uploads')
# Avatar upload directory: /static/media/uploads/avatars/
# The trailing '' component makes the string end with a path separator.
AvatarPath = os.path.join(ImgPath, 'avatars', '')
# SQLite database directory (a sibling of `backend`)
SqlitePath = os.path.join(BasePath.parent, 'sqlite_db')
E:\song3\agv_backend_demo\backend\app\core\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\crud\crud_user.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, NoReturn
from sqlalchemy import func, select, update, delete, desc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from backend.app.api import jwt
from backend.app.models import User
from backend.app.schemas.user import CreateUser, DeleteUser, UpdateUser
async def get_user_by_id(db: AsyncSession, user_id: int) -> Optional[User]:
    """Return the user with primary key ``user_id``, or None if there is none."""
    stmt = select(User).where(User.id == user_id)
    result = await db.execute(stmt)
    return result.scalars().first()
async def get_user_by_username(db: AsyncSession, username: str) -> Optional[User]:
    """Return the user named ``username``, or None if there is none."""
    stmt = select(User).where(User.username == username)
    result = await db.execute(stmt)
    return result.scalars().first()
async def update_user_login_time(db: AsyncSession, username: str) -> int:
    """Set ``last_login`` to the database's current time for ``username``; return rows updated."""
    stmt = update(User).where(User.username == username).values(last_login=func.now())
    result = await db.execute(stmt)
    return result.rowcount
async def get_email_by_username(db: AsyncSession, username: str) -> str:
    """Return the email of ``username``; raises AttributeError if no such user exists."""
    return (await get_user_by_username(db, username)).email
async def get_username_by_email(db: AsyncSession, email: str) -> str:
    """Return the username registered with ``email``; raises AttributeError if none."""
    result = await db.execute(select(User).where(User.email == email))
    found = result.scalars().first()
    return found.username
async def get_avatar_by_username(db: AsyncSession, username: str) -> str:
    """Return the avatar of ``username`` (may be None); raises AttributeError if no such user."""
    result = await db.execute(select(User).where(User.username == username))
    found = result.scalars().first()
    return found.avatar
async def create_user(db: AsyncSession, create: CreateUser) -> None:
    """
    Add a new user built from ``create`` to the session (nothing is committed here).

    The password is hashed into a copy of the payload, so the caller's
    ``create`` object keeps its original fields. The function returns None;
    ``NoReturn`` would mean it never returns.

    :param db: async session
    :param create: registration payload
    """
    user_data = create.dict()
    user_data['password'] = jwt.get_hash_password(create.password)
    db.add(User(**user_data))
async def update_userinfo(db: AsyncSession, current_user: User, obj: UpdateUser) -> int:
    """Write every field of ``obj`` (None values included) onto the current user's row; return rows updated."""
    new_values = obj.dict()
    stmt = update(User).where(User.id == current_user.id).values(**new_values)
    result = await db.execute(stmt)
    return result.rowcount
async def update_avatar(db: AsyncSession, current_user: User, avatar: str) -> int:
    """Store ``avatar`` on the current user's row; return rows updated."""
    stmt = update(User).where(User.id == current_user.id).values(avatar=avatar)
    result = await db.execute(stmt)
    return result.rowcount
async def delete_user(db: AsyncSession, user_id: int) -> int:
    """
    Delete the user with primary key ``user_id``; return the number of rows deleted.

    ``user_id`` is compared directly with ``User.id``, so it must be the
    integer id. A ``DeleteUser`` schema object cannot be bound here.
    """
    result = await db.execute(delete(User).where(User.id == user_id))
    return result.rowcount
async def check_email(db: AsyncSession, email: str) -> Optional[User]:
    """
    Return the user registered with ``email``, or None when the address is unused.

    The return type is Optional because ``first()`` yields None when nothing matches.
    """
    result = await db.execute(select(User).where(User.email == email))
    return result.scalars().first()
async def delete_avatar(db: AsyncSession, user_id: int) -> int:
    """Clear (set to NULL) the avatar of user ``user_id``; return rows updated."""
    stmt = update(User).where(User.id == user_id).values(avatar=None)
    result = await db.execute(stmt)
    return result.rowcount
async def reset_password(db: AsyncSession, username: str, password: str) -> int:
    """Store the hash of ``password`` for ``username``; return rows updated."""
    hashed = jwt.get_hash_password(password)
    stmt = update(User).where(User.username == username).values(password=hashed)
    result = await db.execute(stmt)
    return result.rowcount
def get_users() -> Select:
    """Build (without executing) a SELECT of all users, most recently joined first."""
    newest_first = desc(User.time_joined)
    return select(User).order_by(newest_first)
async def get_user_is_super(db: AsyncSession, user_id: int) -> bool:
    """Return ``is_superuser`` of user ``user_id``; raises AttributeError if no such user."""
    return (await get_user_by_id(db, user_id)).is_superuser
async def get_user_is_active(db: AsyncSession, user_id: int) -> bool:
    """Return ``is_active`` of user ``user_id``; raises AttributeError if no such user."""
    return (await get_user_by_id(db, user_id)).is_active
async def super_set(db: AsyncSession, user_id: int) -> int:
    """
    Flip ``is_superuser`` for user ``user_id``; return the number of rows updated.

    The flip is done by a single UPDATE that negates the column in SQL rather
    than by reading the flag and writing its opposite. Two concurrent calls
    therefore cannot both write the same value, and a missing user yields 0
    (the caller's "fail" path) instead of an AttributeError.
    """
    result = await db.execute(
        update(User)
        .where(User.id == user_id)
        .values(is_superuser=~User.is_superuser)
    )
    return result.rowcount
async def active_set(db: AsyncSession, user_id: int) -> int:
    """
    Flip ``is_active`` for user ``user_id``; return the number of rows updated.

    The flip is done by a single UPDATE that negates the column in SQL rather
    than by reading the flag and writing its opposite. Two concurrent calls
    therefore cannot both write the same value, and a missing user yields 0
    (the caller's "fail" path) instead of an AttributeError.
    """
    result = await db.execute(
        update(User)
        .where(User.id == user_id)
        .values(is_active=~User.is_active)
    )
    return result.rowcount
E:\song3\agv_backend_demo\backend\app\crud\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\database\base_class.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import uuid
from datetime import datetime
from typing import Optional
from sqlalchemy import func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, declared_attr, MappedAsDataclass
from typing_extensions import Annotated
# Shared Mapped type for an auto-increment integer primary key. Add it by hand:
#   MappedBase subclasses:           id: Mapped[id_key]
#   DataClassBase / Base subclasses: id: Mapped[id_key] = mapped_column(init=False)
id_key = Annotated[
    int,
    mapped_column(primary_key=True, index=True, autoincrement=True, comment='主键id'),
]
class _BaseMixin:
    """
    Mixin that adds common audit columns to a model.

    Mixin: an OOP pattern for sharing attributes between classes, see
    `Wiki <https://en.wikipedia.org/wiki/Mixin/>`__
    """
    # Creator; presumably a user id — verify. Required constructor argument (no default).
    create_user: Mapped[int] = mapped_column(comment='创建者')
    # Last modifier; presumably a user id — verify.
    update_user: Mapped[Optional[int]] = mapped_column(default=None, comment='修改者')
    # NOTE(review): with dataclass mapping, `default=func.now()` is also the
    # dataclass default, so the SQL function object is assigned to the attribute
    # of new instances and rendered inline at INSERT. `insert_default=func.now()`
    # is SQLAlchemy's documented form for SQL-expression defaults; confirm
    # before changing.
    created_time: Mapped[datetime] = mapped_column(init=False, default=func.now(), comment='创建时间')
    # Set to the database's now() by UPDATE statements that don't set it explicitly.
    updated_time: Mapped[Optional[datetime]] = mapped_column(init=False, onupdate=func.now(), comment='更新时间')
class MappedBase(DeclarativeBase):
    """
    Root declarative base (a plain DeclarativeBase). The other base classes
    and the models derive from it, and its ``metadata`` is what create_table()
    uses to create the tables.

    `DeclarativeBase <https://docs.sqlalchemy.org/en/20/orm/declarative_config.html>`__
    `mapped_column() <https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.mapped_column>`__
    """

    @declared_attr.directive
    def __tablename__(cls) -> str:  # noqa
        # Default table name: the class name in lower case (models may override it).
        class_name = cls.__name__
        return class_name.lower()
class DataClassBase(MappedAsDataclass, MappedBase):
    """
    Abstract declarative base with native dataclass integration.

    Models get a generated ``__init__``; columns declared with ``init=False``
    are left out of it. The usual dataclass rules (such as field/default
    ordering) apply, so take care when combining it with DeclarativeBase.
    `MappedAsDataclass <https://docs.sqlalchemy.org/en/20/orm/dataclasses.html#orm-declarative-native-dataclasses>`__
    """

    __abstract__ = True
class Base(_BaseMixin, MappedAsDataclass, MappedBase):
    """
    Abstract declarative dataclass base that also carries the audit columns
    from _BaseMixin (create_user, update_user, created_time, updated_time).
    """

    __abstract__ = True
def use_uuid() -> str:
    """
    Generate a random (version 4) UUID as 32 lowercase hex characters, without hyphens.

    :return: the hex string
    """
    return str(uuid.uuid4()).replace('-', '')
E:\song3\agv_backend_demo\backend\app\database\db_mysql.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from backend.app.common.log import log
from backend.app.core.conf import settings
from backend.app.core.path_conf import SqlitePath
from backend.app.database.base_class import MappedBase
"""
说明:SqlAlchemy
"""
# MySQL alternative (asyncmy driver):
# SQLALCHEMY_DATABASE_URL = f'mysql+asyncmy://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:' \
#                           f'{settings.DB_PORT}/{settings.DB_DATABASE}?charset={settings.DB_CHARSET}'
# Active database: an SQLite file under SqlitePath (aiosqlite driver).
SQLALCHEMY_DATABASE_URL = f'sqlite+aiosqlite:///{SqlitePath}/test.db'
try:
    # Database engine. Creating it does not open a connection yet, so this
    # mainly catches an invalid URL or a missing driver.
    async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=settings.DB_ECHO, future=True)
except Exception as e:
    log.error('❌ 数据库链接失败 {}', e)
    # Non-zero exit status so the shell / process manager sees a failed startup.
    sys.exit(1)
else:
    # Session factory; expire_on_commit=False keeps loaded attributes readable after commit.
    async_db_session = async_sessionmaker(bind=async_engine, autoflush=False, expire_on_commit=False)
async def get_db() -> AsyncSession:
    """
    Dependency that yields one session per use.

    If an exception is raised while the session is in use, the transaction is
    rolled back and the exception propagates. The session is always closed.
    NOTE: annotated as AsyncSession, but this is an async generator (it yields).
    """
    session = async_db_session()
    try:
        yield session
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()
async def create_table():
    """Create all tables registered on ``MappedBase.metadata`` inside one engine transaction."""
    create_all = MappedBase.metadata.create_all
    async with async_engine.begin() as conn:
        await conn.run_sync(create_all)
E:\song3\agv_backend_demo\backend\app\database\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\enums\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
E:\song3\agv_backend_demo\backend\app\middleware\access_middle.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from backend.app.common.log import log
class AccessMiddleware(BaseHTTPMiddleware):
    """
    Log one line per request: status code, client host, method, URL and elapsed time.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        start_time = datetime.now()
        response = await call_next(request)
        end_time = datetime.now()
        # request.client is None when the server supplies no peer address; log a
        # placeholder instead of failing the request with an AttributeError.
        client_host = request.client.host if request.client else '-'
        log.info(f"{response.status_code} {client_host} {request.method} {request.url} {end_time - start_time}")
        return response
E:\song3\agv_backend_demo\backend\app\middleware\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import FastAPI
from backend.app.core.conf import settings
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from backend.app.middleware.access_middle import AccessMiddleware
def register_middleware(app: FastAPI) -> None:
    """
    Add the middlewares enabled in settings, in a fixed order: CORS, GZip,
    then access logging. In Starlette the last one added ends up outermost.
    """
    if settings.MIDDLEWARE_CORS:
        # NOTE(review): any origin combined with credentials is very permissive;
        # restrict allow_origins in production.
        cors_options = dict(
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        app.add_middleware(CORSMiddleware, **cors_options)
    if settings.MIDDLEWARE_GZIP:
        app.add_middleware(GZipMiddleware)
    if settings.MIDDLEWARE_ACCESS:
        # Per-request access log
        app.add_middleware(AccessMiddleware)
E:\song3\agv_backend_demo\backend\app\models\user.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from sqlalchemy import func, String
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.orm import Mapped, mapped_column
from backend.app.database.base_class import use_uuid, id_key, DataClassBase
class User(DataClassBase):
    """ User table (sys_user). """
    __tablename__ = 'sys_user'  # explicit name instead of the lower-cased class name

    # Auto-increment primary key (see id_key); not a constructor argument.
    id: Mapped[id_key] = mapped_column(init=False)
    # Unique identifier: 32-char hex produced by use_uuid at INSERT time.
    uid: Mapped[str] = mapped_column(String(50), init=False, insert_default=use_uuid, unique=True, comment='唯一标识')
    username: Mapped[str] = mapped_column(String(20), unique=True, index=True, comment='用户名')
    password: Mapped[str] = mapped_column(String(255), comment='密码')  # stores the hashed password
    email: Mapped[str] = mapped_column(String(50), unique=True, index=True, comment='邮箱')
    is_superuser: Mapped[bool] = mapped_column(default=False, comment='超级权限')
    is_active: Mapped[bool] = mapped_column(default=True, comment='用户账号状态')
    avatar: Mapped[Optional[str]] = mapped_column(String(255), default=None, comment='头像')
    mobile_number: Mapped[Optional[str]] = mapped_column(String(11), default=None, comment='手机号')
    wechat: Mapped[Optional[str]] = mapped_column(String(20), default=None, comment='微信')
    qq: Mapped[Optional[str]] = mapped_column(String(10), default=None, comment='QQ')
    blog_address: Mapped[Optional[str]] = mapped_column(String(255), default=None, comment='博客地址')
    introduction: Mapped[Optional[str]] = mapped_column(String(1000), default=None, comment='自我介绍')
    # NOTE(review): `default=func.now()` is also the dataclass default, so the
    # SQL function object is assigned to new instances and rendered at INSERT.
    time_joined: Mapped[datetime] = mapped_column(init=False, default=func.now(), comment='注册时间')
    # Last login time. It is set explicitly by the login path
    # (update_user_login_time). There is deliberately no `onupdate`: that would
    # also bump it on every other UPDATE of the row (avatar, profile, password,
    # permission or status changes).
    last_login: Mapped[Optional[datetime]] = mapped_column(init=False, comment='上次登录')
E:\song3\agv_backend_demo\backend\app\models\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
# 导入所有模型,并将 Base 放在最前面, 以便 Base 拥有它们
# imported by Alembic
"""
from backend.app.database.base_class import MappedBase # noqa
from backend.app.models.user import User
E:\song3\agv_backend_demo\backend\app\schemas\token.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional
from pydantic import BaseModel
class Token(BaseModel):
    # Token response body: status code/message plus the issued access token
    # (presumably returned by the login endpoint — verify).
    code: int = 200
    msg: str = 'Success'
    access_token: str
    token_type: str = 'Bearer'  # authorization scheme for the token
    is_superuser: Optional[bool] = None  # superuser flag of the user, when provided
E:\song3\agv_backend_demo\backend\app\schemas\user.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import datetime
from typing import Optional
from pydantic import BaseModel, Field, EmailStr
class Auth(BaseModel):
    # Username/password credentials; CreateUser extends this.
    username: str
    password: str
class ELCode(BaseModel):
    # Payload with only an email address; EmailStr validates its format.
    # Presumably used to request an email login code — verify against callers.
    email: EmailStr
class Auth2(ELCode):
    # Email (validated, from ELCode) plus a verification code.
    code: str
class CreateUser(Auth):
    # Payload for crud.create_user: username/password (from Auth) plus email.
    # NOTE(review): email is a plain str here (ELCode uses EmailStr), so this
    # schema does not validate its format.
    email: str = Field(..., example='user@example.com')  # required; example appears in the API docs
class UpdateUser(BaseModel):
    # Editable profile fields. crud.update_userinfo writes every field of this
    # model (None values included) onto the current user's row.
    username: str
    email: str
    mobile_number: Optional[str] = None
    wechat: Optional[str] = None
    qq: Optional[str] = None
    blog_address: Optional[str] = None
    introduction: Optional[str] = None
class GetUserInfo(UpdateUser):
    # User data returned by the API: UpdateUser fields plus read-only columns.
    id: int
    uid: str
    avatar: Optional[str] = None
    time_joined: datetime.datetime = None  # None default makes pydantic treat it as optional
    last_login: Optional[datetime.datetime] = None
    is_superuser: bool
    is_active: bool

    class Config:
        orm_mode = True  # allow building from ORM objects (attribute access)
class DeleteUser(BaseModel):
    # Identifies a user to delete by primary key.
    id: int
class ResetPassword(BaseModel):
    # Password reset payload: verification code plus the new password entered
    # twice (presumably compared for equality by the service — verify).
    code: str
    password1: str
    password2: str
E:\song3\agv_backend_demo\backend\app\schemas\__init__.py
E:\song3\agv_backend_demo\backend\app\utils\encoding_tree.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
def list_to_tree(data_list, parent_id=0) -> list:
    """
    Build a nested tree from a flat list of node dicts.

    Each node must have 'id' and 'parent_id' keys. Nodes are grouped by
    parent once, so building the tree is linear in the number of nodes;
    re-scanning the whole list for every node would be quadratic. Sibling
    order follows the input order.
    Side effect: every returned node dict gets a 'children' list set in place.

    :param data_list: flat list of node dicts
    :param parent_id: id whose direct children become the tree roots
    :return: list of root nodes with nested 'children'
    """
    children_by_parent = {}
    for item in data_list:
        children_by_parent.setdefault(item['parent_id'], []).append(item)

    def _build(pid):
        # Collect the children of `pid` and recursively attach their subtrees.
        tree = []
        for node in children_by_parent.get(pid, []):
            tree.append(node)
            node['children'] = _build(node['id'])
        return tree

    return _build(parent_id)
def ram_list_to_tree(data_list: list) -> list:
    """
    Build a tree in one pass by sharing dict objects between an id index and their parents.

    Each node is attached to its parent's 'children' list in place (the input
    dicts are mutated). The children of id 0 are returned. When there are
    none, e.g. for an empty input, an empty list is returned instead of
    raising KeyError.

    :param data_list: flat list of dicts with 'id' and 'parent_id' keys
    :return: list of top-level nodes (parent_id == 0)
    """
    res = {}
    for v in data_list:
        res.setdefault(v["id"], v)
    for v in data_list:
        res.setdefault(v["parent_id"], {}).setdefault("children", []).append(v)
    return res.get(0, {}).get("children", [])
if __name__ == '__main__':
    # Sample for list_to_tree (the entry with id 3 appears twice on purpose or by accident).
    test_data1 = [
        {'id': 1, 'title': 'GGG', 'parent_id': 0},
        {'id': 2, 'title': 'AAA', 'parent_id': 0},
        {'id': 3, 'title': 'BBB', 'parent_id': 1},
        {'id': 4, 'title': 'CCC', 'parent_id': 1},
        {'id': 5, 'title': 'DDD', 'parent_id': 2},
        {'id': 6, 'title': 'EEE', 'parent_id': 3},
        {'id': 7, 'title': 'FFF', 'parent_id': 4},
        {'id': 3, 'title': 'BBB', 'parent_id': 1},
    ]
    # Sample for ram_list_to_tree, listed children-first.
    test_data2 = [
        {'id': 10, 'parent_id': 8, 'name': "ACAB"},
        {'id': 9, 'parent_id': 8, 'name': "ACAA"},
        {'id': 8, 'parent_id': 7, 'name': "ACA"},
        {'id': 7, 'parent_id': 1, 'name': "AC"},
        {'id': 6, 'parent_id': 3, 'name': "ABC"},
        {'id': 5, 'parent_id': 3, 'name': "ABB"},
        {'id': 4, 'parent_id': 3, 'name': "ABA"},
        {'id': 3, 'parent_id': 1, 'name': "AB"},
        {'id': 2, 'parent_id': 0, 'name': "AA"},
        {'id': 1, 'parent_id': 0, 'name': "A"},
    ]
    for builder, sample in ((list_to_tree, test_data1), (ram_list_to_tree, test_data2)):
        print(json.dumps(builder(sample), indent=4))
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from backend.app.core.path_conf import AvatarPath
def cut_path(path: str = AvatarPath, split_point: str = 'app') -> list:
    """
    Split ``path`` at every occurrence of ``split_point``.

    :param path: path to split (defaults to the avatar upload directory)
    :param split_point: separator substring
    :return: the pieces, exactly as ``str.split`` returns them
    """
    return path.split(split_point)
E:\song3\agv_backend_demo\backend\app\utils\generate_string.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import datetime
import uuid
def get_uuid() -> str:
    """
    Generate a random (version 4) UUID in its canonical hyphenated form.

    :return: 36-character UUID string
    """
    return f'{uuid.uuid4()}'
def get_current_timestamp() -> float:
    """
    Return the current time as a POSIX timestamp.

    :return: seconds since the epoch, as a float
    """
    now = datetime.datetime.now()
    return now.timestamp()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import Form
def encode_as_form(cls):
    """
    Class decorator for pydantic models: make every constructor parameter a
    required ``Form(...)`` field, so FastAPI reads the model from form data.

    Example::

        @encode_as_form
        class Pydantic(BaseModel):
            ...

    :param cls: the pydantic model class (modified in place)
    :return: the same class
    """
    signature = cls.__signature__
    form_params = [param.replace(default=Form(...)) for param in signature.parameters.values()]
    cls.__signature__ = signature.replace(parameters=form_params)
    return cls
E:\song3\agv_backend_demo\backend\app\utils\re_verify.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re
def search_string(pattern, text) -> bool:
    """
    Report whether ``pattern`` matches anywhere in ``text`` (``re.search``).

    :param pattern: regular expression
    :param text: text to scan
    :return: True on a match, otherwise False
    """
    return re.search(pattern, text) is not None
def match_string(pattern, text) -> bool:
    """
    Report whether ``pattern`` matches at the start of ``text`` (``re.match``).

    :param pattern: regular expression
    :param text: text to check
    :return: True on a match, otherwise False
    """
    return re.match(pattern, text) is not None
def is_mobile(text: str) -> bool:
    """
    Check an 11-digit mobile number starting with 1 followed by 3-9
    (the mainland-China mobile prefix).

    ``re.fullmatch`` with an explicit ASCII digit class is used. A pattern
    anchored with ``$`` would also accept a trailing newline, and the
    ``\\d`` class would accept non-ASCII Unicode digits.

    :param text: candidate phone number
    :return: True if it matches the format
    """
    return re.fullmatch(r"1[3-9][0-9]{9}", text) is not None
def is_wechat(text: str) -> bool:
    """
    Check a WeChat ID: 6-20 characters, starting with an ASCII letter, followed
    by ASCII letters, digits, '-' or '_'.

    The bounded repeat is not wrapped in a ``(...)+`` group. That group would
    remove the 20-character upper limit and can backtrack catastrophically on
    long non-matching input. ``re.fullmatch`` also rejects a trailing newline.

    :param text: candidate WeChat ID
    :return: True if it matches the format
    """
    return re.fullmatch(r"[a-zA-Z][-_a-zA-Z0-9]{5,19}", text) is not None
def is_qq(text: str) -> bool:
    """
    Check a QQ number: 5-11 ASCII digits without a leading zero.

    ``re.fullmatch`` is used so that a trailing newline is rejected, which a
    ``$``-anchored pattern would accept.

    :param text: candidate QQ number
    :return: True if it matches the format
    """
    return re.fullmatch(r"[1-9][0-9]{4,10}", text) is not None
E:\song3\agv_backend_demo\backend\app\utils\send_email.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
import aiosmtplib
from backend.app.common.log import log
from backend.app.core.conf import settings
from backend.app.utils.generate_string import get_uuid
# Random placeholder baked into the templates; send_verification_code_email()
# replaces it with the real verification code.
__only_code = get_uuid()
# Reset-password template; the stated validity is taken from COOKIES_MAX_AGE.
SEND_RESET_PASSWORD_TEXT = (
    f"您的重置密码验证码为:{__only_code}\n为了不影响您正常使用,"
    f"请在{int(settings.COOKIES_MAX_AGE / 60)}分钟内完成密码重置"
)
# Email-login template; the stated validity is taken from EMAIL_LOGIN_CODE_MAX_AGE.
SEND_EMAIL_LOGIN_TEXT = (
    f"您的登录验证码为:{__only_code}\n"
    f"请在{int(settings.EMAIL_LOGIN_CODE_MAX_AGE / 60)}分钟内完成登录"
)
async def send_verification_code_email(to: str, code: str, text: str = SEND_RESET_PASSWORD_TEXT):
    """
    Send a verification-code email.

    :param to: recipient address
    :param code: verification code substituted for the template placeholder
    :param text: message template (defaults to the reset-password text)
    :raises Exception: if connecting, logging in or sending fails (chained to the original error)
    """
    text = text.replace(__only_code, code)
    msg = MIMEMultipart()
    msg['Subject'] = settings.EMAIL_DESCRIPTION
    msg['From'] = settings.EMAIL_USER
    # Set the recipient header explicitly; otherwise the message has no To header.
    msg['To'] = to
    msg.attach(MIMEText(text, _charset='utf-8'))
    # Log in to the SMTP server and send. Leaving the `async with` block sends
    # QUIT and closes the connection, so no explicit quit() call is needed.
    try:
        smtp = aiosmtplib.SMTP(hostname=settings.EMAIL_SERVER, port=settings.EMAIL_PORT, use_tls=settings.EMAIL_SSL)
        async with smtp:
            await smtp.login(settings.EMAIL_USER, settings.EMAIL_PASSWORD)
            await smtp.sendmail(msg['From'], to, msg.as_string())
    except Exception as e:
        log.error('邮件发送失败 {}', e)
        raise Exception('邮件发送失败 {}'.format(e)) from e
E:\song3\agv_backend_demo\backend\app\utils\__init__.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-