D. Fish eating fruit

题:https://nanti.jisuanke.com/t/41403

题意:求任意俩点之间距离之和模3后的三个结果的总数(原距离之和)

第一种做法:

树形dp

 

#include<bits/stdc++.h>
using namespace std;
#define pb push_back
typedef long long ll;
const int M=1e4+4;
const int mod=1e9+7;
struct node{
    int v;
    ll w;
};
ll C[M][3],S[M][3],ans[M];
//C[i][j]:表示以i为根,然后路径消耗取模后为j的路径数
//S[i][j]:表示以i为根,路径消耗取模后为j的路径总消耗 
vector<node>e[M];;
void dfs(int u,int f,ll pre){
    C[u][0]=C[u][1]=C[u][2]=0;
    S[u][0]=S[u][1]=S[u][2]=0;
    int len=e[u].size();
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i].v;
        if(v==f)
            continue;
        dfs(v,u,e[u][i].w);
        //算跨越跟节点的贡献
        for(int p=0;p<3;p++){
            for(int j=0;j<3;j++)
                for(int k=0;k<3;k++)
                    if(p==(j+k)%3) 
                        ans[p]=(ans[p]+S[u][j]*C[v][k]%mod+C[u][j]*S[v][k]%mod)%mod;
        }
        for(int j=0;j<3;j++){
            C[u][j]=(C[u][j]+C[v][j])%mod;
            S[u][j]=(S[u][j]+S[v][j])%mod;
        }
        
    }
    
    for(int i=0;i<3;i++)//算以u为跟对答案的贡献,就直接算u的每一个子树的贡献 

        ans[i]=(ans[i]+S[u][i])%mod;
    ll c[3],s[3];
    memset(c,0ll,sizeof(c));
    memset(s,0ll,sizeof(s));
    for(int i=0;i<3;i++){
        int t=(i-pre%3+3)%3;
        c[i]=(c[i]+C[u][t])%mod;
        s[i]=(s[i]+(S[u][t]+C[u][t]*pre%mod)%mod)%mod;
    }
    if(f!=0)
        c[pre%3]=(c[pre%3]+1ll)%mod,s[pre%3]=(s[pre%3]+pre)%mod;
    for(int i=0;i<3;i++)
        C[u][i]=c[i],S[u][i]=s[i];
}
int main(){
    int n;
    while(~scanf("%d",&n)){
        for(int i=0;i<=n;i++)
            e[i].clear();
        memset(S,0,sizeof(S));
        memset(C,0,sizeof(C)); 
        for(int i=1;i<n;i++){
            int u,v;
            ll w;
            for(int i=0;i<3;i++)
                ans[i]=0ll;
            scanf("%d%d%lld",&u,&v,&w);
            u++,v++;
            e[u].pb(node{v,w});
            e[v].pb(node{u,w});
        }
        dfs(1,0,0);
        printf("%lld %lld %lld\n",ans[0]*2ll%mod,ans[1]*2ll%mod,ans[2]*2ll%mod);
        
    }
    return 0;;
}
View Code

 第二种做法:

点分治

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int M=2e4+4;
const ll mod=1e9+7;
struct node{
    int v,nextt;
    ll w;
}e[M<<1];
ll sum[4],disnum[4],dissum[4];
int head[M],vis[M],sz[M],maxv[M],tot,n,maxx,root;
void addedge(int u,int v,ll w){
    e[tot].v=v;
    e[tot].nextt=head[u];
    e[tot].w=w;
    head[u]=tot++;
}
void dfssz(int u,int f){
    maxv[u]=0;
    sz[u]=1;
    for(int i=head[u];~i;i=e[i].nextt){
        int v=e[i].v;
        if(v==f||vis[v])
            continue;
        dfssz(v,u);
        sz[u]+=sz[v];
        maxv[u]=max(maxv[u],sz[v]);
    }
}
void dfsroot(int r,int u,int f){
    maxv[u]=max(maxv[u],sz[r]-sz[u]);
    if(maxx>maxv[u]){
        maxx=maxv[u];
        root=u;
    }
    for(int i=head[u];~i;i=e[i].nextt){
        int v=e[i].v;
        if(v==f||vis[v])
            continue;
        dfsroot(r,v,u);
    }
}
void dfsdis(int u,int f,ll d){
//    if(f!=-1&&d!=0)
    disnum[d%3]++;
    disnum[d%3]%=mod;
    dissum[d%3]+=d;
    dissum[d%3]%=mod;
    for(int i=head[u];~i;i=e[i].nextt){
        int v=e[i].v;
        if(vis[v]||v==f)
            continue;
        dfsdis(v,u,(d+e[i].w)%mod);
    }
}
void cal(int u,ll d,int flag){
    for(int i=0;i<3;i++)
        dissum[i]=disnum[i]=0;
    
    dfsdis(u,-1,d);
    for(int i=0;i<3;i++)
        for(int j=0;j<3;j++){
            int t=(i+j)%3;
            sum[t]=(sum[t]+flag*((disnum[i]*dissum[j]%mod+disnum[j]*dissum[i]%mod)%mod)%mod+mod)%mod;
        }
}
void solve(int u){
    maxx=n;
    dfssz(u,-1);
    dfsroot(u,u,-1);
    cal(root,0ll,1);//+
    vis[root]=1;
    for(int i=head[root];~i;i=e[i].nextt){
        int v=e[i].v;
        if(vis[v])
            continue;
        cal(v,e[i].w,-1);
        solve(v);
    }
}
int main(){
    while(~scanf("%d",&n)){
        sum[0]=sum[1]=sum[2]=0;
        tot=0;
        for(int i=0;i<=n;i++)
            head[i]=-1,vis[i]=0;;
        for(int i=1;i<n;i++){
            int u,v;
            ll w;
            scanf("%d%d%lld",&u,&v,&w);
            u++,v++;
            addedge(u,v,w);
            addedge(v,u,w);
        }
        solve(1);
        printf("%lld %lld %lld\n",sum[0],sum[1],sum[2]);
    }
    return 0;
}
View Code

 

posted @ 2019-09-15 09:11  starve_to_death  阅读(167)  评论(0编辑  收藏  举报