floor-sum 算法 学习笔记

一、算法描述

floor-sum 算法用来在 \(\Theta(\log V)\) 的时间复杂度内解决如下问题,其中 \(V\) 是值域。

定义

\[f(a, b, c, n) = \sum_{i = 1}^n \lfloor \dfrac{ai + b}{c} \rfloor \]

给定 \(a, b, c, n(c, n \geq 0)\),求出 \(f(a, b, c, n)\) 的值。

首先,若 \(a, b\) 不满足 \(0 \leq a, b < c\),那么有

\[\begin{align*} f(a, b, c, n) &= \sum_{i = 1}^n \lfloor \dfrac{ai + b}{c} \rfloor \\ &= \sum_{i = 1}^n \lfloor \dfrac{\lfloor \frac{a}{c} \rfloor ci + (a \bmod c)i + \lfloor \frac{b}{c} \rfloor c + (b \bmod c)}{c} \rfloor \\ &= \sum_{i = 1}^n \lfloor \dfrac{a}{c} \rfloor i + \sum_{i = 1}^n \lfloor \dfrac{b}{c} \rfloor + \sum_{i = 1}^n \lfloor \dfrac{(a \bmod c)i + (b \bmod c)}{c} \rfloor \\ &= \lfloor \dfrac{a}{c} \rfloor \dfrac{n(n + 1)}{2} + \lfloor \dfrac{a}{c} \rfloor n + f(a \bmod c, b \bmod c, c, n) \end{align*} \]

因此我们只需考虑 \(0 \leq a, b < c\) 的情况。考虑对表达式变形。

\[f(a, b, c, n) = \sum_{i = 1}^n \lfloor \dfrac{ai + b}{c} \rfloor = \sum_{i = 1}^n \sum_{j = 1}^{\lfloor \frac{an + b}{c} \rfloor} [j \leq \lfloor \dfrac{ai + b}{c} \rfloor] = \sum_{j = 1}^{\lfloor \frac{an + b}{c} \rfloor} \sum_{i = 1}^n [j \leq \lfloor \dfrac{ai + b}{c} \rfloor] \]

由于我们交换了求和顺序,我们需要使中括号内的表达式变成关于 \(i\) 的限制。考虑对中括号变形。

\[j \leq \lfloor \dfrac{ai + b}{c} \rfloor \Leftrightarrow j \leq \dfrac{ai + b}{c} \Leftrightarrow cj \leq ai + b \Leftrightarrow ai \geq cj - b \Leftrightarrow ai > cj - b - 1 \Leftrightarrow i > \dfrac{cj - b - 1}{a} \Leftrightarrow i > \lfloor \dfrac{cj - b - 1}{a} \rfloor \]

然后我们将最终的式子代入到原来的 \(f(a, b, c, n)\),得到

\[\begin{align*} f(a, b, c, n) &= \sum_{j = 1}^{\lfloor \frac{an + b}{c} \rfloor} \sum_{i = 1}^n [i > \lfloor \dfrac{cj - b - 1}{a} \rfloor] \\ &= \sum_{j = 1}^{\lfloor \frac{an + b}{c} \rfloor} (n - \lfloor \dfrac{cj - b - 1}{a} \rfloor) \\ &= \lfloor \dfrac{an + b}{c} \rfloor n - \sum_{j = 1}^{\lfloor \frac{an + b}{c} \rfloor} \lfloor \dfrac{cj - b - 1}{a} \rfloor \\ &= \lfloor \dfrac{an + b}{c} \rfloor n - f(c, -b - 1, a, \dfrac{an + b}{c}) \end{align*} \]

于是我们得到了一个关于 \(f(a, b, c, n)\) 的递归式。回顾一下我们刚才的推导过程,我们发现该函数的变化形如 \(f(a, b, c, n) \rightarrow f(a \bmod c, b \bmod c, c, n) \rightarrow f(c, -(b \bmod c) - 1, a \bmod c, \lfloor \dfrac{(a \bmod c)n + (b \bmod c)}{c} \rfloor)\)。事实上,这个过程类似于对 \(a\)\(c\) 辗转相除,因此该函数的时间复杂度为 \(\Theta(\log V)\)

参考代码
int floor(int a,int b){
	return (a>=0)?(a/b):-((-a+b-1)/b);
}
int f(int a,int b,int c,int n){
	// Calculate the sum of (a * i + b) / c(1 <= i <= n).
	int sa=floor(a,c),sb=floor(b,c);
	int ans=sa*n*(n+1)/2+sb*n;
	a-=sa*c,b-=sb*c;
	if(a==0||n==0) return ans;
	return (a*n+b)/c*n-f(c,-b-1,a,(a*n+b)/c)+ans;
}

二、例题

1. ABC313G Redistribution of Piles

简要题意

\(n\) 个盘子和 \(1\) 个袋子。初始时,第 \(i\) 个盘子上有 \(a_i\) 块石头,袋子是空的。

有以下两种操作:

  • 对于每个盘子,如果当前盘子上有石头,那么从该盘子上移除一块石头。然后将所有移除掉的石头放进袋子。
  • 从袋子里拿出 \(n\) 块石头,并依次在每个盘子中放入 \(1\) 块。这种操作能够进行当且仅当此时袋子里有至少 \(n\) 块石头。

你可以按任意顺序进行任意多次操作,包括 \(0\) 次。求所有能够到达的局面数量,对 \(998244353\) 取模。

\(1 \leq n \leq 2 \times 10^5\)\(0 \leq a_i \leq 10^9\)

我们称一种局面为一种“状态”当且仅当这种局面中存在某个盘子为空。称一种局面与一种状态互相对应当且仅当从该状态出发,能够只经过操作 2 变成该局面。显然,一种局面对应的状态是唯一的(只需不断地从每个盘子中移除一块石头,直到某个盘子为空),但一种状态可能对应多个局面。因此答案即为所有状态对应的局面数量之和。

考虑如何求出一种状态对应的局面数量。显然每次操作会多出恰好 \(n\) 块石头和恰好一种局面。因此,如果总共有 \(\mathrm{sum}\) 块石头,且该状态有 \(x\) 块石头,则该状态会对应 \(\lfloor \dfrac{sum - x}{n} \rfloor + 1 = \lfloor \dfrac{sum - x + n}{n} \rfloor\) 种状态。

然而我们发现这道题的值域(单个盘子上的石子数)是 \(10^9\) 级别的,这说明虽然我们刚才的操作大大优化了复杂度,但现在仅连状态数都仍然不可接受。我们考虑将这些状态归类,定义一个状态的类别为这个状态中空盘子的数量,并设计一种算法快速地求出每一类包含的局面数量。事实上,若初始时将这些盘子按石子数升序排序,则每种状态中的空盘子一定集中在最左边,这种状态的类别即为最右端空盘子的编号。由此,对于一种类别,我们就能快速表示出它对应的状态。

现在我们只需考虑如何求出第 \(i\) 类状态包含的局面数。事实上,我们发现状态之间只能通过操作 1 转移。因此,如果我们找出了石子数最大的状态,该类别中的状态都可以被表示成一个非负整数 \(k\),表示保持左边的 \(i\) 个空盘子不变,右边 \(n - i\) 个非空的盘子上石子数量都减少 \(k\)。显然 \(k\) 的值域为 \([0, a_i - a_{i + 1})\),且其中的每个整数都会出现恰好 \(1\) 次。将上面的式子代入,我们得出第 \(i\) 类状态包含的局面数为

\[\sum_{k = 0}^{a_i - a_{i + 1}} \lfloor \dfrac{\mathrm{sum} - x + n + k(n - i)}{n} \rfloor \]

其中 \(\mathrm{sum}\) 表示总石子数,\(x\) 表示该类中石子数最大的状态中的石子数。

我们发现这个式子满足 floor-sum 算法的要求,因此只需调用 \(f(n - i, \mathrm{sum} - x + n, n, a_{i + 1} - a_i)\) 即可。时间复杂度为 \(\Theta(n \cdot \log V)\)

AC 代码
#include <algorithm>
#include <cstdio>
using namespace std;
const int N=200003,mod=998244353;
int n,a[N];
long long floor(long long a,long long b){
	return (a>=0)?(a/b):-((-a+b-1)/b);
}
int f(long long a,long long b,int c,int n){
	// calculate the sum of floor(a * i + b)/c (1 <= i <= n).
	int sa=(floor(a,c)%mod+mod)%mod,sb=(floor(b,c)%mod+mod)%mod;
	int ans=(((long long)(n+1)*n/2%mod*sa%mod+(long long)n*sb%mod)%mod+mod)%mod;
	a-=floor(a,c)*c,b-=floor(b,c)*c;
	if(n==0||a==0) return ans;
	return ((a*n+b)/c*n%mod-f(c,-b-1,a,(a*n+b)/c)+ans+mod)%mod;
}
int main(){
//	freopen("pile.in","r",stdin);
//	freopen("pile.out","w",stdout);
	int i,ans=0;
	long long sum=0;
	scanf("%d",&n);
	for(i=1;i<=n;i++)
		scanf("%d",&a[i]);
	sort(a+1,a+n+1);
	for(i=1;i<=n;i++) sum+=a[i];
	for(i=n;i>0;i--) a[i]-=a[1];
	for(i=1;i<=n;i++) sum-=a[i];
	for(i=1,ans=(sum/n+1)%mod;i<n;i++){
		ans=(ans+f(n-i,sum+n,n,a[i+1]-a[i]))%mod;
		sum+=(long long)(a[i+1]-a[i])*(n-i);
	}
	printf("%d",ans);
//	fclose(stdin);
//	fclose(stdout);
	return 0;
}

2. ABC372G Ax + By < C

简要题意

给定三个长为 \(n\) 的序列 \(A = (A_1, A_2, \cdots, A_n), B = (B_1, B_2, \cdots, B_n), C = (C_1, C_2, \cdots, C_n)\)。你需要求出满足以下条件的二元组 \((x, y)\) 的数量:

  • \(x, y\) 都是正整数。
  • 对每个 \(1 \leq i \leq n\),都有 \(A_i x + B_i y < C_i\)

单个测试点内有 \(T\) 组测试数据,你需要对每组数据都求出答案。可以证明这样的二元组数量一定是有限的。

\(1 \leq T, n, \sum n \leq 2 \times 10^5\)\(1 \leq A_i, B_i, C_i \leq 10^9\)

对于限制 \(A_i x + B_i y < C_i\),变形得 \(y \leq -\dfrac{A_i}{B_i} x + \dfrac{C_i - 1}{B_i}\),我们发现这类似于直线的解析式。考虑把这些条件画到平面直角坐标系上,那么一个条件相当于钦定第一象限上的点 \((x, y)\) 必须在某条直线上或在这条直线的下方。也就是说,每个条件都限制了 \((x, y)\) 必须在某个半平面上。为了综合这些限制,我们求一遍半平面交,相当于求出了这些直线围成的下凸壳(再加上两条坐标轴)。

显然对于直线 \(x = x_0\),若该直线与凸壳交点的纵坐标为 \(y_0\),那么横坐标为 \(x_0\) 的点就有 \(\lfloor y_0 \rfloor\) 个。我们考虑分别对于凸壳上的每一段求出其下方的点对数量,令第 \(i\) 段为 \(y = -\dfrac{a_i}{b_i} x + \dfrac{c_i - 1}{b_i}(l_i < x \leq r_i)\),那么将两个公式结合起来即可得到答案为

\[ \sum_{k = l_i + 1}^{r_i} \lfloor \dfrac{-a_i k + c_i - 1}{b_i} \rfloor = \sum_{k = 1}^{r_i - l_i} \lfloor \dfrac{-a_i k + l_i k + c_i - 1}{b_i} \rfloor \]

我们发现这个式子满足 floor-sum 算法的要求,因此只需调用 \(f(-a_i, l_i k + c_i - 1, b_i, r_i - l_i)\) 即可。时间复杂度为 \(\Theta(\sum n \cdot \log V)\)

AC 代码
#include <algorithm>
#include <cstdio>
using namespace std;
const int N=200003;
const long long PIN=4557430888798830399;
struct Line{
	int a,b,c; // y = -(a / b) * x + (c / b)
}a[N],b[N];
const Line era={0,1,0};
int n,m; long long c[N];
bool cmp(const Line& l1,const Line& l2){
	return (long long)l1.a*l2.b<(long long)l1.b*l2.a;
}
long long sec(Line l1,Line l2){
	// It is guaranteed that a1 / b1 <= a2 / b2.
	if((long long)l1.c*l2.b> (long long)l1.b*l2.c) return -1;
	if((long long)l1.a*l2.b==(long long)l1.b*l2.a) return PIN;
	long long sa=(long long)l1.b*l2.a-(long long)l1.a*l2.b;
	long long sb=(long long)l1.b*l2.c-(long long)l1.c*l2.b;
	return sb/sa;
}
int floor(int a,int b){
	return (a>=0)?(a/b):-((-a+b-1)/b);
}
long long f(int a,int b,int c,int n){
	// Calculate the sum of (a * i + b) / c(1 <= i <= n).
	long long sa=floor(a,c),sb=floor(b,c);
	long long ans=sa*n*(n+1)/2+sb*n;
	a-=sa*c,b-=sb*c; if(a==0||n==0) return ans;
	return ((long long)a*n+b)/c*n-f(c,-b-1,a,((long long)a*n+b)/c)+ans;
}
int main(){
//	freopen("line.in","r",stdin);
//	freopen("line.out","w",stdout);
	int i,t; long long ans;
	for(scanf("%d",&t);t>0;t--){
		scanf("%d",&n),m=0;
		for(i=1;i<=n;i++)
			scanf("%d%d%d",&a[i].a,&a[i].b,&a[i].c),a[i].c--;
		sort(a+1,a+n+1,cmp);
		for(i=1;i<=n;i++){
			while(m>0&&sec(b[m],a[i])==-1) m--;
			while(m>1&&sec(b[m],a[i])<=sec(b[m-1],b[m])) m--;
			b[++m]=a[i];
		}
		while(m>1&&sec(era,b[m-1])<=sec(b[m-1],b[m])) m--;
		for(i=1;i<m;i++)
			c[i]=sec(b[i],b[i+1]);
		c[m]=sec(era,b[m]),ans=0;
		for(i=1;i<=m;i++){
			ans+=f(-b[i].a,b[i].c,b[i].b,c[i  ]);
			ans-=f(-b[i].a,b[i].c,b[i].b,c[i-1]);
		}
		printf("%lld\n",ans);
	}
//	fclose(stdin);
//	fclose(stdout);
	return 0;
}
posted @ 2024-09-23 22:53  kilomiles  阅读(331)  评论(0)    收藏  举报