Codeforces 775 E Tyler and Strings
Codeforces 775 E Tyler and Strings
题意
给出长度为 \(n\) 的字符串 \(s\) 和长度为 \(m\) 的字符串 \(t\)。 若将 \(s\) 中字符重新排列,生成字符串 \(p\),求字符串 \(p\) 的个数,满足字典序 \(p < t\)。输出答案对 \(998244353\) 取模后的值。
数据满足: \(1 \leq n,m,s_i,t_i \leq 2\times 10^5\)。
题解
两个字符串长度关系,共有三种情况:\(|s| < |t|\) , \(|s| = |t|\) 或 \(|s| > |t|\)。
我们先考虑两个字符串长度相等的情况,如果想要 \(p < t\),字符串 \(p\) 必然满足:
- 存在一个位置 \(i(1 \leq i \leq n)\) 满足 \(p_j = t_j (1 \leq j < i)\) 并且 \(p_i < t_i\)。
对于不同的字符串 \(p\) ,我们可以枚举这个第一个小于 \(t\) 串对应字符位置 \(i\) ,那么 \(p\) 的构造就是:
- \(p_j = t_j (1 \leq j < i)\) , \(p_i\) 位置选择任意一个小于 \(t_i\) 的字符,\(p_j (i < j \leq n)\) 随意选择即可。
我们需要维护后缀地随意排列构成不同字符串地个数,这显然是一个组合数问题:
\(x_1 + x_2 + x_3 + ... + x_k = tot\),求可以构成不同排列地个数。
答案为: \(\frac{tot!}{x_1! ...x_p! ... x_k!}\)。
如果存在一个某个数字 \(w < t_i\) ,其还剩余 \(x_p\) 个,那么如果在 \(s\) 的第 \(i\) 个位置放置该数字,那么后续可以构成的不重复的字符串为:\(\frac{(tot-1)!}{x_1! ...(x_p-1)! ... x_k!}\),此时答案应该加上这个数字。
注意到\(\frac{(tot-1)!}{x_1! ...(x_p-1)! ... x_k!} = \frac{tot!}{x_1! ...x_p! ... x_k!} \times \frac{x_p}{tot}\),那么如果存在很多不同的\(w < t_i\),我们可以合并计算。
即:\(\frac{tot!}{x_1! ...x_p! ... x_k!} \times \frac{x_{p1} + x_{p2} + ... x_{pr}}{tot}\)
因此需要有一个数据结构,支持单点修改,区间查询,用树状数组可容易实现。
按照上面的方法,分析\(|s| > |t|\) 和 \(|s| < |t|\) 的两种情况,发现有一个细节没有考虑完善:
- 当 \(|s| < |t|\) 时,如果 \(s\) 和 \(t\) 的前缀完全相同,那么会少考虑一次,原因是此时不存在一个位置\(i(1 \leq i \leq n)\) 让 \(p_i < t_i\),但事实上字典序\(p < t\)成立。
特殊判断后,本题就能顺利通过,时间复杂度 \(O(n \log_2 n)\)。
C++ 代码示例
# include <bits/stdc++.h>
# define int long long
# define lowbit(x) (x&(-x))
using namespace std;
const int N = 2e5+10;
const int mo = 998244353;
int fac[N],c[N],a[N],b[N];
void update(int x,int d) {
for (;x<=N-10;x+=lowbit(x)) c[x]+=d;
}
int query(int x) {
int res=0;
for (;x;x-=lowbit(x)) res+=c[x];
return res;
}
int pow(int x,int n) {
int ans = 1;
while (n) {
if (n&1) ans = ans * x %mo;
x=x*x%mo;
n>>=1;
}
return ans;
}
int inv(int x) {
return pow(x,mo-2);
}
signed main() {
int n,m; cin>>n>>m;
map<int,int>tmp;
for (int i=1;i<=n;i++) {
cin>>a[i];
update(a[i],1);
tmp[a[i]]++;
}
for (int i=1;i<=m;i++) cin>>b[i];
fac[0]=1; for (int i=1;i<=N-10;i++) fac[i]=fac[i-1]*i%mo;
int res = fac[n];
for (auto x :tmp) {
res = res * inv(fac[x.second]) % mo;
}
int ans = 0,tot = n;
for (int i=1;i<=min(n,m);i++) {
(ans+=res*inv(tot)%mo*query(b[i]-1)%mo)%=mo;
res=res*inv(tot)%mo*(query(b[i])-query(b[i]-1))%mo;
tot--;
update(b[i],-1);
}
if (n < m) {
sort(a+1,a+1+n);
sort(b+1,b+1+n);
bool f = true;
for (int i=1;i<=n;i++) if (a[i]!=b[i]) {
f = false; break;
}
if (f) ans = (ans+1)%mo;
}
cout<<ans<<endl;
return 0;
}