2017 西安网络赛A Tree(树上静态查询,带权并查集,矩阵乘法压位,好题)

题目链接

题意:

给出 \(n(n \leq 3000)\) 个结点的一棵树,树上每个结点有一个 \(64 \times 64\)\(0,1\)矩阵,每个结点上的矩阵是根据输入的 \(seed\) (unsigned long long)生成的,给出 \(q\) 个询问 \((u,v)\) ,询问 \(u→v\) 的路上(包括 \(u,v\) )的矩阵相乘(膜2乘)的结果 \(M\),输出\((\sum_{i=1}^{64} \sum_{j=1}^{64} M_{ij} * 19^i *26^j) mod 19260817\)

题解:

比赛时写了个树链剖分,写了很久还gg了,最后T了,赛后听大家说树链剖分主要是动态查询,所以复杂度高,有很多静态查询的方法,比如树分治等。

这题讲一个带权并查集的做法:(来自 \(quailty\) )

矩阵乘法压位显然,每个询问 \(u→v\) 拆成 \(u→lca\)\(lca→v\) (不包括\(lca\)) ,在树上 \(DFS\) 一遍,从 \(u\) 子树出来时处理 \(u→v\)\(v→u(v∈T_u)\) 的询问,类似 \(tarjan\)\(LCA\) 的思路,带权并查集维护 \(T_u\) 内的点到 \(u\) 的两个方向的路径矩阵乘积,完成 \(u\) 处的询问后将 \(u\) 的在并查集上的根设为\(fa(u)\) ,复杂度\(O((n+q)α(n)64^2)\)

(1)对于询问 \((u,v)\),拆成 \(u→lca\)\(tv→v\)\(tv\)\(lca\)\(v\) 的路径上 \(lca\) 的儿子结点。

(2) 由于是矩阵相乘,所以得注意方向,路径拆成两部分后,前一部分是向上乘,后一部分是向下乘,带权并查集中除了维护每个点的父亲结点 \(pa[x]\) 之外,还要维护两个矩阵 \(up[x],dw[x]\),分别表示 \(x\) 向上到根结点的矩阵乘(不包括根结点)和根结点向下到 \(x\) 的矩阵乘(不包括根结点),在并查集路径压缩时更新 \(up[x],dw[x]\),完成 \(u\) 处的询问后将 \(u\) 的在并查集上的根设为\(fa(u)\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
using namespace std;
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define pb push_back
#define fi first
#define se second
#define dbg(...) cerr<<"["<<#__VA_ARGS__":"<<(__VA_ARGS__)<<"]"<<endl;
typedef vector<int> VI;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
const int inf=0x3fffffff;
const ll mod=19260817;
const int maxn=3000+10;

int head[maxn];
struct edge
{
    int to,next;
}e[maxn*2];   //
int tol=0;
void add(int u,int v)
{
    e[++tol].to=v,e[tol].next=head[u],head[u]=tol;
}
int deep[maxn],fa[maxn][13];

void bfs(int rt)
{
    queue<int> q;
    deep[rt] = 0;
    fa[rt][0] = rt;
    q.push(rt);
    while(!q.empty())
    {
        int t = q.front();
        q.pop();
        for(int i = 1 ; i <= 12 ; i++)
            fa[t][i] = fa[fa[t][i-1]][i-1];
        for(int i = head[t] ; i ; i = e[i].next)
        {
            int v = e[i].to;
            if(v == fa[t][0]) continue;
            deep[v] = deep[t]+1;
            fa[v][0] = t;
            q.push(v);
        }
    }
}

int lca(int u,int v)
{
    if(deep[u] > deep[v]) swap(u,v);
    int hu = deep[u],hv = deep[v];
    int tu = u,tv = v;
    for(int det = hv-hu, i = 0; det ;det>>=1, i++)
        if(det&1)
            tv = fa[tv][i];
    if(tu == tv) return tu;
    for(int i = 12 ; i>=0 ; i--)
    {
        if(fa[tu][i] == fa[tv][i]) continue;
        tu = fa[tu][i];
        tv = fa[tv][i];
    }
    return fa[tu][0];
}

int up(int u,int k)
{
    int tu=u;
    for(int det = k,i = 0;det;det >>= 1, i++)
        if(det&1)
            tu = fa[tu][i];
    return tu;
}

struct Matrix
{
    ull a[65];
    Matrix()
    {
        memset(a,0,sizeof(a));
    }
    void clear()
    {
        memset(a,0,sizeof(a));
    }
    void init()
    {
        rep(i,0,64) a[i]=1ull<<i; //
    }
    Matrix operator * (const Matrix &B)const
    {
        Matrix C;
        rep(i,0,64)
            rep(k,0,64)
            if(a[i]>>k&1)
                C.a[i]^=B.a[k];
        return C;
    }
}M[maxn];

struct DSU
{
    int pa[maxn];
    Matrix up[maxn],dw[maxn];
    void init(int n)
    {
        rep(i,1,n+1) pa[i]=i,up[i].init(),dw[i].init();
    }
    int find(int x)
    {
        if(pa[x]==x) return x;
        int f=find(pa[x]);
        up[x]=up[x]*up[pa[x]];
        dw[x]=dw[pa[x]]*dw[x];
        return pa[x]=f;
    }
    void Union(int x,int y)
    {
        int fx=find(x),fy=find(y);
        if(fx==fy) return;
        if(deep[fx]<deep[fy])
            swap(fx,fy);
        pa[fx]=fy;
        up[fx]=dw[fx]=M[fx];
    }
}dsu;

struct Query
{
    int x,id,kd;
    Query(int a=0,int b=0,int c=0):x(a),id(b),kd(c) {}
};

vector<Query> query[maxn];
Matrix ans[maxn*10][2];

void dfs(int u,int f)
{
    for(int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==f) continue;
        dfs(v,u);
        dsu.Union(v,u);
    }
    for(auto item:query[u])
    {
        int x=item.x;
        dsu.find(x);
        if(!item.kd) ans[item.id][0]=dsu.up[x]*M[u];
        else ans[item.id][1]=M[u]*dsu.dw[x];
    }
}
ll f1[66],f2[66];
int main()
{
    f1[0]=f2[0]=1ll;
    rep(i,1,65) f1[i]=(1ll*f1[i-1]*19)%mod,f2[i]=(1ll*f2[i-1]*26)%mod;
    int n,q;
    while(~scanf("%d%d",&n,&q))
    {
        tol=0;
        rep(i,1,n+1) head[i]=0,M[i].clear(),query[i].clear();
        rep(i,1,n)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            add(u,v);add(v,u);
        }
        bfs(1);
        dsu.init(n);
        ull seed;
        scanf("%llu",&seed);
        rep(i,1,n+1) rep(p,1,65)
        {
            seed^=seed*seed+15;
            rep(q,1,65)
            M[i].a[p-1]|=seed&(1ull<<(q-1));
        }
        rep(i,1,q+1)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            if(u==v)
            {
                ans[i][0]=M[u];
                ans[i][1].init();
                continue;
            }
            int f=lca(u,v);
            query[f].push_back(Query(u,i,0));
            if(v!=f)
            {
                int tv=up(v,deep[v]-deep[f]-1);
                query[tv].pb(Query(v,i,1));
            }
            else ans[i][1].init();
        }
        dfs(1,0);
        rep(_,1,q+1)
        {
            ans[_][0]=ans[_][0]*ans[_][1];
            ll res=0;
            rep(i,0,64) rep(j,0,64) res=(res+1ll*(ans[_][0].a[i]>>j&1)*f1[i+1]*f2[j+1]%mod)%mod;
            printf("%lld\n",res%mod);
        }
    }
    return 0;
}

\(quailty\) 代码:

#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
const int mod=19260817;
const int N=3005;
const int Q=30005;
vector<int> e[N];
int deep[N], f[N][12];
void read(int &x)
{
    char ch;
    while(!isdigit(ch=getchar()));
    x=ch-'0';
    while(isdigit(ch=getchar()))
        x=x*10+ch-'0';
}
void dfs(int x,int pre)
{
    deep[x]=deep[pre]+1;
    for(auto &y:e[x])
    if(y!=pre)
    {
        f[y][0]=x;
        for(int i=1;i<=11;++i)
            f[y][i]=f[f[y][i-1]][i-1];
        dfs(y,x);
    }
}
int LCA(int x,int y)
{
    if(deep[x]>deep[y]) swap(x,y);
    for(int i=11;i>=0;--i)
        if(deep[f[y][i]]>=deep[x])
            y=f[y][i];
    if(x==y) return x;
    for(int i=11;i>=0;--i)
        if(f[x][i]!=f[y][i])
        {
            x=f[x][i];
            y=f[y][i];
        }
    return f[x][0];
}
int up(int x,int k)
{
    for(int i=11;i>=0;--i)
    if((k>>i)&1)
        x=f[x][i];
    return x;
}
struct Matrix
{
    ull a[64];
    Matrix()
    {
        memset(a,0,sizeof(a));
    }
    void clear()
    {
        memset(a,0,sizeof(a));
    }
    void init()
    {
        for(int i=0;i<64;i++)
            a[i]=(1ULL<<i);
    }
    Matrix operator * (const Matrix &B)const
    {
        Matrix C;
        for(int i=0;i<64;i++)
            for(int j=0;j<64;j++)
                if(a[i]>>j&1)
                    C.a[i]^=B.a[j];
        return C;
    }
}M[N],res[Q][2];
struct DSU
{
    int fa[N];
    Matrix up[N],dw[N];
    void Init(int n)
    {
        for(int i=1;i<=n;i++)
            fa[i]=i,up[i].init(),dw[i].init();
    }
    int Find(int x)
    {
        if(fa[x]==x)return x;
        int f=Find(fa[x]);
        up[x]=up[x]*up[fa[x]];
        dw[x]=dw[fa[x]]*dw[x];
        return fa[x]=f;
    }
    void Union(int x,int y)
    {
        x=Find(x),y=Find(y);
        if(x==y)return;
        if(deep[x]<deep[y])
            swap(x,y);
        fa[x]=y;
        up[x]=dw[x]=M[x];
    }
}dsu;
struct path
{
    int x,o,d;
    path(){}
    path(int _x,int _o,int _d):x(_x),o(_o),d(_d){}
};
vector<path> que[N];
void dfs2(int u,int pre)
{
    for(int i=0;i<(int)e[u].size();i++)
    {
        int v=e[u][i];
        if(v==pre)continue;
        dfs2(v,u);
        dsu.Union(u,v);
    }
    for(int i=0;i<(int)que[u].size();i++)
    {
        int r=dsu.Find(que[u][i].x);
        if(que[u][i].d==0)res[que[u][i].o][0]=dsu.up[que[u][i].x]*M[r];
        else res[que[u][i].o][1]=M[r]*dsu.dw[que[u][i].x];
    }
}
int main()
{
    int n,m;
    while(scanf("%d%d",&n,&m)!=EOF)
    {
        for(int i=1;i<=n;++i)
            e[i].clear(),M[i].clear(),que[i].clear();
        for(int i=1;i<n;++i)
        {
            int x,y;
            read(x);read(y);
            e[x].push_back(y);
            e[y].push_back(x);
        }
        ull seed;
        scanf("%llu",&seed);
        for(int i=1;i<=n;++i)
            for(int j=0;j<64;++j)
            {
                seed^=seed*seed+15;
                for(int k=0;k<64;++k)
                    M[i].a[j]|=seed&(1ULL<<k);
            }
        dfs(1,0);
        for(int i=1;i<=m;++i)
        {
            int x,y;
            read(x);read(y);
            int lca=LCA(x,y);
            que[lca].push_back(path(x,i,0));
            if(lca!=y)
            {
                que[up(y,deep[y]-deep[lca]-1)].push_back(path(y,i,1));
            }
        }
        for(int i=1;i<=m;i++)
            res[i][0].init(),res[i][1].init();
        dsu.Init(n);
        dfs2(1,0);
        for(int i=1;i<=m;i++)
            res[i][0]=res[i][0]*res[i][1];
        for(int _=1;_<=m;_++)
        {
            int tmp=0;
            for(int i=0,p=19;i<64;i++,p=19LL*p%mod)
                for(int j=0,q=26;j<64;j++,q=26LL*q%mod)
                    tmp=(tmp+1LL*(res[_][0].a[i]>>j&1)*p*q)%mod;
            printf("%d\n",tmp);
        }
    }
    return 0;
}
posted @ 2017-10-04 19:08  tarjan's  阅读(122)  评论(0编辑  收藏  举报