[LOJ3626/QOJ4889] 愚蠢的在线法官
考虑这个矩阵长啥样,首先显然 \(A\) 不能重复否则答案是 \(0\)(有两行两列相同)。
把 \(A\) 重标号为 DFS 序的顺序,那么行列式的值不改变,因为交换 \(A_i,A_j\) 相当于同时交换两行两列。
考虑把权值 \(v\) 做树上差分,令 \(B_u=v_u-v_{fa(u)}\),那么就等价于对每个 \(i\) 把 \(i\) 子树内的所有点形成的这个矩阵中的每个值都加上 \(B_i\)。那子树的 DFS 序是连续的,所以相当于做一个矩形加。
矩阵上做二维差分不影响行列式,考虑二维差分,那么矩形加就变成了四个单点加。
假设这个对应的 DFS 序区间是 \([l,r]\),那么相当于 M[l][l]++,M[r+1][r+1]++,M[l][r+1]--,M[r+1][l]--
,发现这个类似于矩阵树的 Kirchhoff 矩阵,那相当于建一个图出来,然后对每个点连边 \([l,r+1]\),要数生成树个数。
那怎么做呢,发现由于各子树 DFS 序要么包含要么不交,然后你发现任取一个子图如果想同胚于 \(K_4\) 那最后我们发现肯定会出现相交,所以这个是广义串并联图,套广义串并联方法就行了。
时间复杂度 \(O(n\log n)\)。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
inline int read(){
int x=0,f=1;char c=getchar();
for(;(c<'0'||c>'9');c=getchar()){if(c=='-')f=-1;}
for(;(c>='0'&&c<='9');c=getchar())x=x*10+(c&15);
return x*f;
}
const int mod=998244353;
int ksm(int x,int y,int p=mod){
int ans=1;
for(int i=y;i;i>>=1,x=1ll*x*x%p)if(i&1)ans=1ll*ans*x%p;
return ans%p;
}
int inv(int x,int p=mod){return ksm(x,p-2,p)%p;}
mt19937 rnd(time(0));
int randint(int l,int r){return rnd()%(r-l+1)+l;}
void add(int &x,int v){x+=v;if(x>=mod)x-=mod;}
void Mod(int &x){if(x>=mod)x-=mod;}
int cmod(int x){if(x>=mod)x-=mod;return x;}
const int N=5e5+5;
int n,m;
vector<int>G[N];
struct Node{
int f,g;
Node(int F,int G):f(F),g(G){}
Node(){}
bool operator<(const Node &rhs)const{
if(f!=rhs.f)return f<rhs.f;
return g<rhs.g;
}
};
Node cmpr(Node wy,Node wz){return Node(1ll*wy.f*wz.f,cmod(1ll*wy.f*wz.g%mod+1ll*wy.g*wz.f%mod));}
Node twist(Node u,Node w){return Node(cmod(1ll*w.f*u.g%mod+1ll*u.f*w.g%mod),1ll*w.g*u.g%mod);}
queue<int>q;
set<pair<int,Node> >adj[N];
#define fi first
#define se second
#define mk make_pair
int dfn[N],dfc=0;
int a[N],val[N],fa[N],sz[N];
bool inq[N];
signed main(void){
#ifdef YUNQIAN
freopen("in.in","r",stdin);
#endif
n=read(),m=read();
for(int i=1;i<=n;i++)val[i]=read();
for(int i=1;i<=m;i++)a[i]=read();
for(int i=1;i<=n-1;i++){
int u=read(),v=read();
G[u].emplace_back(v),G[v].emplace_back(u);
}
function<void(int)>dfs=[&](int u){
dfn[u]=++dfc,sz[u]=1;
for(int v:G[u])if(v!=fa[u])fa[v]=u,dfs(v),sz[u]+=sz[v];
add(val[u],mod-val[fa[u]]);
};
dfs(1);
for(int i=1;i<=m;i++)a[i]=dfn[a[i]];sort(a+1,a+m+1);
vector<int>fa(m+2);
for(int i=1;i<=m+1;i++)fa[i]=i;
function<int(int)>find=[&](int x){return x==fa[x]?x:fa[x]=find(fa[x]);};
auto adde=[&](int x,int y,Node w){
if(x==y)return ;fa[find(x)]=find(y);
set<pair<int,Node> >::iterator t=adj[x].lower_bound(mk(y,Node(0,0)));
Node cur=t->se;
if(t->fi==y){
adj[x].erase(mk(y,cur)),adj[x].insert(mk(y,twist(cur,w)));
adj[y].erase(mk(x,cur)),adj[y].insert(mk(x,twist(cur,w)));
}
else adj[x].insert(mk(y,w)),adj[y].insert(mk(x,w));
};
for(int i=1;i<=n;i++){
int l=dfn[i],r=dfn[i]+sz[i]-1;
int L=lower_bound(a+1,a+m+1,l)-a;
int R=upper_bound(a+1,a+m+1,r)-a-1;
adde(L,R+1,Node(val[i],1));
}
m++;
auto chk=[&](int x){
if(inq[x])return ;
if(adj[x].size()<=2)q.push(x),inq[x]=1;
};
for(int i=1;i<=m;i++)chk(i);
for(int i=1;i<=m;i++)if(find(i)!=find(1))return puts("0"),0;
int ans=1;
while(q.size()){
int x=q.front();q.pop();
if(adj[x].size()==1){
auto t=adj[x].begin();int y=t->fi;Node w=t->se;
ans=1ll*ans*w.f%mod;
adj[y].erase(mk(x,w)),chk(y);
}
else if(adj[x].size()==2){
auto p=adj[x].begin(),q=--adj[x].end();
int y=p->fi,z=q->fi;auto wy=p->se,wz=q->se;
adj[y].erase(mk(x,wy)),adj[z].erase(mk(x,wz));
Node w=Node(1ll*wy.f*wz.f%mod,cmod(1ll*wy.f*wz.g%mod+1ll*wy.g*wz.f%mod));
auto r=adj[y].lower_bound(mk(z,Node(0,0)));
if(r->fi!=z)adj[y].insert(mk(z,w)),adj[z].insert(mk(y,w));
else{
auto u=r->se;
auto v=Node(cmod(1ll*w.f*u.g%mod+1ll*u.f*w.g%mod),1ll*w.g*u.g%mod);
adj[y].erase(mk(z,u)),adj[y].insert(mk(z,v));
adj[z].erase(mk(y,u)),adj[z].insert(mk(y,v));
chk(y),chk(z);
}
}
}
cout<<ans<<endl;
return 0;
}