cf 995F Cowmpany Cowmpensation - 拉格朗日插值 + 树形dp

传送门
第一次独立完成div1的F题,😭,不容易啊,因为数组开小了,导致MLE了,不应该是返回RE吗?cf绝了。。。
题意就是每个人,除了根,有一个上司,也就是一棵树,然后进行分工资,上司的工资必须≥下属的工资
如果说d比价小,那么就直接用树形dp解决
也就是说对于每个点,求出当前点的工资为i时,有多少种情况,设dp[u][i],表示对于u结点来说,工资为i时有多少种情况
比如对于1个结点u,求工资为3时的个数,有孩子结点v,只需要把dp[u][3] *= (dp[v][1] + dp[v][2] + dp[v][3])即可,也就是对于每个工资,乘上孩子结点的前缀和
但是d比较大,那么就想到用拉格朗日插值法去做,先求出n个连续数字的点,然后对d数字进行插值求出即可

对于n个数字进行拉格朗日插值时,取完模后对答案无影响

开始wa了因为我怕爆ll,但你在公式里可以发现,y[i]还是要取模的
然后我还怕拉格朗日插值的准确性会有问题,事实发现,拉格朗日的准确性是很高的,而且点越多越准确

代码是用\(O(n)\)求拉格朗日插值的,其实\(O(n^2)\)也能求

#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
#include <map>
#include <set>
#include <cmath>
#include <stack>
#include <algorithm>
#include <ctime>
using namespace std;
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define make make_pair
#define fi first
#define se second
#define vi std::vector<int>;
#define DEBUG cout << "debug" << endl
#define CLOSE ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr)
#define CASE int Kase=0;cin>>Kase;for(int kase=1;kase<=Kase;kase++)
const int N = 3000 + 5;
const int M = 3000 + 5;
const int MOD = 1e9 + 7;
const int CM = 998244353;
const int INF = 0x3f3f3f3f;
const double eps = 1e-6;
template <typename T>
T MAX(T a, T b) { return a > b ? a : b; }
template <typename T>
T MIN(T a, T b) { return a > b ? b : a; }
template <typename T>
T GCD(T a, T b){ return b == 0 ? a : GCD(b, a % b); }
template <typename T>
T LCM(T a, T b){ return a / GCD(a, b) * b; }
template <typename T>
T ABS(T a, T b) { return a > 0 ? a : -a; }
template <typename T>
T ADD(T a, T b, T MOD){ return (a + b) % MOD; }
template <typename T>
T DEL(T a, T b, T MOD) { return ((a - b) % MOD + MOD) % MOD; }
struct Edge{
    int to, next;
}e[M << 1];
int head[N], tot;
void add(int u, int v){
    e[++tot].to = v;
    e[tot].next = head[u];
    head[u] = tot;
}
ll dp[N][N];
ll pre[N];
ll a[N];
int n, d, m;
struct Lagrange{
    int n;
    int mod;
    ll y[N];
    void init(int n, int mod, ll *y){
        this->n = n, this->mod = mod;
        for(int i = 1; i <= n; i++)
            this->y[i] = y[i];
    }
    ll pow(ll a, ll b, ll p){
        ll ans = 1; a %= p;
        while(b){
            if(b & 1) ans = ans * a % p;
            a = a * a % p;
            b >>= 1;
        }
        return ans;
    }
    inline ll add(ll x, ll y){return x + y >= mod ? x + y - mod : x + y;}
    inline ll dec(ll x, ll y){return x - y < 0 ? x - y + mod : x - y;}
    inline ll mul(ll x, ll y){return x * y % mod;}
    ll pre[N], shuff[N]; // 分子前缀积和后缀积
    ll fac[N], inv[N];
    void Pre(int k){
        pre[0] = 1, shuff[n + 1] = 1; fac[0] = 1;
        for(int i = 1; i <= n; i++) pre[i] = mul(pre[i - 1], dec(k, i));
        for(int i = n; i >= 1; i--) shuff[i] = mul(shuff[i + 1], dec(k, i));
        for(int i = 1; i <= n; i++) fac[i] = mul(fac[i - 1], i);
        inv[n] = pow(fac[n], mod - 2, mod);
        for(int i = n - 1; i >= 0; i--) inv[i] = mul(inv[i + 1], i + 1);
    }
    ll cal(ll k){
        Pre(k);
        ll ans = 0;
        for(int i = 1; i <= n; i++) {
            ll up = mul(pre[i - 1], shuff[i + 1]);
            ll down = mul(inv[i - 1], inv[n - i]);
            up = mul(y[i], up);
            up = mul(up, down);
            if((n - i) % 2 == 0) ans = add(ans, up);
            else ans = dec(ans, up);
        }
        return ans;
    }
} lagrange;
void dfs(int u, int fath){
    for(int j = 1; j <= m; j++) dp[u][j] = 1;
    for(int i = head[u]; i; i = e[i].next) {
        int v = e[i].to;
        if(v == fath) continue;
        dfs(v, u);
        pre[0] = 0;
        for(int j = 1; j <= m; j++) pre[j] = (pre[j - 1] + dp[v][j]) % MOD;
        for(int j = 1; j <= m; j++) dp[u][j] = dp[u][j] * pre[j] % MOD;
    }
}
void solve(){
    scanf("%d%d", &n, &d); m = n + 1;
    for(int i = 1; i < n; i++) {
        int p; scanf("%d", &p);
        add(i + 1, p); add(p, i + 1);
    }
    dfs(1, 0);
    // DEBUG;
    a[0] = 0;
    for(int i = 1; i <= m; i++)
        a[i] = (a[i - 1] + dp[1][i]) % MOD;
    lagrange.init(m, MOD, a);
    printf("%lld\n", lagrange.cal(d));
}
void TIME(){
    clock_t start, finish;
    double totaltime;
    start = clock();

    // 待测程序

    // CASE {
        solve();
    // }
    // 待测程序

    finish = clock();
    totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
    printf("\nTime:%lfms\n", totaltime * 1000);
}
int main(){
    #ifdef ONLINE_JUDGE
        // CLOSE;
        // CASE {
            solve();
        // }
    #else
        // freopen("./data/a.txt", "r", stdin);
        // freopen("./data/a.txt", "w", stdout);
        TIME();
    #endif
    return 0;
}
posted @ 2020-08-15 19:26  Emcikem  阅读(162)  评论(0编辑  收藏  举报