题解 高维入侵/[WC2016]论战捆竹竿
考场上看错题了,以为要求的是能形成的本质不同的字符串数,于是完全背包死活过不了样例
求的是能形成的不同长度数,可以联想到同余最短路
于是暴力的做法是求出所有boarder长度,\(O(n^2logn)\) 跑同余最短路
但显然是过不去的,于是考虑优化
有一个 \(O(n^2)\) 的优化
当我们求出在 \(\bmod mod\) 意义下的最短路数组时,考虑怎么改变mod
- 关于同余最短路改模数:
比如说变成base,那转移有两种- 用 \(dis_i\) 更新 \(dis'_{dis_i\bmod base}\)
- 用 \(dis_i+k*mod\) 更新 \(dis'_{i+k*mod \bmod base}\)
然后改模数的过程可以在 \(\bmod mod\) 意义下的等价类中考虑,每个等价类中dis最小的点一定不会被更新
所以可以从dis最小的地方断环为链,将贡献写成 \(dis_i+(j-i)*mod\) 的形式可以用前缀min优化成 \(O(n)\)
于是对每个boarder都这样做一遍就是 \(O(n^2)\) 的
但是还不够,需要再优化
一个字符串的boarder形成 \(O(logn)\) 段等差序列
尝试将每个等差序列放在一起做
对于一个形如 \(kb+a, k\in[0, l]\) 的等差数列,在模 \(a\) 意义下按同余最短路的方式建边会形成 \(\gcd(a, b)\) 个环
对每个环按上面的方法断环为链,将贡献写成 \(dis_i+(j-i)*b+a\) 的形式可以单调队列优化DP
每做完一个等差数列就将模数换为下一个等差数列的首项
复杂度 \(O(nlogn)\),在uoj上加取模优化可以过
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f3f3f3f3f
#define N 500010
#define ll long long
// #define int long long
int n;
ll dis[N], dis2[N], w, ans;
char s[N];
int nxt[N], bdr[N], tem[N], sta[N], cnt, top, base, lst;
struct que{ll val; int pos;}q[N];
inline int gcd(int a, int b) {return !b?a:gcd(b, a%b);}
inline int md(int a, int mod) {return a>=mod?a-mod:a;}
void kmp() {
nxt[1]=0; cnt=0;
for (int i=2,j=0; i<=n; ++i) {
while (j && s[i]!=s[j+1]) j=nxt[j];
if (s[i]==s[j+1]) ++j;
nxt[i]=j;
}
for (int u=nxt[n]; u; u=nxt[u]) bdr[++cnt]=n-u;
bdr[++cnt]=n;
// cout<<"bdr: "; for (int i=1; i<=cnt; ++i) cout<<bdr[i]<<' '; cout<<endl;
}
void turn(int lst, int base) {
// cout<<"turn: "<<lst<<' '<<base<<endl;
if (!lst) return ;
for (int i=0; i<lst; ++i) dis2[i]=dis[i], dis[i]=INF;
for (int i=0; i<lst; ++i) dis[dis2[i]%base]=min(dis[dis2[i]%base], dis2[i]);
// cout<<"dis': "; for (int i=0; i<base; ++i) cout<<dis[i]<<' '; cout<<endl;
int g=gcd(lst, base);
for (int t=0; t<g; ++t) {
tem[top=1]=t;
for (int i=md(t+lst, base); i!=t; i=md(i+lst, base)) tem[++top]=i;
// cout<<"tem: "; for (int i=1; i<=top; ++i) cout<<tem[i]<<' '; cout<<endl;
ll mn=INF;
for (int i=1; i<=top; ++i) mn=min(mn, dis[tem[i]]);
int pos=1, now=0;
for (int i=1; i<=top; ++i) if (dis[tem[i]]==mn) {pos=i; break;}
for (int i=pos; i<=top; ++i) sta[++now]=tem[i];
for (int i=1; i<pos; ++i) sta[++now]=tem[i];
// cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
mn=INF;
for (int i=1; i<=top; ++i) {
dis[sta[i]]=min(dis[sta[i]], mn+1ll*i*lst);
mn=min(mn, dis[sta[i]]-1ll*i*lst);
}
}
}
void solve(int a, int b, int l) {
// cout<<"solve: "<<a<<' '<<b<<' '<<l<<endl;
int g=gcd(a, b);
base=a;
turn(lst, base);
if (b<=0) return ;
for (int t=0; t<g; ++t) {
tem[top=1]=t;
for (int i=md(b+t, a); i!=t; i=md(i+b, a)) tem[++top]=i;
// cout<<"tem: "; for (int i=1; i<=top; ++i) cout<<tem[i]<<' '; cout<<endl;
ll mn=INF;
for (int i=1; i<=top; ++i) mn=min(mn, dis[tem[i]]);
int pos=1, now=0;
for (int i=1; i<=top; ++i) if (dis[tem[i]]==mn) {pos=i; break;}
for (int i=pos; i<=top; ++i) sta[++now]=tem[i];
for (int i=1; i<pos; ++i) sta[++now]=tem[i];
// cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
int ql=1, qr=0;
for (int i=1; i<=top; ++i) {
while (ql<=qr && q[ql].pos<i-l+1) ++ql;
if (ql<=qr) dis[sta[i]]=min(dis[sta[i]], q[ql].val+1ll*i*b+a);
while (ql<=qr && q[qr].val>=dis[sta[i]]-1ll*i*b) --qr;
q[++qr]={dis[sta[i]]-1ll*i*b, i};
}
}
// cout<<"dis: "; for (int i=0; i<base; ++i) cout<<dis[i]<<' '; cout<<endl;
}
signed main()
{
int T;
scanf("%d", &T);
while (T--) {
ans=0; lst=0;
scanf("%d%lld%s", &n, &w, s+1);
memset(dis, 127, sizeof(dis)); dis[0]=0;
kmp();
for (int l=1,r; l<=cnt; l=r+1,lst=base) {
for (r=l+1; r+1<=cnt&&bdr[r+1]-bdr[r]==bdr[r]-bdr[r-1]; ++r);
// if (l!=1) puts("error");
if (l==cnt) solve(bdr[l], -1, 1);
else solve(bdr[l], bdr[r]-bdr[r-1], r-l+1);
}
// cout<<"dis: "; for (int i=0; i<base; ++i) cout<<dis[i]<<' '; cout<<endl;
for (int i=0; i<base; ++i) if (w-n>=dis[i]) ans+=(w-n-dis[i])/base+1;
printf("%lld\n", ans);
}
return 0;
}