「NOI十联测」反函数
30pts
令(为1,)为-1;
暴力枚举每个点为起始点的路径,一条路径是合法的当且仅当路径权值和为0且路径上没有出现过负数。
将所有答案算出。
100pts
使用点分治。
要求知道经过重心root的路径,这里默认把root当做树的根。
经过root的路径$ (x,y)$分为两种:
- root是路径一端点。
- root不是路径\((x,y)\)一端点,可以将路径分为\((x,root.son[x]),root,(root.son[y],y)\),其中\(root.son[x]\)表示\((x,root)\)路径中是root子节点的点。
这样就只需要链的信息。
每个节点记录从\(root.son\)到它的权值和\(val\)。
链\((x,root.son[x] )\)是合法的仅当\(val[x]\)是该路径上最大的。
否则令权值最大的点为\(y\),那么\((x,y)\)的权值和为负数,不合法。
同样的\((root.son[x],x )\)合法仅当\(val[x]\)是该路径上最小的。
同时再记录\((x,root.son[x])\)路径上最大值,\((root.son[x],x)\)上最小值出现的次数,即它可以划分的个数。
具体实现如下:
void add(int f,int x,int Mx,int ct,int Mi,int ct1) {
sum[x]=sum[f]+val[x];
if(sum[x]>Mx)ct=1,Mx=sum[x];
else if(sum[x]==Mx)++ct;
if(sum[x]<Mi)ct1=1,Mi=sum[x];
else if(sum[x]==Mi)++ct1;
if(Mx>=0&&sum[x]==Mx)Max[Mx].push_back(ct);
if(Mi<=0&&sum[x]==Mi)Min[-Mi].push_back(ct1);
++all_si;
travel(x)if(to[q]!=f&&!mark[q])add(x,to[q],Mx,ct,Mi,ct1);
}
合并的时候,要是路径和为0,分三类。
- 从根向下
- \((x,root.son[x]),(root.son[y],y)\)中某一个值为0
- \((x,root.son[x]),(root.son[y],y)\)均不为0
由于root也要算在内,根据root分类。
第一种情况暴力。
if(val[rt]==1) {
rep(w,0,(int)Min[1].size()-1)++ans[Min[1][w]];
} else {
rep(w,0,(int)Max[1].size()-1)++ans[Max[1][w]];
}
第二种情况,先写暴力代码:
if(val[rt]==1) {
if(Max[0].size()&&Min[1].size()) {
rep(w,0,(int)Min[1].size()-1) {
rep(e,0,(int)Max[0].size()-1) {
++ans[Min[1][w]+Max[0][e]];
}
}
}
} else {
if(Max[1].size()&&Min[0].size()) {
rep(w,0,(int)Min[0].size()-1) {
rep(e,0,(int)Max[1].size()-1) {
++ans[Min[0][w]+Max[1][e]];
}
}
}
}
发现其实就是一个卷积的形式。
而且由于vector里面是路径上出现最大或最小的次数,因此必然是值连续的。
故使用FFT复杂度正确,共计\(o(nlog(n^2))\)。
第3种情况两条路拼成一条,因此算S的时候要减1,其他和2一样。
总复杂度$o(nlog(n^2)) $,常数不管。
Code
#include<bits/stdc++.h>
#define rep(q,a,b) for(int q=a,q##_end_=b;q<=q##_end_;++q)
#define dep(q,a,b) for(int q=a,q##_end_=b;q>=q##_end_;--q)
#define mem(a,b) memset(a,b,sizeof a )
#define debug(a) cerr<<#a<<' '<<a<<"___"<<endl
using namespace std;
typedef long long ll;
void in(int &r) {
static char c;
r=0;
while(c=getchar(),!isdigit(c));
do r=(r<<1)+(r<<3)+(c^48);
while(c=getchar(),isdigit(c));
}
const int mn=50005;
ll ans[132000];
const int mod=998244353;
int mlv(int x,int v){
int ans=1;
while(v){
if(v&1)ans=1LL*ans*x%mod;
x=1LL*x*x%mod,v>>=1;
}
return ans;
}
namespace NTT {
const int g=3;
const int gg=332748118;
const int mn=132000;
int to[mn],lim,n,a[mn],b[mn];
void DFT(int* a,int inv) {
rep(q,0,n-1)if(to[q]<q)swap(a[q],a[to[q]]);
for(int len=2; len<=n; len<<=1) {
int sp=len>>1;
ll mv=1;
int ml=mlv(inv?g:gg,(mod-1)/len);
for(int k=0; k<sp; ++k) {
for(int* p=a; p!=a+n; p+=len) {
int mid=1LL*p[k+sp]*mv%mod;
p[k+sp]=(p[k]-mid)%mod;
p[k]=(p[k]+mid)%mod;
}
mv=mv*ml%mod;
}
}
}
void solve(int* a1,int len,int* b1,int len1,int ty) {
n=1,lim=0;
while(n<(len+len1))n<<=1,++lim;
rep(q,0,n-1)to[q]=(to[q>>1]>>1)|(q&1)<<(lim-1);
rep(q,0,len-1)a[q]=a1[q];
rep(q,len,n-1)a[q]=0;
rep(q,0,len1-1)b[q]=b1[q];
rep(q,len1,n-1)b[q]=0;
DFT(a,1),DFT(b,1);
rep(q,0,n-1)a[q]=1LL*a[q]*b[q]%mod;
DFT(a,0);
int inv_n=mlv(n,mod-2);
rep(q,0,n-1)ans[q]+=ty*((1LL*a[q]*inv_n%mod+mod)%mod);
}
}
int head[mn],ne[mn<<1],to[mn<<1],cnt1=1;
#define link(a,b) link_edge(a,b),link_edge(b,a)
#define link_edge(a,b) to[++cnt1]=b,ne[cnt1]=head[a],head[a]=cnt1
#define travel(x) for(int q(head[x]);q;q=ne[q])
int val[mn];
bool mark[mn<<1];
int si[mn],all_si,Mn,root;
void get(int f,int x) {
si[x]=1;
travel(x)if(to[q]!=f&&!mark[q])get(x,to[q]),si[x]+=si[to[q]];
}
void find(int f,int x) {
travel(x)if(to[q]!=f&&!mark[q])find(x,to[q]);
si[x]=max(si[x],all_si-si[x]);
if(si[x]<Mn)Mn=si[x],root=x;
}
int num[mn],sum[mn];
vector<int> Max[mn],Min[mn];
void add(int f,int x,int Mx,int ct,int Mi,int ct1) {
sum[x]=sum[f]+val[x];
if(sum[x]>Mx)ct=1,Mx=sum[x];
else if(sum[x]==Mx)++ct;
if(sum[x]<Mi)ct1=1,Mi=sum[x];
else if(sum[x]==Mi)++ct1;
if(Mx>=0&&sum[x]==Mx)Max[Mx].push_back(ct);
if(Mi<=0&&sum[x]==Mi)Min[-Mi].push_back(ct1);
++all_si;
travel(x)if(to[q]!=f&&!mark[q])add(x,to[q],Mx,ct,Mi,ct1);
}
void clear(int v) {
rep(q,0,v)Max[q].clear(),Min[q].clear();
}
int mid[mn],mid1[mn];
void get_FFT(int a,int b,int ty,int v) {
int mq=0,mw=0;
rep(w,0,(int)Max[a].size()-1)++mid[Max[a][w]],mq=max(mq,Max[a][w]);
rep(w,0,(int)Min[b].size()-1)++mid1[Min[b][w]-v],mw=max(mw,Min[b][w]-v);
NTT::solve(mid,mq+1,mid1,mw+1,ty);
rep(w,0,mq)mid[w]=0;
rep(w,0,mw)mid1[w]=0;
}
void calc(int rt,int v,int ty) {
rep(q,0,v)mid[q]=mid1[q]=0;
if(val[rt]==1) {
rep(q,1,v-1)if(Max[q].size()&&Min[q+1].size())get_FFT(q,q+1,ty,1);
if(Max[0].size()&&Min[1].size())get_FFT(0,1,ty,0);
if(ty==1)rep(w,0,(int)Min[1].size()-1)++ans[Min[1][w]];
} else {
rep(q,1,v-1)if(Max[q+1].size()&&Min[q].size())get_FFT(q+1,q,ty,1);
if(Max[1].size()&&Min[0].size())get_FFT(1,0,ty,0);
if(ty==1)rep(w,0,(int)Max[1].size()-1)++ans[Max[1][w]];
}
}
void solve(int x) {
get(0,x);
all_si=si[x],Mn=1e9,root=x;
int mid=all_si;
find(0,x);
int mid_root=root;
travel(mid_root)if(!mark[q]) {
mark[q]=mark[q^1]=1;
solve(to[q]);
mark[q]=mark[q^1]=0;
}
travel(mid_root)if(!mark[q]) {
mark[q]=mark[q^1]=1;
add(0,to[q],-1e9,0,1e9,0);
mark[q]=mark[q^1]=0;
}
calc(mid_root,mid,1);
clear(mid);
travel(mid_root)if(!mark[q]) {
mark[q]=mark[q^1]=1;
all_si=0;
add(0,to[q],-1e9,0,1e9,0);
calc(mid_root,all_si,-1);
clear(all_si);
mark[q]=mark[q^1]=0;
}
}
int main() {
int n,m,a,b;
in(n);
rep(q,1,n-1)in(a),in(b),link(a,b);
char c;
rep(q,1,n) {
while(c=getchar(),c!=')'&&c!='(');
val[q]=c=='('?1:-1;
}
solve(1);
in(m);
rep(q,1,m)in(a),printf("%lld\n",ans[a]);
return 0;
}