loj#575. 「LibreOJ NOI Round #2」不等关系

loj#575. 「LibreOJ NOI Round #2」不等关系

首先考虑如果点\(i\)\(i+1\)中连接的是\(>\),那么可以考虑从\(i\)\(i+1\)连一条边,否则从\(i+1\)\(i\)连边,如果我们先不考虑\(<\),那么就变成了一条条从大往小链接的链,这种图赋权值的方案数就是

\[\frac{n!}{\prod siz_i} \]

那考虑有反向边,可以考虑容斥:原方案数-一条反向边方案数+两条反向边方案数......

\(f_i\)表示考虑前\(i\)个点,先不考虑阶乘的方案数。

那么注意到如果\(i\)在不考虑反向边的时候不是链首(链的最后一个),很显然没有意义,因此我们考虑只对于链首计算。

考虑从\(i\)往前的一条弱联通意义下的链,如果我们加入了\(k\)条反向边,那么根据容斥,我们最终的方案数就是原本这条弱联通链的方案数乘上\((-1)^k\)

那么就很好转移了,直接枚举上一条链的链首\(j\)。设前\(i\)条边中共有\(cnt_i\)条正向边。那么有

\[f(i)=\sum_{j=0}^{i-1}{f(j)\times (-1)^{cnt[i-1]-cnt[j]}\times\frac{1}{(i-j)!}} \]

\((-1)^{cnt[i-1]}\)是个常数,因此可以提出来,那么变成了

\[f(i)=(-1)^{cnt[i-1]}\sum_{j=0}^{i-1}{f(j)\times (-1)^{-cnt[j]}\times\frac{1}{(i-j)!}} \]

考虑设\(F(x)=f(x)*(-1)^{-cnt[x]},G[x]=\frac{1}{x!}\),那么原式即为

\[F(i)=(-1)^{cnt[i-1]+cnt[i]}\times\sum_{j}^{i-1}F(j)\times G(i-j) \]

这是个卷积形式,分治\(NTT\)即可。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define N 500005
#define MAXN 500000
#define mod 998244353
#define gi 332748118
#define g 3
#define pb push_back
#define int long long
using namespace std;
int read()
{
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
	return x*f;
}
int n,cnt[N],fac[N],ifac[N],inv[N],pw[N],rev[N];
char s[N];
vector<int>F,G,S,T;
int ksm(int a,int b)
{
	int res=1;
	while(b)
	{
		if(b&1)res*=a,res%=mod;
		a*=a;a%=mod;b>>=1;
	}
	return res;
}
int get_limit(int x)
{
	int limit=1;while(limit<=x)limit<<=1;
	for(int i=0;i<limit;i++)rev[i]=((rev[i>>1]>>1)|((i&1)?limit>>1:0));
	return limit;
}
void NTT(vector<int>&a,int limit,int type)
{
	for(int i=0;i<limit;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int mid=1;mid<limit;mid<<=1)
	{
		int Wn=ksm(type==1?g:gi,(mod-1)/(mid<<1));
		for(int j=0;j<limit;j+=(mid<<1))
		{
			int w=1;
			for(int k=0;k<mid;k++,w=(w*Wn%mod)%mod)
			{
				int x=a[j+k],y=w*a[j+k+mid]%mod;
				a[j+k]=(x+y)%mod;
				a[j+k+mid]=(x-y+mod)%mod;
			}
		}
	}
	if(type==-1)
	{
		int INV=ksm(limit,mod-2);
		for(int i=0;i<limit;i++)a[i]=a[i]*INV%mod;
	}
}
vector<int> operator*(vector<int>&a,vector<int>&b)
{
	int len=a.size()+b.size()-1;
	int limit=get_limit(len);
	a.resize(limit);b.resize(limit);
	NTT(a,limit,1);NTT(b,limit,1);
	for(int i=0;i<limit;i++)a[i]=a[i]*b[i]%mod;
	NTT(a,limit,-1);a.resize(len);
	return a;
}
void solve(int l,int r)
{
	if(l==r)return;
	//cout<<l<<" "<<r<<endl;
	//for(int i=0;i<F.size();i++)printf("%d ",F[i]);
	//puts("");
	//for(int i=0;i<F.size();i++)printf("%d ",G[i]);
	//puts("");
	int mid=(l+r)>>1;
	solve(l,mid);
	S.clear();T.clear();
	for(int i=l;i<=mid;i++)
	{
		if(s[i]=='>')S.pb(F[i]);
		else S.pb(0);
		T.pb(G[i-l]);
	}
	for(int i=mid+1;i<=r;i++)S.pb(0),T.pb(G[i-l]);
	S=S*T;
	for(int i=mid+1;i<=r;i++)F[i]=(F[i]+S[i-l]*pw[cnt[i-1]]%mod*pw[cnt[i]]%mod)%mod;
	solve(mid+1,r);
}
signed main()
{
	scanf("%s",(s+1));s[0]='>';cnt[0]=1;n=strlen(s+1);
	fac[0]=fac[1]=ifac[1]=ifac[0]=inv[1]=1;pw[0]=1;pw[1]=(-1+mod)%mod;
	for(int i=2;i<=MAXN;i++)
	{
		pw[i]=(-pw[i-1]+mod)%mod;
		inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		fac[i]=fac[i-1]*i%mod;
		ifac[i]=ifac[i-1]*inv[i]%mod;
	}
	for(int i=1;i<=n+1;i++)cnt[i]=cnt[i-1]+(s[i]=='>');
	F.pb(-1);
	for(int i=1;i<=n+1;i++)F.pb(0);
	for(int i=0;i<=n+1;i++)G.pb(ifac[i]);
	//return 0;
	solve(0,n+1);
	//for(int i=0;i<=n+1;i++)printf("%d ",pw[i]);
	//puts("");
	printf("%d\n",fac[n+1]*pw[cnt[n+1]]%mod*F[n+1]%mod); 
}

posted @ 2021-03-03 20:46  shao0320  阅读(179)  评论(0编辑  收藏  举报
****************************************** 页脚Html代码 ******************************************