Codeforces 995F Cowmpany Cowmpensation 树上DP+多项式插值
给一棵 $n$ 个点的树，每个点的权值（工资）为正整数，且不能超过其父节点的权值；根节点的权值不超过 $D$。
问一共有多少种分配工资的方式（答案对 $10^9+7$ 取模）？
题解:
An immediate observation is that we can compute the answer in $O(nD)$ with a simple dynamic program. How can we speed it up, though?
To speed it up, we need the following lemma.
Lemma 1: For a tree with $n$ vertices, the answer is a polynomial in $D$ of degree at most $n$.
We can prove this by induction, using the fact that for any polynomial $p(x)$ of degree $d$, the prefix sum $\sum_{i=1}^{n} p(i)$ is a polynomial in $n$ of degree $d+1$.
Now the solution is easy: compute the answer for $0 \le D \le n$ with the DP and use polynomial interpolation to compute the answer for $D > n$.
The complexity is $O(n^2)$ for the initial DP and $O(n)$ for the interpolation step.
// Codeforces 995F "Cowmpany Cowmpensation".
//
// dp[v][j] = number of valid salary assignments of v's subtree in which v's
// salary is at most j (each child's salary <= its parent's). For the root this
// is the answer for D = j. The answer is a polynomial in D of degree <= n, so
// the DP is evaluated for D = 0..n and extrapolated to D = k with Lagrange
// interpolation over those n+1 consecutive sample points.
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <utility>

#pragma GCC optimize("unroll-loops")
#define ll long long
#define ull unsigned long long
#define IO ios::sync_with_stdio(false);
#define rep(ii, a, b) for (int ii = (a); ii <= (b); ++ii)
#define per(ii, a, b) for (int ii = (b); ii >= (a); --ii)
// Debug helpers: print "expr=value".
#define show(x) cout << #x << "=" << (x) << '\n'
#define showa(a, b) cout << #a << '[' << (b) << "]=" << (a)[b] << '\n'
#define show2(x, y) cout << #x << "=" << (x) << " " << #y << "=" << (y) << '\n'
using namespace std;

const int maxn = 4e3 + 10;  // bound on n; also sizes the DP table and edge list
const ll mod = 1e9 + 7;

ll n, k;       // number of vertices, bound D on the root's salary
ll num[maxn];  // num[D] = answer for D = 0..n (interpolation samples)

// Returns ans * a^b mod c by binary exponentiation (a is assumed already < c).
ll pow_mod(ll a, ll b, ll c = mod, ll ans = 1) {
    while (b) {
        if (b & 1) ans = a * ans % c;
        a = a * a % c;
        b >>= 1;
    }
    return ans;
}

namespace polysum {
const int maxn = 101000;
// f: factorials, g: inverse factorials, p1/p2: prefix/suffix products for calcn.
ll f[maxn], g[maxn], p1[maxn], p2[maxn];

// Evaluates at x = n the unique polynomial of degree <= d whose values at
// x = 0..d are a[0..d] (Lagrange interpolation on consecutive nodes, O(d)).
// Requires init(m) with m >= d beforehand so that g[0..d] is filled.
ll calcn(int d, ll *a, ll n) {
    if (n <= d) return a[n];  // n is one of the sample points
    p1[0] = p2[0] = 1;
    rep(i, 0, d) {  // p1[i] = prod_{j=0}^{i-1} (n - j)
        ll t = (n - i + mod) % mod;
        p1[i + 1] = p1[i] * t % mod;
    }
    rep(i, 0, d) {  // p2[m] = prod_{j=d-m+1}^{d} (n - j)
        ll t = (n - d + i + mod) % mod;
        p2[i + 1] = p2[i] * t % mod;
    }
    ll ans = 0;
    rep(i, 0, d) {
        // Basis polynomial i: prod_{j != i} (n - j) / (i - j); its denominator
        // is i! * (d - i)! * (-1)^(d - i).
        ll t = g[i] * g[d - i] % mod * p1[i] % mod * p2[d - i] % mod * a[i] % mod;
        if ((d - i) & 1)
            ans = (ans - t + mod) % mod;
        else
            ans = (ans + t) % mod;
    }
    return ans;
}

// Precomputes factorials f[0..maxm+4] and their modular inverses g[0..maxm+4].
void init(int maxm) {
    f[0] = f[1] = g[0] = g[1] = 1;
    rep(i, 2, maxm + 4) f[i] = f[i - 1] * i % mod;
    g[maxm + 4] = pow_mod(f[maxm + 4], mod - 2);  // Fermat inverse (mod is prime)
    per(i, 1, maxm + 3) g[i] = g[i + 1] * (i + 1) % mod;
}
}  // namespace polysum

// dp[v][j]: ways to assign v's subtree with salary(v) <= j; dp[v][0] stays 0.
ll dp[maxn][maxn];

// Adjacency list (parent -> child); edge indices start at 1, 0 terminates a list.
struct node {
    int to, next;
} e[maxn];
int head[maxn], nume;

// Adds the directed edge a -> b.
void add(int a, int b) {
    e[++nume] = node{b, head[a]};
    head[a] = nume;
}

// Fills dp[now][1..n] for the subtree rooted at now.
void dfs(int now) {
    rep(i, 1, n) dp[now][i] = 1;
    for (int i = head[now]; i; i = e[i].next) {
        int child = e[i].to;
        dfs(child);
        // With salary(now) = j, each child independently takes any salary <= j.
        rep(j, 1, n) dp[now][j] = dp[now][j] * dp[child][j] % mod;
    }
    // Prefix sums turn "salary exactly j" into "salary at most j".
    rep(i, 1, n) dp[now][i] = (dp[now][i] + dp[now][i - 1]) % mod;
}

int main() {
//#define test
#ifdef test
    auto _start = chrono::high_resolution_clock::now();
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    IO;
    cin >> n >> k;
    rep(i, 2, n) {
        int a;  // parent of vertex i
        cin >> a;
        add(a, i);
    }
    dfs(1);
    polysum::init(n + 1);
    rep(i, 0, n) num[i] = dp[1][i];  // n+1 samples determine a degree-<=n polynomial
    ll ans = polysum::calcn(n, num, k);
    cout << ans << '\n';
#ifdef test
    auto _end = chrono::high_resolution_clock::now();
    cerr << "elapsed time: " << chrono::duration<double, milli>(_end - _start).count() << " ms\n";
    fclose(stdin);
    fclose(stdout);
#endif
    return 0;
}