FJWC2019 子图 (三元环计数、四元环计数)

给定 n 个点和 m 条边的一张图和一个值 k ,求图中边数为 k 的联通子图个数 mod 1e9+7。

\(n \le 10^5, m \le 2 \times 10^5, 1 \le k \le 4\)

观察到 k 的值贼小,考虑分类讨论

下面代码中du[]代表点的度数。(度 找不到比较好的英文,而这个拼音比较巨,所以du是我的代码习惯中里出现拼音的少数几中情况之一)

观察图片

k = 1,输出 m。 k = 2, 枚举点 2,组合数一下即可。

对于 k = 3 我们分为三种情况(按照图片顺序从左到右)

图(1) 我们枚举边 2--3 ,则答案为(du[2] - 1) * (du[3] - 1)。

图(2) 我们枚举点 2 ,组合数一下即可。

图(3) 则转化为三元环计数问题。

需要注意 图(1) 中 可能出现 1 和 4 重合的情况,恰好是一个三元环。每个三元环会被统计三次。所以需要减去 3 * 三元环数量。

所以最后我们需要减去 2 * 三元环数量。

对于 k = 4,我们分为五种情况(按照图片顺序从左到右)

先考虑图(3)是一个菊花比较简单,枚举点 2 直接组合数一下即可。

然后我们考虑图(2), 我们考虑枚举 边 2--3 ,则答案为 (du[2] - 2) * (du[3] - 1) + (du[2] - 1) * (du[3] - 2)。

注意到可能存在 1 号点 和 (4号点 或 5号点) 重合的情况,那么就变成了 图(4) 。每一种图 (4) 都会被统计两次,需要减去两次 图(4) 的方案总数。

然后我们考虑图(1)。 我们考虑枚举点 3, 再依次枚举它的出边。我们开个变量 tmp 维护 已经扫过的出边的另一端点的出边数量和,即 sum(du[i] - 1) (i 是扫过的出边连接的点) 每次我们将本次扫的 (du[i] - 1) 与 tmp 乘一下累加到答案然后把 (du[i] - 1) 加到 tmp 里。

这里 我们会有三种重合情况:(假设点 2 是之前枚举过的点, 点 4 是当前正在枚举的点)

  1. 点 5 和 点 1 重合。 那么图变成了四元环,即图(5)。每个四元环会在这里被统计四次,接下来需要 -= 4 * 四元环方案数。
  2. 点 5 和 点 2 重合,点 1 和点 4 不重合。 那么图变成了图(4)。每个 图(4) 会在这里被统计两次, 接下来需要 -= 2 * 四元环方案数。
  3. 点 5 和 点 2 重合,点 1 和点 4 重合。 那么图变成了一个三元环。每个三元环会在这里被统计三次,接下来需要 -= 3 * 三元环方案数。

然后考虑图(4)。图 4 是有一个三元环和一条边组成。我们考虑枚举点 2 , ans += 点 2 所在三元环数量 * (du[2] - 2) 。

考虑图(5)。图(5)是一个四元环计数。

现在我们把问题转化成了给定图求每个点所在三元环数量,以及四元环总数。三元环数量可以用给每个点三元环数量 / 3 来求出。

首先考虑求每个点所在三元环的数量。

这是常数比较大的 \(O(m \sqrt m)\) 的算法,另外还有 \(O(m \sqrt n)\) 的转化为有向图的常数较小的做法,没看太懂这里不解释

我们考虑把点分为两种,第一种是度数 \(du[x] \le \sqrt m\) 的,第二种是度数 \(du[x] > \sqrt m\) 的 (一下简称第一种点、第二种点)。

然后我们三元环分为两类,第一类是包含第一种点的三元环, 第二类是不包含第一种点的三元环(环上所有点都是第二种点)。

先考虑枚举所有第一种点 x, 我们考虑枚举所有无序出边对,然后我们就可以得到三个顶点 (x, y, z) 。我们可以维护一个 set ,然后直接判断这三个点是否构成三元环。 如果构成三元环我们考虑统计答案。 首先 对于所有第一类点,我们会枚举所有三元环,直接在枚举这个点时暴力统计贡献即可。 对于第二类点,假设 y 和 z 都是第二类点, 那么 (x, y, z) 这个三元环只会在此时被统计,我们将 ans[y]++, ans[z]++。 如果 y 和 z 中只有一个第二类点(假设y),那么三元环 (x, y, z) 会被枚举两次,其中 x 和 z 是第一类点。而 y 只能被统计一次,我们可以考虑当 x 的标号 小于 z 的标号时候将 ans[y]++ 以避免重复统计。这一部分的时间复杂度为 \(O(m \sqrt m)\),因为每条边最多只会作为第一条出边出现两次,而每个点 x 的出度是 \(O(\sqrt m)\) 级别的,所以第二条出边最多为 \(O(\sqrt m)\) 个,所以复杂度为 \(O(m \sqrt m)\)

然后我们考虑枚举第二类三元环。由于第二类点的个数 \(\le \sqrt m\) 个,我们考虑暴力 \(O({\sqrt m}^3) = O(m\sqrt m)\) 枚举三个点判断是否是三元环即可。

然后我们考虑四元环计数。

考虑枚举枚举点 1(设为 x), 并钦定它是标号最大的点(钦定2 3 4 号点的标号都比 1 号点小),然后我们枚举 4 号点(设为y),再枚举 与 4 号点相连的 3 号点(设为z, z != x)。我们考虑维护一个 cnt 数组,cnt[x] 表示 和 1 号点的距离为 2 的路径数量。 我们每次枚举到 z 的时候,前面已经有 cnt[z] 条路径与 x 相连了, 我们让 ans += cnt[z], 然后 cnt[z]++ 即可。

至于时间复杂度的证明,我们可以将点按照度数从小到大排序后重新标号跑一遍算法,时间复杂度为 \(O(m \sqrt m)\)。我并不会证。。。代码里偷懒忘了排序了,复杂度可能会被卡掉

大毒瘤代码

#include <cmath>
#include <cstdio>
#include <vector>
#include <unordered_set>
using namespace std;

const int xkj = 1000000007;

int n, m, k;
vector<int> out[100010];
unordered_set<int> hsh[100010];
int x[200010], y[200010], du[100010], swh[100010], fuck[100010], bucket[100010], cnt;

int Cx2(int x) { return x * (long long)(x - 1) % xkj * 500000004 % xkj; }
int Cx3(int x) { return x * (long long)(x - 1) % xkj * (x - 2) % xkj * 166666668 % xkj; }
int Cx4(int x) { return x * (long long)(x - 1) % xkj * (x - 2) % xkj * (x - 3) % xkj * 41666667 % xkj; }

void prepare_swh() //获取每个点三元环数量
{
	for (int i = 1; i <= n; i++) du[i] = 0, swh[i] = 0;
	for (int i = 1; i <= m; i++) du[x[i]]++, du[y[i]]++;
	int sqm = sqrt(m + 0.233); cnt = 0;
	for (int i = 1; i <= n; i++)
	{
		if (du[i] <= sqm) //暴力统计
		{
			int sz = out[i].size();
			for (int j = 0; j < sz; j++)
			{
				for (int k = j + 1; k < sz; k++)
				{
					if (hsh[out[i][j]].count(out[i][k]))
					{
						swh[i]++;
						if (du[out[i][j]] > sqm && du[out[i][k]] > sqm) swh[out[i][j]]++, swh[out[i][k]]++;
						else
						{
							int x = out[i][j], y = out[i][k];
							if (du[x] > sqm || du[y] > sqm)
							{
								if (du[x] <= sqm) swap(x, y);
								if (i < y) swh[x]++;
							}
						}
					}
				}
			}
		}
		else fuck[++cnt] = i;
	}
	for (int i = 1; i <= cnt; i++)
	{
		for (int j = i + 1; j <= cnt; j++)
		{
			if (hsh[fuck[i]].count(fuck[j]))
			{
				for (int k = j + 1; k <= cnt; k++)
				{
					if (hsh[fuck[j]].count(fuck[k]) && hsh[fuck[i]].count(fuck[k]))
					{
						swh[fuck[i]]++, swh[fuck[j]]++, swh[fuck[k]]++;
					}
				}
			}
		}
	}
}

int qcnt()
{
	int ans = 0;
	for (int i = 1; i <= n; i++) ans = (ans + swh[i]) % xkj;
	return ans * 333333336LL % xkj;
}

int qcnt2() //获取 sigma 每个点在多少个三元环内*(这个点的度数-2)
{
	int ans = 0;
	for (int i = 1; i <= n; i++) ans = (ans + swh[i] * (long long)(du[i] - 2) % xkj) % xkj;
	return ans;
}

int qcnt3() //获取正方形
{
	int ans = 0;
	for (int i = 1; i <= n; i++)
	{
		for (int j : out[i]) if (j < i)
		{
			for (int k : out[j]) if (k < i)
			{
				ans = (ans + bucket[k]) % xkj;
				bucket[k]++;
			}
		}
		for (int j : out[i]) if (j < i)
		{
			for (int k : out[j]) if (k < i)
			{
				bucket[k]--;
			}
		}
	}
	return ans;
}

int main()
{
	freopen("subgraph.in", "r", stdin), freopen("subgraph.out", "w", stdout);
	scanf("%d%d%d", &n, &m, &k);
	for (int i = 1; i <= m; i++)
	{
		scanf("%d%d", &x[i], &y[i]);
		out[x[i]].push_back(y[i]), out[y[i]].push_back(x[i]);
		hsh[x[i]].insert(y[i]), hsh[y[i]].insert(x[i]);
	}
	prepare_swh();
	if (k == 1) { printf("%d\n", m); }
	if (k == 2)
	{
		int ans = 0;
		for (int i = 1; i <= n; i++) ans = (ans + Cx2(out[i].size())) % xkj; // []---[]---[]
		printf("%d\n", ans);
	}
	if (k == 3)
	{
		int ans = 0;
		for (int i = 1; i <= n; i++) ans = (ans + Cx3(out[i].size())) % xkj; // 菊花型
		for (int i = 1; i <= m; i++) ans = (ans + (out[x[i]].size() - 1) * (long long)(out[y[i]].size() - 1) % xkj) % xkj; //链型+三元环*3
		ans = (ans - 2 * qcnt()) % xkj;//三元环数量
		ans = (ans + xkj) % xkj;
		printf("%d\n", ans);
	}
	if (k == 4)
	{
		int ans = 0;
		for (int i = 1; i <= n; i++) ans = (ans + Cx4(out[i].size())) % xkj; // 菊花型
		for (int i = 1; i <= m; i++) ans = (ans + Cx2(out[x[i]].size() - 1) * (long long)(out[y[i]].size() - 1) % xkj) % xkj,
		ans = (ans + Cx2(out[y[i]].size() - 1) * (long long)(out[x[i]].size() - 1) % xkj) % xkj; //箭头形+陷阱型*2
		ans = (ans + xkj - 3 * (long long)qcnt2() % xkj) % xkj; //减去陷阱型
		for (int i = 1; i <= n; i++) //链型
		{
			int tmp = 0;
			for (int j : out[i])
			{
				int tmp1 = out[j].size() - 1;
				ans = (ans + tmp * (long long)tmp1 % xkj) % xkj;
				tmp = (tmp + tmp1) % xkj;
			}
		}
		ans = (ans + xkj - 3 * (long long)qcnt() % xkj) % xkj;
		ans = (ans + xkj - 3 * (long long)qcnt3() % xkj) % xkj;
		printf("%d\n", ans);
	}
	return 0;
}
posted @ 2019-03-15 18:39  ghj1222  阅读(2089)  评论(2编辑  收藏  举报