树位DP
给定一棵n个节点的树,和一个n的排列(b[i]),求树的DFS序中严格小于给定排列的方案数。n<=1e6
这是一道。。树位DP题,我们沿用数位DP的思想,逐位确定。
首先我们考虑最没有限制的情况,如果一个以x为根的树不受限制,它的DFS序有多少种。
这个显然可以换根DP。先进行子树DP。设$f[x]$为答案,那么可得$f[x]=son[x]! \times \prod f[son]$其中son[x]表示儿子个数。
这个可以理解,就是表示从目前到遍历完子树的DFS序可以分成一段一段的,每一段是一个儿子的DFS,然后段与段是排列的关系,因为没有限制。
然后我们就可以再次换根DP找到以x为根的整棵树的答案。
枚举1到b[1]-1把他们的f加入答案。
然后考虑以b[1]为根。
我们现在用一个dfs来解决问题,问题是,以b序列中一个值为根,子树严格小于的方案数。
那么进入dfs,我们目的是甩锅给下一层,然后递归解决。但是有的东西不能甩,本层必须解决。
设当前位是len
先分一下类。
1.从len+1就小于:这种问题本层即可解决,找到接下来第一段的可能情况,也就是有多少儿子<b[len+1],然后假设为cnt,那么第一段有cnt种情况,剩下的仍然是排列和累乘。
2.在某个儿子的子树中开始小于,这是一个递归的问题一会再说。
3.从某个儿子开始不等,比如说前两棵子树都恰好覆盖了一段树上序列,然后接下来选一个小于b序列当前位的儿子作为下一位,那么应该是总儿子数减去已经让它完全覆盖的儿子数,这样得到了剩下可以选的儿子数,然后我在找到可以选的中所有小于b当前位的儿子数,还是第一位的情况*剩下的排列和累乘。
那么我们考虑顺着b数组来捋,解决第二个问题。
我们进行一个儿子次的循环。
循环内部每次找到一个儿子等于b[late],late为上一个儿子覆盖完后到的b序列的位置,第一个则为len+1,相当于给挨个拿儿子往b序列上贴,接下来我们找到了一个和当前问题一样的问题,找到这个儿子在限制下的排列数,果断甩锅,当一个儿子的子树不能完全贴到b上,break。
但是我们遇到了一个问题,怎么判断这个儿子的子树能不能把子树的size个点全贴到b上呢?
我们就需要用一个东西来记录这个儿子的子树是否能够吧子树个size全贴到b上,发现这个也是可以递归解决的,
用dfs返回结构体也好,全局变量修改也罢,总之我的dfs要返回一个flag,表示能不能全贴上,这个flag是1必须是所有的儿子都能按顺序贴到b上,即儿子的flag都是1,具体实现就是我之前的儿子次循环真的进行了儿子次,并没有从中间break掉。当然循环中如果找不到一个儿子等于b[late]也要break。
然后在顺一下思路:分三类,第一类可以一进dfs就算完,第二类是通过枚举儿子是否等于b[late]并dfs判断能不能全部贴到序列上,如果能,我累加儿子子树中开始严格小于的答案,然后接着吧late+=size[son],相当于把这个儿子贴到序列上,然后累加一下第三类答案也就是从下一个儿子处开始小于b的答案,接着找下一个儿子等于b[late]的子树中小于的答案……直至循环结束。
交叉着进行二三类答案的计算。
当递归到叶子节点时,处理flag,如果我的值等于b的当前值为1,否则为0,然后就是返回值,如果我的值小于b[len]那么返回1,表示递归的一条链是可以严格小于的。
然后问题就解决了。
然鹅会T。
观察一下数据范围,1e6,但是在dfs的过程中是对于每个点我进行了两层循环,也就是说每个点被作为儿子枚举了n次,复杂度是$O(n^{2})$的,复杂度瓶颈卡在了我在找一个儿子是否等于b[late]的时候是枚举所有儿子的,接下来就很简单了,用一个数据结构维护一下每个点的儿子,支持删除,和单点,区间查询,splay和sgtree都可以,但我觉得动态开点的sgtree好写(得多)。然后就能A了。
#include<cstring> #include<iostream> #include<cstdio> #define ll long long using namespace std; const int N=300020,mod=1e9+7; int fr[N],b[N],son[N],size[N],fa[N],flag,tt,n; ll f[N],fac[N]; bool v[N]; struct node{int to,pr;}mo[N*2]; long long rd() { long long s=0,w=1; char cc=getchar(); while(cc<'0'||cc>'9') {if(cc=='-') w=-1;cc=getchar();} while(cc>='0'&&cc<='9') s=(s<<3)+(s<<1)+cc-'0',cc=getchar(); return s*w; } ll inv(ll a) { ll ans=1,k=mod-2; for(;k;k>>=1,a=1ll*a*a%mod) if(k&1) ans=1ll*ans*a%mod; return ans; } void add(int x,int y) { mo[++tt].to=y; mo[tt].pr=fr[x]; fr[x]=tt; } void first_dfs(int x) { ll ans=1;size[x]=1; for(int i=fr[x];i;i=mo[i].pr) { int to=mo[i].to; if(to==fa[x])continue; son[x]++; fa[to]=x; first_dfs(to); ans=1ll*ans*f[to]%mod; size[x]+=size[to]; } f[x]=1ll*fac[son[x]]*ans%mod; } void re_dfs(int x) { for(int i=fr[x];i;i=mo[i].pr) { int to=mo[i].to; if(to==fa[x]) continue; //cout<<x<<" "<<to<<endl; //cout<<to<<" "<<1ll*f[x]%mod*inv[f[to]]%mod<<endl; //cout<<son[to]<<endl; f[to]=1ll*f[to]*f[x]%mod*inv(f[to])%mod*inv(son[x])%mod*(++son[to])%mod; re_dfs(to); } } ll dfs(int len,int x) { if(son[x]==0) { flag=x==b[len]; return x<b[len]; } long long ans=0,sum=0; for(int i=fr[x];i;i=mo[i].pr) if(mo[i].to!=fa[x]&&mo[i].to<b[len+1])sum++; //cout<<ans<<endl; ans=(ans+1ll*sum*f[x]%mod*inv(son[x])%mod)%mod; //cout<<x<<" "<<ans<<endl; ll lat=len+1,pi=1ll*f[x]*inv(fac[son[x]])%mod;sum=son[x]; //cout<<x<<" "<<" "<<lat<<endl; flag=1; for(int k=1;k<=son[x];k++) { bool jud=0; for(int i=fr[x];i;i=mo[i].pr) { int to=mo[i].to; if(to==fa[x]) continue; //cout<<to<<" "<<b[lat]<<endl; if(to==b[lat]) { v[to]=1; pi=pi*inv(f[to])%mod; long long tmp=dfs(lat,to); ans=(ans+1ll*tmp*fac[sum-1]%mod*pi%mod)%mod; // cout<<to<<" "<<b[lat]<<" "<<tmp<<" "<<flag<<endl; lat=lat+size[to],sum--; jud=1; break; } } if(!flag) break; if(!jud) break; int cnt=0; for(int i=fr[x];i;i=mo[i].pr) { int to=mo[i].to; if(to==fa[x]) continue; if(v[to]) continue; if(to<b[lat]) cnt++; } //cout<<b[lat]<<" s"<<cnt<<" "<<sum<<" "<<pi<<" "<<ans<<endl; if(sum!=son[x])ans=(ans+1ll*pi*cnt%mod*fac[sum-1]%mod)%mod; //cout<<ans<<endl; } if(flag==1&&sum==0) flag=1; else flag=0; return ans; } ll solve() { ll ans=0; first_dfs(1); re_dfs(1); for(int i=1;i<b[1];i++)ans=(ans+f[i])%mod; //cout<<ans<<endl; memset(son,0,sizeof(son)); memset(fa,0,sizeof(fa)); memset(f,0,sizeof(f)); memset(size,0,sizeof(size)); first_dfs(b[1]); ans=(ans+dfs(1,b[1]))%mod; return ans; } int main() { //freopen("travel2.in","r",stdin); //freopen("data1.in","r",stdin); //freopen("data1.out","w",stdout); n=rd();fac[0]=1; for(int i=1;i<=n;i++)b[i]=rd(),fac[i]=1ll*fac[i-1]*i%mod; for(int i=1,x,y;i<n;i++) { x=rd(),y=rd(); add(x,y);add(y,x); } printf("%lld\n",solve()); } /* g++ -std=c++11 1.cpp -o 1 ./1 6 1 3 6 2 5 4 1 2 1 3 1 4 4 5 1 6 */
#include<cstring> #include<iostream> #include<cstdio> #define ll long long using namespace std; const int N=300020,mod=1e9+7; int fr[N],b[N],son[N],size[N],fa[N],flag,tt,n; ll f[N],fac[N]; bool v[N]; struct node{int to,pr;}mo[N*2]; long long rd() { long long s=0,w=1; char cc=getchar(); while(cc<'0'||cc>'9') {if(cc=='-') w=-1;cc=getchar();} while(cc>='0'&&cc<='9') s=(s<<3)+(s<<1)+cc-'0',cc=getchar(); return s*w; } ll inv(ll a) { ll ans=1,k=mod-2; for(;k;k>>=1,a=1ll*a*a%mod) if(k&1) ans=1ll*ans*a%mod; return ans; } void add(int x,int y) { mo[++tt].to=y; mo[tt].pr=fr[x]; fr[x]=tt; } void first_dfs(int x) { ll ans=1;size[x]=1; for(int i=fr[x];i;i=mo[i].pr) { int to=mo[i].to; if(to==fa[x])continue; son[x]++; fa[to]=x; first_dfs(to); ans=1ll*ans*f[to]%mod; size[x]+=size[to]; } f[x]=1ll*fac[son[x]]*ans%mod; } void re_dfs(int x) { for(int i=fr[x];i;i=mo[i].pr) { int to=mo[i].to; if(to==fa[x]) continue; //cout<<x<<" "<<to<<endl; //cout<<to<<" "<<1ll*f[x]%mod*inv[f[to]]%mod<<endl; //cout<<son[to]<<endl; f[to]=1ll*f[to]*f[x]%mod*inv(f[to])%mod*inv(son[x])%mod*(++son[to])%mod; re_dfs(to); } } ll dfs(int len,int x) { if(son[x]==0) { flag=x==b[len]; return x<b[len]; } long long ans=0,sum=0; for(int i=fr[x];i;i=mo[i].pr) if(mo[i].to!=fa[x]&&mo[i].to<b[len+1])sum++; //cout<<ans<<endl; ans=(ans+1ll*sum*f[x]%mod*inv(son[x])%mod)%mod; //cout<<x<<" "<<ans<<endl; ll lat=len+1,pi=1ll*f[x]*inv(fac[son[x]])%mod;sum=son[x]; //cout<<x<<" "<<" "<<lat<<endl; flag=1; for(int k=1;k<=son[x];k++) { bool jud=0; for(int i=fr[x];i;i=mo[i].pr) { int to=mo[i].to; if(to==fa[x]) continue; //cout<<to<<" "<<b[lat]<<endl; if(to==b[lat]) { v[to]=1; pi=pi*inv(f[to])%mod; long long tmp=dfs(lat,to); ans=(ans+1ll*tmp*fac[sum-1]%mod*pi%mod)%mod; // cout<<to<<" "<<b[lat]<<" "<<tmp<<" "<<flag<<endl; lat=lat+size[to],sum--; jud=1; break; } } if(!flag) break; if(!jud) break; int cnt=0; for(int i=fr[x];i;i=mo[i].pr) { int to=mo[i].to; if(to==fa[x]) continue; if(v[to]) continue; if(to<b[lat]) cnt++; } //cout<<b[lat]<<" s"<<cnt<<" "<<sum<<" "<<pi<<" "<<ans<<endl; if(sum!=son[x])ans=(ans+1ll*pi*cnt%mod*fac[sum-1]%mod)%mod; //cout<<ans<<endl; } if(flag==1&&sum==0) flag=1; else flag=0; return ans; } ll solve() { ll ans=0; first_dfs(1); re_dfs(1); for(int i=1;i<b[1];i++)ans=(ans+f[i])%mod; //cout<<ans<<endl; memset(son,0,sizeof(son)); memset(fa,0,sizeof(fa)); memset(f,0,sizeof(f)); memset(size,0,sizeof(size)); first_dfs(b[1]); ans=(ans+dfs(1,b[1]))%mod; return ans; } int main() { //freopen("travel2.in","r",stdin); //freopen("data1.in","r",stdin); //freopen("data1.out","w",stdout); n=rd();fac[0]=1; for(int i=1;i<=n;i++)b[i]=rd(),fac[i]=1ll*fac[i-1]*i%mod; for(int i=1,x,y;i<n;i++) { x=rd(),y=rd(); add(x,y);add(y,x); } printf("%lld\n",solve()); } /* g++ -std=c++11 1.cpp -o 1 ./1 6 1 3 6 2 5 4 1 2 1 3 1 4 4 5 1 6 */