LOJ 2743(洛谷 4365) 「九省联考 2018」秘密袭击——整体DP+插值思想

题目:https://loj.ac/problem/2473

   https://www.luogu.org/problemnew/show/P4365

参考:https://blog.csdn.net/xyz32768/article/details/82952313

   https://zhang-rq.github.io/2018/05/04/%E4%B9%9D%E7%9C%81%E8%81%94%E8%80%832018-%E7%A7%98%E5%AF%86%E8%A2%AD%E5%87%BBCoaT/

   https://blog.csdn.net/qq_35649707/article/details/79923740

关于如何已知 n+1 个点值 n2 还原出 n 次多项式的系数:

   https://zhang-rq.github.io/2018/05/04/%E6%8B%89%E6%A0%BC%E6%9C%97%E6%97%A5%E6%8F%92%E5%80%BC%E6%B3%95%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0/?tdsourcetag=s_pctim_aiomsg 

看到连值域也只有 1666 ,首先想到的是枚举第 k 大值是什么,设为 w 。

然后自己只能想到 n6 的 DP ……就是记录 “有 i 个 > w 的值和 j 个 = w 的值” 的连通块个数,求出第 k 大值恰好是 w 的连通块个数。

其实可以这样考虑,就是求 “第 k 大值 >= w ” 的连通块的个数。设为 f[ w ] 的话,答案就是 \( \sum\limits_{w=1}^{W}f[w] \) 。这样对于一个 w 就会被算 w 遍,恰好是答案。

然后就可以枚举 w ,令 \( dp[i][j] \) 表示以 i 为根的子树、有 j 个点的值 >=w 的连通块个数。 j 是和子树大小有关的,所以 DP 是 n3

正解是这样考虑:

用整体 DP 的思想,考虑一次 DP 把所有 w 的答案都做出来。

那么 DP 的时候每个点就记录了 n2 个值,来表示有 j 个点的值 >= w 的连通块个数。形如:

转移的时候就是当前点 cr 与孩子 v 的横着的格子对应转移;对于一个横着的格子伸出去的竖着的数组,转移形如卷积的样子。

所以考虑把每个竖着的数组写成一个多项式的样子。\( dp[cr][w]=\sum\limits_{i=0}^{\infty}a_i * x^i \) 这样。

如果把多个 w 像这样同时 DP ,可以发现一个点自己对数组的影响就是:

  1.在 w<a[cr] 的那些位置的多项式上 +1,表示多出一个 “有0个点>=w” 的连通块;

  2.在 w>=a[cr] 的那些位置的多项式上 *x ,表示原来 “有 j 个点 >=w ” 的连通块会变成 “有 j+1 个点 >=w ” 的连通块。

因为一个点对整个横着的数组的操作很少,所以考虑对横着的数组用动态开点线段树维护。

  比如可能经过了好几个点,对 w=3 这个位置和 w=4 这个位置的操作全都是 *x ,那么 w=3 和 w=4 的这两个位置就不用分别维护,像粘在一起的一样打标记就行了。

  所以这样可以通过把询问一起做来降低复杂度。

转移就是线段树合并。这样复杂度是 nlogn 。

但是线段树一个节点上维护了一个多项式。要合并很麻烦。所以考虑把 x 换成实际的值,这样线段树的节点上就只记录了一个值,合并的时候就是 O(1) 了。

把 x 换成实际的值求出来的结果是多项式的点值。所以做 n+1 遍求出 n+1 个点值,就能用拉格朗日插值 O(n2) 还原出系数了。 

已知原来要求的多项式是 \( \sum\limits_{i=0}^{\infty}a_i * x^i \) ,答案是每个点的线段树所有叶子的多项式中 次数>=k 的那些项的系数和。“所有叶子”表示考虑所有 w 的情况。

那个“每个点”很不好。所以考虑再记一个多项式表示“子树里所有点”的情况,即 \( \sum\limits_{i=0}^{\infty}(\sum\limits_{j \in tree_i}a_{j,i} ) * x^i \) 。

线段树转移的种种就参见参考的那些博客……

注意 unsigned int 类型的所有数都是 >=0 的,也就是它的负数也是 >=0 的。所以 upt( ) 不能写 if(x<0) ... 这样。

自己写(抄)的不知为何常数很大,在洛谷上只能60分,在 LOJ 上可以垫底 AC 。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define u32 unsigned int
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=1670,M=N*N<<1; u32 mod=64123;
u32 upt(u32 x){while(x>=mod)x-=mod;return x;}
u32 pw(int x,int k)
{u32 ret=1;while(k){if(k&1)ret=ret*x%mod;x=x*x%mod;k>>=1;}return ret;}

int n,m,k,a[N],hd[N],xnt,to[N<<1],nxt[N<<1];
int rt[N],tot,ls[M],rs[M],dpl[M],dtop;
struct Node{
  u32 a,b,c,d;
  Node(u32 a=1,u32 b=0,u32 c=0,u32 d=0):a(a),b(b),c(c),d(d) {}
  Node operator* (const Node &t)const
  {
    return Node(a*t.a%mod,upt(b*t.a%mod+t.b),
        upt(a*t.c%mod+c),upt(b*t.c%mod+t.d+d));
  }
  void init(){a=1;b=c=d=0;}
}vl[M];
int nwnd()
{
  if(!dtop)return ++tot;
  int ret=dpl[dtop--]; vl[ret].init(); return ret;
}
void del(int &x)
{
  if(ls[x])del(ls[x]); if(rs[x])del(rs[x]);
  dpl[++dtop]=x; x=0;
}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void pshd(int cr)
{
  if(!ls[cr])ls[cr]=nwnd(); if(!rs[cr])rs[cr]=nwnd();
  vl[ls[cr]]=vl[ls[cr]]*vl[cr];
  vl[rs[cr]]=vl[rs[cr]]*vl[cr]; vl[cr].init();
}
void mdfy(int l,int r,int &cr,int L,int R,Node k)
{
  if(L>R)return; if(!cr)cr=nwnd();
  if(l>=L&&r<=R){vl[cr]=vl[cr]*k;return;}
  int mid=l+r>>1; pshd(cr);
  if(L<=mid)mdfy(l,mid,ls[cr],L,R,k);
  if(mid<R)mdfy(mid+1,r,rs[cr],L,R,k);
}
u32 qry(int l,int r,int cr)
{
  if(l==r)return vl[cr].d;
  int mid=l+r>>1; pshd(cr);
  return upt(qry(l,mid,ls[cr])+qry(mid+1,r,rs[cr]));
}
void mrg(int &x,int &y)
{
  if(!x)swap(x,y);if(!y)return;//
  if(!ls[x]&&!rs[x])swap(x,y);
  if(!ls[y]&&!rs[y])
    {
      vl[x]=vl[x]*Node(vl[y].b,0,0,vl[y].d);
      return;
    }
  pshd(x); pshd(y);/////
  mrg(ls[x],ls[y]); mrg(rs[x],rs[y]);
}
void dfs(int cr,int fa,int x)
{
  mdfy(1,m,rt[cr],1,m,Node(0,1,0,0));//f=1//m not n!
  //but g isn't changed so del()
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)
      {
    dfs(v,cr,x);
    mrg(rt[cr],rt[v]);
    del(rt[v]);
      }
  mdfy(1,m,rt[cr],1,a[cr],Node(x,0,0,0));
  mdfy(1,m,rt[cr],1,m,Node(1,0,1,0)*Node(1,1,0,0));
}
int c[N],f[N],g[N],inv[N],ans[N];
void solve()
{
  for(int x=1;x<=n+1;x++)
    {
      dfs(1,0,x);
      c[x]=qry(1,m,rt[1]);//m not n!!
      del(rt[1]);///when dfs del(rt[v])//time?n^2
    }
  f[0]=mod-1; f[1]=1;
  for(int i=2;i<=n+1;i++)
    {
      for(int j=n+1;j;j--)
    f[j]=upt((mod-i)*f[j]%mod+f[j-1]);
      f[0]=(mod-i)*f[0]%mod;
    }
  inv[1]=1;
  for(int i=2;i<=n+1;i++)
    inv[i]=(mod-mod/i)*inv[mod%i]%mod;
  u32 ans=0;
  for(int i=1;i<=n+1;i++)
    {
      g[0]=(mod-f[0])*inv[i]%mod;
      for(int j=1;j<=n;j++)
    g[j]=upt(g[j-1]-f[j]+mod)*inv[i]%mod;
      u32 pl=0;
      for(int j=k;j<=n;j++)pl=upt(pl+g[j]);
      u32 ml=c[i];
      for(int j=1;j<=n+1;j++)
    if(j<i) ml=ml*inv[i-j]%mod;
    else if(j>i) ml=ml*(mod-inv[j-i])%mod;
      ans=upt(ans+ml*pl%mod);
    } 
  printf("%d\n",ans);
}
int main()
{
  n=rdn();k=rdn();m=rdn();
  for(int i=1;i<=n;i++)a[i]=rdn();
  for(int i=1,u,v;i<n;i++)
    u=rdn(),v=rdn(),add(u,v),add(v,u);
  solve(); return 0;
}

 

posted on 2019-03-20 14:25  Narh  阅读(366)  评论(0编辑  收藏  举报

导航