Revisiting the .NET Core ecosystem reading series: AspNetCoreRateLimit rules [2]
Preface
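This post follows directly from the previous one:
https://www.cnblogs.com/aoximin/p/15315102.html
It picks up at Invoke, where the previous post stopped, and goes through what Invoke actually does, line by line.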
Main text
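Invoke is a method, and most methods follow the same pattern:
- Parameter validation
- Parameter conversion and validation (optional)
- Core processing
- Return value (including void)
Let's read it following that pattern.
Start with the first two steps.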
// check if rate limiting is enabled
if (_options == null)
{
await _next.Invoke(context);
return;
}
// compute identity from request
var identity = await ResolveIdentityAsync(context);
// check white list
if (_processor.IsWhitelisted(identity))
{
await _next.Invoke(context);
return;
}
The code above first checks whether the options are null; if they are, the request is simply passed on to the next middleware.
// compute identity from request
var identity = await ResolveIdentityAsync(context);
This is the parameter conversion step.
Here is ResolveIdentityAsync:
public virtual async Task<ClientRequestIdentity> ResolveIdentityAsync(HttpContext httpContext)
{
string clientIp = null;
string clientId = null;
if (_config.ClientResolvers?.Any() == true)
{
foreach (var resolver in _config.ClientResolvers)
{
clientId = await resolver.ResolveClientAsync(httpContext);
if (!string.IsNullOrEmpty(clientId))
{
break;
}
}
}
if (_config.IpResolvers?.Any() == true)
{
foreach (var resolver in _config.IpResolvers)
{
clientIp = resolver.ResolveIp(httpContext);
if (!string.IsNullOrEmpty(clientIp))
{
break;
}
}
}
return new ClientRequestIdentity
{
ClientIp = clientIp,
Path = httpContext.Request.Path.ToString().ToLowerInvariant().TrimEnd('/'),
HttpVerb = httpContext.Request.Method.ToLowerInvariant(),
ClientId = clientId ?? "anon"
};
}
For a method like this, start with the return value, since that is all the caller uses:
return new ClientRequestIdentity
{
ClientIp = clientIp,
Path = httpContext.Request.Path.ToString().ToLowerInvariant().TrimEnd('/'),
HttpVerb = httpContext.Request.Method.ToLowerInvariant(),
ClientId = clientId ?? "anon"
};
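So it builds the identity from the HttpContext: ClientIp, Path (lower-cased, trailing '/' removed), HttpVerb (lower-cased) and ClientId (defaulting to "anon" when no client id is found).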
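To make the normalization concrete, here is a minimal C# sketch (top-level statements, .NET 5 or later assumed). `Normalize` is a hypothetical helper, not part of AspNetCoreRateLimit; it only repeats the string operations visible in ResolveIdentityAsync above.

```csharp
using System;

// Hypothetical helper: repeats the normalization ResolveIdentityAsync applies.
(string Path, string Verb, string ClientId) Normalize(string path, string method, string clientId)
{
    return (
        path.ToLowerInvariant().TrimEnd('/'), // lower-case, drop trailing '/'
        method.ToLowerInvariant(),            // "GET" -> "get"
        clientId ?? "anon");                  // no resolved client id -> "anon"
}

var identity = Normalize("/API/Values/", "GET", null);
Console.WriteLine($"{identity.Verb}:{identity.Path} ({identity.ClientId})");
// get:/api/values (anon)
```

One side effect worth knowing: because of TrimEnd('/'), a request to the root path "/" ends up with an empty Path.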
As mentioned in the previous post, we only look at the IP side here, so let's see how ClientIp is obtained:
if (_config.IpResolvers?.Any() == true)
{
foreach (var resolver in _config.IpResolvers)
{
clientIp = resolver.ResolveIp(httpContext);
if (!string.IsNullOrEmpty(clientIp))
{
break;
}
}
}
This was covered in the previous post, but here is a quick recap.
_config is an IRateLimitConfiguration,
registered like this:
services.AddSingleton<IRateLimitConfiguration, RateLimitConfiguration>();
IpResolvers in RateLimitConfiguration:
public IList<IIpResolveContributor> IpResolvers { get; } = new List<IIpResolveContributor>();
When the middleware is constructed, it calls:
_config.RegisterResolvers();
which runs RateLimitConfiguration.RegisterResolvers:
public virtual void RegisterResolvers()
{
string clientIdHeader = GetClientIdHeader();
string realIpHeader = GetRealIp();
if (clientIdHeader != null)
{
ClientResolvers.Add(new ClientHeaderResolveContributor(clientIdHeader));
}
// the contributors are resolved in the order of their collection index
if (realIpHeader != null)
{
IpResolvers.Add(new IpHeaderResolveContributor(realIpHeader));
}
IpResolvers.Add(new IpConnectionResolveContributor());
}
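So IpResolvers ends up with a header-based resolver (only if a real-IP header is configured) followed by the connection-based resolver, and they are tried in that order. The previous post covered them in detail, so here we only care about what they do.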
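The "first non-empty result wins" loop can be sketched as follows. This is a simplified C# illustration under stated assumptions: the resolvers are plain `Func<string>` delegates standing in for IIpResolveContributor, `ResolveFirst` is a hypothetical name, and the IP values are made-up documentation addresses.

```csharp
using System;
using System.Collections.Generic;

// Sketch of the resolver loop: try each resolver in registration order
// and keep the first non-empty result (the original uses `break`).
string ResolveFirst(IEnumerable<Func<string>> resolvers)
{
    foreach (var resolve in resolvers)
    {
        var value = resolve();
        if (!string.IsNullOrEmpty(value))
        {
            return value;
        }
    }
    return null; // nothing resolved
}

// Header resolver first, connection resolver second, as in RegisterResolvers.
var headerPresent = new List<Func<string>> { () => "203.0.113.7", () => "10.0.0.1" };
var headerMissing = new List<Func<string>> { () => null, () => "10.0.0.1" };

Console.WriteLine(ResolveFirst(headerPresent)); // 203.0.113.7
Console.WriteLine(ResolveFirst(headerMissing)); // 10.0.0.1
```

Because the header resolver is registered first, a configured real-IP header (for example one set by a reverse proxy) takes priority over the connection's remote address.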
Back in Invoke: for IP rate limiting, the parts of the identity that matter are ClientIp, Path and HttpVerb.
Invoke then checks whether the request is whitelisted:
// check white list
if (_processor.IsWhitelisted(identity))
{
await _next.Invoke(context);
return;
}
The IsWhitelisted method:
public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity)
{
if (_options.ClientWhitelist != null && _options.ClientWhitelist.Contains(requestIdentity.ClientId))
{
return true;
}
if (_options.IpWhitelist != null && IpParser.ContainsIp(_options.IpWhitelist, requestIdentity.ClientIp))
{
return true;
}
if (_options.EndpointWhitelist != null && _options.EndpointWhitelist.Any())
{
string path = _options.EnableRegexRuleMatching ? $".+:{requestIdentity.Path}" : $"*:{requestIdentity.Path}";
if (_options.EndpointWhitelist.Any(x => $"{requestIdentity.HttpVerb}:{requestIdentity.Path}".IsUrlMatch(x, _options.EnableRegexRuleMatching)) ||
_options.EndpointWhitelist.Any(x => path.IsUrlMatch(x, _options.EnableRegexRuleMatching)))
return true;
}
return false;
}
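It checks the client-id whitelist, the IP whitelist and the endpoint whitelist in turn. The IP check is this line: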
if (_options.IpWhitelist != null && IpParser.ContainsIp(_options.IpWhitelist, requestIdentity.ClientIp))
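It returns true if the client IP is in the IP whitelist. ContainsIp is worth a look if you are interested: it also handles IPv6, which is still not widely used, but some of your users may well connect over it.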
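To get a feel for what an IP whitelist check involves, including IPv6, here is a hypothetical C# sketch of CIDR matching with System.Net.IPAddress. It is not IpParser's implementation: `InCidr` and `ContainsIp` here only understand single addresses and CIDR notation, which is an assumption made for brevity.

```csharp
using System;
using System.Net;

// Hypothetical CIDR check (not IpParser): does `ip` fall inside `range`?
// `range` is a single address ("10.0.0.5") or CIDR ("192.168.0.0/24", "2001:db8::/32").
bool InCidr(string range, string ip)
{
    var parts = range.Split('/');
    var network = IPAddress.Parse(parts[0]).GetAddressBytes();
    var candidate = IPAddress.Parse(ip).GetAddressBytes();
    if (network.Length != candidate.Length) return false; // IPv4 (4 bytes) vs IPv6 (16 bytes)

    var prefix = parts.Length > 1 ? int.Parse(parts[1]) : network.Length * 8;
    for (var i = 0; i < network.Length && prefix > 0; i++, prefix -= 8)
    {
        var bits = Math.Min(prefix, 8);
        var mask = (byte)(0xFF << (8 - bits)); // keep only the leading `bits` bits of this byte
        if ((network[i] & mask) != (candidate[i] & mask)) return false;
    }
    return true;
}

bool ContainsIp(string[] whitelist, string ip) =>
    !string.IsNullOrEmpty(ip) && Array.Exists(whitelist, range => InCidr(range, ip));

var allowList = new[] { "127.0.0.1", "::1/128", "192.168.0.0/24" };
Console.WriteLine(ContainsIp(allowList, "192.168.0.42")); // True
Console.WriteLine(ContainsIp(allowList, "192.168.1.42")); // False
```

Comparing raw address bytes works the same way for 4-byte IPv4 and 16-byte IPv6 addresses; only the length check keeps the two families apart.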
Next, the core processing logic:
var rules = await _processor.GetMatchingRulesAsync(identity, context.RequestAborted);
var rulesDict = new Dictionary<RateLimitRule, RateLimitCounter>();
foreach (var rule in rules)
{
// increment counter
var rateLimitCounter = await _processor.ProcessRequestAsync(identity, rule, context.RequestAborted);
if (rule.Limit > 0)
{
// check if key expired
if (rateLimitCounter.Timestamp + rule.PeriodTimespan.Value < DateTime.UtcNow)
{
continue;
}
// check if limit is reached
if (rateLimitCounter.Count > rule.Limit)
{
//compute retry after value
var retryAfter = rateLimitCounter.Timestamp.RetryAfterFrom(rule);
// log blocked request
LogBlockedRequest(context, identity, rateLimitCounter, rule);
if (_options.RequestBlockedBehaviorAsync != null)
{
await _options.RequestBlockedBehaviorAsync(context, identity, rateLimitCounter, rule);
}
if (!rule.MonitorMode)
{
// break execution
await ReturnQuotaExceededResponse(context, rule, retryAfter);
return;
}
}
}
// if limit is zero or less, block the request.
else
{
// log blocked request
LogBlockedRequest(context, identity, rateLimitCounter, rule);
if (_options.RequestBlockedBehaviorAsync != null)
{
await _options.RequestBlockedBehaviorAsync(context, identity, rateLimitCounter, rule);
}
if (!rule.MonitorMode)
{
// break execution (Int32 max used to represent infinity)
await ReturnQuotaExceededResponse(context, rule, int.MaxValue.ToString(System.Globalization.CultureInfo.InvariantCulture));
return;
}
}
rulesDict.Add(rule, rateLimitCounter);
}
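The per-rule checks in the loop boil down to a small decision, sketched below in C#. `Decide` is a hypothetical helper written for illustration; it ignores MonitorMode (which only determines whether a "Block" actually rejects the request) and the Retry-After computation.

```csharp
using System;

// Hypothetical condensation of the per-rule checks in Invoke, run after the
// counter has been incremented. MonitorMode and Retry-After are left out.
string Decide(double limit, double count, DateTime windowStart, TimeSpan period, DateTime now)
{
    if (limit <= 0) return "Block";                 // limit of zero or less: always block
    if (windowStart + period < now) return "Skip";  // counter window already expired
    return count > limit ? "Block" : "Pass";        // over the limit?
}

var start = new DateTime(2021, 9, 1, 0, 0, 0, DateTimeKind.Utc);
var oneSecond = TimeSpan.FromSeconds(1);
Console.WriteLine(Decide(2, 3, start, oneSecond, start.AddMilliseconds(500))); // Block
Console.WriteLine(Decide(2, 2, start, oneSecond, start.AddMilliseconds(500))); // Pass
```

Note that the comparison is strictly greater than: with a limit of 2, the second request in the window still passes.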
First, the inputs the core logic works with:
var rules = await _processor.GetMatchingRulesAsync(identity, context.RequestAborted);
var rulesDict = new Dictionary<RateLimitRule, RateLimitCounter>();
Here is GetMatchingRulesAsync:
public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default)
{
var policies = await _policyStore.GetAsync($"{_options.IpPolicyPrefix}", cancellationToken);
var rules = new List<RateLimitRule>();
if (policies?.IpRules?.Any() == true)
{
// search for rules with IP intervals containing client IP
var matchPolicies = policies.IpRules.Where(r => IpParser.ContainsIp(r.Ip, identity.ClientIp));
foreach (var item in matchPolicies)
{
rules.AddRange(item.Rules);
}
}
return GetMatchingRules(identity, rules);
}
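It first loads the IP policies and collects the rules of every policy whose IP range contains the client IP; it then hands them to GetMatchingRules, which filters them by endpoint and merges in the general rules.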
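The first half of GetMatchingRulesAsync (select the matching IP policies, flatten their rules) is a Where plus a SelectMany. A minimal C# sketch, assuming exact IP equality instead of IpParser.ContainsIp and rules represented as plain strings; `CollectIpRules` is a hypothetical name.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch: keep the IP policies that match the client, then flatten their rules.
// Simplification: exact IP equality instead of IpParser.ContainsIp; rules are plain strings.
List<string> CollectIpRules(IEnumerable<(string Ip, string[] Rules)> ipRules, string clientIp) =>
    ipRules
        .Where(policy => policy.Ip == clientIp)
        .SelectMany(policy => policy.Rules)
        .ToList();

var ipPolicies = new List<(string Ip, string[] Rules)>
{
    ("84.247.85.224", new[] { "*:1s:10", "*:15m:200" }),
    ("84.247.85.225", new[] { "*:1s:5" }),
};

Console.WriteLine(string.Join(", ", CollectIpRules(ipPolicies, "84.247.85.224")));
// *:1s:10, *:15m:200
```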
GetMatchingRules is where the real work happens:
protected virtual List<RateLimitRule> GetMatchingRules(ClientRequestIdentity identity, List<RateLimitRule> rules = null)
{
var limits = new List<RateLimitRule>();
if (rules?.Any() == true)
{
if (_options.EnableEndpointRateLimiting)
{
// search for rules with endpoints like "*" and "*:/matching_path"
string path = _options.EnableRegexRuleMatching ? $".+:{identity.Path}" : $"*:{identity.Path}";
var pathLimits = rules.Where(r => path.IsUrlMatch(r.Endpoint, _options.EnableRegexRuleMatching));
limits.AddRange(pathLimits);
// search for rules with endpoints like "matching_verb:/matching_path"
var verbLimits = rules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsUrlMatch(r.Endpoint, _options.EnableRegexRuleMatching));
limits.AddRange(verbLimits);
}
else
{
// ignore endpoint rules and search for global rules only
var genericLimits = rules.Where(r => r.Endpoint == "*");
limits.AddRange(genericLimits);
}
// get the most restrictive limit for each period
limits = limits.GroupBy(l => l.Period).Select(l => l.OrderBy(x => x.Limit)).Select(l => l.First()).ToList();
}
// search for matching general rules
if (_options.GeneralRules != null)
{
var matchingGeneralLimits = new List<RateLimitRule>();
if (_options.EnableEndpointRateLimiting)
{
// search for rules with endpoints like "*" and "*:/matching_path" in general rules
var pathLimits = _options.GeneralRules.Where(r => $"*:{identity.Path}".IsUrlMatch(r.Endpoint, _options.EnableRegexRuleMatching));
matchingGeneralLimits.AddRange(pathLimits);
// search for rules with endpoints like "matching_verb:/matching_path" in general rules
var verbLimits = _options.GeneralRules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsUrlMatch(r.Endpoint, _options.EnableRegexRuleMatching));
matchingGeneralLimits.AddRange(verbLimits);
}
else
{
//ignore endpoint rules and search for global rules in general rules
var genericLimits = _options.GeneralRules.Where(r => r.Endpoint == "*");
matchingGeneralLimits.AddRange(genericLimits);
}
// get the most restrictive general limit for each period
var generalLimits = matchingGeneralLimits
.GroupBy(l => l.Period)
.Select(l => l.OrderBy(x => x.Limit).ThenBy(x => x.Endpoint))
.Select(l => l.First())
.ToList();
foreach (var generalLimit in generalLimits)
{
// add general rule if no specific rule is declared for the specified period
if (!limits.Exists(l => l.Period == generalLimit.Period))
{
limits.Add(generalLimit);
}
}
}
foreach (var item in limits)
{
if (!item.PeriodTimespan.HasValue)
{
// parse period text into time spans
item.PeriodTimespan = item.Period.ToTimeSpan();
}
}
limits = limits.OrderBy(l => l.PeriodTimespan).ToList();
if (_options.StackBlockedRequests)
{
limits.Reverse();
}
return limits;
}
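That is a lot of code, but it isn't deep: it calls almost no other methods and is mostly plain branching and LINQ.
Since it is made of a few if blocks, let's read it one block at a time.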
var limits = new List<RateLimitRule>();
if (rules?.Any() == true)
{
if (_options.EnableEndpointRateLimiting)
{
// search for rules with endpoints like "*" and "*:/matching_path"
string path = _options.EnableRegexRuleMatching ? $".+:{identity.Path}" : $"*:{identity.Path}";
var pathLimits = rules.Where(r => path.IsUrlMatch(r.Endpoint, _options.EnableRegexRuleMatching));
limits.AddRange(pathLimits);
// search for rules with endpoints like "matching_verb:/matching_path"
var verbLimits = rules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsUrlMatch(r.Endpoint, _options.EnableRegexRuleMatching));
limits.AddRange(verbLimits);
}
else
{
// ignore endpoint rules and search for global rules only
var genericLimits = rules.Where(r => r.Endpoint == "*");
limits.AddRange(genericLimits);
}
// get the most restrictive limit for each period
limits = limits.GroupBy(l => l.Period).Select(l => l.OrderBy(x => x.Limit)).Select(l => l.First()).ToList();
}
This block handles the IP-specific rules, and its logic revolves around _options.EnableEndpointRateLimiting.
So what does the documentation say EnableEndpointRateLimiting does?
If EnableEndpointRateLimiting is set to false then the limits will apply globally and only the rules that have as endpoint * will apply. For example, if you set a limit of 5 calls per second, any HTTP call to any endpoint will count towards that limit.
If EnableEndpointRateLimiting is set to true, then the limits will apply for each endpoint as in {HTTP_Verb}{PATH}. For example if you set a limit of 5 calls per second for *:/api/values a client can call GET /api/values 5 times per second but also 5 times PUT /api/values.
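In short: when it is false, only rules whose endpoint is "*" apply, and every call to any endpoint counts toward them; with a limit of 5 calls per second, any 5 calls use up the quota.
When it is true, limits are tracked per endpoint in the form {HTTP_Verb}:{PATH}, so a limit of 5 calls per second on *:/api/values allows 5 GET /api/values calls and, separately, 5 PUT /api/values calls per second.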
Put simply, the flag decides whether limits can target specific endpoints.
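The endpoint matching can be sketched like this in C#. Assumptions: regex matching is off (EnableRegexRuleMatching = false), and `WildcardMatch` is a hypothetical matcher in which '*' matches any run of characters; the library's own IsUrlMatch may differ in details. The original code appends the path matches and the verb matches as two separate lists, so the same rule can appear twice (the GroupBy that follows collapses duplicates); the sketch simply uses `||`.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

// Hypothetical wildcard matcher: '*' in a rule endpoint matches any run of characters.
bool WildcardMatch(string source, string endpointPattern) =>
    Regex.IsMatch(
        source,
        "^" + Regex.Escape(endpointPattern).Replace(@"\*", ".*") + "$",
        RegexOptions.IgnoreCase);

// Which rule endpoints apply to a request, mirroring the two branches of GetMatchingRules.
List<string> MatchingEndpoints(IEnumerable<string> ruleEndpoints, string verb, string path, bool enableEndpointRateLimiting)
{
    if (!enableEndpointRateLimiting)
    {
        // only global "*" rules count
        return ruleEndpoints.Where(e => e == "*").ToList();
    }
    // "*:/path" as source matches rules "*" and "*:/path";
    // "verb:/path" as source matches "*", "*:/path" and "verb:/path"
    return ruleEndpoints
        .Where(e => WildcardMatch($"*:{path}", e) || WildcardMatch($"{verb}:{path}", e))
        .ToList();
}

var endpoints = new[] { "*", "*:/api/values", "get:/api/values", "put:/api/other" };
Console.WriteLine(string.Join(", ", MatchingEndpoints(endpoints, "get", "/api/values", true)));
// *, *:/api/values, get:/api/values
Console.WriteLine(string.Join(", ", MatchingEndpoints(endpoints, "get", "/api/values", false)));
// *
```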
With that explanation in mind, the code is easy to follow.
Now the next if block:
// search for matching general rules
if (_options.GeneralRules != null)
{
var matchingGeneralLimits = new List<RateLimitRule>();
if (_options.EnableEndpointRateLimiting)
{
// search for rules with endpoints like "*" and "*:/matching_path" in general rules
var pathLimits = _options.GeneralRules.Where(r => $"*:{identity.Path}".IsUrlMatch(r.Endpoint, _options.EnableRegexRuleMatching));
matchingGeneralLimits.AddRange(pathLimits);
// search for rules with endpoints like "matching_verb:/matching_path" in general rules
var verbLimits = _options.GeneralRules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsUrlMatch(r.Endpoint, _options.EnableRegexRuleMatching));
matchingGeneralLimits.AddRange(verbLimits);
}
else
{
//ignore endpoint rules and search for global rules in general rules
var genericLimits = _options.GeneralRules.Where(r => r.Endpoint == "*");
matchingGeneralLimits.AddRange(genericLimits);
}
// get the most restrictive general limit for each period
var generalLimits = matchingGeneralLimits
.GroupBy(l => l.Period)
.Select(l => l.OrderBy(x => x.Limit).ThenBy(x => x.Endpoint))
.Select(l => l.First())
.ToList();
foreach (var generalLimit in generalLimits)
{
// add general rule if no specific rule is declared for the specified period
if (!limits.Exists(l => l.Period == generalLimit.Period))
{
limits.Add(generalLimit);
}
}
}
Again, the key is _options.GeneralRules.
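GeneralRules are the rules that apply to every client, and they are matched against the endpoint in the same way. For each period the most restrictive general rule is kept, and it is only added if the IP-specific rules do not already define a rule for that period, so an IP-specific rule always takes precedence for its period.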
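The two merge steps (lowest limit per period, then general rules only for periods the IP rules leave empty) can be sketched in C# with tuples standing in for RateLimitRule. `MostRestrictive` and `Merge` are hypothetical helpers for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Step 1: for each Period keep the lowest Limit (the most restrictive rule).
// The general-rule branch of GetMatchingRules also breaks ties by Endpoint; this sketch does so for both.
List<(string Endpoint, string Period, double Limit)> MostRestrictive(
    IEnumerable<(string Endpoint, string Period, double Limit)> rules) =>
    rules.GroupBy(r => r.Period)
         .Select(g => g.OrderBy(r => r.Limit).ThenBy(r => r.Endpoint).First())
         .ToList();

// Step 2: a general rule is added only when no IP-specific rule covers its Period.
List<(string Endpoint, string Period, double Limit)> Merge(
    IEnumerable<(string Endpoint, string Period, double Limit)> ipRules,
    IEnumerable<(string Endpoint, string Period, double Limit)> generalRules)
{
    var limits = MostRestrictive(ipRules);
    foreach (var general in MostRestrictive(generalRules))
    {
        if (!limits.Exists(l => l.Period == general.Period))
        {
            limits.Add(general);
        }
    }
    return limits;
}

var ipSet = new List<(string Endpoint, string Period, double Limit)>
{
    ("*", "1s", 10), ("*:/api/values", "1s", 3),
};
var generalSet = new List<(string Endpoint, string Period, double Limit)>
{
    ("*", "1s", 1), ("*", "1m", 100),
};
foreach (var rule in Merge(ipSet, generalSet))
{
    Console.WriteLine($"{rule.Period} {rule.Endpoint} {rule.Limit}");
}
// 1s *:/api/values 3
// 1m * 100
```

The example shows a consequence that is easy to miss: the IP-specific 1s rule wins even though the general 1s rule is stricter.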
Then the last part:
foreach (var item in limits)
{
if (!item.PeriodTimespan.HasValue)
{
// parse period text into time spans
item.PeriodTimespan = item.Period.ToTimeSpan();
}
}
limits = limits.OrderBy(l => l.PeriodTimespan).ToList();
if (_options.StackBlockedRequests)
{
limits.Reverse();
}
return limits;
The foreach converts each period string into a TimeSpan, and the list is then sorted from the shortest period to the longest.
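Period parsing and ordering, sketched in C#. `ParsePeriod` is a hypothetical stand-in for the library's ToTimeSpan extension and assumes the `<number><unit>` format with units s, m, h and d only; `OrderPeriods` (also hypothetical) applies the reversal that StackBlockedRequests triggers.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;

// Hypothetical stand-in for ToTimeSpan(): "<number><unit>", unit in s/m/h/d.
TimeSpan ParsePeriod(string period)
{
    var value = double.Parse(period[..^1], CultureInfo.InvariantCulture);
    return period[^1] switch
    {
        's' => TimeSpan.FromSeconds(value),
        'm' => TimeSpan.FromMinutes(value),
        'h' => TimeSpan.FromHours(value),
        'd' => TimeSpan.FromDays(value),
        _ => throw new FormatException($"Unknown period unit in '{period}'"),
    };
}

// Shortest window first; reversed (longest first) when StackBlockedRequests is true.
List<string> OrderPeriods(IEnumerable<string> periods, bool stackBlockedRequests)
{
    var ordered = periods.OrderBy(p => ParsePeriod(p)).ToList();
    if (stackBlockedRequests)
    {
        ordered.Reverse(); // List<T>.Reverse, in place
    }
    return ordered;
}

var samplePeriods = new[] { "1d", "1s", "15m", "12h" };
Console.WriteLine(string.Join(", ", OrderPeriods(samplePeriods, false))); // 1s, 15m, 12h, 1d
Console.WriteLine(string.Join(", ", OrderPeriods(samplePeriods, true)));  // 1d, 12h, 15m, 1s
```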
Next is _options.StackBlockedRequests. Same routine as before: when we run into a configuration option, we check the documentation.
If StackBlockedRequests is set to false, rejected calls are not added to the throttle counter. If a client makes 3 requests per second and you've set a limit of one call per second, other limits like per minute or per day counters will only record the first call, the one that wasn't blocked. If you want rejected requests to count towards the other limits, you'll have to set StackBlockedRequests to true.
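In my own words: if StackBlockedRequests is false, rejected requests are not counted. If a client sends 3 requests per second and the limit is 1 per second, the per-minute and per-day counters only record the one request that was not blocked.
If you want rejected requests to be counted as well, set StackBlockedRequests to true.
Put simply, it decides whether rejected requests are counted.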
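Of course, no counting happens at this point yet. What StackBlockedRequests = true does here is reverse the list so that periods go from longest to shortest, and that turns out to be the key to what follows.
Here is an educated guess: with StackBlockedRequests = false, limits are sorted by PeriodTimespan from shortest to longest, i.e. second, minute, hour, day.
Normally the per-second limit is hit first. In the Invoke loop shown earlier, each rule's counter is incremented right before that rule is checked, and the loop returns as soon as one rule blocks the request, so the rules after it never count that request. A neat design.
You might ask: if it is the per-minute limit that gets reached, hasn't the per-second counter still counted the request?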
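That doesn't matter, because the minute window contains the second. What this ordering really solves is the following: say I send 60 requests within one second with a limit of 1 per second, so 59 of them are rejected. If all 60 were counted, they would also use up a 60-per-minute limit, and the user would be locked out for the rest of that minute. Hence the suggestion to set StackBlockedRequests to false.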
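The guess above can be checked with a small simulation. This C# sketch is a simplification: two hard-coded rules (1 per second, 60 per minute), counters that never expire, and a loop that, like Invoke, increments each rule's counter before checking it and stops at the first rule that blocks. `Simulate` is a hypothetical helper.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Simplified simulation of the Invoke loop (hypothetical; counters never expire here).
// Two rules: at most 1 request per second and 60 per minute.
Dictionary<string, int> Simulate(int requests, bool stackBlockedRequests)
{
    var rules = new List<(string Period, int Seconds, int Limit)> { ("1s", 1, 1), ("1m", 60, 60) };
    var ordered = stackBlockedRequests
        ? rules.OrderByDescending(r => r.Seconds).ToList() // longest period first
        : rules.OrderBy(r => r.Seconds).ToList();          // shortest period first
    var counters = rules.ToDictionary(r => r.Period, r => 0);

    for (var i = 0; i < requests; i++)
    {
        foreach (var rule in ordered)
        {
            counters[rule.Period]++;                       // increment first, as in Invoke
            if (counters[rule.Period] > rule.Limit) break; // blocked: later rules never see it
        }
    }
    return counters;
}

// 60 requests inside one second.
var unstacked = Simulate(60, stackBlockedRequests: false);
var stackedRun = Simulate(60, stackBlockedRequests: true);
Console.WriteLine($"false: 1s={unstacked["1s"]}, 1m={unstacked["1m"]}"); // false: 1s=60, 1m=1
Console.WriteLine($"true: 1s={stackedRun["1s"]}, 1m={stackedRun["1m"]}"); // true: 1s=60, 1m=60
```

With StackBlockedRequests = false the minute counter only sees the one accepted request; with true it reaches 60, so a 61st request within that minute would be blocked by the per-minute rule.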
Conclusion
For length reasons, how the counting works is covered in the next post.