树位DP

  给定一棵n个节点的树,和一个n的排列(b[i]),求树的DFS序中严格小于给定排列的方案数。n<=1e6

  这是一道。。树位DP题,我们沿用数位DP的思想,逐位确定。

  首先我们考虑最没有限制的情况,如果一个以x为根的树不受限制,它的DFS序有多少种。

  这个显然可以换根DP。先进行子树DP。设$f[x]$为答案,那么可得$f[x]=son[x]! \times \prod f[son]$其中son[x]表示儿子个数。

  这个可以理解,就是表示从目前到遍历完子树的DFS序可以分成一段一段的,每一段是一个儿子的DFS,然后段与段是排列的关系,因为没有限制。

  然后我们就可以再次换根DP找到以x为根的整棵树的答案。

  枚举1到b[1]-1把他们的f加入答案。

  然后考虑以b[1]为根。

  我们现在用一个dfs来解决问题,问题是,以b序列中一个值为根,子树严格小于的方案数。

  那么进入dfs,我们目的是甩锅给下一层,然后递归解决。但是有的东西不能甩,本层必须解决。

  设当前位是len

  先分一下类。

  1.从len+1就小于:这种问题本层即可解决,找到接下来第一段的可能情况,也就是有多少儿子<b[len+1],然后假设为cnt,那么第一段有cnt种情况,剩下的仍然是排列和累乘。

  2.在某个儿子的子树中开始小于,这是一个递归的问题一会再说。

  3.从某个儿子开始不等,比如说前两棵子树都恰好覆盖了一段树上序列,然后接下来选一个小于b序列当前位的儿子作为下一位,那么应该是总儿子数减去已经让它完全覆盖的儿子数,这样得到了剩下可以选的儿子数,然后我在找到可以选的中所有小于b当前位的儿子数,还是第一位的情况*剩下的排列和累乘。

  那么我们考虑顺着b数组来捋,解决第二个问题。

  我们进行一个儿子次的循环。

  循环内部每次找到一个儿子等于b[late],late为上一个儿子覆盖完后到的b序列的位置,第一个则为len+1,相当于给挨个拿儿子往b序列上贴,接下来我们找到了一个和当前问题一样的问题,找到这个儿子在限制下的排列数,果断甩锅,当一个儿子的子树不能完全贴到b上,break。

  但是我们遇到了一个问题,怎么判断这个儿子的子树能不能把子树的size个点全贴到b上呢? 

  我们就需要用一个东西来记录这个儿子的子树是否能够吧子树个size全贴到b上,发现这个也是可以递归解决的,

  用dfs返回结构体也好,全局变量修改也罢,总之我的dfs要返回一个flag,表示能不能全贴上,这个flag是1必须是所有的儿子都能按顺序贴到b上,即儿子的flag都是1,具体实现就是我之前的儿子次循环真的进行了儿子次,并没有从中间break掉。当然循环中如果找不到一个儿子等于b[late]也要break。

  然后在顺一下思路:分三类,第一类可以一进dfs就算完,第二类是通过枚举儿子是否等于b[late]并dfs判断能不能全部贴到序列上,如果能,我累加儿子子树中开始严格小于的答案,然后接着吧late+=size[son],相当于把这个儿子贴到序列上,然后累加一下第三类答案也就是从下一个儿子处开始小于b的答案,接着找下一个儿子等于b[late]的子树中小于的答案……直至循环结束。

  交叉着进行二三类答案的计算。

  当递归到叶子节点时,处理flag,如果我的值等于b的当前值为1,否则为0,然后就是返回值,如果我的值小于b[len]那么返回1,表示递归的一条链是可以严格小于的。

  然后问题就解决了。

  然鹅会T。

  观察一下数据范围,1e6,但是在dfs的过程中是对于每个点我进行了两层循环,也就是说每个点被作为儿子枚举了n次,复杂度是$O(n^{2})$的,复杂度瓶颈卡在了我在找一个儿子是否等于b[late]的时候是枚举所有儿子的,接下来就很简单了,用一个数据结构维护一下每个点的儿子,支持删除,和单点,区间查询,splay和sgtree都可以,但我觉得动态开点的sgtree好写(得多)。然后就能A了。

  

#include<cstring>
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;
const int N=300020,mod=1e9+7;
int fr[N],b[N],son[N],size[N],fa[N],flag,tt,n;
ll f[N],fac[N];
bool v[N];
struct node{int to,pr;}mo[N*2];
long long rd()
{
    long long s=0,w=1;
    char cc=getchar();
    while(cc<'0'||cc>'9') {if(cc=='-') w=-1;cc=getchar();}
    while(cc>='0'&&cc<='9') s=(s<<3)+(s<<1)+cc-'0',cc=getchar();
    return s*w;
}
ll inv(ll a)
{
    ll ans=1,k=mod-2;
    for(;k;k>>=1,a=1ll*a*a%mod) if(k&1) ans=1ll*ans*a%mod;
    return ans;
}
void add(int x,int y)
{
    mo[++tt].to=y;
    mo[tt].pr=fr[x];
    fr[x]=tt;
}
void first_dfs(int x)
{
    ll ans=1;size[x]=1;
    for(int i=fr[x];i;i=mo[i].pr)
    {
        int to=mo[i].to;
        if(to==fa[x])continue;
        son[x]++;
        fa[to]=x;
        first_dfs(to);
        ans=1ll*ans*f[to]%mod;
        size[x]+=size[to];
    }
    f[x]=1ll*fac[son[x]]*ans%mod;
}
void re_dfs(int x)
{
    for(int i=fr[x];i;i=mo[i].pr)
    {
        int to=mo[i].to;
        if(to==fa[x]) continue;
        //cout<<x<<" "<<to<<endl;
        //cout<<to<<" "<<1ll*f[x]%mod*inv[f[to]]%mod<<endl;
        //cout<<son[to]<<endl;
        f[to]=1ll*f[to]*f[x]%mod*inv(f[to])%mod*inv(son[x])%mod*(++son[to])%mod;
        re_dfs(to);
    }
}
ll dfs(int len,int x)
{
    if(son[x]==0) 
    {
        flag=x==b[len];
        return x<b[len];
    }
    long long ans=0,sum=0;
    for(int i=fr[x];i;i=mo[i].pr)
        if(mo[i].to!=fa[x]&&mo[i].to<b[len+1])sum++;
    //cout<<ans<<endl;
    ans=(ans+1ll*sum*f[x]%mod*inv(son[x])%mod)%mod;
    //cout<<x<<" "<<ans<<endl;
    ll lat=len+1,pi=1ll*f[x]*inv(fac[son[x]])%mod;sum=son[x];
    //cout<<x<<" "<<" "<<lat<<endl;
    flag=1;
    for(int k=1;k<=son[x];k++)
    {
        bool jud=0;
        for(int i=fr[x];i;i=mo[i].pr)
        {
            int to=mo[i].to;
            if(to==fa[x]) continue;
            //cout<<to<<" "<<b[lat]<<endl;
            if(to==b[lat])
            {
                v[to]=1;
                pi=pi*inv(f[to])%mod;
                long long tmp=dfs(lat,to);
                ans=(ans+1ll*tmp*fac[sum-1]%mod*pi%mod)%mod;
            //    cout<<to<<" "<<b[lat]<<" "<<tmp<<" "<<flag<<endl;
                lat=lat+size[to],sum--;
                jud=1;
                break;
            }
        }
        if(!flag) break;
        if(!jud) break;
        int cnt=0;
        for(int i=fr[x];i;i=mo[i].pr)
        {
            int to=mo[i].to;
            if(to==fa[x]) continue;
            if(v[to]) continue;
            if(to<b[lat]) cnt++;
        }
        //cout<<b[lat]<<" s"<<cnt<<" "<<sum<<" "<<pi<<" "<<ans<<endl;
        if(sum!=son[x])ans=(ans+1ll*pi*cnt%mod*fac[sum-1]%mod)%mod;
        //cout<<ans<<endl;
        
    }
    if(flag==1&&sum==0) flag=1;
    else flag=0;
    return ans;
}
ll solve()
{
    ll ans=0;
    first_dfs(1);
    re_dfs(1);
    for(int i=1;i<b[1];i++)ans=(ans+f[i])%mod;
    //cout<<ans<<endl;
    memset(son,0,sizeof(son));
    memset(fa,0,sizeof(fa));
    memset(f,0,sizeof(f));
    memset(size,0,sizeof(size));
    first_dfs(b[1]);
    ans=(ans+dfs(1,b[1]))%mod;
    return ans;
}
int main()
{
    //freopen("travel2.in","r",stdin);
    //freopen("data1.in","r",stdin);
    //freopen("data1.out","w",stdout);
    n=rd();fac[0]=1;
    for(int i=1;i<=n;i++)b[i]=rd(),fac[i]=1ll*fac[i-1]*i%mod;
    for(int i=1,x,y;i<n;i++)
    {
        x=rd(),y=rd();
        add(x,y);add(y,x);
    }
    printf("%lld\n",solve());
}
/*
g++ -std=c++11 1.cpp -o 1
./1
6
1 3 6 2 5 4 
1 2
1 3
1 4
4 5
1 6
*/
80pts更容易理解
#include<cstring>
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;
const int N=300020,mod=1e9+7;
int fr[N],b[N],son[N],size[N],fa[N],flag,tt,n;
ll f[N],fac[N];
bool v[N];
struct node{int to,pr;}mo[N*2];
long long rd()
{
    long long s=0,w=1;
    char cc=getchar();
    while(cc<'0'||cc>'9') {if(cc=='-') w=-1;cc=getchar();}
    while(cc>='0'&&cc<='9') s=(s<<3)+(s<<1)+cc-'0',cc=getchar();
    return s*w;
}
ll inv(ll a)
{
    ll ans=1,k=mod-2;
    for(;k;k>>=1,a=1ll*a*a%mod) if(k&1) ans=1ll*ans*a%mod;
    return ans;
}
void add(int x,int y)
{
    mo[++tt].to=y;
    mo[tt].pr=fr[x];
    fr[x]=tt;
}
void first_dfs(int x)
{
    ll ans=1;size[x]=1;
    for(int i=fr[x];i;i=mo[i].pr)
    {
        int to=mo[i].to;
        if(to==fa[x])continue;
        son[x]++;
        fa[to]=x;
        first_dfs(to);
        ans=1ll*ans*f[to]%mod;
        size[x]+=size[to];
    }
    f[x]=1ll*fac[son[x]]*ans%mod;
}
void re_dfs(int x)
{
    for(int i=fr[x];i;i=mo[i].pr)
    {
        int to=mo[i].to;
        if(to==fa[x]) continue;
        //cout<<x<<" "<<to<<endl;
        //cout<<to<<" "<<1ll*f[x]%mod*inv[f[to]]%mod<<endl;
        //cout<<son[to]<<endl;
        f[to]=1ll*f[to]*f[x]%mod*inv(f[to])%mod*inv(son[x])%mod*(++son[to])%mod;
        re_dfs(to);
    }
}
ll dfs(int len,int x)
{
    if(son[x]==0) 
    {
        flag=x==b[len];
        return x<b[len];
    }
    long long ans=0,sum=0;
    for(int i=fr[x];i;i=mo[i].pr)
        if(mo[i].to!=fa[x]&&mo[i].to<b[len+1])sum++;
    //cout<<ans<<endl;
    ans=(ans+1ll*sum*f[x]%mod*inv(son[x])%mod)%mod;
    //cout<<x<<" "<<ans<<endl;
    ll lat=len+1,pi=1ll*f[x]*inv(fac[son[x]])%mod;sum=son[x];
    //cout<<x<<" "<<" "<<lat<<endl;
    flag=1;
    for(int k=1;k<=son[x];k++)
    {
        bool jud=0;
        for(int i=fr[x];i;i=mo[i].pr)
        {
            int to=mo[i].to;
            if(to==fa[x]) continue;
            //cout<<to<<" "<<b[lat]<<endl;
            if(to==b[lat])
            {
                v[to]=1;
                pi=pi*inv(f[to])%mod;
                long long tmp=dfs(lat,to);
                ans=(ans+1ll*tmp*fac[sum-1]%mod*pi%mod)%mod;
            //    cout<<to<<" "<<b[lat]<<" "<<tmp<<" "<<flag<<endl;
                lat=lat+size[to],sum--;
                jud=1;
                break;
            }
        }
        if(!flag) break;
        if(!jud) break;
        int cnt=0;
        for(int i=fr[x];i;i=mo[i].pr)
        {
            int to=mo[i].to;
            if(to==fa[x]) continue;
            if(v[to]) continue;
            if(to<b[lat]) cnt++;
        }
        //cout<<b[lat]<<" s"<<cnt<<" "<<sum<<" "<<pi<<" "<<ans<<endl;
        if(sum!=son[x])ans=(ans+1ll*pi*cnt%mod*fac[sum-1]%mod)%mod;
        //cout<<ans<<endl;
        
    }
    if(flag==1&&sum==0) flag=1;
    else flag=0;
    return ans;
}
ll solve()
{
    ll ans=0;
    first_dfs(1);
    re_dfs(1);
    for(int i=1;i<b[1];i++)ans=(ans+f[i])%mod;
    //cout<<ans<<endl;
    memset(son,0,sizeof(son));
    memset(fa,0,sizeof(fa));
    memset(f,0,sizeof(f));
    memset(size,0,sizeof(size));
    first_dfs(b[1]);
    ans=(ans+dfs(1,b[1]))%mod;
    return ans;
}
int main()
{
    //freopen("travel2.in","r",stdin);
    //freopen("data1.in","r",stdin);
    //freopen("data1.out","w",stdout);
    n=rd();fac[0]=1;
    for(int i=1;i<=n;i++)b[i]=rd(),fac[i]=1ll*fac[i-1]*i%mod;
    for(int i=1,x,y;i<n;i++)
    {
        x=rd(),y=rd();
        add(x,y);add(y,x);
    }
    printf("%lld\n",solve());
}
/*
g++ -std=c++11 1.cpp -o 1
./1
6
1 3 6 2 5 4 
1 2
1 3
1 4
4 5
1 6
*/
100pts

 

posted @ 2019-08-23 17:41  starsing  阅读(257)  评论(0编辑  收藏  举报