[字符串border-period][数论][莫比乌斯反演][CF1205E]Expected Value Again
-
大致题意:定义 \(f(s)\) 表示字符串 \(s\) 的 border 个数,即满足 \(1\le i<|s|\) 且 \(s\) 的长度为 \(i\) 的前缀和后缀相等的 \(i\) 个数。求所有长度为 \(n\),字符集大小为 \(k\) 的字符串 \(s\) 的 \(f(s)^2\) 平均值,对 \(10^9+7\) 取模,\(1\le n\le 10^5\),\(1\le k\le 10^9\)
-
首先我们知道,一个字符串 \(s\) 有长度为 \(i\) 的 border 就有长度为 \(|s|-i\) 的 period(一个字符串 \(s\) 有长度为 \(l\) 的 period 当且仅当对于所有的 \(1\le i\le|s|-l\) 都有 \(s_i=s_{i+l}\)),于是把 border 个数转为求 period 个数
-
首先转化一下问题:\(f(s)^2\) 相当于枚举两个 \(1\le i,j<n\),使得 \(s\) 同时存在 \(i,j\) 两个 period,于是可以考虑算出每一对 \(1\le i,j<n\) 的贡献后加起来
-
现在考虑对于一对 \(i,j\),有多少个长度为 \(n\) 的字符串 \(s\) 满足 \(s\) 同时存在 \(i,j\) 两个 period
-
易得这相当于 \(n\) 个点,在编号差为 \(i\) 或 \(j\) 的点对之间连边表示这两个下标上的字符相等,计算出连通块个数为 \(cnt\),则这对 \(i,j\) 的贡献为 \(k^{cnt}\)
-
考虑如何计算连通块个数,设 \(d=\gcd(i,j)\),那么这 \(n\) 个下标显然可以按照下标模 \(d\) 的结果划分成 \(d\) 个独立的部分,总连通块个数为这 \(d\) 个部分的连通块个数之和
-
故先考虑 \(i,j\) 互质的情况。这里先说结论:如果 \(n\ge i+j\) 则这个图是连通的,否则这个图没有环
-
证明:先把所有编号差为 \(i\) 的点对连上,这样会形成 \(i\) 条链,同一条链上的点编号模 \(i\) 意义下同余
-
再考虑编号差为 \(j\) 的点对,可以看出,如果 \(n\ge i+j\),则任意一条链上都存在一个点,使得这个点能通过往右跳 \(j\) 步到达另一条链(链上最小的点编号不能超过 \(i\))
-
也就是说,我们如果现在在一条模 \(i\) 为 \(x\) 的链上,那么可以跳到一条模 \(i\) 为 \((x+j)\bmod i\) 的链上。由于 \(i,j\) 互质,所以 \(x,x+j,x+2j,\dots,x+(i-1)j\) 在模 \(i\) 意义下互不相同,取遍了 \([0,i)\) 内的所有值,也就是说只需每次从模 \(i\) 为 \(x\) 的链跳到 \((x+j)\bmod i\) 的链,就能走遍所有的 \(i\) 条链,故该图连通
-
而如果 \(n<i+j\),那么同样地,两条链最多被一对编号差为 \(j\) 的点连接起来,而对于所有的 \(x\),模 \(i\) 为 \(x\) 的链到模 \(i\) 为 \((x+j)\bmod i\) 的链的边都是可能存在的,也就是看上去最坏情况下所有的链会被连成一个环。但是我们注意到 \(x=0\) 时,链上最小的点编号为 \(i\),由于 \(n<i+j\),这条链必定不存在到 \((x+j)\bmod i\) 的边,故这些把所有 \(i\) 条链连成环的 \(i\) 条边中至少要断掉一条,也就证明了这个图无环
-
于是我们得出:
-
\[cnt=\begin{cases}n-(n-i)-(n-j)=i+j-n & n<i+j\\1 & n\ge i+j\end{cases}=\max(1,i+j-n) \]
-
而对于一般情况 \(\gcd(i,j)=d\),我们也可以得出如果对于所有的 \(0\le x<d\) 都满足 \([1,n]\) 中模 \(d\) 为 \(x\) 的数个数小于 \(\frac{i+j}d\),则这个图无环,否则这个图所有 \(d\) 个部分都是连通的。这时:
-
\[cnt=\begin{cases}i+j-n & n\le i+j-d\\d & n> i+j-d\end{cases}=\max(d,i+j-n) \]
-
这样就得出了答案的一个简洁的表达式:
-
\[ans\times k^n=\sum_{i=1}^{n-1}\sum_{j=1}^{n-1}k^{\max(\gcd(i,j),i+j-n)} \]
-
这样是 \(O(n^2)\) 的
-
先去掉 \(\max\),即不对 \(\gcd(i,j)\) 取 \(\max\),那么这时候的结果就是 \(\sum_{i=1}^{n-1}\sum_{j=1}^{n-1}k^{i+j-n}\),把 \(k^{-n}\) 提出来之后,这是一个关于 \(k\) 的多项式,每一项的系数都可以 \(O(1)\) 计算出来
-
然后再把 \(i+j-n<\gcd(i,j)\) 的贡献替换掉
-
考虑到式子中有 \(\gcd(i,j)\) 这一项,先进行莫比乌斯反演:
-
\[\sum_{d|p}\mu(\frac pd)\sum_{i=1}^{\lfloor\frac{n-1}p\rfloor}\sum_{j=1}^{\lfloor\frac{n-1}p\rfloor}k^{ip+jp-n}[ip+jp<n+d] \]
-
枚举 \(d\) 和 \(p\) 的复杂度是调和级数,即 \(O(n\log n)\)
-
先把 \(k^{-n}\) 提出来,那么这可以表示成一个关于 \(k^p\) 的多项式。首先 \(ip+jp<n+d\) 相当于 \(i+j\le\lfloor\frac{n+d-1}p\rfloor\),而我们又有 \(d\le p\),故 \(\lfloor\frac{n+d-1}p\rfloor\le\lfloor\frac{n-1}p\rfloor+1\)
-
故我们考虑这个多项式的次数不超过 \(\lfloor\frac{n-1}p\rfloor+1\) 的所有项,易得在这个项次数的限制下,次数为 \(i\)(\(\ge 2\))的项系数为 \(i-1\)
-
这相当于枚举一个 \(i\)(\(ip\le 2n-2,i\le\lfloor\frac{n+d-1}p\rfloor\)),把答案多项式的 \(ip\) 次项系数减掉 \((i-1)\times\mu(\frac pd)\)
-
具体实现可以先枚举 \(p\),然后用一个长度为 \(\lfloor\frac{2n-2}p\rfloor\) 的数组储存这个 \(p\) 导致的答案多项式系数变化量,第 \(i\) 个位置对应答案多项式的 \(ip\) 次项。然后枚举 \(p\) 的约数 \(d\),计算出扣掉的贡献之后给这个数组打上标记,约数枚举完了之后通过标记得出变化量即可计算对答案多项式的影响
-
此外上面的过程我们只是扣掉了 \(i+j-n<\gcd(i,j)\) 的 \(k^{i+j-n}\) 的贡献,我们还需对于所有的 \(i+j-n<\gcd(i,j)\) 令答案加上 \(k^d\)。这个比较好处理,和上面一样枚举 \(d|p\) 之后,把 \(\sum_{i=1}^{\lfloor\frac{n-1}p\rfloor}\sum_{j=1}^{\lfloor\frac{n-1}p\rfloor}k^d[ip+jp<n+d]\) 根据 \(\mu(\frac pd)\) 的正负性计入答案或从答案中扣除即可
-
而上面已经提到,对于 \(x\le\lfloor\frac{n-1}p\rfloor+1\),满足 \(i+j=x\) 的 \((i,j)\) 恰好有 \(x-1\) 对,故这个式子的值为 \(\binom{\lfloor\frac{n+d-1}p\rfloor}2k^d\)
-
枚举了 \(d|p\) 之后的东西都可以 \(O(1)\) 算出,故总复杂度为 \(O(n\log n)\)
Code
#include <bits/stdc++.h>
const int N = 1e5 + 5, M = N << 1, rqy = 1e9 + 7;
int n, k, tot, ik, pk[N], pri[N], miu[N], a[M], b[M], ans;
bool mark[N];
std::vector<int> di[N];
int qpow(int a, int b)
{
int res = 1;
while (b)
{
if (b & 1) res = 1ll * res * a % rqy;
a = 1ll * a * a % rqy;
b >>= 1;
}
return res;
}
void sieve()
{
mark[0] = mark[miu[1] = 1] = 1;
for (int i = 2; i <= n; i++)
{
if (!mark[i]) miu[pri[++tot] = i] = -1;
for (int j = 1; j <= tot; j++)
{
if (1ll * i * pri[j] > n) break;
mark[i * pri[j]] = 1;
if (i % pri[j] == 0) break;
miu[i * pri[j]] = -miu[i];
}
}
for (int i = 1; i <= n; i++)
for (int j = i; j <= n; j += i)
di[j].push_back(i);
}
int main()
{
std::cin >> n >> k;
sieve(); ik = qpow(k, rqy - 2);
pk[0] = 1;
for (int i = 1; i <= n; i++) pk[i] = 1ll * pk[i - 1] * k % rqy;
for (int i = n; i <= (n - 1 << 1); i++) a[i] = (n - 1 << 1) - i + 1;
for (int p = 1; p < n; p++)
{
int l = (n - 1 << 1) / p;
for (int i = 1; i <= l + 1; i++) b[i] = 0;
for (int i = 0; i < di[p].size(); i++)
{
int d = di[p][i], m = (n + d - 1) / p;
b[m] = (miu[p / d] + b[m] + rqy) % rqy;
ans = (1ll * m * (m - 1) / 2 % rqy * pk[d]
% rqy * (miu[p / d] + rqy) + ans) % rqy;
}
for (int i = l; i >= 1; i--) b[i] = (b[i] + b[i + 1]) % rqy;
for (int i = 1; i <= l; i++)
a[i * p] = (a[i * p] - 1ll * b[i] * (i - 1) % rqy + rqy) % rqy;
}
for (int i = n; i <= (n - 1 << 1); i++)
ans = (1ll * a[i] * pk[i - n] + ans) % rqy;
for (int i = 1; i <= n; i++) ans = 1ll * ans * ik % rqy;
return std::cout << ans << std::endl, 0;
}