[日常训练]网络(bitset+连通性dp)
Description
小 D 正在研究网络模型。
在计算机图形学中,三角网格是最常用的网络模型之一,用于表示较为复杂的三维模型。
三角网格是一系列由空间中的顶点以及若干连接这些顶点的三角形面组成的。
而网格细分,则是将给定的原始网格进行精细地改进,从而产生更加平滑的效果。
小 D 对一个三角形面使用了最为简单的网格细分方法:每次选择一个三角形面,在内部新增一个点,并将此顶点与三角形的三个顶点之间连边,从而将一个三角形面变为三个更小的三角形面。不断重复上述过程,直到小 D 觉得足够为止。
小 D 将细分后的网格看成了一张无向图,并给每条边赋予了一个权值。
现在输入这张无向图,请求出图中权值和最大的简单路径。其中,简单路径,指不经过重复的点或边的路径。
设顶点个数为 \(n\),\(n\le 1000\),边权 \(\in[1,10^6]\)。
时空限制 \(\text{1s/512MB}\),开启 \(\text{O2}\) 优化和 c++11
。
Solution
首先还原出无向图的网格细分方式。
先用 bitset 存下无向图的邻接矩阵 \(a\)。
枚举三元组 \((u,v,w)\),满足 \(u,v,w\) 两两有边,然后把 \((u,v),(v,w),(u,w)\) 三条边在 bitset 中都删掉。
接下来,求 \(a_u,a_v,a_w\) 的交集,要么为空,要么有一个元素 \(c\)。递归处理 三角形 \((u,v,c),(v,w,c),(u,w,c)\) 即可。
如果递归结束之后,能把 \(n\) 个点都找出来,那么就求出了一种合法的网格细分方式。
因为三角形面只有 \(O(n)\) 个,所以这部分时间复杂度为 \(O(\frac{n^3}{64})\)。
接下来考虑建一棵树,树上的每个节点都对应一个三角形面。
考虑 dp,设 \(i\) 号节点对应三角形 \((u,v,w)\),三角形中心为 \(c\),那么它有三个儿子,分别为 \((u,v,c),(v,w,c),(u,w,c)\)。
记 \(dp(i,s)\) 表示只考虑三角形 \(i\) 内部的点的连边,不考虑边 \((u,v),(v,w),(u,w)\),满足 \(u,v,w,c\) 这四个点的状态为 \(s\),最大的边权之和(注意不是最长路径)是多少。
其中 \(s\) 包含:
- \(u,v,w\) 的度数 \(deg\)
- \(u,v,w\) 这四个点相互间的连通性 \(bel\),用最小表示法
- 三角形 \(i\) 内部有多少个 \(deg=1\) 的点(\(u,v,w\) 不算),记作 \(one\)
具体实现时,可以用 \(map\) 维护每个状态的 dp 值。
先不考虑 \((u,c),(v,c),(w,c)\) 这三条边是否连上,先把 \(i\) 的三个儿子合并起来。
注意这里合并的时候,要记 \(4\) 个点的状态,也就是要记:
- \(u,v,w,c\) 的度数 \(deg\)
- \(u,v,w,c\) 这四个点相互间的连通性 \(bel\)
- 三角形 \(i\) 内部有多少个 \(deg=1\) 的点(\(u,v,w,c\) 不算),记作 \(one\)
合法状态下,不能有环,且不能有 \(deg>2\) 的点。
因此合并两个部分的时候,要判断:
- 同一个点两侧 \(deg\) 之和必须 \(\le 2\)
- \(u,v,w,c\) 中,不能存在两个点 \(i,j\),在两个部分都满足 \(bel_i=bel_j\)
同时,对于 \(one\le 3\) 的状态,显然是不合法的,因为不管 \(u,v,w,c\) 之间的点相互怎么连边,或者 \(u,v,w\) 跟三角形外部怎么连边,都不能把这三个 \(deg=1\) 的点串成一条链。
那么加一个判断(算是剪枝):
- 两边的 \(one\) 之和 \(\le2\)
接下来,考虑 \((u,c),(v,c),(w,c)\) 这三条边是否要连,先 \(O(2^3)\) 枚举连边状态。
连边 \((x,y)\) 的时候,要判断:
- \(bel_x\ne bel_y\)
- \(deg_x,deg_y\le1\)
连边之后,更新 \(bel,deg\),此时的合法状态必须满足下列两个条件之一:
- \(c\) 跟 \(u,v,w\) 中的至少一个点连通
- \(deg_c=0\)
不然的话,\(c\) 所在的链就不能和外部串起来了。
然后把状态中,跟 \(c\) 有关的信息删除,就得到了一个 \(i\) 的合法状态 \(s\) 了,可以更新 \(dp(i,s)\) 的最大值。
考虑计算答案,即在 dp 到三角形 \(i\) 时,计算三角形内部的最长路。(\(u,v,w\) 都不能在路径上)
那么此时必须满足:
- \(deg_u=deg_v=deg_w=0\)
- \(one=2\)(这里的 \(one\) 算上 \(c\))
最后,设最外层的三角形为 \((x,y,z)\),考虑它们之间的连边。
在 \(dp(root,s)\) 的基础上,枚举 \(x,y,z\) 之间的连边情况,判断连边是否合法。
此时的 \(one\) 要算上 \(x,y,z\),然后判断是否 \(one=2\),更新答案。
Code
一道磨炼心态和体力的好题。
代码长度约 \(7\text{KB}\)。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pb push_back
#define vi vector<int>
#define mp map<state, int>
#define fi first
#define se second
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
const int e = 1e4 + 5;
struct point
{
int x, y;
point(){}
point(int _x, int _y) :
x(_x), y(_y) {}
}a[e];
struct tri
{
int u, v, w;
tri(){}
tri(int _u, int _v, int _w) :
u(_u), v(_v), w(_w) {}
}p[e];
bitset<e>b[e], bit;
int n, m, all, ru, rv, rw, pool, son[e][3], mid[e], ans, cost[1005][1005];
inline void sort3(int &x, int &y, int &z)
{
if (x > y) swap(x, y);
if (x > z) swap(x, z);
if (y > z) swap(y, z);
}
inline void join(int x, int y)
{
b[x][y] = b[y][x] = 1;
}
inline void clear(int x, int y)
{
b[x][y] = b[y][x] = 0;
}
inline void check(int u, int v, int w)
{
clear(u, v);
clear(u, w);
clear(v, w);
bit = b[u] & b[v] & b[w];
if (!bit.count()) return;
int c = bit._Find_first();
all++;
check(u, v, c);
check(u, w, c);
check(v, w, c);
}
inline int build(int u, int v, int w)
{
sort3(u, v, w);
int x = ++pool;
p[x] = tri(u, v, w);
clear(u, v);
clear(u, w);
clear(v, w);
bit = b[u] & b[v] & b[w];
if (!bit.count()) return x;
int c = bit._Find_first();
mid[x] = c;
son[x][0] = build(u, v, c);
son[x][1] = build(u, w, c);
son[x][2] = build(v, w, c);
return x;
}
struct node
{
int id, deg, bel;
node(){}
node(int _id, int _deg, int _bel) :
id(_id), deg(_deg), bel(_bel) {}
};
inline bool operator < (node a, node b)
{
return a.id != b.id ? a.id < b.id : a.deg != b.deg ? a.deg < b.deg : a.bel < b.bel;
}
struct state
{
vector<node>g;
int one;
state()
{
one = 0;
}
inline void upt()
{
if (!g.size()) return;
sort(g.begin(), g.end());
}
inline void smaller()
{
int i, j, col[4], t = 0, l = g.size();
memset(col, 0, sizeof(col));
for (i = 0; i < l; i++)
if (!col[i])
{
col[i] = ++t;
for (j = i + 1; j < l; j++)
if (g[i].bel == g[j].bel) col[j] = t;
}
for (i = 0; i < l; i++) g[i].bel = col[i];
}
inline void merge(int x, int y)
{
int i, bx = g[x].bel, by = g[y].bel, l = g.size();
if (bx == by) return;
for (i = 0; i < l; i++)
if (g[i].bel == by) g[i].bel = bx;
smaller();
}
inline bool link(int x, int y)
{
if (x > y) swap(x, y);
if (g[x].bel == g[y].bel || g[x].deg > 1 || g[y].deg > 1) return 0;
merge(x, y);
g[x].deg++; g[y].deg++;
return 1;
}
inline void erase(int x)
{
vector<node> h = g;
g.clear();
int i;
for (i = 0; i <= 3; i++)
if (i != x) g.pb(h[i]);
smaller();
}
};
inline bool operator < (state a, state b)
{
int i, la = a.g.size(), lb = b.g.size();
if (la != lb) return la < lb;
if (a.one != b.one) return a.one < b.one;
for (i = 0; i < la; i++)
if (a.g[i] < b.g[i]) return 1;
else if (b.g[i] < a.g[i]) return 0;
return 0;
}
inline bool operator == (state a, state b)
{
return !(a < b) && !(b < a);
}
mp f[e], nxt;
mp :: iterator it, it1, it2;
inline void dfs(int k)
{
int u = p[k].u, v = p[k].v, w = p[k].w, c = mid[k], i, z, id[3], j, pos[3] = {u, v, w};
state s;
s.g.pb(node(u, 0, 1));
s.g.pb(node(v, 0, 2));
s.g.pb(node(w, 0, 3));
f[k][s] = 0;
if (!c) return;
for (i = 0; i <= 2; i++) dfs(son[k][i]);
for (z = 0; z <= 2; z++)
{
int ch = son[k][z], x;
if (z == 0) x = w;
else if (z == 1) x = v;
else x = u;
nxt.clear();
for (it1 = f[k].begin(); it1 != f[k].end(); it1++)
for (it2 = f[ch].begin(); it2 != f[ch].end(); it2++)
{
state s1 = (*it1).fi, s2 = (*it2).fi, p1 = s1, p2 = s2;
int v1 = (*it1).se, v2 = (*it2).se;
bool ok = 1;
if (s1.g.size() == 3)
{
s1.g.pb(node(c, 0, 4));
s1.upt();
s1.smaller();
}
s2.g.pb(node(x, 0, 4));
s2.upt();
for (i = 0; i <= 3 && ok; i++)
for (j = 0; j <= 3; j++)
if (i != j)
ok &= s1.g[i].bel != s1.g[j].bel || s2.g[i].bel != s2.g[j].bel;
if (!ok) continue;
for (i = 0; i <= 3 && ok; i++)
{
s1.g[i].deg += s2.g[i].deg;
ok &= s1.g[i].deg <= 2;
}
s1.one += s2.one;
if (!ok || s1.one > 2) continue;
for (i = 0; i <= 2; i++)
for (j = i + 1; j <= 3; j++)
if (s2.g[j].bel == s2.g[i].bel) s1.merge(i, j);
nxt[s1] = max(nxt[s1], v1 + v2);
}
f[k].swap(nxt);
}
for (i = 0; i <= 2; i++)
{
int x = pos[i];
id[i] = i;
if (c < x) id[i]++;
}
int idc = 0;
if (u < c) idc++;
if (v < c) idc++;
if (w < c) idc++;
nxt.clear();
state now;
for (it = f[k].begin(); it != f[k].end(); it++)
{
for (int x = 0; x < 7; x++)
{
bool ok = 1;
state s = (*it).fi;
int val = (*it).se, d;
for (i = 0; i <= 2 && ok; i++)
if (x & (1 << i))
{
ok &= s.link(id[i], idc);
val += cost[pos[i]][c];
}
if (!ok) continue;
if (d = s.g[idc].deg)
{
bool pd = 0;
for (i = 0; i <= 3 && !pd; i++)
pd |= i != idc && s.g[i].bel == s.g[idc].bel;
if (d == 1) s.one++;
if (pd)
{
state t = s;
t.erase(idc);
int &dp = nxt[t];
dp = max(dp, val);
}
}
else
{
state t = s;
t.erase(idc);
int &dp = nxt[t];
dp = max(dp, val);
}
if (s.one != 2) continue;
bool pd = 1;
for (i = 0; i <= 3 && pd; i++)
pd &= i == idc || !s.g[i].deg;
if (pd) ans = max(ans, val);
}
}
f[k].swap(nxt);
if (k != 1) return;
for (it = f[1].begin(); it != f[1].end(); it++)
{
for (int x = 0; x < 7; x++)
{
state s = (*it).fi, lst = s;
int val = (*it).se;
bool ok = 1;
for (i = 0; i <= 2 && ok; i++)
if (x & (1 << i))
{
ok &= s.link(i, (i + 1) % 3);
val += cost[pos[i]][pos[(i + 1) % 3]];
}
if (!ok) continue;
for (i = 0; i <= 2; i++)
if (s.g[i].deg == 1) s.one++;
if (s.one == 2) ans = max(ans, val);
}
}
}
int main()
{
// freopen("mesh3.in", "r", stdin);
read(n); m = 3 * (n - 2);
int i, x, y, w, j, z;
for (i = 1; i <= m; i++)
{
read(x); read(y); read(z);
a[i] = point(x, y);
join(x, y);
cost[x][y] = cost[y][x] = z;
}
if (n == 3)
{
ans = max(ans, cost[1][2] + cost[1][3]);
ans = max(ans, cost[1][3] + cost[3][2]);
ans = max(ans, cost[1][2] + cost[2][3]);
cout << ans << endl;
return 0;
}
for (i = 1; i <= m && !ru; i++)
{
int u = a[i].x, v = a[i].y;
for (w = 1; w <= n; w++)
if (b[u][w] && b[v][w])
{
all = 3;
check(u, v, w);
for (j = 1; j <= m; j++) join(a[j].x, a[j].y);
if (all == n)
{
ru = u; rv = v; rw = w;
break;
}
}
}
build(ru, rv, rw);
dfs(1);
cout << ans << endl;
return 0;
}