WGOI R1 真夏飞焰 题解
【题目大意】
题目链接。
给定序列 \(a\),我们定义序列 \(a,b\) 是「\(k\) 相似」的,当且仅当对于 \(a\) 中每一个四元组 \((l_1,r_1,l_2,r_2)\),若满足 \(r_1 - l_1 + 1 = r_2 - l_2 + 1 \le k,l_2 = r_1 + 1,a_{l_1\ldots r_1} = a_{l_2 \ldots r_2}\),则有 \(b_{l_1\ldots r_1} = b_{l_2 \ldots r_2}\)。
给出 \(n\) 个限制形如 \(b_i \in [L_i,R_i]\),对 \(k \in [1,\dfrac{n}{2}]\) 求出:如果序列 \(a,b\) 是 「\(k\) 相似」的,有多少种合法的序列 \(b\),对 \(998244353\) 取模。
\(1 \le n \le 10^6,1 \le L_i \le R_i \le 10^6\)。
【题解】
考虑朴素暴力:用并查集维护 \(b\) 的相等关系,同一个连通块内 \(b\) 相等,因此我们可以把 \(L\) 取 \(\max\),\(R\) 取 \(\min\),最终答案是所有连通块的 \(R-L+1\) 之积。枚举 \(l_1,l_2\),如果有 \(a_{l_1\ldots r_1} = a_{l_2 \ldots r_2}\),则对于 \(j \in [l_1,r_1]\),合并 \(j,j+r_1-l_1\)。判断相等可以用 SA,时间复杂度 \(\mathcal O(n^3 \log n)\),期望得分 \(10\) 分。
有一个简单优化:注意到只有 \(\mathcal O(n^2)\) 对不同的 \((i,j)\),我们对于一次区间合并操作,可以差分,最后再扫一遍一起合并,时间复杂度 \(\mathcal O(n^2 \log n)\),期望得分 \(20\) 分,但是实测如果实现精良可以得 \(40\) 分。
上面的做法有两个瓶颈:一是找到所有区间 \([i,j]\) 需要 \(\mathcal O(n^2)\) 的时间,二是合并区间需要 \(\mathcal O(n^2 \log n)\) 的时间。我们分开解决。
第一部分是套路的,其实就是找所有形如 \(\text{AA}\) 的串。用 NOI2016 优秀的拆分 的套路,对于 \(\text{|A|} = i\),在 \(j = i,2i,\ldots,ki\) 设置关键点,一个形如 \(\text{AA}\) 的串会恰好跨过两个关键点。我们枚举相邻的关键点 \(p,q\),求后缀 \(p,q\) 的最长公共前缀 \(d_1\) 和前缀 \(p,q\) 的最长公共后缀 \(d_2\),如果 \(d_1 + d_2 - 1 \ge i\),可以得到一个长度为 \(\mathcal O(i)\) 的合并区间。由调和级数,枚举关键点的复杂度为 \(\mathcal O(n \log n)\),求最长公共前后缀可以 SA \(\mathcal O(1)\) 回答。
事实上,如果你写到这里就直接提交,由于数据很强,可以获得 \(100\) 分,但是我们的复杂度还是 \(\mathcal O(n^2 \log n)\),用一个全 \(\texttt{a}\) 串即可卡满。
注意到只有 \(\mathcal O(n)\) 个点,即只有 \(\mathcal O(n)\) 次有用的合并,我们希望快速找到它们。我想到这里就不会了 qwq,这里借鉴了题解的思路,类似倍增优化建图。
建 \(\mathcal O(\log n)\) 层点,每层 \(n\) 个,\((i,j)\) 表示第 \(j\) 层第 \(i\) 个点。我们用 \((i,j)\) 表示区间 \([i,i+2^j-1]\) 的合并情况,合并 \((i,j),(k,j)\) 表示 \([i,i+2^j-1]\) 和 \([k,k+2^j-1]\) 两个区间内的点连通。
考虑区间连边的简单情况:区间长度是 \(2^k\)。这时我们找到两个虚点,如果某一层已经连通则说明大区间已经连通,退出;否则,像 st 表一样递归两个子区间,最后给这一层的区间连边。由于只有 \(\mathcal O(n \log n)\) 的点,所以合并的复杂度为 \(\mathcal O(n \log^2 n)\)。
可能连边的过程说起来比较抽象,C++ 代码如下。
代码
void add(int id,int x,int y){ // x,y 是区间的左端点,2^id 是区间长度
if(!id) return mg(x,y),void(); // 叶子节点直接退出
if(!mg(get(id,x),get(id,y))) return ; // 已经连通直接退出
add(id - 1,x,y),add(id - 1,x + (1 << id - 1),y + (1 << id - 1)); // 向下递归
}
那么如果连边区间长度不是 \(2^k\),拆成两个连就行了。
总时间复杂度 \(\mathcal O(n \log ^2 n)\),其中瓶颈在并查集,可以通过。
完整代码用了 atcoder 的 SA,为了可读性就省略了。
代码
using namespace atcoder;
const int N = 2e6+5,mod = 998244353;
using namespace std;
int T,n,s[N],tmp,inv[N],ans,fa[N*10],ql[N*10],qr[N*10];
int qp(int a,int b){
int r = 1;
for(;b;b >>= 1,a = 1ll * a * a % mod) if(b & 1) r = 1ll * r * a % mod;
return r;
}int fd(int x){return x == fa[x] ? x : fa[x] = fd(fa[x]);} vector<int> g[N];
bool mg(int u,int v){
if(u = fd(u),v = fd(v),u == v) return 0;
if(ans) ans = 1ll * ans * inv[qr[u] - ql[u] + 1] % mod * inv[qr[v] - ql[v] + 1] % mod;
fa[v] = u,ql[u] = max(ql[u],ql[v]),qr[u] = min(qr[u],qr[v]);
if(qr[u] < ql[u]) ans = 0; else ans = 1ll * ans * (qr[u] - ql[u] + 1) % mod;
return 1;
}struct SAA{
int sa[N],rk[N],h[N],st[20][N],n;
void init(int _n){for(int i = 1;i <= 2 * _n;i ++) h[i] = 0;n = _n;}
void SA(int *s){
vector<int> _s;
for(int i = 1;i <= n;i ++) _s.push_back(s[i]);
vector<int> _sa = suffix_array(_s);
for(int i = 1;i <= n;i ++) sa[i] = _sa[i - 1] + 1,rk[sa[i]] = i;
for(int i = 1,k = 0;i <= n;h[rk[i ++]] = k) for(k --,k = max(k,0);s[i + k] == s[sa[rk[i] - 1] + k];k ++);
for(int i = 1;i <= n;i ++) st[0][i] = h[i];
for(int j = 1;j <= __lg(n);j ++) for(int i = 1;i + (1 << j) - 1 <= n;i ++)
st[j][i] = min(st[j-1][i],st[j-1][i+(1<<j-1)]);
}int qry(int i,int j){
if(i == j) return n;
if(i = rk[i],j = rk[j],i > j) swap(i,j);
int k = __lg(j - i);
return min(st[k][i+1],st[k][j-(1<<k)+1]);
}
}s1,s2;
int get(int id,int u){return id * n + u;}
void add(int id,int x,int y){
if(!id) return mg(x,y),void();
if(!mg(get(id,x),get(id,y))) return ;
add(id - 1,x,y),add(id - 1,x + (1 << id - 1),y + (1 << id - 1));
}void los(){
cin >> n,ans = 1;
for(int i = 1;i <= 2 * n;i ++) s[i] = fa[i] = i;
for(int i = 1;i <= 20 * n;i ++) fa[i] = i;
for(int i = 1;i <= n;i ++) cin >> s[i];
for(int i = 1;i <= n;i ++) cin >> ql[i] >> qr[i];
for(int i = 1;i <= n;i ++) ans = 1ll * ans * (qr[i] - ql[i] + 1) % mod;
s1.init(n),s2.init(n),s1.SA(s),reverse(s+1,s+n+1),s2.SA(s);
for(int i = 1;i <= n / 2;i ++){
int fg = 0;
for(int j = i;j <= n;j += i){
if(!ans) break; if(j + i > n) break;
int p = j,q = j + i,d1 = min(i,s1.qry(p,q)),d2 = min(i,s2.qry(n - p + 1,n - q + 1));
if(d1 + d2 - 1 < i) continue;
auto dda = [&](int l,int r,int ql,int qr){
int k = __lg(r - l + 1);
add(k,l,ql),add(k,r-(1<<k)+1,qr-(1<<k)+1);
}; dda(p - d2 + 1,p + d1 - 1,p - d2 + 1 + i,p + d1 - 1 + i);
}cout << ans << " ";
}cout << "\n";
}int main(){
// freopen("yoimiya.in","r",stdin),freopen("yoimiya.out","w",stdout);
ios::sync_with_stdio(0),cin.tie(0);
for(int i = 1;i <= 1000000;i ++) inv[i] = qp(i,mod - 2);
for(cin >> T;T --;) los();
}