Codeforces Round #665 (Div. 2), Problem D: Maximum Distributed Tree (greedy + DFS)

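**Note: do not reduce values modulo 1e9+7 before sorting them!** The sort has to compare true magnitudes, and residues modulo a prime do not preserve the order of the original numbers.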

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
#include<vector>
#include<string>
#include<fstream>
using namespace std;
// Loop helpers: rep(i, a, n) runs i = a..n ascending; per(i, a, n) runs
// i = n..a descending. The previous definitions ended in ';', so every use
// expanded to an empty loop followed by a statement that could not see i.
// Arguments are parenthesized so expressions such as n - 1 expand safely.
#define rep(i, a, n) for (int i = (a); i <= (n); ++i)
#define per(i, a, n) for (int i = (n); i >= (a); --i)
#define px first
#define py second
typedef long long ll;
typedef std::pair<int, int> PII;
const int N = 5e5 + 5;            // capacity of the global arrays (edge-slot arrays use N << 1)
const int mod = 1e9 + 7;          // modulus applied to the printed answer
const double Pi = std::acos(-1.0);
const int INF = 0x3f3f3f3f;
// Unused in this problem. 3 * 332748118 == 998244354 == 1 (mod 998244353),
// so these presumably are an NTT primitive root and its inverse for 998244353.
const int G = 3, Gi = 332748118;

// Returns a^b mod `mod` by binary exponentiation. The base is reduced into
// [0, mod) first, so a * a can never overflow ll even for large or negative a
// (squaring an unreduced base overflows once |a| exceeds about 3.04e9).
// Intended for b >= 0; a negative b now yields 1 instead of looping forever
// (b >>= 1 stalls at -1 for negative values).
ll qpow(ll a, ll b) {
    a %= mod;
    if (a < 0) a += mod;
    ll res = 1;
    while (b > 0) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

// Greatest common divisor via the iterative Euclidean algorithm; gcd(a, 0) == a.
ll gcd(ll a, ll b) {
    while (b != 0) {
        const ll r = a % b;
        a = b;
        b = r;
    }
    return a;
}

// Descending-order comparator for sort: true when a must come before b.
bool cmp(ll a, ll b) { return b < a; }
//
// Reads the next integer from stdin. Non-digit characters are skipped (any
// '-' among them makes the result negative), then decimal digits are
// accumulated until the first non-digit, which is consumed.
// Returns 0 if the input ends before any digit is found.
//
// `c` is an int rather than a char so that EOF is distinguishable from a real
// character (with a char, end of input made the skip loop spin forever), and
// so isdigit() only ever receives EOF or an unsigned-char value, as <cctype>
// requires.
inline int read()
{
    int c = getchar();
    int x = 0, f = 1;
    while (c != EOF && !isdigit(c)) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}

// Global state used by add(), dfs() and main().
int T;              // number of test cases
int n;              // node count of the current tree
int m;              // number of factors read into p[1..m]
int head[N];        // first edge slot of each node's adjacency list, -1 when empty
int cnt = 0;        // next unused edge slot in to[] / nxt[]
int tot = 0;        // count of filled p[] entries after padding/merging
int tp = 0;         // count of contributions written to a[]
ll p[N];            // factors, 1-indexed
ll a[N];            // per-edge contributions (n - siz[v]) * siz[v], 1-indexed
int to[N << 1];     // endpoint stored in each edge slot (two slots per undirected edge)
int nxt[N << 1];    // next slot in the same adjacency list; -1 ends the list
ll siz[N];          // subtree sizes filled by dfs()
ll res;             // answer accumulator for the current test

// Inserts the undirected edge u-v as two directed slots, u -> v then v -> u.
// Each slot is pushed onto the front of its source node's adjacency list.
inline void add(int u, int v){
    const int fwd = cnt++;
    to[fwd] = v;
    nxt[fwd] = head[u];
    head[u] = fwd;

    const int back = cnt++;
    to[back] = u;
    nxt[back] = head[v];
    head[v] = back;
}

// Computes subtree sizes for the tree rooted at u (pre is u's parent; main
// passes 0 for the root). For every tree edge (parent, v) below u, it appends
// that edge's contribution (n - siz[v]) * siz[v] to a[++tp].
//
// The traversal is iterative: a path-shaped tree with n ~ 1e5 nodes would
// otherwise recurse ~1e5 frames deep and can overflow a small default stack.
// siz[] is NOT reduced modulo `mod`: it is a node count that is later
// multiplied and sorted, so it must keep its true value.
// Entries land in a[] in a different order than a recursive walk would produce;
// main sorts a[1..n-1] before use, so only the multiset of values matters.
void dfs(int u, int pre){
    // (node, parent) pairs in pre-order: every node appears before its children.
    std::vector<std::pair<int, int> > order;
    std::vector<std::pair<int, int> > pending(1, std::make_pair(u, pre));
    while (!pending.empty()) {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        siz[cur.first] = 1;
        for (int i = head[cur.first]; i != -1; i = nxt[i]) {
            if (to[i] != cur.second) pending.push_back(std::make_pair(to[i], cur.first));
        }
    }
    // Walk backwards so every descendant is processed before its ancestor:
    // siz[v] is final by the time it is added to its parent. Index 0 is u
    // itself, whose edge to `pre` lies outside this subtree and is skipped.
    for (std::size_t k = order.size(); k-- > 1; ) {
        const int v = order[k].first;
        const int parent = order[k].second;
        siz[parent] += siz[v];
        a[++ tp] = (n - siz[v]) * siz[v];
    }
}

// For each test case: build the tree and read the factors p[]. Distribute the
// factors so that each of the n - 1 edges gets exactly one slot, then print
// sum(a[i] * p[i]) mod `mod`. The i-th smallest edge contribution is paired
// with the i-th slot of the sorted factors.
int main()
{
    scanf("%d", &T);
    while (T-- != 0) {
        scanf("%d", &n);
        const int edges = n - 1;

        // Fresh adjacency lists and counters for this test.
        tp = 0;
        tot = 0;
        cnt = 0;
        for (int node = 0; node <= n; ++node) head[node] = -1;

        for (int e = 1; e <= edges; ++e) {
            p[e] = 0;
            a[e] = 0;
            int u, v;
            scanf("%d%d", &u, &v);
            add(u, v);
        }

        scanf("%d", &m);
        for (int j = 1; j <= m; ++j) scanf("%lld", p + j);

        if (m > edges) {
            // More factors than edges: sort the raw (unreduced) values, then
            // fold every surplus factor into p[edges]. That last slot is paired
            // with the largest edge contribution and is not re-sorted afterwards.
            sort(p + 1, p + m + 1);
            for (int j = edges + 1; j <= m; ++j) p[edges] = p[edges] * p[j] % mod;
            tot = edges;
        } else {
            // Fewer factors than edges: pad with 1s so every edge has a factor.
            tot = m;
            while (tot < edges) p[++tot] = 1;
            sort(p + 1, p + edges + 1);
        }

        dfs(1, 0);
        sort(a + 1, a + n);

        res = 0;
        for (int i = 1; i <= edges; ++i) {
            res = (res + a[i] * p[i] % mod) % mod;
        }
        printf("%lld\n", res);
    }
    return 0;
}
posted @ 2020-08-24 14:40  A_sc  阅读(97)  评论(0编辑  收藏  举报