E97 换根DP+矩阵加速 P6803 [CEOI 2020] 星际迷航

视频链接:E97 换根DP+矩阵加速 P6803 [CEOI 2020] 星际迷航_哔哩哔哩_bilibili

 

 

 

 

 

 

参考:G04 矩阵加速 P1962 斐波那契数列 - 董晓 - 博客园

P6803 [CEOI 2020] 星际迷航 - 洛谷 | 计算机科学教育新生态

复制代码
// 换根DP+矩阵加速 O(n+logD)
#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
const int N=100005,M=1e9+7;
vector<int> e[N];
int n; LL D,m,ans;
int f[N],r[N],s[N],c[N][2];

void dfs(int x,int fa){
  for(int y:e[x]){
    if(y==fa) continue;
    dfs(y,x);
    s[x]+=!f[y]; //败点儿子的数量
    c[x][f[y]]+=r[y]; //可逆点的数量
  }
  f[x]=(s[x]>0); //若有败点儿子 则x为胜点
  if(s[x]==0) r[x]=c[x][1]+1; //没有败点儿子
  else if(s[x]==1) r[x]=c[x][0]; //只有一个败点儿子
  else r[x]=0; //有多个败点儿子
}
void dfs2(int x,int fa){
  if(f[x]==0) ++m;
  for(int y:e[x]){
    if(y==fa) continue;
    int fx=f[x],rx=r[x],sx=s[x],c0=c[x][0],c1=c[x][1];
    
    s[x]-=!f[y]; //去掉y的贡献
    c[x][f[y]]-=r[y];
    f[x]=(s[x]>0);
    if(s[x]==0) r[x]=c[x][1]+1;
    else if(s[x]==1) r[x]=c[x][0];
    else r[x]=0;
    s[y]+=!f[x]; //加上x的贡献
    c[y][f[x]]+=r[x];
    f[y]|=!f[x];
    if(s[y]==0) r[y]=c[y][1]+1;    
    else if(s[y]==1) r[y]=c[y][0];
    else r[y]=0;
    
    dfs2(y,x);
    f[x]=fx,r[x]=rx,s[x]=sx,c[x][0]=c0,c[x][1]=c1;
  }
}

struct mat{
  LL a[2][2];
  mat(){memset(a,0,sizeof a);}
  mat operator*(const mat &x){
    mat t;
    for(int i=0;i<2;i++)
    for(int j=0;j<2;j++)
    for(int k=0;k<2;k++)
      t.a[i][j]=(t.a[i][j]+a[i][k]*x.a[k][j])%M;
    return t;
  }  
}F,A;
void qpow(LL n){
  while(n){
    if(n&1) F=A*F;
    A=A*A;
    n>>=1;
  }
}
void add(LL &a,LL b){a=(a+b)%M;}
int main(){
  scanf("%d%lld",&n,&D);
  for(int i=1,x,y;i<n;i++){
    scanf("%d%d",&x,&y);
    e[x].push_back(y),e[y].push_back(x);
  }
  dfs(1,0);
  dfs2(1,0); //换根DP
  
  for(int i=1;i<=n;i++){
    if(f[i]==0){
      add(A.a[0][0],n-r[i]);
      add(A.a[0][1],n);
      add(A.a[1][0],r[i]);
    } 
    else{
      add(A.a[0][0],r[i]);
      add(A.a[1][0],n-r[i]);
      add(A.a[1][1],n);
    }
  }
  F.a[0][0]=m,F.a[1][0]=n-m;
  qpow(D-1); //矩阵加速
  
  if(f[1]==0) ans=r[1]*F.a[0][0]%M;
  else ans=((n-r[1])*F.a[0][0]+n*F.a[1][0])%M;
  printf("%d\n",ans);
}
复制代码

 

posted @   董晓  阅读(57)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
点击右上角即可分享
微信分享提示