FastAPI系列:中间件
中间件介绍
中间件是一个函数,它在每个请求被特定的路径操作处理之前 ,以及在每个响应返回之前工作
装饰器版中间件
1.必须使用装饰器@app.middleware("http"),且middleware_type必须为http
2.中间件参数:request, call_next,且call_next
它将接收 request
作为参数
@app.middleware("http")
async def custom_middleware(request: Request, call_next):
logger.info("Before request")
response = await call_next(request) # 让请求继续处理
logger.info("After request")
# 也可以在返回response之前做一些事情,比如添加响应头header
# response.headers['xxx'] = 'xxx'
return response
@app.get("/")
def read_root():
logger.info("执行了.......")
return {"message": "hello world"}
自定义中间件BaseHTTPMiddleware
BaseHTTPMiddleware是一个抽象类,允许您针对请求/响应接口编写ASGI中间件
要使用 实现中间件类BaseHTTPMiddleware
,您必须重写该 async def dispatch(request, call_next)
方法,
如果您想为中间件类提供配置选项,您应该重写该__init__
方法,确保第一个参数是app
,并且任何剩余参数都是可选关键字参数。app
如果执行此操作,请确保在实例上设置该属性。
# 通过继承BaseHTTPMiddleware来实现自定义的中间件
import time
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from fastapi import FastAPI, Request
from starlette.responses import Response
app = FastAPI()
# 基于BaseHTTPMiddleware的中间件实例
class TimeCcalculateMiddleware(BaseHTTPMiddleware):
# dispatch必须实现
async def dispatch(self, request: Request, call_next):
print('start')
start_time = time.time()
response = await call_next(request)
process_time = round(time.time() - start_time, 4)
#返回接口响应事件
response.headers['X-Process-Time'] = f"{process_time} (s)"
print('end')
return response
class AuthMiddleware(BaseHTTPMiddleware):
def __init__(self,app, header_value='auth'):
super().__init__(app)
self.header_value = header_value
#dispatch必须实现
async def dispatch(self, request:Request, call_next):
print('auth start')
response = await call_next(request)
response.headers['Custom'] = self.header_value
print('auth end')
return response
# fastapi实例的add_middleware方法
app.add_middleware(TimeCcalculateMiddleware)
app.add_middleware(AuthMiddleware, header_value='CustomAuth')
@app.get('/index')
async def index():
print('index start')
return {
'code': 200
}
"""执行顺序
auth start
start
index start
end
auth end
"""
ip白名单中间件(基于纯ASGI中间)
根据官网说明BaseHTTPMiddleware有一些已知的局限性:
使用BaseHTTPMiddleware
将阻止对contextlib.ContextVar
的更改向上传播。
也就是说,如果您ContextVar
在端点中设置 a 值并尝试从中间件读取它,您会发现该值与您在端点中设置的值不同
纯ASGI中间件,使用类的方式
class ASGIMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
await self.app(scope, receive, send)
上面的中间件是最基本的ASGI中间件。它接收父 ASGI 应用程序作为其构造函数的参数,并实现async __call__
调用该父应用程序的方法。
无论如何,ASGI 中间件必须是接受三个参数的可调用对象:scope
、receive
和send
- scope是一个保存有关连接信息的字典,其中scope["type"]可能是:
"http"
:用于 HTTP 请求。"websocket"
:用于 WebSocket 连接。"lifespan"
:用于 ASGI 生命周期消息。
receive
和send
可以用来与ASGI服务器交换ASGI事件消息。这些消息的类型和内容取决于作用域类型。在ASGI规范中了解更多信息
# 基于自定义类来实现
from fastapi import FastAPI
app = FastAPI()
from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.requests import HTTPConnection
import typing
class WhiteIpMiddleware:
def __init(self, app:ASGIApp, allow_ip: typing.Sequence[str] = ()) -> None:
self.app = app
self.allow_ip = allow_ip or '*'
async def __call__(self, scope:Scope, receive:Receive, send:Send)->None:
if scope['type'] in ('http','websocket') and scope['scheme'] in ('http', 'ws'):
conn = HTTPConnection(scope=scope)
if self.allow_ip and conn.client.host not in self.allow_ip:
response = PlainTextResponse(content='不在ip白名单内', status_code=403)
await response(scope, receive, send)
return
await self.app(scope, receive, send)
else:
await self.app(scope, receive, send)
app.add_middleware(WhiteIpMiddleware, allow_ip=['127.0.0.2'])
@app.get('/index')
async def index():
print('index-start')
return {'code': 200}
跨域中间件cors
同源:协议,域,端口相同
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware # fastapi内置了一个CORSMiddleware,可以直接使用
import uvicorn
app = FastAPI()
origins = [
"http://localhost.tiangolo.com",
"https://localhost.tiangolo.com",
"http://localhost",
"http://localhost:8080",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins, #一个允许跨域请求的源列表
allow_credentials=True, #指示跨域请求支持 cookies,默认False, 另外,允许凭证时allow_origins 不能设定为 ['*'],必须指定源。
allow_methods=["*"], # 一个允许跨域请求的 HTTP 方法列表,默认get
allow_headers=["*"], # 一个允许跨域请求的 HTTP 请求头列表
)
@app.get("/")
async def main():
return {"message": "hello world"}
if __name__ == '__main__':
uvicorn.run(app=app)
-------------------------------------------
个性签名:代码过万,键盘敲烂!!!
如果觉得这篇文章对你有小小的帮助的话,记得在右下角点个“推荐”哦,博主在此感谢!