AspNetCore添加API限流

最近发现有客户在大量的请求我们的接口,出于性能考虑遂添加了请求频率限制。

由于我们接口请求的是.Net Core写的API网关,所以可以直接添加一个中间件,中间件中使用请求的地址当key,通过配置中心读取对应的请求频率参数设置,然后通过设置redis的过期时间就能实现了。

添加一个中间件ApiThrottleMiddleware,使用httpContext.Request.Path获取请求的接口,然后以次为key去读取配置中心设置的请求频率设置。(Ps:使用_configuration.GetSection(apiUrl).Get<ApiThrottleConfig>()不知为何返回值为null,这个还在查)

 1     public class ApiThrottleMiddleware
 2     {
 3         private readonly RequestDelegate _next;
 4         private IConfiguration _configuration;
 5         private readonly IRedisRunConfigDatabaseProvider _redisRunConfigDatabaseProvider;
 6         private readonly IDatabase _database;
 7 
 8         public ApiThrottleMiddleware(RequestDelegate next,
 9             IConfiguration configuration,
10             IRedisRunConfigDatabaseProvider redisRunConfigDatabaseProvider)
11         {
12             _next = next;
13             _configuration = configuration;
14             _redisRunConfigDatabaseProvider = redisRunConfigDatabaseProvider;
15             _database = _redisRunConfigDatabaseProvider.GetDatabase();
16         }
17 
18         public async Task Invoke(HttpContext httpContext)
19         {
20             var middlewareContext = httpContext.GetOrCreateMiddlewareContext();
21             var apiUrl = httpContext.Request.Path.ToString();
22 
23             var jsonValue= _configuration.GetSection(apiUrl).Value;
24             var apiThrottleConfig=JsonConvert.DeserializeObject<ApiThrottleConfig>(jsonValue);
25             //var apiThrottleConfig = _configuration.GetSection(apiUrl).Get<ApiThrottleConfig>();
26             
27             await _next.Invoke(httpContext);
28         }
29 }

 

我们使用的配置中心是Apollo,设置的格式如下,其中Duration为请求间隔/秒,Limit为调用次数。(下图设置为每分钟允许请求10次)

 

 

(Ps: 由于在API限流中间件前我们已经通过了一个接口签名验证的中间件了,所以我们可以拿到调用客户的具体信息)

如果请求地址没有配置请求频率控制,则直接跳过。否则先通过SortedSetLengthAsync获取对应key的记录数,其中key我们使用了 $"{客户Id}:{插件编码}:{请求地址}",以此来限制每个客户,每个插件对应的某个接口来控制请求频率。获取key对应集合,当前时间-配置的时间段到当前时间的记录。

 1         /// <summary>
 2         /// 获取key
 3         /// </summary>
 4         /// <param name="signInfo"></param>
 5         /// <param name="apiUrl">接口地址</param>
 6         /// <returns></returns>
 7         private string GetApiRecordKey(InterfaceSignInfo signInfo,string apiUrl)
 8         {
 9             var key = $"{signInfo.LicNo}:{signInfo.PluginCode}:{apiUrl}";
10             return key;
11         }
12 
13         /// <summary>
14         /// 获取接口调用次数
15         /// </summary>
16         /// <param name="signInfo"></param>
17         /// <param name="apiUrl">接口地址</param>
18         /// <param name="duration">超时时间</param>
19         /// <returns></returns>
20         public async Task<long> GetApiRecordCountAsync(InterfaceSignInfo signInfo, string apiUrl, int duration)
21         {
22             var key = GetApiRecordKey(signInfo, apiUrl);
23             var nowTicks = DateTime.Now.Ticks;
24             return await _database.SortedSetLengthAsync(key, nowTicks - TimeSpan.FromSeconds(duration).Ticks, nowTicks);
25         }        

如果请求次数大于等于我们设置的频率就直接返回接口调用频率超过限制错误,否则则在key对应的集合中添加一条记录,同时将对应key的过期时间设置为我们配置的限制时间。

需要注意的是,如果一直有接口调用的话,会使得zset对应的key的过期时间一直被更新,那么会导致集合里面的member一直增加;所以我们可以再通过调用删除掉已过期的member(或者使用定时任务定时删除过期member)

 1 /// <summary>
 2 /// 添加调用次数
 3 /// </summary>
 4 /// <param name="signInfo"></param>
 5 /// <param name="apiUrl">接口地址</param>
 6 /// <param name="duration">超时时间</param>
 7 /// <returns></returns>
 8 public async Task AddApiRecordCountAsync(InterfaceSignInfo signInfo, string apiUrl, int duration)
 9 {
10     var key = GetApiRecordKey(signInfo, apiUrl);
11     var nowTicks = DateTime.Now.Ticks;
12     await _database.SortedSetAddAsync(key, nowTicks, nowTicks);
13     await _database.KeyExpireAsync(key, TimeSpan.FromSeconds(duration));
14 
15     await _database.SortedSetRemoveRangeByScoreAsync(key, 0, nowTicks - TimeSpan.FromSeconds(duration).Ticks, flags: CommandFlags.FireAndForget);
16 }

 

然后只需要在Startup中,在API签名验证中间件后调用我们这个API限流中间件就行了。

以下为完整的代码

 

  1 using ApiGateway.Core.Configuration;
  2 using ApiGateway.Core.Domain.Authentication;
  3 using ApiGateway.Core.Domain.Configuration;
  4 using ApiGateway.Core.Domain.Errors;
  5 using Microsoft.AspNetCore.Http;
  6 using Microsoft.Extensions.Configuration;
  7 using Newtonsoft.Json;
  8 using StackExchange.Redis;
  9 using System;
 10 using System.Threading.Tasks;
 11 
 12 namespace ApiGateway.Core.Middleware.Api
 13 {
 14     /// <summary>
 15     /// API限流中间件
 16     /// </summary>
 17     public class ApiThrottleMiddleware
 18     {
 19         private readonly RequestDelegate _next;
 20         private IConfiguration _configuration;
 21         private readonly IRedisRunConfigDatabaseProvider _redisRunConfigDatabaseProvider;
 22         private readonly IDatabase _database;
 23 
 24         public ApiThrottleMiddleware(RequestDelegate next,
 25             IConfiguration configuration,
 26             IRedisRunConfigDatabaseProvider redisRunConfigDatabaseProvider)
 27         {
 28             _next = next;
 29             _configuration = configuration;
 30             _redisRunConfigDatabaseProvider = redisRunConfigDatabaseProvider;
 31             _database = _redisRunConfigDatabaseProvider.GetDatabase();
 32         }
 33 
 34         public async Task Invoke(HttpContext httpContext)
 35         {
 36             var middlewareContext = httpContext.GetOrCreateMiddlewareContext();
 37             var apiUrl = httpContext.Request.Path.ToString();
 38 
 39             var jsonValue= _configuration.GetSection(apiUrl.Replace('/', '.')).Value;
 40             if (!string.IsNullOrEmpty(jsonValue))
 41             {
 42                 var apiThrottleConfig = JsonConvert.DeserializeObject<ApiThrottleConfig>(jsonValue);
 43                 //var apiThrottleConfig = _configuration.GetSection(apiUrl).Get<ApiThrottleConfig>();
 44                 var count = await GetApiRecordCountAsync(middlewareContext.InterfaceSignInfo, apiUrl, apiThrottleConfig.Duration);
 45                 if (count >= apiThrottleConfig.Limit)
 46                 {
 47                     middlewareContext.Errors.Add(new Error("接口调用频率超过限制", GatewayErrorCode.OverThrottleError));
 48                     return;
 49                 }
 50                 else
 51                 {
 52                     await AddApiRecordCountAsync(middlewareContext.InterfaceSignInfo, apiUrl, apiThrottleConfig.Duration);
 53                 }
 54             }
 55             
 56             await _next.Invoke(httpContext);
 57         }
 58 
 59         /// <summary>
 60         /// 获取接口调用次数
 61         /// </summary>
 62         /// <param name="signInfo"></param>
 63         /// <param name="apiUrl">接口地址</param>
 64         /// <param name="duration">超时时间</param>
 65         /// <returns></returns>
 66         public async Task<long> GetApiRecordCountAsync(InterfaceSignInfo signInfo, string apiUrl, int duration)
 67         {
 68             var key = GetApiRecordKey(signInfo, apiUrl);
 69             var nowTicks = DateTime.Now.Ticks;
 70             return await _database.SortedSetLengthAsync(key, nowTicks - TimeSpan.FromSeconds(duration).Ticks, nowTicks);
 71         }
 72 
 73         /// <summary>
 74         /// 添加调用次数
 75         /// </summary>
 76         /// <param name="signInfo"></param>
 77         /// <param name="apiUrl">接口地址</param>
 78         /// <param name="duration">超时时间</param>
 79         /// <returns></returns>
 80         public async Task AddApiRecordCountAsync(InterfaceSignInfo signInfo, string apiUrl, int duration)
 81         {
 82             var key = GetApiRecordKey(signInfo, apiUrl);
 83             var nowTicks = DateTime.Now.Ticks;
 84             await _database.SortedSetAddAsync(key, nowTicks, nowTicks);
 85             await _database.KeyExpireAsync(key, TimeSpan.FromSeconds(duration));
 86 
 87             await _database.SortedSetRemoveRangeByScoreAsync(key, 0, nowTicks - TimeSpan.FromSeconds(duration).Ticks, flags: CommandFlags.FireAndForget);
 88         }
 89 
 90         /// <summary>
 91         /// 获取key
 92         /// </summary>
 93         /// <param name="signInfo"></param>
 94         /// <param name="apiUrl">接口地址</param>
 95         /// <returns></returns>
 96         private string GetApiRecordKey(InterfaceSignInfo signInfo,string apiUrl)
 97         {
 98             var key = $"_api_throttle:{signInfo.LicNo}:{signInfo.PluginCode}:{apiUrl}";
 99             return key;
100         }
101     }
102 }
View Code

 

posted @ 2021-08-13 10:00  Cyril-Hcj  阅读(786)  评论(0编辑  收藏  举报