题解 [UER #9] 知识网络
只会到前 50 pts /kk
暴力就是对 tag 建虚点,每个点向其 tag 连权为 0 的边,同时 tag 向这个点连权为 1 的边
然后 01 BFS就可以过 30 pts
m 较小,n 很大的话有原图中边的点数与 m 同阶,只对这些点跑最短路,剩下的与虚点答案一样
然后正解:
考虑每个 tag 对应虚点到目标点的最短路
一定是从这个虚点走到某个标签为当前 tag 的点(可行的这样的点可能有多个),再走到目标点
那么再考虑所有标签为当前 tag 的点到目标点的最短路
要么先走到这个虚点,要么直接走向目标点
那么如果建出了虚点到所有目标点的最短路 DAG,那么可以不用先走到虚点的点一定满足 目标点在 DAG 上是这个点的后继
那么反过来考虑,每个目标点对长度 \(dis+1\) (+1 是因为要求序列长度)产生的贡献就是这个标签的点的数量-这个标签的点在 DAG 上是这个点的前驱数
那么 bitset 处理即可
然后发现直接 bitset 复杂度是 \(O(k(n+m)+\frac{n(n+m)k}{\omega})\),过不去
发现操作时 bitset 中有好多位置是空的,于是用 ull 进行分块 bitset
复杂度就变成了 \(O(k(n+m)+\frac{\sum siz_i(n+m)}{\omega})=O(k(n+m)+\frac{n(n+m)}{\omega})\),可以过
- 当 bitset 中有好多位置是空的时:考虑使用 ull 进行分块 bitset 以优化复杂度
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 200010
#define pb push_back
#define ll long long
#define ull unsigned long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n, m, k;
ull s[N];
ll ans[N], sum;
int head[N], ecnt;
vector<int> to[N], bel[N];
int p[N], q[N<<1], dis[N], sta[N], cnt[N], pre[N], top, l, r;
struct edge{int to, next, val;}e[N<<1];
inline void add(int s, int t, int w) {e[++ecnt]={t, head[s], w}; head[s]=ecnt;}
void solve(int tag) {
memset(pre, 0, sizeof(pre));
memset(cnt, 0, sizeof(cnt));
memset(dis, 0x3f, sizeof(dis));
for (int i=1; i<=n+k; ++i) to[i].clear();
dis[q[l=r=N]=n+tag]=0;
while (l<=r) {
int u=q[l++];
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (dis[u]+e[i].val<dis[v]) {
dis[v]=dis[u]+e[i].val;
to[u].pb(v), ++cnt[v];
if (e[i].val) q[++r]=v;
else q[--l]=v;
}
else if (dis[u]+e[i].val==dis[v]) to[u].pb(v), ++cnt[v];
}
}
// cout<<"dis: "; for (int i=1; i<=n; ++i) cout<<dis[i]<<' '; cout<<endl;
l=0, r=l-1; top=0;
for (int i=1; i<=n+k; ++i) if (dis[i]!=INF && !cnt[i]) q[++r]=i;
while (l<=r) {
int u=q[l++];
sta[++top]=u;
for (auto& v:to[u]) if (--cnt[v]==0) q[++r]=v;
}
// cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
for (int l=0,r; l<bel[tag].size(); l=r+1) {
r=min(l+64-1, (int)bel[tag].size()-1);
memset(s, 0, sizeof(s));
for (int i=l; i<=r; ++i) s[bel[tag][i]]|=1llu<<i-l;
for (int i=1; i<=top; ++i) for (auto& v:to[sta[i]]) s[v]|=s[sta[i]];
for (int i=1; i<=n; ++i) if (dis[i]!=INF) pre[i]+=__builtin_popcountll(s[i]);
}
// cout<<"pre: "; for (int i=1; i<=n; ++i) cout<<pre[i]<<' '; cout<<endl;
for (int i=1; i<=n; ++i) if (dis[i]!=INF) {
// cout<<"i: "<<i<<' '<<bel[tag].size()-pre[i]<<' '<<pre[i]-(p[i]==tag)<<endl;
ans[dis[i]+1]+=bel[tag].size()-pre[i];
ans[dis[i]]+=pre[i]-(p[i]==tag);
}
}
signed main()
{
n=read(); m=read(); k=read();
memset(head, -1, sizeof(head));
for (int i=1; i<=n; ++i) bel[p[i]=read()].pb(i), add(i, n+p[i], 0), add(n+p[i], i, 1);
for (int i=1,u,v; i<=m; ++i) {
u=read(); v=read();
add(u, v, 1); add(v, u, 1);
}
for (int i=1; i<=k; ++i) if (bel[i].size()) solve(i);
for (int i=1; i<=k<<1; ++i) assert(!(ans[i]&1)), sum+=(ans[i]/=2);
ans[k<<1|1]=1ll*n*(n-1)/2-sum;
for (int i=1; i<=(k<<1|1); ++i) printf("%lld%c", ans[i], " \n"[i==(k<<1|1)]);
return 0;
}