洛谷3244 落忆枫音 (拓扑图dp+式子)

题目大意就是 给你一个DAG

然后添加一条边\(x->y\) ,询问以1为根的生成树的个数

QWQ

首先假设没有添加的边

答案就应该是

\[ans=\prod_{i=1}^{n} in[i] \]

QWQ就相当于每个点选择一个父亲。

那么加入一条边,我们会有一些不合法的情况,那就是包含一条\(y->x\)路径,剩下随便选的方案数。假设全集是\(C\),然后路径上的点的集合是\(S\),那我们实际上求的就是$$\frac{F(C)}{F(S)}$$
其中\(F(S)\)表示\(S\)集合中所有点的入度的乘积

然后对于这个东西,我们可以考虑拓扑图上dp的方式
来解决

//假设我们添加了一条x->y的边,要想不合法,就是求y->x的路径条数
//所以我们要将令起点,也就是y的初值f[y]=ans
 
void addedge(int x,int y)
{
	nxt[++cnt]=point[x];
	to[cnt]=y;
	in[y]++;
	point[x]=cnt;
}

int qsm(int i,int j)
{
	int ans=1;
	while (j)
	{
		if (j&1) ans=ans*i%mod;
		i=i*i%mod;
		j>>=1;
	}
	return ans;
}

void tpsort()
{
	//cout<<ans<<endl;
	for (int i=1;i<=n;i++)
	{
		if (!in[i]) q.push(i);
	}
	while (!q.empty())
	{
		int now = q.front();
		q.pop();
		//cout<<now<<endl;
		//int ymh=0;
		//if (now==y) ymh=1; 
		f[now]=f[now]*qsm(d[now],mod-2)%mod;
		
		//cout<<now<<" "<<f[now]<<endl;
		for (int i=point[now];i;i=nxt[i])
		{
			int p =to[i];
			in[p]--;
		    f[p]=(f[p]+f[now])%mod;
		    if (!in[p]) q.push(p);
		}
	}
}

下面是整个的代码

// luogu-judger-enable-o2
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long

using namespace std;

inline int read()
{
  int x=0,f=1;char ch=getchar();
  while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
  while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
  return x*f;
}

const int maxn = 2e5+1e2;
const int maxm = 2*maxn;
const int mod = 1e9+7;

int point[maxn],nxt[maxm],to[maxm];
int n,m;
int cnt,in[maxn];
queue<int> q;
int ans;
int f[maxn];
int x,y;
int d[maxn];

//假设我们添加了一条x->y的边,要想不合法,就是求y->x的路径条数
//所以我们要将令起点,也就是y的初值f[y]=ans
 
void addedge(int x,int y)
{
	nxt[++cnt]=point[x];
	to[cnt]=y;
	in[y]++;
	point[x]=cnt;
}

int qsm(int i,int j)
{
	int ans=1;
	while (j)
	{
		if (j&1) ans=ans*i%mod;
		i=i*i%mod;
		j>>=1;
	}
	return ans;
}

void tpsort()
{
	//cout<<ans<<endl;
	for (int i=1;i<=n;i++)
	{
		if (!in[i]) q.push(i);
	}
	while (!q.empty())
	{
		int now = q.front();
		q.pop();
		//cout<<now<<endl;
		//int ymh=0;
		//if (now==y) ymh=1; 
		f[now]=f[now]*qsm(d[now],mod-2)%mod;
		
		//cout<<now<<" "<<f[now]<<endl;
		for (int i=point[now];i;i=nxt[i])
		{
			int p =to[i];
			in[p]--;
		    f[p]=(f[p]+f[now])%mod;
		    if (!in[p]) q.push(p);
		}
	}
}

signed main()
{
  n=read(),m=read(),x=read(),y=read();
  for (int i=1;i<=m;i++)
  {
  	 int u=read(),v=read();
  	 addedge(u,v);
  }
  ans=1;
  for (int i=2;i<=n;i++)
  {
  	if (i==y) ans=ans*(in[i]+1)%mod,d[i]=in[i]+1;
  	else ans=ans*in[i]%mod,d[i]=in[i];
  }
  f[y]=ans;
  if (x==1)
  {
  	cout<<ans<<"\n";
  	return 0;
  }
  tpsort();
  cout<<(ans-f[x]+mod)%mod<<endl;
  return 0;
}

posted @ 2018-12-22 16:16  y_immortal  阅读(169)  评论(0编辑  收藏  举报