luogu4365 秘密袭击 (生成函数+线段树合并+拉格朗日插值)
求所有可能联通块的第k大值的和,考虑枚举这个值:
$ans=\sum\limits_{i=1}^{W}{i\sum\limits_{S}{[i是第K大]}}$
设cnt[i]为连通块中值>=i的个数
$ans=\sum\limits_{i=1}^{W}{i\sum\limits_{S}{[cnt[i]>=K]-[cnt[i+1]>=K]}}$
$ans=\sum\limits_{i=1}^{W}{\sum\limits_{S}{[cnt[i]>=K]}}$
于是先考虑树上dp,设f[i][j][k]表示以i为根的连通块中,值>=j的数量为k的情况数
然后$ans=\sum\limits_{i=1}^{N}{\sum\limits_{j=1}^{W}{\sum\limits_{k=K}^{N}{f[i][j][k]}}}$
转移和背包类似,所以这样做是$O(N^2W)$的
考虑使用生成函数优化,设$F[i][j]=\sum{f[i][j][k]x^k}$,再设$G[i][j]=\sum{F[s][j]},i是s的祖先$
于是转移就变成了$F[i][j]*=(F[s][j]+1),G[i][j]+=G[s][j],G[i][j]+=F[i][j]$,其中s是i的孩子
同时有初值$F[i][j]=(d[i]>=j?x:1)$,答案就是G[1][*]的K~N项系数的和
然后当然不能真的去乘了..
考虑先将F和G用点值表达,最后再插回来
首先枚举x=1..N+1,然后给每个点i开动态开点的线段树维护F[i][j]和G[i][j]的值
然后用线段树合并来做对应位置的相乘和相加
具体来说,我们让线段树上的结点维护一个作用在$(f,g)$上的变换$(a,b,c,d)$,使得最终得到$(af+b,cf+d+g)$
然后也不难得到变换的乘法(有结合律但没有交换律)
然后就可以做了 复杂度我也不会分析 反正有可能跑的比暴力还慢
别忘了回收掉不用的点
1 #include<bits/stdc++.h> 2 #define pa pair<int,int> 3 #define CLR(a,x) memset(a,x,sizeof(a)) 4 #define MP make_pair 5 #define fi first 6 #define se second 7 using namespace std; 8 typedef long long ll; 9 typedef unsigned long long ull; 10 typedef unsigned int ui; 11 typedef long double ld; 12 const int maxn=1700,maxp=3e6; 13 const int P=64123; 14 15 inline char gc(){ 16 return getchar(); 17 static const int maxs=1<<16;static char buf[maxs],*p1=buf,*p2=buf; 18 return p1==p2&&(p2=(p1=buf)+fread(buf,1,maxs,stdin),p1==p2)?EOF:*p1++; 19 } 20 inline ll rd(){ 21 ll x=0;char c=gc();bool neg=0; 22 while(c<'0'||c>'9'){if(c=='-') neg=1;c=gc();} 23 while(c>='0'&&c<='9') x=(x<<1)+(x<<3)+c-'0',c=gc(); 24 return neg?(~x+1):x; 25 } 26 27 struct Node{ 28 int a,b,c,d; 29 Node(int _a=1,int _b=0,int _c=0,int _d=0){a=_a,b=_b,c=_c,d=_d;} 30 }val[maxp]; 31 Node operator *(Node x,Node y){ 32 return Node(1ll*x.a*y.a%P,(1ll*x.b*y.a+y.b)%P,(1ll*x.a*y.c+x.c)%P,(1ll*x.b*y.c+x.d+y.d)%P); 33 } 34 35 int N,K,W,dan[maxn],eg[maxn*2][2],egh[maxn],ect; 36 int ch[maxp][2],stk[maxp],sh,rt[maxn]; 37 int yy[maxn]; 38 39 inline void adeg(int a,int b){ 40 eg[++ect][0]=b,eg[ect][1]=egh[a],egh[a]=ect; 41 } 42 43 inline int newnode(){ 44 int p=stk[sh--]; 45 assert(sh>=1); 46 ch[p][0]=ch[p][1]=0; 47 val[p]=Node(); 48 return p; 49 } 50 51 inline void delall(int &p){ 52 if(!p) return; 53 delall(ch[p][0]);delall(ch[p][1]); 54 stk[++sh]=p;p=0; 55 } 56 57 inline void pushdown(int p){ 58 if(!ch[p][0]) ch[p][0]=newnode(); 59 if(!ch[p][1]) ch[p][1]=newnode(); 60 val[ch[p][0]]=val[ch[p][0]]*val[p]; 61 val[ch[p][1]]=val[ch[p][1]]*val[p]; 62 val[p]=Node(); 63 } 64 65 void mul(int &p,int l,int r,int x,int y,Node z){ 66 if(!p) p=newnode(); 67 if(x<=l&&r<=y){ 68 val[p]=val[p]*z; 69 }else{ 70 int m=(l+r)>>1;pushdown(p); 71 if(x<=m) mul(ch[p][0],l,m,x,y,z); 72 if(y>=m+1) mul(ch[p][1],m+1,r,x,y,z); 73 } 74 } 75 76 int merge(int &p,int &q){ 77 if(!p||!q) return p|q; 78 if(!ch[p][0]&&!ch[p][1]) swap(p,q); 79 if(!ch[q][0]&&!ch[q][1]){ 80 val[p]=val[p]*Node(val[q].b,0,0,val[q].d); 81 return p; 82 } 83 pushdown(p),pushdown(q); 84 ch[p][0]=merge(ch[p][0],ch[q][0]); 85 ch[p][1]=merge(ch[p][1],ch[q][1]); 86 return p; 87 } 88 89 void dfs(int x,int f,int id){ 90 mul(rt[x],1,W,1,W,Node(0,1,0,0)); 91 for(int i=egh[x];i;i=eg[i][1]){ 92 int b=eg[i][0];if(b==f) continue; 93 dfs(b,x,id); 94 merge(rt[x],rt[b]); 95 delall(rt[b]); 96 } 97 mul(rt[x],1,W,1,dan[x],Node(id,0,0,0)); 98 mul(rt[x],1,W,1,W,Node(1,0,1,0)); 99 mul(rt[x],1,W,1,W,Node(1,1,0,0)); 100 } 101 102 int query(int p,int l,int r){ 103 if(!p) return 0; 104 if(l==r) return val[p].d; 105 int m=(l+r)>>1;pushdown(p); 106 return (query(ch[p][0],l,m)+query(ch[p][1],m+1,r))%P; 107 } 108 109 int fpow(int x,int y){ 110 int r=1; 111 while(y){ 112 if(y&1) r=1ll*r*x%P; 113 x=1ll*x*x%P,y>>=1; 114 }return r; 115 } 116 117 int l[maxn],tmp[maxn],ans[maxn]; 118 void calc(){ 119 l[0]=1; 120 for(int i=1;i<=N+1;i++){ 121 for(int j=i-1;j>=0;j--){ 122 l[j+1]=(l[j+1]+l[j])%P; 123 l[j]=-1ll*i*l[j]%P; 124 } 125 } 126 for(int i=1;i<=N+1;i++){ 127 int ib=-fpow(i,P-2); 128 tmp[0]=1ll*l[0]*ib%P; 129 for(int j=1;j<=N;j++){ 130 tmp[j]=1ll*(l[j]-tmp[j-1])*ib%P; 131 } 132 int k=0,x=1; 133 for(int j=0;j<=N;j++){ 134 k=(1ll*x*tmp[j]+k)%P; 135 x=1ll*x*i%P; 136 } 137 k=1ll*fpow(k,P-2)*yy[i]%P; 138 for(int j=0;j<=N;j++){ 139 ans[j]=(1ll*tmp[j]*k+ans[j])%P; 140 } 141 } 142 } 143 144 int main(){ 145 //freopen("","r",stdin); 146 N=rd(),K=rd(),W=rd(); 147 for(int i=1;i<=N;i++) dan[i]=rd(); 148 for(int i=1;i<N;i++){ 149 int a=rd(),b=rd(); 150 adeg(a,b);adeg(b,a); 151 } 152 for(int i=1;i<maxp-5;i++) stk[++sh]=i; 153 154 for(int i=1;i<=N+1;i++){ 155 dfs(1,0,i); 156 yy[i]=query(rt[1],1,W); 157 delall(rt[1]); 158 } 159 calc(); 160 int a=0; 161 for(int i=K;i<=N;i++) a=(a+ans[i])%P; 162 printf("%d\n",(a+P)%P); 163 return 0; 164 }