【XSY2887】【GDOI2018】小学生图论题 分治FFT 多项式exp

题目描述

  在一个 \(n\) 个点的有向图中,编号从 \(1\)\(n\),任意两个点之间都有且仅有一条有向边。现在已知一些单向的简单路径(路径上任意两点各不相同),例如 \(2\to 4\to 1\)。且已知的这些简单路径之间没有公共的顶点,其
余的边的方向等概率随机。

  你需要求出强连通分量(如果同时存在 \(a\)\(b\)\(b\)\(a\) 的有向路径,则 \(a\), \(b\) 属于同一个强联通分量) 的期望个数。如果最后答案是 \(\frac{A}{B}\),则输出 \(A \times B^{-1} \bmod 998244353\)\(B^{-1}\) 表示 \(B\) 在模 \(998244353\) 意义下的逆元。

  \(n\leq 100000\)

题解

  直接做好像不太好做。

  考虑整张图缩点后长什么样。

  

  可以发现,强连通分量个数\(=\)关键边(红色的箭头指的那些边)个数\(+1\)

  如果\(m=0\),那么一条边都没有确定。

  枚举一条关键边左边有多少点,那么这条关键边左边的点连到右边的点的边的方向都是确定的。答案是

\[\sum_{i=1}^{n-1}\binom{n}{i}{(\frac{1}{2})}^{i(n-i)} \]

  如果有些边已经确定,那么可以做一个背包DP。

  正解是用一个多项式表示一个路径,一个长度为\(k\)的路径对应的多项式是

\[1+2x+2x^2+\cdots+2x^{k-1}+x^k \]

  设把所有多项式乘起来后的多项式是\(\sum_{i=0}^na_ix^i\),那么答案是

\[\sum_{i=1}^{n-1}a_i{(\frac{1}{2})}^{i(n-i)} \]

  为什么这样是对的?

  如果一条路径在这条关键边左边的点数在\(1\sim k-1\)之间,那么这条路径的一条边会从左边连到右边,会消掉一个\(\frac{1}{2}\)

  可以分治FFT做。时间复杂度:\(O(n\log^2 n)\)

  其实还可以继续推下去。

\[\begin{align} &1+2x+2x^2+\cdots+2x^{k-1}+x^k\\ =&(1+x)(1+x+\cdots +x^{k-1})\\ =&\frac{(1+x)(1-x^k)}{1-x} \end{align} \]

  三个部分的\(\ln\)都是可以快速求的。

  然后\(\exp\)回来就行了。

  时间复杂度:\(O(n\log n)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
	if(a>b)
		swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
int rd()
{
	int s=0,c;
	while((c=getchar())<'0'||c>'9');
	do
	{
		s=s*10+c-'0';
	}
	while((c=getchar())>='0'&&c<='9');
	return s;
}
void put(int x)
{
	if(!x)
	{
		putchar('0');
		return;
	}
	static int c[20];
	int t=0;
	while(x)
	{
		c[++t]=x%10;
		x/=10;
	}
	while(t)
		putchar(c[t--]+'0');
}
int upmin(int &a,int b)
{
	if(b<a)
	{
		a=b;
		return 1;
	}
	return 0;
}
int upmax(int &a,int b)
{
	if(b>a)
	{
		a=b;
		return 1;
	}
	return 0;
}
const ll p=998244353;
const int W=262144;
const int N=300000;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
ll w[N];
ll inv[N];
void ntt(ll *a,int n,int t)
{
	static int rev[N];
	rev[0]=0;
	for(int i=1;i<n;i++)
	{
		rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
		if(rev[i]>i)
			swap(a[i],a[rev[i]]);
	}
	for(int i=2;i<=n;i<<=1)
		for(int j=0;j<n;j+=i)
			for(int k=0;k<i/2;k++)
			{
				ll u=a[j+k];
				ll v=a[j+k+i/2]*w[W/i*k]%p;
				a[j+k]=(u+v)%p;
				a[j+k+i/2]=(u-v)%p;
			}
	if(t==-1)
	{
		reverse(a+1,a+n);
		ll inv=fp(n,p-2);
		for(int i=0;i<n;i++)
			a[i]=a[i]*inv%p;
	}
}
void mul(ll *a,ll *b,ll *c,int n,int m,int l)
{
	static ll a1[N],a2[N];
	if(l==-1)
		l=n+m;
	n=min(n,l);
	l=min(m,l);
	int k=1;
	while(k<=n+m)
		k<<=1;
	for(int i=0;i<=n;i++)
		a1[i]=a[i];
	for(int i=n+1;i<k;i++)
		a1[i]=0;
	for(int i=0;i<=m;i++)
		a2[i]=b[i];
	for(int i=m+1;i<k;i++)
		a2[i]=0;
	ntt(a1,k,1);
	ntt(a2,k,1);
	for(int i=0;i<k;i++)
		a1[i]=a1[i]*a2[i]%p;
	ntt(a1,k,-1);
	for(int i=0;i<=l;i++)
		c[i]=a1[i];
}
void getinv(ll *a,ll *b,int n)
{
	if(n==1)
	{
		b[0]=fp(a[0],p-2);
		return;
	}
	getinv(a,b,n>>1);
	static ll a1[N],a2[N];
	for(int i=0;i<n;i++)
		a1[i]=a[i];
	for(int i=n;i<n<<1;i++)
		a1[i]=0;
	for(int i=0;i<n>>1;i++)
		a2[i]=b[i];
	for(int i=n>>1;i<n<<1;i++)
		a2[i]=0;
	ntt(a1,n<<1,1);
	ntt(a2,n<<1,1);
	for(int i=0;i<n<<1;i++)
		a1[i]=a2[i]*(2-a1[i]*a2[i]%p)%p;
	ntt(a1,n<<1,-1);
	for(int i=0;i<n;i++)
		b[i]=a1[i];
}
void getln(ll *a,ll *b,int n)
{
	static ll a1[N],a2[N];
	for(int i=1;i<n;i++)
		a1[i-1]=a[i]*i%p;
	a1[n-1]=0;
	getinv(a,a2,n);
	mul(a1,a2,a1,n-1,n-1,n-1);
	for(int i=1;i<n;i++)
		b[i]=a1[i-1]*inv[i]%p;
	b[0]=0;
}
void getexp(ll *a,ll *b,int n)
{
	if(n==1)
	{
		b[0]=1;
		return;
	}
	getexp(a,b,n>>1);
	static ll a1[N],a2[N],a3[N];
	for(int i=n>>1;i<n;i++)
		b[i]=0;
	getln(b,a1,n);
	for(int i=0;i<n>>1;i++)
	{
		a2[i]=b[i];
		a3[i]=a[i+(n>>1)]-a1[i+(n>>1)];
	}
	for(int i=n>>1;i<n;i++)
		a2[i]=a3[i]=0;
	ntt(a2,n,1);
	ntt(a3,n,1);
	for(int i=0;i<n;i++)
		a2[i]=a2[i]*a3[i]%p;
	ntt(a2,n,-1);
	for(int i=0;i<n>>1;i++)
		b[i+(n>>1)]=a2[i];
}
int n,m;
void init()
{
	w[0]=1;
	w[1]=fp(3,(p-1)/W);
	inv[1]=1;
	for(int i=2;i<=W;i++)
	{
		w[i]=w[i-1]*w[1]%p;
		inv[i]=-p/i*inv[p%i]%p;
	}
}
ll a[N],b[N],c[N];
int main()
{
	open("graph");
	init();
	scanf("%d%d",&n,&m);
	int k;
	int sum=n;
	for(int i=1;i<=m;i++)
	{
		k=rd();
		for(int j=1;j<=k;j++)
			rd();
		sum-=k;
		c[k]++;
	}
	for(int i=1;i<=n;i++)
		if(c[i])
			for(int j=1;i*j<=n;j++)
				a[i*j]=(a[i*j]-c[i]*inv[j])%p;
	for(int i=1;i<=n;i++)
	{
		a[i]=(a[i]+m*inv[i])%p;
		a[i]=(a[i]-(sum+m)*(i&1?-1:1)*inv[i])%p;
	}
	int l=1;
	while(l<=n)
		l<<=1;
	getexp(a,b,l);
	ll ans=0;
	for(int i=1;i<n;i++)
		ans=(ans+b[i]*fp(inv[2],((ll)i*(n-i)%(p-1))))%p;
	ans++;
	ans=(ans+p)%p;
	printf("%lld\n",ans);
	return 0;
}
posted @ 2018-05-04 10:52  ywwyww  阅读(615)  评论(0编辑  收藏  举报