hard(2018.10.18)
题意:给你一棵\(n\)个节点的树,\(q\)个询问,每次询问读入\(u,v,k,op\),需要满足树上有\(k\)对点的简单路径交都等于\(u,v\)之间的简单路径,\(op=1\)表示\(k\)对点中每个点只能存在于一个点对中,否则每个点可以存在于多个点对中,问那k对点有多少种选法,答案对\(998244353\)取模。
数据范围:对于\(100%\)的数据,保证 \(1≤n≤10^5,1≤u,v≤n,u \ne v,1≤k≤min(n,500),op∈{0,1}\),保证每个节点的度数不超过\(500\)。
我们抓住“两两路径之交是\((u,v)\)”这条性质。 可以发现\(u,v\)是独立的。我们等价于要求:在\(u\)的子树中选\(k\)个点使它们两两\(lca\)是\(u\)的方案数,对\(v\)也求同样的东西,再把两者相乘。如果\(u,v\)存在祖孙关系,不妨设\(u\)是\(v\)的祖先,那么\(u\)的子树就要改为以\(v\)的方向作为根方向前提下的子树。 显然为了使两两\(lca\)是\(u\),在\(u\)的每一个儿子中就至多只能选一个点。然后这题就差不多了。
设\(g[x][i]\)为在\(x\)的子树里选\(i\)个点的方案数,\(ans\)为最后的答案,\(u,v\)为读入的\(u,v\),钦定\(u\)为深度小的那个点,\(tmp\)为\(u-v\)路径上最靠近\(u\)的那个点
那么有:
\[if(lca(u,v)==u)ans=tmp[k]*g[v][k]
\]
\[else\ \ \ \ \ \ \ ans=g[u][k]*g[v][k]
\]
注意:\(k\)个点对是不等价的,比如说我们可以选\((i,j)\)为第一个点对和选\((i,j)\)为第二个点对是两种方案。
代码:
#include<cstdio>
#include<algorithm>
int n,q,cnt,fac[501],inv[501],facinv[501],pre[200001],nxt[200001],h[100001],f[100001][20],size[100001],dep[100001],mod=998244353;
struct oo{
int d[601],du;oo(){d[du=0]=1;}
void add(int x){du++;for(int i=du;i;i--)d[i]=(d[i]+1ll*d[i-1]*x)%mod;}
void del(int x){for(int i=1;i<=du;i++)d[i]=((d[i]-1ll*d[i-1]*x)%mod+mod)%mod;du--;}
int cal(int x,int op){int ans=0;for(int i=op?x-1:0;i<=x;i++)ans=(ans+1ll*d[i]*facinv[x-i])%mod;return (1ll*ans*fac[x])%mod;}
}g[100001];
void add(int x,int y){
pre[++cnt]=y;nxt[cnt]=h[x];h[x]=cnt;
pre[++cnt]=x;nxt[cnt]=h[y];h[y]=cnt;}
void dfs(int x){size[x]=1;
for(int i=1;i<20;i++){if(dep[x]<(1<<i))break;f[x][i]=f[f[x][i-1]][i-1];}
for(int i=h[x];i;i=nxt[i])if(pre[i]!=f[x][0]){dep[pre[i]]=dep[x]+1,f[pre[i]][0]=x,dfs(pre[i]),size[x]+=size[pre[i]];g[x].add(size[pre[i]]);}}
int lca(int x,int y){
if(dep[x]>dep[y])std::swap(x,y);int poor=dep[y]-dep[x];
for(int i=19;i>=0;i--)if(poor&(1<<i))y=f[y][i];
for(int i=19;i>=0;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
if(x==y)return x;return f[x][0];}
int get(int x,int y){int poor=dep[x]-dep[y]-1;for(int i=19;i>=0;i--)if(poor&(1<<i))x=f[x][i];return x;}
int main(){
scanf("%d%d",&n,&q);inv[1]=fac[0]=facinv[0]=1;for(int i=2;i<=500;i++)inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
for(int i=1;i<=500;i++)fac[i]=1ll*fac[i-1]*i%mod,facinv[i]=1ll*facinv[i-1]*inv[i]%mod;
for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),add(x,y);dfs(1);
for(int i=1,u,v,k,op;i<=q;i++){
scanf("%d%d%d%d",&u,&v,&k,&op);if(dep[u]>dep[v])std::swap(u,v);
if(lca(u,v)==u){
int now=get(v,u);oo s=g[u];s.del(size[now]),s.add(n-size[u]);
printf("%d\n",(1ll*s.cal(k,op)*g[v].cal(k,op))%mod);}
else printf("%d\n",(1ll*g[u].cal(k,op)*g[v].cal(k,op))%mod);}}