题解 白楼剑
写一半代码没了,直接引出了 bash 学习笔记
考虑扔到 SAM 上
因为是要求子串在第二个串中出现过,所以考虑对第二个串的每个子串统计贡献
一个暴力是对于每个询问,将 \(s_{l,r}\) 放在第二个串的 SAM 上跑
那么跑到的每个节点及其祖先都可以产生 \(len*cnt\) 的贡献
然后需要优化这个东西
看着不太能优化,考虑暴力数据结构
尝试回滚莫队
令当前块的右端点为 \(R\)
将贡献拆为左端点 \(> R\)(子串完全位于 \(R\) 右侧)和左端点 \(\leqslant R\) 两种
第一种将每个块的询问按右端点排序,从 \(R+1\) 开始将 \(s_{R+1, n}\) 在正串的 SAM 上跑即可
第二种对每个询问都在反串的 SAM 上倍增找到 \(s_{R+1, q_r}\) 对应的节点,再暴力向左移动到 \(s_{q_l, q_r}\) 同时更新答案即可
为保证这部分跳 fail 的复杂度,要预处理 \(f_{i, c}\) 为 \(i\) 的祖先中第一个有出边 \(c\) 的,虽然我代码里没有写
然后细节有很多,第二种时需要时刻保证 \(len\) 没有超出询问串的范围
然后我倍增忘记预处理 \(lg\) 数组调了一晚上
复杂度 \(O(n|\Sigma|+(n+q)\sqrt n)\),其中 \(|\Sigma|\) 为字符集大小
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Large int sentinel value.
#define INF 0x3f3f3f3f
// Capacity of every fixed-size array in this file.
#define N 200010
// Shorthands for pair members and vector::push_back.
#define fir first
#define sec second
#define pb push_back
// Shorthands for the 64-bit integer types.
#define ll long long
#define ull unsigned long long
//#define int long long
// n = length of s1, m = length of s2 (both set in main via strlen), q = number of queries.
int n, m, q;
// lim = max(n, m); set in main and used as the upper bound for length-indexed tables.
int lim;
// Input strings, stored 1-indexed (main reads them into s1+1 / s2+1).
char s1[N], s2[N];
// Base of the polynomial rolling hash used by the force and task1 solvers.
const ull base=13131;
namespace force{
ull pw[N], h[N];
unordered_map<ull, int> mp;
inline ull hashing(int l, int r) {return h[r]-h[l-1]*pw[r-l+1];}
void solve() {
pw[0]=1;
for (int i=1; i<=lim; ++i) pw[i]=pw[i-1]*base;
for (int i=1; i<=m; ++i) h[i]=h[i-1]*base+s2[i];
for (int l=1; l<=m; ++l)
for (int r=l; r<=m; ++r)
++mp[hashing(l, r)];
for (int i=1; i<=n; ++i) h[i]=h[i-1]*base+s1[i];
ll ans;
for (int i=1,l,r; i<=q; ++i) {
scanf("%d%d", &l, &r); ans=0;
for (int j=l; j<=r; ++j)
for (int k=j; k<=r; ++k) if (mp.find(hashing(j, k))!=mp.end())
ans=max(ans, 1ll*(k-j+1)*mp[hashing(j, k)]);
printf("%lld\n", ans);
}
}
}
namespace task1{
ull pw[N], h[N];
ll ans[N], bit[N];
vector<pair<int, int>> que[N];
unordered_map<ull, int> mp[5010];
inline ull hashing(int l, int r) {return h[r]-h[l-1]*pw[r-l+1];}
inline void upd(int i, ll val) {for (; i<=n; i+=i&-i) bit[i]=max(bit[i], val);}
inline ll query(int i) {ll ans=1; for (; i; i-=i&-i) ans=max(ans, bit[i]); return ans;}
void solve() {
pw[0]=1;
for (int i=1; i<=lim; ++i) pw[i]=pw[i-1]*base;
for (int i=1; i<=m; ++i) h[i]=h[i-1]*base+s2[i];
for (int l=1; l<=m; ++l)
for (int r=l; r<=m; ++r)
++mp[r-l+1][hashing(l, r)];
for (int i=1; i<=n; ++i) h[i]=h[i-1]*base+s1[i];
for (int i=1,l,r; i<=q; ++i) {
scanf("%d%d", &l, &r);
que[l].pb({r, i});
}
for (int i=n; i; --i) {
ull h=0;
for (int j=i; j<=n; ++j) {
h=h*base+s1[j];
if (mp[j-i+1].find(h)!=mp[j-i+1].end()) bit[j]=max(bit[j], 1ll*(j-i+1)*mp[j-i+1][h]);
else break;
}
for (auto& it:que[i]) for (int j=i; j<=it.fir; ++j) ans[it.sec]=max(ans[it.sec], bit[j]);
}
for (int i=1; i<=q; ++i) printf("%lld\n", ans[i]);
}
}
namespace task{
// Main solver: answers each query [l, r] on s1 with the maximum of
// length * (occurrences in s2) over substrings of s1[l..r] that occur in s2.
//
// Queries are grouped by the block of their left endpoint and sorted by right
// endpoint inside each block. With R the last index of the block:
//   * substrings lying entirely in [R+1, r] are handled by feeding
//     s1[R+1..r] to sam1 (automaton of s2) while r only moves right;
//   * substrings starting in [l, R] are handled per query by starting from
//     the sam2 (automaton of reversed s2) state of s1[R+1..r] and extending
//     it leftwards one character at a time down to l.
// Queries whose two endpoints share a block are answered by a direct scan.
ll ans[N];
// bel[i]: block of position i; ls/rs: first/last position of a block.
// at[i] / mat[i]: sam2 state and length of the longest prefix of s1[i..n]
// that occurs in s2 (0 / 0 when even s1[i] does not occur).
int bel[N], ls[N], rs[N], mat[N], at[N], sqr;
struct ques{int l, r, id;};
// Orders queries by right endpoint.
inline bool operator < (ques a, ques b) {return a.r<b.r;}
vector<ques> que[N];

// Suffix automaton over lowercase letters with occurrence counts and binary
// lifting on the suffix-link tree. State 0 is the root; fail[0] == -1.
struct sam{
    // g[u]: maximum of len[v] * right[v] over u and all its link ancestors.
    ll g[N];
    int fa[21][N], dep[N], lg[N], pos[N];
    int len[N], fail[N], tr[N][26], cnt[N], tem[N], right[N], now, tot;

    void init() {fail[now=tot=0]=-1;}

    // Appends one lowercase character to the automaton.
    void insert(char c) {
        const int ch=c-'a';
        const int cur=++tot;
        len[cur]=len[now]+1;
        int p=now;
        while (p!=-1&&!tr[p][ch]) {
            tr[p][ch]=cur;
            p=fail[p];
        }
        if (p==-1) fail[cur]=0;
        else {
            const int q=tr[p][ch];
            if (len[q]==len[p]+1) fail[cur]=q;
            else {
                // Split q: the clone keeps the shorter strings.
                const int cln=++tot;
                len[cln]=len[p]+1;
                fail[cln]=fail[q];
                memcpy(tr[cln], tr[q], sizeof(tr[q]));
                while (p!=-1&&tr[p][ch]==q) {
                    tr[p][ch]=cln;
                    p=fail[p];
                }
                fail[cur]=fail[q]=cln;
            }
        }
        now=cur;
        right[cur]=1;
        pos[len[cur]]=cur;
    }

    // Finalises the automaton: occurrence counts, g[], depths and the
    // binary-lifting table. Call once after all insertions.
    void build() {
        // Counting sort of states 1..tot by len into tem[1..tot].
        for (int i=1; i<=tot; ++i) ++cnt[len[i]];
        for (int i=1; i<=lim; ++i) cnt[i]+=cnt[i-1];
        for (int i=1; i<=tot; ++i) tem[cnt[len[i]]--]=i;
        // Push occurrence counts up the link tree, longest states first.
        for (int i=tot; i; --i) right[fail[tem[i]]]+=right[tem[i]];
        // lg[i] = floor(log2(i)) + 1. dep[] can be as large as tot + 1 (when
        // the link tree is a single path), so the table must cover tot + 1;
        // otherwise find() would see lg == 0 and skip every jump.
        for (int i=1; i<=tot+1; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
        dep[0]=1;
        for (int i=1; i<=tot; ++i) {
            const int u=tem[i];
            fa[0][u]=fail[u];
            dep[u]=dep[fail[u]]+1;
            g[u]=max(g[fail[u]], 1ll*len[u]*right[u]);
            for (int j=1; dep[u]>=1<<j; ++j)
                fa[j][u]=fa[j-1][fa[j-1][u]];
        }
    }

    // Climbs from u to the highest link ancestor (possibly u itself) whose
    // len is still >= l. Callers pass l >= 1 with len[u] >= l.
    int find(int u, int l) {
        for (int i=lg[dep[u]]-1; i>=0; --i)
            if (len[fa[i][u]]>=l)
                u=fa[i][u];
        return u;
    }

    // Feeds character ch to the automaton from state u with current matched
    // length `matched`. Follows suffix links until a transition on ch exists;
    // after each link step the matched length becomes min(len[u], cap).
    // On success u is the new state, matched is incremented, and true is
    // returned. If no state on the link chain has the transition, u and
    // matched are reset to 0 and false is returned.
    // len[] is only read for a valid state, never for the -1 sentinel.
    bool walk(int& u, int& matched, char ch, int cap=INT_MAX) {
        const int c=ch-'a';
        while (u!=-1&&!tr[u][c]) {
            u=fail[u];
            if (u!=-1) matched=min(len[u], cap);
        }
        if (u==-1) {
            u=0;
            matched=0;
            return false;
        }
        u=tr[u][c];
        ++matched;
        return true;
    }

    // Best length * count among the string of length `matched` in state u
    // and everything represented by u's link ancestors. Requires u != 0.
    ll best(int u, int matched) const {
        return max(g[fail[u]], 1ll*matched*right[u]);
    }
}sam1, sam2;

// Answers a single query by scanning s1[it.l..it.r] on sam1.
void force_query(ques it) {
    ll res=0;
    int u=0, len=0;
    for (int i=it.l; i<=it.r; ++i)
        if (sam1.walk(u, len, s1[i]))
            res=max(res, sam1.best(u, len));
    task::ans[it.id]=res;
}

// Reads q queries from stdin and prints one answer per line, in input order.
void solve() {
    sqr=sqrt(n);
    for (int i=1; i<=n; ++i) bel[i]=(i-1)/sqr+1;
    for (int i=1; i<=n; ++i) rs[bel[i]]=i;
    for (int i=n; i; --i) ls[bel[i]]=i;

    // sam1 is built on s2, sam2 on s2 reversed.
    sam1.init(); sam2.init();
    for (int i=1; i<=m; ++i) sam1.insert(s2[i]);
    for (int i=m; i; --i) sam2.insert(s2[i]);
    sam1.build(); sam2.build();

    for (int i=1,l,r; i<=q; ++i) {
        scanf("%d%d", &l, &r);
        que[bel[l]].push_back({l, r, i});
    }

    // Scan s1 right-to-left on sam2 to record, for every i, the state and
    // length of the longest prefix of s1[i..n] occurring in s2.
    for (int i=n,u=0,len=0; i; --i) {
        if (sam2.walk(u, len, s1[i])) {
            at[i]=u;
            mat[i]=len;
        }
        else at[i]=mat[i]=0;
    }

    for (int b=1; b<=bel[n]; ++b) {
        // State of the rightward scan over s1[rs[b]+1 .. r] on sam1, shared
        // by all cross-block queries of this block.
        ll right_best=0;
        int u=0, len=0;
        const int l=rs[b]+1;
        int r=rs[b];
        sort(que[b].begin(), que[b].end());
        for (auto& it:que[b]) {
            if (bel[it.l]==bel[it.r]) {
                force_query(it);
                continue;
            }
            while (r<it.r) {
                ++r;
                if (sam1.walk(u, len, s1[r]))
                    right_best=max(right_best, sam1.best(u, len));
            }
            ll cur_best=right_best;

            // Starting sam2 state for s1[l..it.r], truncated to the query:
            // if the recorded match overshoots it.r, climb to the ancestor
            // that contains the string of length it.r-l+1.
            const int span=it.r-l+1;
            int v=at[l], vlen;
            if (mat[l]<=span) vlen=mat[l];
            else {
                v=sam2.find(at[l], span);
                vlen=min(sam2.len[v], span);
            }

            // Extend leftwards; before adding s1[j] the match may cover at
            // most s1[j+1..it.r], i.e. it.r - j characters.
            for (int j=rs[b]; j>=it.l; --j)
                if (sam2.walk(v, vlen, s1[j], it.r-j))
                    cur_best=max(cur_best, sam2.best(v, vlen));

            task::ans[it.id]=cur_best;
        }
    }

    for (int i=1; i<=q; ++i) printf("%lld\n", ans[i]);
}
}
signed main()
{
    // Input and output are redirected to the judge's files.
    freopen("loss.in", "r", stdin);
    freopen("loss.out", "w", stdout);

    // Both strings are stored 1-indexed; q is the number of queries.
    scanf("%s%s%d", s1+1, s2+1, &q);
    n=strlen(s1+1);
    m=strlen(s2+1);
    lim=max(n, m);

    // force::solve() and task1::solve() are alternative solvers that read
    // the same query format; only the full solution is run here.
    task::solve();
    return 0;
}