雅礼集训2019Day2 T2—Bracket(点分治+FFT)

描述

给定一棵有 nn 个节点的无根树,每个节点上是一个字符,要么是((,要么是))

定义 S(x,y)S(x, y) 为从$ $x 开始沿着最短路走到 yy,将沿途经过的点上的字符依次连起来得到的字符串。 合法括号序定义如下:

1,()()是合法的。

2,若 AA合法,则(A)(A)也合法。

3,若 ABA, B 分别合法,则 ABAB 也合法。

函数f(x,y)f(x, y) 等于对 S(x,y)S(x, y) 进行划分,使得每一个部分都是合法括号序,能得到的最大的段数,比如(())()()(())()()的最大段数为 3, (()())(())(()())(())的最大段数为 22

特别的,如果 S(x,y)S(x,y)本身并不是合法括号序,则f(x,y)=0f(x,y)=0

mm 次询问,每次输入一个 kk,查询有多少点对的 ff 值为kk


点分治,得到每个点到重心的前缀和ss
一条路径合法只有2端前缀和分别为sss-s

考虑一条路径能被划分成多少段,也就是有多少字段和为00
如果一个前缀和为ss,那该点到重心的点中前缀和为ss的次数就是划分的段数

每次fftfft合并2端的答案即可

#include<bits/stdc++.h>
using namespace std;
const int RLEN=1<<21|1;
inline char gc(){
    static char ibuf[RLEN],*ib,*ob;
    (ib==ob)&&(ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
    return (ib==ob)?EOF:*ib++;
} 
#define gc getchar
inline int read(){
    char ch=gc();
    int res=0,f=1;
    while(!isdigit(ch))f^=ch=='-',ch=gc();
    while(isdigit(ch))res=(res+(res<<2)<<1)+(ch^48),ch=gc();
    return f?res:-res;
}
#define ll long long
#define pii pair<int,int>
#define fi first
#define se second
#define pb push_back
#define pob pop_back
#define pf push_front
#define pof pop_front
#define mp make_pair
#define bg begin
#define re register
struct plx{
	double x,y;
	plx(double _x=0,double _y=0):x(_x),y(_y){}
	friend inline plx operator +(const plx &a,const plx &b){
		return plx(a.x+b.x,a.y+b.y);
	}
	friend inline plx operator -(const plx &a,const plx &b){
		return plx(a.x-b.x,a.y-b.y);
	}
	friend inline plx operator *(const plx &a,const plx &b){
		return plx(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
	}
	friend inline plx operator /(const plx &a,const double &b){
		return plx(a.x/b,a.y/b);
	}
};
const int N=50005;
int rev[N<<2];
inline void init_rev(int lim){
	for(int i=0;i<lim;i++)rev[i]=(rev[i>>1]>>1)|((i&1)*(lim>>1));
}
const double pi=acos(-1);
inline void fft(plx *f,int lim,int kd){
	for(int i=0;i<lim;i++)if(i>rev[i])swap(f[i],f[rev[i]]);
	for(int mid=1;mid<lim;mid<<=1){
		plx now=plx(cos(pi/mid),kd*sin(pi/mid));
		for(int i=0;i<lim;i+=mid*2){
			plx w=plx(1,0);
			for(int j=0;j<mid;j++,w=w*now){
				plx a0=f[i+j],a1=w*f[i+j+mid];
				f[i+j]=a0+a1,f[i+j+mid]=a0-a1;
			}
		}
	}
	if(kd==-1)for(int i=0;i<lim;i++)f[i]=f[i]/lim;
}
vector<int> e[N];
vector<int> f[N],g[N];
plx A[N<<2],B[N<<2];
ll ans[N];
inline void chemx(int &a,int b){
	a<b?a=b:0;
}
inline void mul(int k,int kd){
	int n=f[k].size(),m=g[k].size();
	if(!n||!m)return;
	int lim=1,mxa=0,mxb=0;
	for(int i=0;i<n;i++)chemx(mxa,f[k][i]);
	for(int i=0;i<m;i++)chemx(mxb,g[k][i]);
	while(lim<=(mxa+mxb))lim<<=1;
	init_rev(lim);
	for(int i=0;i<lim;i++)A[i]=B[i]=plx(0,0);
	for(int i=0;i<n;i++)A[f[k][i]].x++;
	for(int i=0;i<m;i++)B[g[k][i]].x++;
	fft(A,lim,1),fft(B,lim,1);
	for(int i=0;i<lim;i++)A[i]=A[i]*B[i];
	fft(A,lim,-1);
	for(int i=0;i<lim;i++)ans[i+(k!=0)]+=(ll)(A[i].x+0.5)*kd;
}
int siz[N],son[N],vis[N],maxn,rt,mxd;
int n,q,a[N];
void getrt(int u,int fa){
	siz[u]=1,son[u]=0;
	for(int &v:e[u]){
		if(v==fa||vis[v])continue;
		getrt(v,u),siz[u]+=siz[v];
		if(siz[v]>son[u])son[u]=siz[v];
	}
	son[u]=max(son[u],maxn-siz[u]);
	if(son[u]<son[rt])rt=u;
}
void getdisf(int u,int fa,int s,int mx,int cnt){
	s+=a[u],mxd++;
	if(!s)cnt++;
	if(s==1)s=cnt=0,mx++;
	if(!s)f[mx].pb(cnt);
	for(int &v:e[u]){
		if(v==fa||vis[v])continue;
		getdisf(v,u,s,mx,cnt);
	}
}
void getdisg(int u,int fa,int s,int mx,int cnt){
	s+=a[u];
	if(!s)cnt++;
	if(s==-1)s=cnt=0,mx++;
	if(!s)g[mx].pb(cnt);
	for(int &v:e[u]){
		if(v==fa||vis[v])continue;
		getdisg(v,u,s,mx,cnt);
	}
}
inline void solve(int u){
	vis[u]=1,mxd=0;
	getdisf(u,0,0,0,0);
	for(int &v:e[u])if(!vis[v])getdisg(v,0,0,0,0);
	g[0].pb(0);
	for(int i=0;i<=mxd;i++)mul(i,1),f[i].clear(),g[i].clear();
	for(int &v:e[u]){
		if(vis[v])continue;mxd=1;
		if(a[u]==1)getdisf(v,u,0,1,0);
		else getdisf(v,u,-1,0,0);
		getdisg(v,u,0,0,0);
		for(int i=0;i<=mxd;i++)mul(i,-1),f[i].clear(),g[i].clear();
	}
	for(int &v:e[u]){
		if(vis[v])continue;
		maxn=siz[v],getrt(v,rt=0),solve(rt);
	}
}
char op[5];
signed main(){
	n=read();
	for(int i=1;i<n;i++){
		int u=read(),v=read();
		e[u].pb(v),e[v].pb(u);
	}
	for(int i=1;i<=n;i++){
		scanf("%s",op+1);
		if(op[1]==')')a[i]=1;
		else a[i]=-1;
	}
	maxn=son[0]=n,getrt(1,rt=0);
	solve(rt);
	for(int i=1;i<=n;i++)ans[0]-=ans[i];
	ans[0]+=1ll*n*n,q=read();
	for(int i=1;i<=q;i++)
	cout<<ans[read()]<<'\n';
}
posted @ 2019-07-24 14:22  Stargazer_cykoi  阅读(129)  评论(0编辑  收藏  举报