「NOI十联测」反函数

30pts

令(为1,)为-1;

暴力枚举每个点为起始点的路径,一条路径是合法的当且仅当路径权值和为0且路径上没有出现过负数。

将所有答案算出。

100pts

使用点分治。

要求知道经过重心root的路径,这里默认把root当做树的根。

经过root的路径$ (x,y)$分为两种:

  1. root是路径一端点。
  2. root不是路径\((x,y)\)一端点,可以将路径分为\((x,root.son[x]),root,(root.son[y],y)\),其中\(root.son[x]\)表示\((x,root)\)路径中是root子节点的点。

这样就只需要链的信息。

每个节点记录从\(root.son\)到它的权值和\(val\)

\((x,root.son[x] )\)是合法的仅当\(val[x]\)是该路径上最大的。

否则令权值最大的点为\(y\),那么\((x,y)\)的权值和为负数,不合法。

同样的\((root.son[x],x )\)合法仅当\(val[x]\)是该路径上最小的。

同时再记录\((x,root.son[x])\)路径上最大值,\((root.son[x],x)\)上最小值出现的次数,即它可以划分的个数。

具体实现如下:

void add(int f,int x,int Mx,int ct,int Mi,int ct1) {
    sum[x]=sum[f]+val[x];
    if(sum[x]>Mx)ct=1,Mx=sum[x];
    else if(sum[x]==Mx)++ct;
    if(sum[x]<Mi)ct1=1,Mi=sum[x];
    else if(sum[x]==Mi)++ct1;
    if(Mx>=0&&sum[x]==Mx)Max[Mx].push_back(ct);
    if(Mi<=0&&sum[x]==Mi)Min[-Mi].push_back(ct1);
    ++all_si;
    travel(x)if(to[q]!=f&&!mark[q])add(x,to[q],Mx,ct,Mi,ct1);
}

合并的时候,要是路径和为0,分三类。

  1. 从根向下
  2. \((x,root.son[x]),(root.son[y],y)\)中某一个值为0
  3. \((x,root.son[x]),(root.son[y],y)\)均不为0

由于root也要算在内,根据root分类。

第一种情况暴力。

if(val[rt]==1) {
	rep(w,0,(int)Min[1].size()-1)++ans[Min[1][w]];
} else {
	rep(w,0,(int)Max[1].size()-1)++ans[Max[1][w]];
}

第二种情况,先写暴力代码:

if(val[rt]==1) {
	if(Max[0].size()&&Min[1].size()) {
		rep(w,0,(int)Min[1].size()-1) {
			rep(e,0,(int)Max[0].size()-1) {
				++ans[Min[1][w]+Max[0][e]];
			}
		}
	}
} else {
	if(Max[1].size()&&Min[0].size()) {
		rep(w,0,(int)Min[0].size()-1) {
			rep(e,0,(int)Max[1].size()-1) {
				++ans[Min[0][w]+Max[1][e]];
			}
		}
	}
}

发现其实就是一个卷积的形式。

而且由于vector里面是路径上出现最大或最小的次数,因此必然是值连续的。

故使用FFT复杂度正确,共计\(o(nlog(n^2))\)

第3种情况两条路拼成一条,因此算S的时候要减1,其他和2一样。

总复杂度$o(nlog(n^2)) $,常数不管。

Code

#include<bits/stdc++.h>
#define rep(q,a,b) for(int q=a,q##_end_=b;q<=q##_end_;++q)
#define dep(q,a,b) for(int q=a,q##_end_=b;q>=q##_end_;--q)
#define mem(a,b) memset(a,b,sizeof a )
#define debug(a) cerr<<#a<<' '<<a<<"___"<<endl
using namespace std;
typedef long long ll;
void in(int &r) {
    static char c;
    r=0;
    while(c=getchar(),!isdigit(c));
    do r=(r<<1)+(r<<3)+(c^48);
    while(c=getchar(),isdigit(c));
}
const int mn=50005;
ll ans[132000];
const int mod=998244353;
int mlv(int x,int v){
    int ans=1;
    while(v){
        if(v&1)ans=1LL*ans*x%mod;
        x=1LL*x*x%mod,v>>=1;
    }
    return ans;
}
namespace NTT {
    const int g=3;
    const int gg=332748118;
    const int mn=132000;
    int to[mn],lim,n,a[mn],b[mn];
    void DFT(int* a,int inv) {
        rep(q,0,n-1)if(to[q]<q)swap(a[q],a[to[q]]);
        for(int len=2; len<=n; len<<=1) {
            int sp=len>>1;
            ll mv=1;
            int ml=mlv(inv?g:gg,(mod-1)/len);
            for(int k=0; k<sp; ++k) {
                for(int* p=a; p!=a+n; p+=len) {
                    int mid=1LL*p[k+sp]*mv%mod;
                    p[k+sp]=(p[k]-mid)%mod;
                    p[k]=(p[k]+mid)%mod;
                }
                mv=mv*ml%mod;
            }
        }
    }
    void solve(int* a1,int len,int* b1,int len1,int ty) {
        n=1,lim=0;
        while(n<(len+len1))n<<=1,++lim;
        rep(q,0,n-1)to[q]=(to[q>>1]>>1)|(q&1)<<(lim-1);
        rep(q,0,len-1)a[q]=a1[q];
        rep(q,len,n-1)a[q]=0;
        rep(q,0,len1-1)b[q]=b1[q];
        rep(q,len1,n-1)b[q]=0;
        DFT(a,1),DFT(b,1);
        rep(q,0,n-1)a[q]=1LL*a[q]*b[q]%mod;
        DFT(a,0);
        int inv_n=mlv(n,mod-2);
        rep(q,0,n-1)ans[q]+=ty*((1LL*a[q]*inv_n%mod+mod)%mod);
    }
}
int head[mn],ne[mn<<1],to[mn<<1],cnt1=1;
#define link(a,b) link_edge(a,b),link_edge(b,a)
#define link_edge(a,b) to[++cnt1]=b,ne[cnt1]=head[a],head[a]=cnt1
#define travel(x) for(int q(head[x]);q;q=ne[q])
int val[mn];
bool mark[mn<<1];
int si[mn],all_si,Mn,root;
void get(int f,int x) {
    si[x]=1;
    travel(x)if(to[q]!=f&&!mark[q])get(x,to[q]),si[x]+=si[to[q]];
}
void find(int f,int x) {
    travel(x)if(to[q]!=f&&!mark[q])find(x,to[q]);
    si[x]=max(si[x],all_si-si[x]);
    if(si[x]<Mn)Mn=si[x],root=x;
}
int num[mn],sum[mn];
vector<int> Max[mn],Min[mn];
void add(int f,int x,int Mx,int ct,int Mi,int ct1) {
    sum[x]=sum[f]+val[x];
    if(sum[x]>Mx)ct=1,Mx=sum[x];
    else if(sum[x]==Mx)++ct;
    if(sum[x]<Mi)ct1=1,Mi=sum[x];
    else if(sum[x]==Mi)++ct1;
    if(Mx>=0&&sum[x]==Mx)Max[Mx].push_back(ct);
    if(Mi<=0&&sum[x]==Mi)Min[-Mi].push_back(ct1);
    ++all_si;
    travel(x)if(to[q]!=f&&!mark[q])add(x,to[q],Mx,ct,Mi,ct1);
}
void clear(int v) {
    rep(q,0,v)Max[q].clear(),Min[q].clear();
}
int mid[mn],mid1[mn];
void get_FFT(int a,int b,int ty,int v) {
    int mq=0,mw=0;
    rep(w,0,(int)Max[a].size()-1)++mid[Max[a][w]],mq=max(mq,Max[a][w]);
    rep(w,0,(int)Min[b].size()-1)++mid1[Min[b][w]-v],mw=max(mw,Min[b][w]-v);
    NTT::solve(mid,mq+1,mid1,mw+1,ty);
    rep(w,0,mq)mid[w]=0;
    rep(w,0,mw)mid1[w]=0;
}
void calc(int rt,int v,int ty) {
    rep(q,0,v)mid[q]=mid1[q]=0;
    if(val[rt]==1) {
        rep(q,1,v-1)if(Max[q].size()&&Min[q+1].size())get_FFT(q,q+1,ty,1);
        if(Max[0].size()&&Min[1].size())get_FFT(0,1,ty,0);
        if(ty==1)rep(w,0,(int)Min[1].size()-1)++ans[Min[1][w]];
    } else {
        rep(q,1,v-1)if(Max[q+1].size()&&Min[q].size())get_FFT(q+1,q,ty,1);
        if(Max[1].size()&&Min[0].size())get_FFT(1,0,ty,0);
        if(ty==1)rep(w,0,(int)Max[1].size()-1)++ans[Max[1][w]];
    }
}
 
void solve(int x) {
    get(0,x);
    all_si=si[x],Mn=1e9,root=x;
    int mid=all_si;
    find(0,x);
    int mid_root=root;
    travel(mid_root)if(!mark[q]) {
        mark[q]=mark[q^1]=1;
        solve(to[q]);
        mark[q]=mark[q^1]=0;
    }
    travel(mid_root)if(!mark[q]) {
        mark[q]=mark[q^1]=1;
        add(0,to[q],-1e9,0,1e9,0);
        mark[q]=mark[q^1]=0;
    }
    calc(mid_root,mid,1);
    clear(mid);
    travel(mid_root)if(!mark[q]) {
        mark[q]=mark[q^1]=1;
        all_si=0;
        add(0,to[q],-1e9,0,1e9,0);
        calc(mid_root,all_si,-1);
        clear(all_si);
        mark[q]=mark[q^1]=0;
    }
}
int main() {
    int n,m,a,b;
    in(n);
    rep(q,1,n-1)in(a),in(b),link(a,b);
    char c;
    rep(q,1,n) {
        while(c=getchar(),c!=')'&&c!='(');
        val[q]=c=='('?1:-1;
    }
    solve(1);
    in(m);
    rep(q,1,m)in(a),printf("%lld\n",ans[a]);
    return 0;
}
posted @ 2019-08-05 20:56  Eeis  阅读(161)  评论(0编辑  收藏  举报