WC2021题解
考场想明白的都是神仙,受我一膜!看了题解到现在我都只是模模糊糊想清楚了。
考场上我的想法是从一个点不停地走反向边扩展,就能找到所有的这种情况:\(A_1\rightarrow A_2\rightarrow A_n\rightarrow B \leftarrow C_n \leftarrow C_2 \leftarrow C_1\)
然后再写一个\(\text{floyd}\),这样复杂度大概是阶乘?不管了,拿了\(\text{32pts}\)跑路了。
正解的想法是每个点我们只拓展一条边,即找到所有的\(A\rightarrow B \leftarrow C\),大大减小了复杂度。
这样是不是计算不到\(D\rightarrow A\rightarrow B \leftarrow C \leftarrow E\)呢?是的。所以考虑一个小操作。
考虑将\(A\)和\(C\)合并,那么就可以这样\(D\rightarrow AC \leftarrow E\)计算到了。
考虑并查集按秩合并,复杂度\(\Theta(n\log n)\)或\(\Theta(n \log^2 n)\),就看你愿不愿意写哈希表了。
#include<bits/stdc++.h>
using namespace std;
#define inf 1e9
const int maxn=3e5+10;
const int mod=1e9+7;
int n,m,k;
map<int,int>mp[maxn];
map<int,int>::iterator it;
queue<pair<int,int> >q;
#define fi first
#define se second
int fa[maxn],siz[maxn];
long long ans;
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
inline int find(int x){
if(fa[x]==x)return x;
return find(fa[x]);
}
inline void merge(int x,int y){
int s=find(x),t=find(y);
if(s==t)return;
if(siz[s]<siz[t])swap(s,t);
fa[t]=s;siz[s]+=siz[t];
for(it=mp[t].begin();it!=mp[t].end();it++){
if(mp[s][it->fi])q.push(make_pair(it->se,mp[s][it->fi]));
else mp[s][it->fi]=it->se;
}
}
int main(){
n=read(),m=read(),k=read();
int x,y,z;
for(int i=1;i<=m;i++){
x=read(),y=read(),z=read();
if(mp[y][z])q.push(make_pair(x,mp[y][z]));
else mp[y][z]=x;
}
for(int i=1;i<=n;i++)
fa[i]=i,siz[i]=1;
while(!q.empty()){
x=q.front().fi;
y=q.front().se;
q.pop();
merge(x,y);
}
for(int i=1;i<=n;i++)
if(fa[i]==i)ans+=1ll*siz[i]*(siz[i]-1)/2;
printf("%lld\n",ans);
return 0;
}
首先显然每一位拆开做,假设我们设 \(dp_{i,j}\) 表示答案为 \(j\),我们就得到了一个 \(O(n|S|m^2)\) 的做法。
而假设我们钦定最后答案为 \(j\) 然后计算方案数,则我们去掉了一个 \(m\)。
然而若我们钦定了小于 \(j\) 的集合 \(S\),那么我们把 \(n\) 换成了 \(2^m\),至此我们有了 \(O(m2^m|S|)\) 的做法。
考虑差分,我们再省去一个 \(m\),于是 \(O(2^m|S|)\) 过了。
#include<bits/stdc++.h>
using namespace std;
#define inf 1e9
const int maxn=1e5+10;
const int mod=1e9+7;
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
const int N=10;
int n,m,A[N][maxn],rt,len,st[maxn],top,bl[maxn];
char S[maxn],op[maxn];
int ls[maxn],rs[maxn],cnt;
inline int nwnd(int x,int y,char o){
++cnt;ls[cnt]=x,rs[cnt]=y;
op[cnt]=o;return cnt;
}
inline int build(int l,int r){
if(l>r)return 0;
//printf("build %d %d\n",l,r);
if(bl[r]){
if(bl[r]==l)return build(l+1,r-1);
int x=build(l,bl[r]-2),y=build(bl[r]+1,r-1);
return nwnd(x,y,S[bl[r]-1]);
}int y=nwnd(0,0,S[r]),x=build(l,r-2);
if(l==r)return y;
return nwnd(x,y,S[r-1]);
}
int dp[maxn][2],coef[1<<N];
inline void dfs(int x,int sta){
dp[x][0]=dp[x][1]=0;
if(!ls[x]&&!rs[x]){
dp[x][~sta>>(op[x]-'0')&1]=1;
return;
}dfs(ls[x],sta),dfs(rs[x],sta);
for(int i=0;i<2;i++)
for(int j=0;j<2;j++){
if(op[x]!='<')dp[x][i|j]=(dp[x][i|j]+1ll*dp[ls[x]][i]*dp[rs[x]][j])%mod;
if(op[x]!='>')dp[x][i&j]=(dp[x][i&j]+1ll*dp[ls[x]][i]*dp[rs[x]][j])%mod;
}
//printf("u=%d: %d %d\n",x,dp[x][0],dp[x][1]);
}
int Pos,p[N],ans;
inline int cmp(int x,int y){return A[x][Pos]<A[y][Pos];}
int main(){
n=read(),m=read();
for(int i=0;i<m;i++)
for(int j=1;j<=n;j++)A[i][j]=read();
scanf("%s",S+1);len=strlen(S+1);
for(int i=1;i<=len;i++){
if(S[i]=='(')st[++top]=i;
else if(S[i]==')')bl[i]=st[top],top--;
}rt=build(1,len);
for(int i=0;i<(1<<m);i++)
dfs(rt,i),coef[i]=dp[rt][1];
for(int i=1;i<=n;i++){
for(int j=0;j<m;j++)p[j]=j;Pos=i;
sort(p,p+m,cmp);int T=0;
ans=(ans+1ll*coef[T]*A[p[0]][i])%mod;
for(int j=1;j<m;j++){
T|=(1<<p[j-1]);
ans=(ans+1ll*coef[T]*(A[p[j]][i]-A[p[j-1]][i]))%mod;
}
}printf("%d\n",ans);
return 0;
}