Cargo.toml
[dependencies]
actix-web = "4"
mysql = "25.0.0"
chrono = "0.4"
serde = { version = "1.0", features = ["derive"] }
jsonwebtoken = "9"
constants.rs
// ---- General envelope values ----

/// Message for a successful request ("ok").
pub const MESSAGE_OK: &str = "ok";
/// "System error": generic failure message (e.g. when no DB connection is available).
pub const SERVER_ERROR: &str = "系统错误";
/// Empty string, used as the placeholder `data` payload in error envelopes.
pub const EMPTY: &str = "";

// ---- Data-operation failures ----

/// "Query failed".
pub const MESSAGE_CAN_NOT_FETCH_DATA: &str = "查询失败";
/// "Insert failed".
pub const MESSAGE_CAN_NOT_INSERT_DATA: &str = "插入失败";
/// "Update failed".
pub const MESSAGE_CAN_NOT_UPDATE_DATA: &str = "更新失败";
/// "Delete failed".
pub const MESSAGE_CAN_NOT_DELETE_DATA: &str = "删除失败";

// ---- Token failures ----

/// "Token generation failed".
pub const MESSAGE_TOKEN_GENERATE_ERROR: &str = "token生成失败";
/// "Token is empty": no bearer token was supplied.
pub const MESSAGE_TOKEN_EMPTY: &str = "token为空";
/// "Token validation failed".
pub const MESSAGE_TOKEN_VALIDATE_ERROR: &str = "token校验失败";

// ---- Routing ----

/// Request paths that bypass token authentication; compared by exact equality.
pub const IGNORE_ROUTES: [&str; 2] = ["/generate-token", "/"];
main.rs
use std::future::Future;
use std::io::empty;
use actix_web::{get, post, web, App, HttpServer, Responder, HttpResponse, Error, HttpRequest, middleware::{from_fn, Next}, body::MessageBody, Result, error, dev::{ServiceRequest, ServiceResponse}, HttpMessage};
use mysql::*;
use mysql::prelude::*;
use serde::{Serialize, Deserialize};
use chrono::Utc;
use mysql::binlog::jsonb::Value::U64;
use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use actix_web::http::header::AUTHORIZATION;
use std::task::{Context, Poll};
use std::pin::Pin;
use actix_web::http::StatusCode;
use actix_web::body::BoxBody;
mod constants;
/// JWT payload signed by `/generate-token` and decoded (HS512) by the token
/// validators; serialized by serde under the Rust field names.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct Claims {
    /// Expiry as a Unix timestamp in seconds (`token` sets it to now + 86400).
    exp: i64,
    /// User name copied from the token request.
    name: String,
    /// User id copied from the token request.
    uid: u64,
}
/// One row of `fiber_user`, as selected by the `/getAll` and `/getAll2` handlers.
#[derive(Serialize, Debug)]
struct User {
    /// `id` column.
    id: u64,
    /// `test` column.
    test: String,
    /// `num` column.
    num: u64,
}
/// JSON body accepted by `/info` and echoed back unchanged.
#[derive(Serialize, Deserialize)]
struct Info {
    /// Free-form text field.
    test: String,
    /// Unsigned numeric field.
    num: u64,
}
#[derive(Deserialize, Serialize)]
struct u_info {
name: String,
uid: u64,
}
/// Uniform JSON envelope `{ "code", "msg", "data" }` returned by the handlers.
///
/// `code` is an application-level status chosen by each handler (200, 403,
/// 500 in this file) and is independent of the HTTP status line.
#[derive(Serialize, Deserialize)]
pub struct MyResponse<T> {
    /// Application-level status code.
    pub code: u64,
    /// Human-readable message.
    pub msg: String,
    /// Response payload.
    pub data: T,
}

impl<T> MyResponse<T> {
    /// Builds an envelope, copying `message` into an owned `String`.
    pub fn new(code: u64, message: &str, data: T) -> MyResponse<T> {
        let msg = String::from(message);
        Self { code, msg, data }
    }
}
/// Starts the HTTP server on 127.0.0.1:8080 with 8 worker threads.
///
/// The MySQL URL is taken from the `DATABASE_URL` environment variable and
/// falls back to the local development database when it is unset.
///
/// # Errors
/// Returns an `io::Error` if the connection pool cannot be created or the
/// address cannot be bound (previously pool failure panicked via `unwrap`).
#[actix_web::main]
async fn main() -> std::io::Result<()> {
    // Development default only; supply real credentials through DATABASE_URL.
    const DEFAULT_DATABASE_URL: &str = "mysql://root:root@localhost:3306/fiber";
    let url = std::env::var("DATABASE_URL").unwrap_or_else(|_| DEFAULT_DATABASE_URL.to_owned());

    let pool = Pool::new(url.as_str()).map_err(|e| {
        std::io::Error::new(
            std::io::ErrorKind::Other,
            format!("failed to create MySQL pool: {e}"),
        )
    })?;
    // `web::Data` is an `Arc`; wrapping once lets every worker share one handle.
    let pool = web::Data::new(pool);

    HttpServer::new(move || {
        App::new()
            .app_data(pool.clone())
            // Bearer-token check for every path not listed in constants::IGNORE_ROUTES.
            .wrap(from_fn(my_middleware))
            .service(infos)
            .service(hello)
            .service(index)
            .service(index2)
            .service(token)
            .service(validate_token)
            .service(some_handler)
    })
    .workers(8)
    .bind(("127.0.0.1", 8080))?
    .run()
    .await
}
#[post("/generate-token")]
async fn token(u_info: web::Json<u_info>) -> impl Responder {
let now = Utc::now().timestamp();
//TODO 查数据库,账号 密码
let my_claims =
Claims {
exp: now + 86400,
name: u_info.name.to_string(),
uid: u_info.uid.clone(),
};
let key = b"secret";
let header = Header { kid: Some("signing_key".to_owned()), alg: Algorithm::HS512, ..Default::default() };
match encode(&header, &my_claims, &EncodingKey::from_secret(key)) {
Ok(token) => HttpResponse::Ok().json(MyResponse::new(
200, constants::MESSAGE_OK, token,
)),
Err(_) => HttpResponse::Ok().json(MyResponse::new(
500, constants::MESSAGE_TOKEN_GENERATE_ERROR, constants::EMPTY,
))
}
}
#[post("/validate-token")]
async fn validate_token(req: HttpRequest) -> impl Responder {
// 获取 Authorization 请求头
let authorization_header = req.headers().get("Authorization");
// 检查是否存在 Authorization 请求头,并提取 token
let token1 = match authorization_header {
Some(header) => {
match header.to_str() {
Ok(token_str) => {
if token_str.starts_with("Bearer ") {
// 将 &str 转换为 String
(&token_str["Bearer ".len()..]).to_string()
} else {
// 返回一个空的 String
"".to_string()
}
}
Err(_) => {
// 返回一个空的 String
"".to_string()
}
}
}
None => {
// 返回一个空的 String
"".to_string()
}
};
// 在这里打印 token,仅用于调试目的
println!("Token: {}", token1);
if token1.is_empty() {
return HttpResponse::Ok().json(MyResponse::new(
403, constants::MESSAGE_TOKEN_EMPTY, constants::EMPTY,
));
}
let key = b"secret";
match jsonwebtoken::decode::<Claims>(&token1, &DecodingKey::from_secret(key), &Validation::new(Algorithm::HS512)) {
Ok(c) => HttpResponse::Ok().json(MyResponse::new(
200, constants::MESSAGE_OK, u_info {
name: c.claims.name,
uid: c.claims.uid,
},
)),
Err(_) => HttpResponse::Ok().json(MyResponse::new(
403, constants::MESSAGE_TOKEN_VALIDATE_ERROR, constants::EMPTY,
)),
}
}
/// Echoes the posted `Info` JSON body back with HTTP 200.
#[post("/info")]
async fn infos(infoo: web::Json<Info>) -> impl Responder {
    let Info { test, num } = infoo.into_inner();
    HttpResponse::Ok().json(Info { test, num })
}
#[get("/")]
async fn hello() -> impl Responder {
HttpResponse::Ok().body("Hello world!")
}
#[get("/getAll")]
async fn index(pool: web::Data<Pool>) -> Result<impl Responder, Error> {
let mut conn = pool.get_conn().map_err(|e| {
actix_web::error::ErrorInternalServerError(e)
})?;
// 执行查询并映射结果到 User 结构体
let results = conn.query_map(
"SELECT id, test, num FROM fiber_user",
|(id, test, num)| User { id, test, num },
).map_err(|e| {
actix_web::error::ErrorInternalServerError(e)
})?;
Ok(web::Json(results))
}
#[get("/getAll2")]
async fn index2(pool: web::Data<Pool>) -> impl Responder {
match pool.get_conn() {
Ok(mut conn) => {
match conn.query_map(
"SELECT id, test, num FROM fiber_user",
|(id, test, num)| User { id, test, num },
) {
Ok(results) => HttpResponse::Ok().json(MyResponse::new(200, constants::MESSAGE_OK, results)),
Err(e) => HttpResponse::Ok().json(MyResponse::new(500, constants::MESSAGE_CAN_NOT_FETCH_DATA, constants::EMPTY)),
}
}
Err(e) => HttpResponse::Ok().json(MyResponse::new(
500, constants::SERVER_ERROR, constants::EMPTY,
)),
}
}
/// Greets the caller using the `Claims` stored in the request extensions by
/// `my_middleware`.
///
/// Answers HTTP 200 with `"Welcome, <name>!"` when claims are present.
/// Otherwise it answers HTTP 401 with envelope code 403 and message "Claims not found".
#[post("/some-endpoint")]
async fn some_handler(req: HttpRequest) -> impl Responder {
    let extensions = req.extensions();
    let claims = match extensions.get::<Claims>() {
        Some(found) => found,
        None => {
            return HttpResponse::Unauthorized().json(MyResponse::new(
                403,
                "Claims not found",
                constants::EMPTY,
            ));
        }
    };

    println!("User Name: {}", claims.name);
    println!("User ID: {}", claims.uid);
    let greeting = format!("Welcome, {}!", claims.name);
    HttpResponse::Ok().json(MyResponse::new(200, "Success", greeting))
}
//校验token的简单中间件
async fn my_middleware(
req: ServiceRequest,
next: Next<BoxBody>,
) -> Result<ServiceResponse<BoxBody>, Error> {
// pre-processing
let authorization_header = req.headers().get("Authorization");
// 检查是否存在 Authorization 请求头,并提取 token
let token1 = match authorization_header {
Some(header) => {
match header.to_str() {
Ok(token_str) => {
if token_str.starts_with("Bearer ") {
// 将 &str 转换为 String
(&token_str["Bearer ".len()..]).to_string()
} else {
// 返回一个空的 String
"".to_string()
}
}
Err(_) => {
// 返回一个空的 String
"".to_string()
}
}
}
None => {
// 返回一个空的 String
"".to_string()
}
};
// 在这里打印 token,仅用于调试目的
println!("Token: {}", token1);
let mut authenticate_pass = true;
println!("url->{}", req.path());
//判断白名单,使用的完全等于。
let path = req.path();
if constants::IGNORE_ROUTES.contains(&path) {
authenticate_pass = false;
}
//这是包含xx开头的方法
// for ignore_route in ignore_routes.iter() {
// if req.path().starts_with(ignore_route) {
// authenticate_pass = false;
// break; // 找到匹配项后退出循环
// }
// }
println!("authenticate_pass->{}", authenticate_pass);
if authenticate_pass {
if token1.is_empty() {
// 直接返回 ServiceResponse
return Ok(req.into_response(
HttpResponse::Unauthorized().json(MyResponse::new(
403,
"Authorization header missing",
constants::EMPTY,
))
));
}
//校验token
let key = b"secret";
match jsonwebtoken::decode::<Claims>(&token1, &DecodingKey::from_secret(key), &Validation::new(Algorithm::HS512)) {
Ok(c) => {
// 将 Claims 存储到请求的 Extensions 中
req.extensions_mut().insert(c.claims);
}
Err(_) => {
return Ok(req.into_response(
HttpResponse::Unauthorized().json(MyResponse::new(
403,
"Token validation error",
constants::EMPTY,
))
));
}
}
}
let resp = next.call(req).await?;
println!("after-middleware");
Ok(resp)
}