LOJ 2743(洛谷 4365) 「九省联考 2018」秘密袭击——整体DP+插值思想
题目:https://loj.ac/problem/2473
https://www.luogu.org/problemnew/show/P4365
参考:https://blog.csdn.net/xyz32768/article/details/82952313
https://blog.csdn.net/qq_35649707/article/details/79923740
关于如何已知 n+1 个点值 n2 还原出 n 次多项式的系数:
看到连值域也只有 1666 ,首先想到的是枚举第 k 大值是什么,设为 w 。
然后自己只能想到 n6 的 DP ……就是记录 “有 i 个 > w 的值和 j 个 = w 的值” 的连通块个数,求出第 k 大值恰好是 w 的连通块个数。
其实可以这样考虑,就是求 “第 k 大值 >= w ” 的连通块的个数。设为 f[ w ] 的话,答案就是 \( \sum\limits_{w=1}^{W}f[w] \) 。这样对于一个 w 就会被算 w 遍,恰好是答案。
然后就可以枚举 w ,令 \( dp[i][j] \) 表示以 i 为根的子树、有 j 个点的值 >=w 的连通块个数。 j 是和子树大小有关的,所以 DP 是 n3 。
正解是这样考虑:
用整体 DP 的思想,考虑一次 DP 把所有 w 的答案都做出来。
那么 DP 的时候每个点就记录了 n2 个值,来表示有 j 个点的值 >= w 的连通块个数。形如:
转移的时候就是当前点 cr 与孩子 v 的横着的格子对应转移;对于一个横着的格子伸出去的竖着的数组,转移形如卷积的样子。
所以考虑把每个竖着的数组写成一个多项式的样子。\( dp[cr][w]=\sum\limits_{i=0}^{\infty}a_i * x^i \) 这样。
如果把多个 w 像这样同时 DP ,可以发现一个点自己对数组的影响就是:
1.在 w<a[cr] 的那些位置的多项式上 +1,表示多出一个 “有0个点>=w” 的连通块;
2.在 w>=a[cr] 的那些位置的多项式上 *x ,表示原来 “有 j 个点 >=w ” 的连通块会变成 “有 j+1 个点 >=w ” 的连通块。
因为一个点对整个横着的数组的操作很少,所以考虑对横着的数组用动态开点线段树维护。
比如可能经过了好几个点,对 w=3 这个位置和 w=4 这个位置的操作全都是 *x ,那么 w=3 和 w=4 的这两个位置就不用分别维护,像粘在一起的一样打标记就行了。
所以这样可以通过把询问一起做来降低复杂度。
转移就是线段树合并。这样复杂度是 nlogn 。
但是线段树一个节点上维护了一个多项式。要合并很麻烦。所以考虑把 x 换成实际的值,这样线段树的节点上就只记录了一个值,合并的时候就是 O(1) 了。
把 x 换成实际的值求出来的结果是多项式的点值。所以做 n+1 遍求出 n+1 个点值,就能用拉格朗日插值 O(n2) 还原出系数了。
已知原来要求的多项式是 \( \sum\limits_{i=0}^{\infty}a_i * x^i \) ,答案是每个点的线段树所有叶子的多项式中 次数>=k 的那些项的系数和。“所有叶子”表示考虑所有 w 的情况。
那个“每个点”很不好。所以考虑再记一个多项式表示“子树里所有点”的情况,即 \( \sum\limits_{i=0}^{\infty}(\sum\limits_{j \in tree_i}a_{j,i} ) * x^i \) 。
线段树转移的种种就参见参考的那些博客……
注意 unsigned int 类型的所有数都是 >=0 的,也就是它的负数也是 >=0 的。所以 upt( ) 不能写 if(x<0) ... 这样。
自己写(抄)的不知为何常数很大,在洛谷上只能60分,在 LOJ 上可以垫底 AC 。
#include<cstdio> #include<cstring> #include<algorithm> #define u32 unsigned int using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=1670,M=N*N<<1; u32 mod=64123; u32 upt(u32 x){while(x>=mod)x-=mod;return x;} u32 pw(int x,int k) {u32 ret=1;while(k){if(k&1)ret=ret*x%mod;x=x*x%mod;k>>=1;}return ret;} int n,m,k,a[N],hd[N],xnt,to[N<<1],nxt[N<<1]; int rt[N],tot,ls[M],rs[M],dpl[M],dtop; struct Node{ u32 a,b,c,d; Node(u32 a=1,u32 b=0,u32 c=0,u32 d=0):a(a),b(b),c(c),d(d) {} Node operator* (const Node &t)const { return Node(a*t.a%mod,upt(b*t.a%mod+t.b), upt(a*t.c%mod+c),upt(b*t.c%mod+t.d+d)); } void init(){a=1;b=c=d=0;} }vl[M]; int nwnd() { if(!dtop)return ++tot; int ret=dpl[dtop--]; vl[ret].init(); return ret; } void del(int &x) { if(ls[x])del(ls[x]); if(rs[x])del(rs[x]); dpl[++dtop]=x; x=0; } void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void pshd(int cr) { if(!ls[cr])ls[cr]=nwnd(); if(!rs[cr])rs[cr]=nwnd(); vl[ls[cr]]=vl[ls[cr]]*vl[cr]; vl[rs[cr]]=vl[rs[cr]]*vl[cr]; vl[cr].init(); } void mdfy(int l,int r,int &cr,int L,int R,Node k) { if(L>R)return; if(!cr)cr=nwnd(); if(l>=L&&r<=R){vl[cr]=vl[cr]*k;return;} int mid=l+r>>1; pshd(cr); if(L<=mid)mdfy(l,mid,ls[cr],L,R,k); if(mid<R)mdfy(mid+1,r,rs[cr],L,R,k); } u32 qry(int l,int r,int cr) { if(l==r)return vl[cr].d; int mid=l+r>>1; pshd(cr); return upt(qry(l,mid,ls[cr])+qry(mid+1,r,rs[cr])); } void mrg(int &x,int &y) { if(!x)swap(x,y);if(!y)return;// if(!ls[x]&&!rs[x])swap(x,y); if(!ls[y]&&!rs[y]) { vl[x]=vl[x]*Node(vl[y].b,0,0,vl[y].d); return; } pshd(x); pshd(y);///// mrg(ls[x],ls[y]); mrg(rs[x],rs[y]); } void dfs(int cr,int fa,int x) { mdfy(1,m,rt[cr],1,m,Node(0,1,0,0));//f=1//m not n! //but g isn't changed so del() for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dfs(v,cr,x); mrg(rt[cr],rt[v]); del(rt[v]); } mdfy(1,m,rt[cr],1,a[cr],Node(x,0,0,0)); mdfy(1,m,rt[cr],1,m,Node(1,0,1,0)*Node(1,1,0,0)); } int c[N],f[N],g[N],inv[N],ans[N]; void solve() { for(int x=1;x<=n+1;x++) { dfs(1,0,x); c[x]=qry(1,m,rt[1]);//m not n!! del(rt[1]);///when dfs del(rt[v])//time?n^2 } f[0]=mod-1; f[1]=1; for(int i=2;i<=n+1;i++) { for(int j=n+1;j;j--) f[j]=upt((mod-i)*f[j]%mod+f[j-1]); f[0]=(mod-i)*f[0]%mod; } inv[1]=1; for(int i=2;i<=n+1;i++) inv[i]=(mod-mod/i)*inv[mod%i]%mod; u32 ans=0; for(int i=1;i<=n+1;i++) { g[0]=(mod-f[0])*inv[i]%mod; for(int j=1;j<=n;j++) g[j]=upt(g[j-1]-f[j]+mod)*inv[i]%mod; u32 pl=0; for(int j=k;j<=n;j++)pl=upt(pl+g[j]); u32 ml=c[i]; for(int j=1;j<=n+1;j++) if(j<i) ml=ml*inv[i-j]%mod; else if(j>i) ml=ml*(mod-inv[j-i])%mod; ans=upt(ans+ml*pl%mod); } printf("%d\n",ans); } int main() { n=rdn();k=rdn();m=rdn(); for(int i=1;i<=n;i++)a[i]=rdn(); for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),add(u,v),add(v,u); solve(); return 0; }