2017 西安网络赛A Tree(树上静态查询,带权并查集,矩阵乘法压位,好题)
题意:
给出 \(n(n \leq 3000)\) 个结点的一棵树,树上每个结点有一个 \(64 \times 64\) 的 \(0,1\)矩阵,每个结点上的矩阵是根据输入的 \(seed\) (unsigned long long)生成的,给出 \(q\) 个询问 \((u,v)\) ,询问 \(u→v\) 的路上(包括 \(u,v\) )的矩阵相乘(膜2乘)的结果 \(M\),输出\((\sum_{i=1}^{64} \sum_{j=1}^{64} M_{ij} * 19^i *26^j) mod 19260817\)。
题解:
比赛时写了个树链剖分,写了很久还gg了,最后T了,赛后听大家说树链剖分主要是动态查询,所以复杂度高,有很多静态查询的方法,比如树分治等。
这题讲一个带权并查集的做法:(来自 \(quailty\) )
矩阵乘法压位显然,每个询问 \(u→v\) 拆成 \(u→lca\) 和 \(lca→v\) (不包括\(lca\)) ,在树上 \(DFS\) 一遍,从 \(u\) 子树出来时处理 \(u→v\) 和 \(v→u(v∈T_u)\) 的询问,类似 \(tarjan\) 求 \(LCA\) 的思路,带权并查集维护 \(T_u\) 内的点到 \(u\) 的两个方向的路径矩阵乘积,完成 \(u\) 处的询问后将 \(u\) 的在并查集上的根设为\(fa(u)\) ,复杂度\(O((n+q)α(n)64^2)\) 。
(1)对于询问 \((u,v)\),拆成 \(u→lca\) 和 \(tv→v\) , \(tv\) 是 \(lca\) 到 \(v\) 的路径上 \(lca\) 的儿子结点。
(2) 由于是矩阵相乘,所以得注意方向,路径拆成两部分后,前一部分是向上乘,后一部分是向下乘,带权并查集中除了维护每个点的父亲结点 \(pa[x]\) 之外,还要维护两个矩阵 \(up[x],dw[x]\),分别表示 \(x\) 向上到根结点的矩阵乘(不包括根结点)和根结点向下到 \(x\) 的矩阵乘(不包括根结点),在并查集路径压缩时更新 \(up[x],dw[x]\),完成 \(u\) 处的询问后将 \(u\) 的在并查集上的根设为\(fa(u)\)。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
using namespace std;
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define pb push_back
#define fi first
#define se second
#define dbg(...) cerr<<"["<<#__VA_ARGS__":"<<(__VA_ARGS__)<<"]"<<endl;
typedef vector<int> VI;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
const int inf=0x3fffffff;
const ll mod=19260817;
const int maxn=3000+10;
int head[maxn];
struct edge
{
int to,next;
}e[maxn*2]; //
int tol=0;
void add(int u,int v)
{
e[++tol].to=v,e[tol].next=head[u],head[u]=tol;
}
int deep[maxn],fa[maxn][13];
void bfs(int rt)
{
queue<int> q;
deep[rt] = 0;
fa[rt][0] = rt;
q.push(rt);
while(!q.empty())
{
int t = q.front();
q.pop();
for(int i = 1 ; i <= 12 ; i++)
fa[t][i] = fa[fa[t][i-1]][i-1];
for(int i = head[t] ; i ; i = e[i].next)
{
int v = e[i].to;
if(v == fa[t][0]) continue;
deep[v] = deep[t]+1;
fa[v][0] = t;
q.push(v);
}
}
}
int lca(int u,int v)
{
if(deep[u] > deep[v]) swap(u,v);
int hu = deep[u],hv = deep[v];
int tu = u,tv = v;
for(int det = hv-hu, i = 0; det ;det>>=1, i++)
if(det&1)
tv = fa[tv][i];
if(tu == tv) return tu;
for(int i = 12 ; i>=0 ; i--)
{
if(fa[tu][i] == fa[tv][i]) continue;
tu = fa[tu][i];
tv = fa[tv][i];
}
return fa[tu][0];
}
int up(int u,int k)
{
int tu=u;
for(int det = k,i = 0;det;det >>= 1, i++)
if(det&1)
tu = fa[tu][i];
return tu;
}
struct Matrix
{
ull a[65];
Matrix()
{
memset(a,0,sizeof(a));
}
void clear()
{
memset(a,0,sizeof(a));
}
void init()
{
rep(i,0,64) a[i]=1ull<<i; //
}
Matrix operator * (const Matrix &B)const
{
Matrix C;
rep(i,0,64)
rep(k,0,64)
if(a[i]>>k&1)
C.a[i]^=B.a[k];
return C;
}
}M[maxn];
struct DSU
{
int pa[maxn];
Matrix up[maxn],dw[maxn];
void init(int n)
{
rep(i,1,n+1) pa[i]=i,up[i].init(),dw[i].init();
}
int find(int x)
{
if(pa[x]==x) return x;
int f=find(pa[x]);
up[x]=up[x]*up[pa[x]];
dw[x]=dw[pa[x]]*dw[x];
return pa[x]=f;
}
void Union(int x,int y)
{
int fx=find(x),fy=find(y);
if(fx==fy) return;
if(deep[fx]<deep[fy])
swap(fx,fy);
pa[fx]=fy;
up[fx]=dw[fx]=M[fx];
}
}dsu;
struct Query
{
int x,id,kd;
Query(int a=0,int b=0,int c=0):x(a),id(b),kd(c) {}
};
vector<Query> query[maxn];
Matrix ans[maxn*10][2];
void dfs(int u,int f)
{
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==f) continue;
dfs(v,u);
dsu.Union(v,u);
}
for(auto item:query[u])
{
int x=item.x;
dsu.find(x);
if(!item.kd) ans[item.id][0]=dsu.up[x]*M[u];
else ans[item.id][1]=M[u]*dsu.dw[x];
}
}
ll f1[66],f2[66];
int main()
{
f1[0]=f2[0]=1ll;
rep(i,1,65) f1[i]=(1ll*f1[i-1]*19)%mod,f2[i]=(1ll*f2[i-1]*26)%mod;
int n,q;
while(~scanf("%d%d",&n,&q))
{
tol=0;
rep(i,1,n+1) head[i]=0,M[i].clear(),query[i].clear();
rep(i,1,n)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
bfs(1);
dsu.init(n);
ull seed;
scanf("%llu",&seed);
rep(i,1,n+1) rep(p,1,65)
{
seed^=seed*seed+15;
rep(q,1,65)
M[i].a[p-1]|=seed&(1ull<<(q-1));
}
rep(i,1,q+1)
{
int u,v;
scanf("%d%d",&u,&v);
if(u==v)
{
ans[i][0]=M[u];
ans[i][1].init();
continue;
}
int f=lca(u,v);
query[f].push_back(Query(u,i,0));
if(v!=f)
{
int tv=up(v,deep[v]-deep[f]-1);
query[tv].pb(Query(v,i,1));
}
else ans[i][1].init();
}
dfs(1,0);
rep(_,1,q+1)
{
ans[_][0]=ans[_][0]*ans[_][1];
ll res=0;
rep(i,0,64) rep(j,0,64) res=(res+1ll*(ans[_][0].a[i]>>j&1)*f1[i+1]*f2[j+1]%mod)%mod;
printf("%lld\n",res%mod);
}
}
return 0;
}
\(quailty\) 代码:
#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
const int mod=19260817;
const int N=3005;
const int Q=30005;
vector<int> e[N];
int deep[N], f[N][12];
void read(int &x)
{
char ch;
while(!isdigit(ch=getchar()));
x=ch-'0';
while(isdigit(ch=getchar()))
x=x*10+ch-'0';
}
void dfs(int x,int pre)
{
deep[x]=deep[pre]+1;
for(auto &y:e[x])
if(y!=pre)
{
f[y][0]=x;
for(int i=1;i<=11;++i)
f[y][i]=f[f[y][i-1]][i-1];
dfs(y,x);
}
}
int LCA(int x,int y)
{
if(deep[x]>deep[y]) swap(x,y);
for(int i=11;i>=0;--i)
if(deep[f[y][i]]>=deep[x])
y=f[y][i];
if(x==y) return x;
for(int i=11;i>=0;--i)
if(f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
return f[x][0];
}
int up(int x,int k)
{
for(int i=11;i>=0;--i)
if((k>>i)&1)
x=f[x][i];
return x;
}
struct Matrix
{
ull a[64];
Matrix()
{
memset(a,0,sizeof(a));
}
void clear()
{
memset(a,0,sizeof(a));
}
void init()
{
for(int i=0;i<64;i++)
a[i]=(1ULL<<i);
}
Matrix operator * (const Matrix &B)const
{
Matrix C;
for(int i=0;i<64;i++)
for(int j=0;j<64;j++)
if(a[i]>>j&1)
C.a[i]^=B.a[j];
return C;
}
}M[N],res[Q][2];
struct DSU
{
int fa[N];
Matrix up[N],dw[N];
void Init(int n)
{
for(int i=1;i<=n;i++)
fa[i]=i,up[i].init(),dw[i].init();
}
int Find(int x)
{
if(fa[x]==x)return x;
int f=Find(fa[x]);
up[x]=up[x]*up[fa[x]];
dw[x]=dw[fa[x]]*dw[x];
return fa[x]=f;
}
void Union(int x,int y)
{
x=Find(x),y=Find(y);
if(x==y)return;
if(deep[x]<deep[y])
swap(x,y);
fa[x]=y;
up[x]=dw[x]=M[x];
}
}dsu;
struct path
{
int x,o,d;
path(){}
path(int _x,int _o,int _d):x(_x),o(_o),d(_d){}
};
vector<path> que[N];
void dfs2(int u,int pre)
{
for(int i=0;i<(int)e[u].size();i++)
{
int v=e[u][i];
if(v==pre)continue;
dfs2(v,u);
dsu.Union(u,v);
}
for(int i=0;i<(int)que[u].size();i++)
{
int r=dsu.Find(que[u][i].x);
if(que[u][i].d==0)res[que[u][i].o][0]=dsu.up[que[u][i].x]*M[r];
else res[que[u][i].o][1]=M[r]*dsu.dw[que[u][i].x];
}
}
int main()
{
int n,m;
while(scanf("%d%d",&n,&m)!=EOF)
{
for(int i=1;i<=n;++i)
e[i].clear(),M[i].clear(),que[i].clear();
for(int i=1;i<n;++i)
{
int x,y;
read(x);read(y);
e[x].push_back(y);
e[y].push_back(x);
}
ull seed;
scanf("%llu",&seed);
for(int i=1;i<=n;++i)
for(int j=0;j<64;++j)
{
seed^=seed*seed+15;
for(int k=0;k<64;++k)
M[i].a[j]|=seed&(1ULL<<k);
}
dfs(1,0);
for(int i=1;i<=m;++i)
{
int x,y;
read(x);read(y);
int lca=LCA(x,y);
que[lca].push_back(path(x,i,0));
if(lca!=y)
{
que[up(y,deep[y]-deep[lca]-1)].push_back(path(y,i,1));
}
}
for(int i=1;i<=m;i++)
res[i][0].init(),res[i][1].init();
dsu.Init(n);
dfs2(1,0);
for(int i=1;i<=m;i++)
res[i][0]=res[i][0]*res[i][1];
for(int _=1;_<=m;_++)
{
int tmp=0;
for(int i=0,p=19;i<64;i++,p=19LL*p%mod)
for(int j=0,q=26;j<64;j++,q=26LL*q%mod)
tmp=(tmp+1LL*(res[_][0].a[i]>>j&1)*p*q)%mod;
printf("%d\n",tmp);
}
}
return 0;
}