【XSY2887】【GDOI2018】小学生图论题 分治FFT 多项式exp
题目描述
在一个 \(n\) 个点的有向图中,编号从 \(1\) 到 \(n\),任意两个点之间都有且仅有一条有向边。现在已知一些单向的简单路径(路径上任意两点各不相同),例如 \(2\to 4\to 1\)。且已知的这些简单路径之间没有公共的顶点,其
余的边的方向等概率随机。
你需要求出强连通分量(如果同时存在 \(a\) 到 \(b\), \(b\) 到 \(a\) 的有向路径,则 \(a\), \(b\) 属于同一个强联通分量) 的期望个数。如果最后答案是 \(\frac{A}{B}\),则输出 \(A \times B^{-1} \bmod 998244353\), \(B^{-1}\) 表示 \(B\) 在模 \(998244353\) 意义下的逆元。
\(n\leq 100000\)
题解
直接做好像不太好做。
考虑整张图缩点后长什么样。
可以发现,强连通分量个数\(=\)关键边(红色的箭头指的那些边)个数\(+1\)。
如果\(m=0\),那么一条边都没有确定。
枚举一条关键边左边有多少点,那么这条关键边左边的点连到右边的点的边的方向都是确定的。答案是
如果有些边已经确定,那么可以做一个背包DP。
正解是用一个多项式表示一个路径,一个长度为\(k\)的路径对应的多项式是
设把所有多项式乘起来后的多项式是\(\sum_{i=0}^na_ix^i\),那么答案是
为什么这样是对的?
如果一条路径在这条关键边左边的点数在\(1\sim k-1\)之间,那么这条路径的一条边会从左边连到右边,会消掉一个\(\frac{1}{2}\)。
可以分治FFT做。时间复杂度:\(O(n\log^2 n)\)
其实还可以继续推下去。
三个部分的\(\ln\)都是可以快速求的。
然后\(\exp\)回来就行了。
时间复杂度:\(O(n\log n)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
if(a>b)
swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
int rd()
{
int s=0,c;
while((c=getchar())<'0'||c>'9');
do
{
s=s*10+c-'0';
}
while((c=getchar())>='0'&&c<='9');
return s;
}
void put(int x)
{
if(!x)
{
putchar('0');
return;
}
static int c[20];
int t=0;
while(x)
{
c[++t]=x%10;
x/=10;
}
while(t)
putchar(c[t--]+'0');
}
int upmin(int &a,int b)
{
if(b<a)
{
a=b;
return 1;
}
return 0;
}
int upmax(int &a,int b)
{
if(b>a)
{
a=b;
return 1;
}
return 0;
}
const ll p=998244353;
const int W=262144;
const int N=300000;
ll fp(ll a,ll b)
{
ll s=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)
s=s*a%p;
return s;
}
ll w[N];
ll inv[N];
void ntt(ll *a,int n,int t)
{
static int rev[N];
rev[0]=0;
for(int i=1;i<n;i++)
{
rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
if(rev[i]>i)
swap(a[i],a[rev[i]]);
}
for(int i=2;i<=n;i<<=1)
for(int j=0;j<n;j+=i)
for(int k=0;k<i/2;k++)
{
ll u=a[j+k];
ll v=a[j+k+i/2]*w[W/i*k]%p;
a[j+k]=(u+v)%p;
a[j+k+i/2]=(u-v)%p;
}
if(t==-1)
{
reverse(a+1,a+n);
ll inv=fp(n,p-2);
for(int i=0;i<n;i++)
a[i]=a[i]*inv%p;
}
}
void mul(ll *a,ll *b,ll *c,int n,int m,int l)
{
static ll a1[N],a2[N];
if(l==-1)
l=n+m;
n=min(n,l);
l=min(m,l);
int k=1;
while(k<=n+m)
k<<=1;
for(int i=0;i<=n;i++)
a1[i]=a[i];
for(int i=n+1;i<k;i++)
a1[i]=0;
for(int i=0;i<=m;i++)
a2[i]=b[i];
for(int i=m+1;i<k;i++)
a2[i]=0;
ntt(a1,k,1);
ntt(a2,k,1);
for(int i=0;i<k;i++)
a1[i]=a1[i]*a2[i]%p;
ntt(a1,k,-1);
for(int i=0;i<=l;i++)
c[i]=a1[i];
}
void getinv(ll *a,ll *b,int n)
{
if(n==1)
{
b[0]=fp(a[0],p-2);
return;
}
getinv(a,b,n>>1);
static ll a1[N],a2[N];
for(int i=0;i<n;i++)
a1[i]=a[i];
for(int i=n;i<n<<1;i++)
a1[i]=0;
for(int i=0;i<n>>1;i++)
a2[i]=b[i];
for(int i=n>>1;i<n<<1;i++)
a2[i]=0;
ntt(a1,n<<1,1);
ntt(a2,n<<1,1);
for(int i=0;i<n<<1;i++)
a1[i]=a2[i]*(2-a1[i]*a2[i]%p)%p;
ntt(a1,n<<1,-1);
for(int i=0;i<n;i++)
b[i]=a1[i];
}
void getln(ll *a,ll *b,int n)
{
static ll a1[N],a2[N];
for(int i=1;i<n;i++)
a1[i-1]=a[i]*i%p;
a1[n-1]=0;
getinv(a,a2,n);
mul(a1,a2,a1,n-1,n-1,n-1);
for(int i=1;i<n;i++)
b[i]=a1[i-1]*inv[i]%p;
b[0]=0;
}
void getexp(ll *a,ll *b,int n)
{
if(n==1)
{
b[0]=1;
return;
}
getexp(a,b,n>>1);
static ll a1[N],a2[N],a3[N];
for(int i=n>>1;i<n;i++)
b[i]=0;
getln(b,a1,n);
for(int i=0;i<n>>1;i++)
{
a2[i]=b[i];
a3[i]=a[i+(n>>1)]-a1[i+(n>>1)];
}
for(int i=n>>1;i<n;i++)
a2[i]=a3[i]=0;
ntt(a2,n,1);
ntt(a3,n,1);
for(int i=0;i<n;i++)
a2[i]=a2[i]*a3[i]%p;
ntt(a2,n,-1);
for(int i=0;i<n>>1;i++)
b[i+(n>>1)]=a2[i];
}
int n,m;
void init()
{
w[0]=1;
w[1]=fp(3,(p-1)/W);
inv[1]=1;
for(int i=2;i<=W;i++)
{
w[i]=w[i-1]*w[1]%p;
inv[i]=-p/i*inv[p%i]%p;
}
}
ll a[N],b[N],c[N];
int main()
{
open("graph");
init();
scanf("%d%d",&n,&m);
int k;
int sum=n;
for(int i=1;i<=m;i++)
{
k=rd();
for(int j=1;j<=k;j++)
rd();
sum-=k;
c[k]++;
}
for(int i=1;i<=n;i++)
if(c[i])
for(int j=1;i*j<=n;j++)
a[i*j]=(a[i*j]-c[i]*inv[j])%p;
for(int i=1;i<=n;i++)
{
a[i]=(a[i]+m*inv[i])%p;
a[i]=(a[i]-(sum+m)*(i&1?-1:1)*inv[i])%p;
}
int l=1;
while(l<=n)
l<<=1;
getexp(a,b,l);
ll ans=0;
for(int i=1;i<n;i++)
ans=(ans+b[i]*fp(inv[2],((ll)i*(n-i)%(p-1))))%p;
ans++;
ans=(ans+p)%p;
printf("%lld\n",ans);
return 0;
}