[CSP-S模拟测试]:旅行(数学+线段树)
题目传送门(内部题12)
输入格式
第一行,一个整数$n$,代表树的点数。
第二行,$n$个整数,第$i$个整数是$B_i$,描述排列$B$。
接下来$n−1$行,每行两个整数$u,v$,描述一条树边$(u,v)$。
保证$1\leqslant B_i\leqslant n$,$1\leqslant u\neq v\leqslant n$。保证数据合法。
输出格式
输出一个整数表示答案对${10}^9+7$取模的值。
样例
样例输入1:
5
2 1 3 5 4
1 2
2 3
2 4
4 5
样例输出1:
3
样例输入2:
6
6 4 5 3 2 1
1 2
2 3
3 4
4 5
5 6
样例输出2:
9
数据范围与提示
样例$1$解释:
满足条件的数列$A$分别是:
•$(1,2,3,4,5)$
•$(1,2,4,5,3)$
•$(2,1,3,4,5)$
数据范围:
对于所有数据,$1\leqslant n\leqslant 300,000$。
题解
首先,如果$dfs$的根小于$B[1]$,那么要求的就是以这个点为根的$dfs$序的个数,而对于一棵有根树,设其根为$i$,$dfs$序的个数为$f(i)$,则有:
$f(u)=|son(u)|!\prod \limits_{v\in son(u)}f(v)$
我们可以任选一个根求出$f$,然后计算其祖先贡献即可。
那么我们现在考虑以$B[1]$为根该怎么办?
思想和上述方法类似,即如果当前在$B_i$点,要走到$B_i+1$点,需要求出所有第$i+1$位小于$B_{i+1}$的方案数,简单计算即可。
为了通过最后两个点,可以用线段树或树状数组快速求出某个点下某棵子树的名次。
时间复杂度:$\Theta(n\log n)$。
期望得分:$100$分。
实际得分:$100$分。
代码时刻
#include<bits/stdc++.h>
using namespace std;
struct rec
{
int nxt;
int to;
}e[600000];
int head[300001],cnt;
int n;
int du[300001];
long long jc[300001],inv[300001],dp[300001],son[300001];
bool vis[300001],jump;
int rt[300001],b[300001],trval[1300000],trwzc[1300000],trsum[1300000],ls[1300000],rs[1300000],tot;
int dfn=1;
map<pair<int,int>,bool> mp;
long long ans,res=1,now;
void add(int x,int y)
{
e[++cnt].nxt=head[x];
e[cnt].to=y;
head[x]=cnt;
}
long long qpow(long long x,long long y)
{
long long res=1;
while(y)
{
if(y&1)res=res*x%1000000007;
x=x*x%1000000007;
y>>=1;
}
return res;
}
void pre_work()
{
jc[0]=inv[0]=1;
for(int i=1;i<=n;i++)
{
jc[i]=1LL*i*jc[i-1]%1000000007;
inv[i]=qpow(jc[i],1000000005);
}
}
void pushup(int x)
{
trsum[x]=trwzc[x];
if(ls[x])trsum[x]+=trsum[ls[x]];
if(rs[x])trsum[x]+=trsum[rs[x]];
}
void insert(int &x,int w)
{
if(!x)
{
x=++tot;
trval[x]=w;
trwzc[x]=trsum[x]=1;
return;
}
if(trval[x]>w)insert(ls[x],w);
else insert(rs[x],w);
pushup(x);
}
void change(int x,int w)
{
if(trval[x]==w)
{
trwzc[x]=0;
pushup(x);
return;
}
if(trval[x]>w)change(ls[x],w);
else change(rs[x],w);
pushup(x);
}
int find(int x,int w)
{
if(!x)return 0;
if(trval[x]==w)return trsum[ls[x]];
if(trval[x]>w)return find(ls[x],w);
return trsum[ls[x]]+trwzc[x]+find(rs[x],w);
}
void pre_dfs(int x)
{
dp[x]=1;
vis[x]=1;
for(int i=head[x];i;i=e[i].nxt)
if(!vis[e[i].to])
{
mp[make_pair(x,e[i].to)]=1;
insert(rt[x],e[i].to);
pre_dfs(e[i].to);
son[x]++;
dp[x]=dp[x]*dp[e[i].to]%1000000007;
}
dp[x]=dp[x]*jc[son[x]]%1000000007;
}
void pro_dfs(int x)
{
long long ez=son[x];
long long flag=0;
while(1)
{
if(!ez)break;
flag=find(rt[x],b[dfn+1]);
now=now*inv[ez]%1000000007*jc[ez-1]%1000000007;
ans=(ans+now*flag%1000000007)%1000000007;
if(mp[make_pair(x,b[dfn+1])])
{
dfn++;
ez--;
change(rt[x],b[dfn]);
pro_dfs(b[dfn]);
if(jump)break;
}else{jump=1;break;}
}
}
int main()
{
scanf("%d",&n);
pre_work();
for(int i=1;i<=n;i++)
scanf("%d",&b[i]);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
du[x]++;
du[y]++;
}
long long res=jc[du[1]];
for(int i=2;i<=n;i++)
res=res*jc[du[i]-1]%1000000007;
if(b[1]!=1)ans=res;
for(int i=2;i<b[1];i++)
{
res=res*inv[du[i-1]]%1000000007*inv[du[i]-1]%1000000007*jc[du[i-1]-1]%1000000007*jc[du[i]]%1000000007;
ans=(ans+res)%1000000007;
}
pre_dfs(b[1]);
now=dp[b[1]];
pro_dfs(b[1]);
printf("%lld",ans);
return 0;
}
rp++