CSP模拟 Town

题意

有一棵树,将它的一些边断开,使得每个连通块的点权异或和为给定的一个数 \(x\),求方案数。

原题好像是牛客的 NC200547,没有账号看不到题,不确定。

思路

朴素的想法是用 f[i][j] 记录“\(i\) 子树中与 \(i\) 相连的连通块异或和为 \(j\)\(i\) 子树内其他连通块异或和为 \(x\))”的方案数。

但注意到一个性质,一棵树存在合法断边方案,当且仅当树所有点的异或和为 \(0\)(拆分为偶数个连通块)或 \(x\)(拆分为奇数个连通块)。

于是我们只需要 f[i][0/1] 记录“\(i\) 子树中,已经划分出偶数/奇数个异或和为 \(x\) 的连通块,剩下没被划分的部分连通且与 \(i\) 相连”的方案数。

在做树形 dp 时,儿子之间的计数是简单的,偶方案数 \(=\) 儿子 1 偶方案数 \(\times\) 儿子 2 偶方案数 \(+\) 儿子 1 奇方案数 \(\times\) 儿子 2 奇方案数,奇方案数 \(=\) 儿子 1 偶方案数 \(\times\) 儿子 2 奇方案数 \(+\) 儿子 1 奇方案数 \(\times\) 儿子 2 偶方案数。

然后假如当前子树的异或和为 \(x\),则父亲偶方案数需要加上儿子们的奇方案数,表示将剩下的那个连通块划分出来(即子树根节点向它的父亲断边)。当前子树异或和为 \(0\) 时同理。但如果当前点是根节点的时候就不能做这步,具体原因见下。

注意特判 \(x=0\) 的情况,此时的方案数即为每个子树异或和为 \(0\) 的节点是否断掉与其父亲相连的边的方案数。

另外,输出答案时,不能直接整棵树异或和为 \(0\) 时候输出 f[1][0],为 \(x\) 的时候输出 f[1][1],这是因为 dp 数组含义是“有剩下没划分的部分”,正确的输出应该是,根节点不进行上述统计操作,整棵树异或和为 \(0\) 时输出 f[1][1],为 \(x\) 时输出 f[1][0],表示剩下的部分全部划分出一个连通块且没有剩余未划分部分

代码

#include<bits/stdc++.h>
using namespace std;
template<class T>inline void rd(T &x){
    T res=0,f=1;
    char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1; ch=getchar();}
    while(isdigit(ch)){res=res*10+ch-'0';ch=getchar();}
    x=res*f;
}
template<class T>inline void wt(T x){
    if(x<0){x=-x;putchar('-');}
    if(x>9) wt(x/10);
    putchar(x%10+'0');
}
typedef long long LL;
const int MAXN=1e6+5;
const LL P=998244353;
int n,val,a[MAXN];
LL f[MAXN][2];
struct EDGE{
    int v,nxt;
}e[MAXN<<1];
int head[MAXN],ecnt=0;
inline void add(const int& u,const int& v){
    e[++ecnt].v=v;
    e[ecnt].nxt=head[u];
    head[u]=ecnt;
}
void dfs(int x,int fa){
    f[x][0]=1;f[x][1]=0;
    LL f0,f1;
    for(int i=head[x],it;i;i=e[i].nxt){
        it=e[i].v;
        if(it==fa) continue;
        dfs(it,x);
        a[x]^=a[it];
        f0=f[x][0];f1=f[x][1];
        f[x][0]=(f0*f[it][0]%P+f1*f[it][1]%P)%P;
        f[x][1]=(f0*f[it][1]%P+f1*f[it][0]%P)%P;
    }
    if(x==1) return;
    f0=f[x][0];f1=f[x][1];
    if(a[x]==0) f[x][0]=(f[x][0]+f1)%P;
    else if(a[x]==val) f[x][1]=(f[x][1]+f0)%P;
}
int main(){
    rd(n);rd(val);
    for(int i=1;i<=n;i++){
        rd(a[i]);
    }
    for(int i=1,u,v;i<n;i++){
        rd(u);rd(v);
        add(u,v);add(v,u);
    }
    dfs(1,0);
    if(val==0){
        if(a[1]==0){
            LL ans0=1;
            for(int i=2;i<=n;i++){
                if(a[i]==0) ans0=ans0*2%P;
            }
            wt(ans0);
        }
        else wt(0);
        return 0;
    }
    if(a[1]==0) wt(f[1][1]);
    else if(a[1]==val) wt(f[1][0]);
    else wt(0);
    return 0;
}
posted @ 2024-10-25 12:52  MessageBoxA  阅读(3)  评论(0编辑  收藏  举报