题解 Gym 104386G【CLC Loves SQRT Technology (Hard Version)】

problem

我们对于一个序列 x 定义一个函数 f(x) 表示为使得这个序列变成回文序列的最少修改次数。

现在给定一个长度为 n 的序列,要求输出他所有子序列 x 的 f(x) 的总和。\(n\leq 10^5\)

solution 1

首先不难想到枚举每个 pair 计算贡献:

\[ans=\sum_{i<j}[a_i\neq a_j]2^{j-i-1}\sum_{\ell}\binom{i-1}{\ell}\binom{n-j}{\ell} \]

发现后面关于 \(\ell\) 的枚举是范德蒙德卷积:

\[\sum_{\ell}\binom{i-1}{\ell}\binom{n-j}{\ell}=\sum_{\ell}\binom{i-1}{i-1-\ell}\binom{n-j}{\ell}=\binom{i-1+n-j}{i-1} \]

于是你会了 \(O(n^2)\) 暴力 😦

solution 2

考虑如果 \(a_i\) 是排列,输出 \(\sum_{\ell}\binom n \ell \left\lfloor\frac\ell 2\right\rfloor\),证明显然吧,枚举序列长度,然后修改次数打满。

考虑计算相等的负贡献。

对于 \(i,j\),如果 \(a_i=a_j\),它们的贡献是

\[2^{j-i-1}\binom{i-1+n-j}{i-1}=(i-1+n-j)!\cdot\frac{1}{2^{i+1}(i-1)!}\cdot\frac{2^j}{(n-j)!} \]

很想卷积啊!卷啊!构造两个生成函数 \(A,B\),把每个点打到 \(A[x-1],B[n-x]\) 上,卷!然后枚举 \(\ell=i-1+n-j\),观察到 \(j-i>0\implies-\ell+n-1>0\implies\ell<n-1\),然后这个是唯一的限制,枚举 \(0\leq\ell<n-1\),然后点乘 \(\ell!\) 计入答案。

然后这样的复杂度是 \(O(n\log n)\)

solution 3

我们有 \(O(c^2)\) 的暴力和 \(O(n\log n)\) 的卷积。

不如撮合一下,设置 \(B\),使得 \(\leq B\) 的暴力,\(>B\) 的卷积,复杂度大概为 \(O(nB+n^2\log n/B)\)(将 \(\sum c^2\) 缩放成 \(\sum c\cdot B\)),然后取等就取在 \(B=\sqrt{n\log n}\),然后是可以跑的。

code

点击查看代码

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
int glim(int x){return 1<<(32-__builtin_clz(x));}
template<unsigned P> struct modint{
	unsigned v; modint():v(0){}
	template<class T> modint(T x):v((x%int(P)+int(P))%int(P)){}
	modint operator-()const{return modint(P-v);}
	modint inv()const{return qpow(*this,int(P)-2);}
	modint&operator+=(const modint&rhs){if(v+=rhs.v,v>=P) v-=P; return *this;}
	modint&operator-=(const modint&rhs){return *this+=-rhs;}
	modint&operator*=(const modint&rhs){v=1ull*v*rhs.v%P; return *this;}
	modint&operator/=(const modint&rhs){return *this*=rhs.inv();}
	friend int raw(const modint&self){return self.v;}
	template<class T> friend modint qpow(modint a,T b){modint r=1;for(;b;b>>=1,a*=a) if(b&1) r*=a; return r;}
	friend modint operator+(modint lhs,const modint&rhs){return lhs+=rhs;}
	friend modint operator-(modint lhs,const modint&rhs){return lhs-=rhs;}
	friend modint operator*(modint lhs,const modint&rhs){return lhs*=rhs;}
	friend modint operator/(modint lhs,const modint&rhs){return lhs/=rhs;}
	friend modint operator==(const modint&lhs,const modint&rhs){return lhs.v==rhs.v;}
	friend modint operator!=(const modint&lhs,const modint&rhs){return lhs.v!=rhs.v;}
};
const int P=998244353,G=3;
typedef modint<P> mint;
void ntt(vector<mint>&a,int op){
	int n=a.size(); vector<mint> w(n);
	vector<int> rev(n);
	for(int i=1;i<n;i++) rev[i]=rev[i>>1]>>1|(i&1?n>>1:0);
	for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int k=1,len=2;len<=n;k<<=1,len<<=1){
		mint wn=qpow(op==1?mint(G):mint(1)/G,(P-1)/len);
		for(int i=raw(w[0]=1);i<k;i++) w[i]=w[i-1]*wn;
		for(int i=0;i<n;i+=len){
			for(int j=0;j<k;j++){
				mint x=a[i+j],y=a[i+j+k]*w[j];
				a[i+j]=x+y,a[i+j+k]=x-y;
			}
		}
	}
	if(op==-1){mint inv=mint(1)/n; for(mint&x:a) x*=inv;}
}

template<int N> struct C_prime{
    mint fac[N+10],ifac[N+10];
    C_prime(){
        for(int i=raw(fac[0]=1);i<=N;i++) fac[i]=fac[i-1]*i;
        ifac[N]=1/fac[N];for(int i=N;i>=1;i--) ifac[i-1]=ifac[i]*i;
    }
    mint operator()(int n,int m){return n>=m?fac[n]*ifac[m]*ifac[n-m]:0;}
};
int n,a[1<<17];
mint qp2[1<<17];
C_prime<1<<17> binom;
mint solveall(){
	mint ans=0;
	for(int i=1;i<=n;i++) ans+=binom(n,i)*(i/2);
	return ans;
}
mint solveeq(vector<int> pts){
	vector<mint> A(n+1),B(n+1);
	for(int pos:pts) A[pos-1]=mint(1)/binom.fac[pos-1]/qp2[pos+1];
	for(int pos:pts) B[n-pos]=qp2[pos]/binom.fac[n-pos];
	A.resize(glim(n)<<1),B.resize(glim(n)<<1);
	ntt(A,1),ntt(B,1);
	for(int i=0;i<A.size();i++) A[i]*=B[i];
	ntt(A,-1);
	mint res=0;
	for(int i=0;i<n-1;i++) res+=A[i]*binom.fac[i];
	return res;
}
mint solveeqbf(vector<int> pts){
	mint res=0;
	for(int i:pts) for(int j:pts) if(i<j) res+=binom(n-j+i-1,i-1)*qp2[j-i-1];
	return res;
}
vector<int> buc[1<<17];
int main(){
//	#ifdef LOCAL
//	 	freopen("input.in","r",stdin);
//	#endif
	scanf("%d",&n);
	for(int i=raw(qp2[0]=1);i<=n;i++) qp2[i]=qp2[i-1]*2;
	for(int i=1;i<=n;i++) scanf("%d",&a[i]),buc[a[i]].push_back(i);
	mint res=solveall();
	//for(int i=1;i<=n;i++) res-=solveeqbf(buc[i]),debug("size=%d,bf=%d,mul=%d\n",buc[i].size(),solveeqbf(buc[i]),solveeq(buc[i]));
	for(int i=1;i<=n;i++)
		if(buc[i].size()<=1314) res-=solveeqbf(buc[i]);
		else res-=solveeq(buc[i]);
	printf("%d\n",raw(res));
	return 0;
}

posted @ 2023-09-13 20:53  caijianhong  阅读(12)  评论(0编辑  收藏  举报