Loading

联合省选 2020 补题记录

有点做不动题,就来写点东西。。。

鸽子了 A 卷的保序回归 和 B 卷。

通过翻游记大概搞到了题目顺序

d1t1 : P6619 [省选联考 2020 A/B 卷] 冰火战士

d1t2 : P6620 [省选联考 2020 A 卷] 组合数问题

d2t1 : P6622 [省选联考 2020 A/B 卷] 信号传递

d2t2 : P6623 [省选联考 2020 A 卷] 树

d2t3 : P6624 [省选联考 2020 A 卷] 作业题

我才发现顺序就是题目编号

【upd】心情好,随手写了一道 B 卷题:P6626 [省选联考 2020 B 卷] 消息传递

补题的时候我都不开 O2 以还原最真实的情景,所以不考虑开 O2 能过的做法


冰火战士

毕竟是补题,当时一群人被卡常大肆交流做法的时候也看到了 树状数组上二分 / 线段树上二分 这个算法标签。。。

但是既然是补题,还是写已知的最优秀的做法吧 qaq。

显然场地温度为 \(d\) 时,冰战士能量和是 温度 \(\le d\) 的能量战士和,火战士能量和是 温度 \(\ge d\) 的战士能量和。

两边都带等不方便二分,不妨把所有火战士的温度都 \(+1\),变为 \(>d\) 的战士能量和。

最后答案是两边能量和较小值的两倍。

这个东西显然可能有一段是平的,不好三分。

考虑二分两个位置,第一个是 最靠后的 冰能量 不大于 火能量 的位置 以及 第一次 冰能量 大于 火能量 的位置

树状数组上二分的时候,火能量可以通过总和减去前缀,冰能量直接前缀计算即可。

显然答案是这两个位置之一。

求第一个位置可以直接树状数组上二分,但是第二个位置是“最靠前的”,不那么常规。

仔细想想发现第二个位置的火能量是可以直接计算的,然后找到火战士 最靠后的值不小于火能量的位置 即可。

第一次写树状数组上二分感觉非常 /tuu,码力完全不行,写了 3h 重构了一发又写了 1h+,考场上绝对完蛋。

跑了 2.3s,质疑线段树二分能不能过去。。。

【upd】Guidingstar 说他考场上线段树二分过了,那可能是洛谷慢?

Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return f?x:-x;
}
const int N = 2000005;
int Q, lsh[N], len, sum2;
int tr1[N], tr2[N];
struct node {
	int op, x, y, z;
} a[N];
inline void add1(int x, int d) {//ice, 前缀 
	for(int i = x; i <= len; i += i & -i) tr1[i] += d;
}
inline void add2(int x, int d) {//fire, 后缀
	sum2+=d; 
	for(int i = x; i <= len; i += i & -i) tr2[i] += d;
}
inline int ask1(int x) {
	int res = 0;
	for(int i = x; i > 0; i -= i & -i) res += tr1[i];
	return res;
}
inline int ask2(int x) {
	int res = 0;
	for(int i = x; i > 0; i -= i & -i) res += tr2[i];
	return res;
}
signed main() {
	Q = read();
	rep(i, 1, Q) {
		a[i].op = read(), a[i].x = read();
		if(a[i].op == 1) a[i].y = read(), a[i].z = read(), lsh[++len] = a[i].y;
	}
	sort(lsh + 1, lsh + len + 1), len = unique(lsh + 1, lsh + len + 1) - lsh - 1;
	rep(i, 1, Q) if(a[i].op == 1)
		a[i].y = lower_bound(lsh + 1, lsh + len + 1, a[i].y) - lsh;
	rep(i, 1, Q) {
		int op = a[i].op, x = a[i].x;
		if(op == 2) {
			if(a[x].x == 0) add1(a[x].y, -a[x].z);
			else add2(a[x].y + 1, -a[x].z);
		} else {
			if(a[i].x == 0) add1(a[i].y, a[i].z);
			else add2(a[i].y + 1, a[i].z);
		}
		int r = 0, s1 = 0, s2 = sum2;
		pair <int, int> ans1 = mkp(0, 0), ans2 = mkp(0, 0);
		for(int i = 1 << 20; i; i >>= 1){
			int tmp = r + i;
			if(tmp <= len && s1 + tr1[tmp] <= s2 - tr2[tmp])
				r = tmp, s1 += tr1[tmp], s2 -= tr2[tmp];
		}
		ans1 = mkp(s1, r);
		if(r < len) {
			ans2.fi = sum2 - ask2(r + 1);
			s1 = 0, s2 = sum2, r = 0;
			for(int i = 1 << 20; i; i >>= 1) {
				int tmp = r + i, t1 = s1 + tr1[tmp], t2 = s2 - tr2[tmp];
				if(tmp <= len && (t2 >= ans2.fi))
					r = tmp, s1 = t1, s2 = t2;
			}
			ans2.se = r;
		}
		ans1 = max(ans1, ans2);
		if(ans1.fi == 0) puts("Peace"); 
		else printf("%d %lld\n", lsh[ans1.se], ans1.fi * 2ll);
	}
}

组合数问题

上来扔给你一个式子叫你算感觉非常可怕。。。

但是再看几眼发现这个东西非常的 naive。

先做一点基础的化简,\(f(k)\) 必然要展开

\[\sum_{k=0}^{n}\sum_{i=0}^{m}a_ik^i x^k\binom{n}{k}\\ =\sum_{i=0}^{m}a_i\sum_{k=0}^{n}k^ix^k\binom{n}{k} \]

注意到那个次幂是一维特别大一维特别小,想到了第二类斯特林数展开:

\[m ^ n = \sum_{i} \begin{Bmatrix}n\\i\end{Bmatrix}i!\binom{m}{i} \]

带进去

\[=\sum_{i = 0} ^{m} a_i \sum_{k = 0} ^ {n} \sum_{j = 0} ^ {m}\begin{Bmatrix} i \\ j \end{Bmatrix}j!\binom{k}{j} x ^ k \binom{n}{k}\\ =\sum_{i = 0} ^{m} a_i \sum_{j = 0} ^ {m} \begin{Bmatrix} i \\ j \end{Bmatrix}j! \sum_{k = 0} ^ {n} x ^ k \binom{n}{j}\binom{n-j}{k-j}\\ =\sum_{i = 0} ^{m} a_i \sum_{j = 0} ^ {m}\binom{n}{j} \begin{Bmatrix} i \\ j \end{Bmatrix}j! x ^ j \sum_{k = 0} ^ {n - j} x ^ k \binom{n-j}{k}\\ =\sum_{i = 0} ^{m} a_i \sum_{j = 0} ^ {m}n^{\underline{j}} \begin{Bmatrix} i \\ j \end{Bmatrix} x ^ j (1+x)^{n-j} \]

这 100 分大概是白送的?心态好一点勇敢想正解都能想出来吧!

Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return f?x:-x;
}

const int N = 1005;
int n, x, mod, m, a[N], S2[N][N], fac[N], dwn[N], ans, pw1[N], pw2[N];
inline int qpow(int n, int k) {
	int res = 1;
	for(; k; k >>= 1, n = 1ll * n * n % mod)
		if(k & 1) res = 1ll * n * res % mod;
	return res;
}
signed main() {
	n = read(), x = read(), mod = read(), m = read();
	rep(i, 0, m) a[i] = read();
	S2[0][0] = 1;
	rep(i, 1, m) {
		rep(j, 1, i)
			S2[i][j] = (1ll * S2[i - 1][j] * j % mod + S2[i - 1][j - 1]) % mod;
	}
	fac[0] = 1;
	for(int i = 1; i <= m; ++i) fac[i] = 1ll * i * fac[i - 1] % mod;
	dwn[0] = 1;
	for(int i = 1; i <= m; ++i) dwn[i] = 1ll * (n - i + 1) * dwn[i - 1] % mod;
	pw1[m] = qpow(1 + x, n - m);
	for(int i = m - 1; i >= 0; --i) pw1[i] = 1ll * pw1[i + 1] * (1 + x) % mod;
	pw2[0] = 1;
	for(int i = 1; i <= m; ++i) pw2[i] = 1ll * pw2[i - 1] * x % mod;
	for(int i = 0; i <= m; ++i) pw1[i] = 1ll * pw1[i] * pw2[i] % mod;
	for(int i = 0; i <= m; ++i) {
		int res = 0;
		for(int j = 0; j <= i; ++j)
			res = (res + 1ll * S2[i][j] * dwn[j] % mod * pw1[j] % mod) % mod;
		ans = (ans + 1ll * res * a[i] % mod) % mod;
	}
	cout << ans << '\n';
}

信号传递

统计 \(c(i,j)\) 表示 \(i\) 塔向 \(j\) 塔传递了几次。

\(dp(msk)\) 表示 \([1,popcount(msk)]\) 使用 \(msk\) 里面为 \(1\) 的位置来填的最小贡献。

每一次枚举一个 \(i\not\in msk\) 来转移。

\(A=msk,B=\{x|x\not\in A \operatorname{and} x\not=i\}\)

那么

\(c(i,j)(j\in A)\) 的贡献是 \(k\times c(i,j)\times (x_i+x_j)\),其中 \(x\) 为下标。

\(c(i,j)(j\in B)\) 的贡献是 \(c(i,j)\times (x_j-x_i)\)(显然 \(j\)\(i\) 后面所以不带绝对值)。

由于每一次枚举 \(i\) 的时候 \(x_j\) 并不知道,所以对于 \(j\in A\) 提前加贡献,对于 \(j\in B\) 延后算贡献。

综上,大概思路是,枚举 \(i\not\in msk\),枚举 \(j\not=i\)

显然 \(x_i=popcount(msk)+1\)(当前集合大小加一)。

如果 \(j\in A\),则转移 \(dp(msk|2^i)\gets dp(msk) + k \times c(i,j)\times x_i + c(j,i)\times x_i\)

如果 \(j\in B\),则转移 \(dp(msk|2^i)\gets dp(msk) + k \times c(j,i)\times x_i - c(i,j)\times x_i\)

复杂度 \(O(2^mm^2)\),可以得到 \(70\) 分。

70 pts Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return f?x:-x;
}
int dp[1 << 23], cnt[1 << 23], n, m, k, c[23][23], S[100000];
signed main() {
	n = read(), m = read(), k = read();
	for(int i = 0; i < n; ++i) S[i] = read() - 1;
	for(int i = 0; i < n - 1; ++i) ++c[S[i]][S[i + 1]];
	memset(dp, 0x3f, sizeof(dp));
	dp[0] = 0;
	for(int msk = 0; msk < 1 << m; ++msk) {
		cnt[msk] = cnt[msk >> 1] + (msk & 1);
		for(int i = 0; i < m; ++i) {
			if(msk >> i & 1) continue;
			int t = 0, x = cnt[msk] + 1;
			for(int j = 0; j < m; ++j) if(i != j) {
				if(msk >> j & 1) t += x * k * c[i][j] + x * c[j][i];
				else t += x * k * c[j][i] - c[i][j] * x;
			}
			ckmin(dp[msk | (1 << i)], t + dp[msk]);
		}
	}
	cout << dp[(1 << m) - 1] << '\n';
	return 0;
}

大眼观察转移,注意到很多计算是重复的。具体来说,只有 \(m\times 2^m\)\(f(msk,i)\),可以提前预处理。

注意到预处理的时候如果再枚举 \(j\),复杂度就会错掉,仍然是 \(O(2^mm^2)\)

观察一下,贡献分为两类:从 \(i\) 往别的塔转移;从别的塔往 \(i\) 转移。

考虑分别预处理这两类贡献设为 \(f(msk,i),g(msk,i)\)

考虑到 \(msk\) 之间的依赖性,可以通过 \(f(msk,i)=f(msk\operatorname{xor}\operatorname{lowbit}(msk),i)+c(\log(\operatorname{lowbit}(msk)),i)\) 来递推,\(g\) 同理。

复杂度这样子就是 \(O(2^mm)\) 了,可惜空间也是 \(O(2^mm)\) 这个复杂度,过不去。

还是给出这部分的代码吧,比较好理解后面的优化。

MLE Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return f?x:-x;
}
int dp[1 << 23], cnt[1 << 23], n, m, k, c[23][23], S[100000];
int f[1 << 23][23], g[1 << 23][23], lg[1 << 23];
inline void init() {
	for(int i = 0; i < m; ++i) lg[1 << i] = i;
	for(int msk = 1; msk < 1 << m; ++msk) {
		for(int i = 0; i < m; ++i) {
			int lb = msk & -msk;
			f[msk][i] = f[msk ^ lb][i] + c[i][lg[lb]];
			g[msk][i] = g[msk ^ lb][i] + c[lg[lb]][i];
		}
	}
}
signed main() {
	n = read(), m = read(), k = read();
	for(int i = 0; i < n; ++i) S[i] = read() - 1;
	for(int i = 0; i < n - 1; ++i) ++c[S[i]][S[i + 1]];
	init();
	memset(dp, 0x3f, sizeof(dp));
	dp[0] = 0;
	for(int msk = 0, U = (1 << m) - 1; msk < U; ++msk) {
		cnt[msk] = cnt[msk >> 1] + (msk & 1);
		for(int i = 0; i < m; ++i) {
			if(msk >> i & 1) continue;
			int x = cnt[msk] + 1, o = x * k, res = 0, nx = msk | (1 << i);
			res += o * f[msk][i];
			res += x * g[msk][i];
			res += o * g[U ^ nx][i];
			res -= x * f[U ^ nx][i];
			ckmin(dp[nx], res + dp[msk]);
		}
	}
	cout << dp[(1 << m) - 1] << '\n';
	return 0;
}

既然都做到上面那一步通过 \(\operatorname{lowbit}\) 来递推了。

考虑用一个栈来保留当前具有公共前缀的 \(msk\)。举个例子可能好理解一些:

比如 \(msk=1101011\),那么栈内应该存的是:\(1000000,1100000,1101000,1101010,1101011\)

可以发现栈顶就是 \(msk\) 去掉 \(\operatorname{lowbit}\)的值,这样子我们可以根据栈顶 \(f,g\) 的值来递推出 \(msk\)\(f,g\) 值,再把 \(msk\) 压入栈中。

但是这个优化使得我们不太方便维护 补集异或 \(2^i\) 的答案,可以考虑用全集的答案 减去当前的答案 再减去 \(i\) 的答案。

根据压栈次数可以发现复杂度为 \(2^m\),而这部分的空间复杂度则降至 \(O(m^2)\),总空间复杂度变成了 \(O(2^m)\)!!!

然后就过去了。

能自己想到这些优化还是很开心的,但是加起来搞了近三个小时,省选场上估计我没这个耐心也没这个时间,还是要完蛋啊。。。

Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return f?x:-x;
}
int dp[1 << 23], n, m, k, U, c[23][23], S[100000], lg[1 << 23];
int f[23][23], g[23][23], stk[23], top, st[23], ed[23];
signed main() {
	n = read(), m = read(), k = read();
	for(int i = 0; i < n; ++i) S[i] = read() - 1;
	for(int i = 0; i < n - 1; ++i)
		++c[S[i]][S[i + 1]], ++st[S[i]], ++ed[S[i + 1]];
	for(int i = 0; i < m; ++i) lg[1 << i] = i;
	U = (1 << m) - 1;
	memset(dp, 0x3f, sizeof(dp));
	dp[0] = 0;
	for(int msk = 0; msk < U; ++msk) {
		if(msk) {
			int lb = msk & -msk, LG = lg[lb];
			while(top && (stk[top] ^ lb) != msk) --top;
			stk[++top] = msk;
			for(int i = 0; i < m; ++i) {
				f[top][i] = f[top - 1][i] + c[i][LG];
				g[top][i] = g[top - 1][i] + c[LG][i];
			}
		}
		int x = top + 1, o = x * k;
		for(int i = 0; i < m; ++i) {
			if(msk >> i & 1) continue;
			int res = 0;
			res += o * f[top][i];
			res += x * g[top][i];
			res += o * (ed[i] - g[top][i]);
			res -= x * (st[i] - f[top][i]);
			res -= o * c[i][i] - c[i][i] * x;
			ckmin(dp[msk | (1 << i)], res + dp[msk]);
		}
	}
	cout << dp[U] << '\n';
	return 0;
}

非常的裸,要你维护一个集合,支持全局加一,查询全局异或和,合并。

直接用 Trie 树即可。

因为要全局加一,考虑从低位往高位插入。这样加一的时候,对于遍历到的节点执行以下两个操作即可:

  • 交换左右子树

  • 往左子树递归

原因很简单,左子树是 \(0\),变成 \(1\),右子树是 \(1\) 变成 \(0\) 并且往下一位进 \(1\)

对于左子树就是一个子问题了。

维护全局异或和,我的方法是维护每一层 \(1\) 的个数,这样子遍历每一层可以 \(O(\log V)\) 查询。

合并直接 Trie 启发式合并就好了。

Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return f?x:-x;
}

const int N = 600006;
const int T = N * 23;
int n, w[N], fa[N], cnt[N][21][2], tot, tr[T][2], rt[N], s[T];
int hed[N], et;
LL ans;
struct edge { int nx, to; } e[N];
inline void adde(int u, int v) {
	e[++et].nx = hed[u], e[et].to = v, hed[u] = et;
}
void add(int u) {
	int p = rt[u];
	for(int i = 0; i < 21; ++i) {
		cnt[u][i][0] += s[tr[p][1]];
		cnt[u][i][1] -= s[tr[p][1]];
		cnt[u][i][1] += s[tr[p][0]];
		cnt[u][i][0] -= s[tr[p][0]];
		swap(tr[p][0], tr[p][1]);
		if(tr[p][0]) p = tr[p][0];
		else break;
	}
}
int merge(int x, int y) {
	if(!x || !y) return x | y;
	s[x] += s[y];
	tr[x][0] = merge(tr[x][0], tr[y][0]);
	tr[x][1] = merge(tr[x][1], tr[y][1]);
	return x;
}
void insert(int u, int x) {
	if(!rt[u]) rt[u] = ++tot;
	int p = rt[u];
	for(int i = 0; i < 21; ++i) {
		int c = x >> i & 1;
		if(!tr[p][c]) tr[p][c] = ++tot;
		p = tr[p][c], ++cnt[u][i][c], ++s[p];
	}
}
void dfs(int u) {
	for(int i = hed[u]; i; i = e[i].nx) {
		int v = e[i].to;
		dfs(v);
		add(v);
		for(int j = 0; j < 21; ++j)
			cnt[u][j][0] += cnt[v][j][0], cnt[u][j][1] += cnt[v][j][1];
		rt[u] = merge(rt[u], rt[v]);
	}
	insert(u, w[u]);
	for(int i = 0; i < 21; ++i)
		if(cnt[u][i][1] & 1) ans += 1 << i;
}
signed main() {
	n = read();
	rep(i, 1, n) w[i] = read();
	rep(i, 2, n) fa[i] = read(), adde(fa[i], i);
	dfs(1);
	cout << ans << '\n';
}

作业题

首先 \(\gcd\) 反演掉,可以考虑简便的欧拉函数 \(\sum\limits_{d|n}\varphi(d)=n\)

\[ans=\sum_{d}\varphi(d)\sum_{G=\{e|w_e\%d=0\}}\left(\sum_{T\in G}\sum_{e\in T}w_e \right) \]

注意到在枚举 \(d\) 的过程中,如果特判掉没有生成树的情况,剩下的情况数非常少。

粗略估计最大上界:这个值域内每一个数因数个数最大为 \(\sqrt{152501}=144\),边数最大为 \(\dfrac{n(n-1)}{2}=435\),每一棵生成树至少需要 \(n-1\) 条边,所以总共计算图的生成树边权和次数最多为 \(144*15=2160\)

很显然这个上界达不到,哪怕达到,你写个 \(O(n^4)\) 的东西都能过。

那么现在的问题转化为,给你一张图,求其所有生成树边权和。

如果你做过 P5296 [北京省选集训2019]生成树计数,你会发现这题是个弱化版,直接套上即可。

大概思路是,设每一条边的边权为一个多项式 \(1+w_ex\),可以发现跑完矩阵树之后一次项系数就是答案。

直接把边权设成 Poly,带进去求 det。注意到可以在 \(\bmod x^2\) 意义下计算,所以这多项式就是个常数,复杂度 \(O(n^3)\)。千万不要在 \(\bmod x^{n+1}\) 意义下计算,因为这样子是 \(O(n^5)\) 的。

如果你像我一样脑抽了可以写个插值,求出这个多项式的 \(n\) 个点值然后把多项式插出来。如果你不知道怎么快速插这个多项式可以看看 拉格朗日插值如何插出系数,或者写个高斯消元也不影响复杂度。复杂度 \(O(n^4)\)

因为一开始写的是第二种方法,有一个点没过去,加了一点剪枝,代码比较 shit。除了上面提到的没有生成树用并查集判掉,还加了一个记忆化。因为发现不同的 \(d\) 提出的边可能会相同,就哈希了一下。不过加上这玩意 \(3s+\to 800ms\) 还是比较震惊的。。。

方法一非常的稳健,最大点 \(80ms\)

两种方法不那么好写,可能是我码力太菜了,码了 2h 左右才过去。

我感觉我考场上都没有勇气来想这种题的正解,更别说想到之后还有没有足够的时间码出来。

Code for First Solution
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return f?x:-x;
}
#define mod 998244353
const LL P = 10000000000000061ll;
inline int qpow(int n, int k) {
	int res = 1;
	for(; k; k >>= 1, n = 1ll * n * n % mod)
		if(k & 1) res = 1ll * n * res % mod;
	return res;
}
const int N = 32;
const int M = 160005;
int n, m, mp[N][N], x[N], y[N], f[N], ans;
int phi[M], pri[M], pct;
bool vis[M], bok[M];
struct node {
	int u, v, w;
	node() { u = v = w = 0; }
	node(int u_, int v_, int w_) { u = u_, v = v_, w = w_; }
} e[N * N];
struct poly {
int a[2];
poly(int x0 = 0, int x1 = 0) { a[0] = x0, a[1] = x1; }
inline int& operator [] (const int &k) { return a[k]; }
friend poly operator * (poly a, poly b) {
	return poly(1ll * a[0] * b[0] % mod, (1ll * a[0] * b[1] + 1ll * a[1] * b[0]) % mod);
}
poly inv() {
	int iv = qpow(a[0], mod - 2);
	return poly(iv, 1ll * iv * (mod - iv) % mod * a[1] % mod);
}
poly& operator += (const poly &b) {
	(a[0] += b.a[0]) %= mod, (a[1] += b.a[1]) %= mod;
	return *this;
}
poly& operator -= (const poly &b) {
	(a[0] += mod - b.a[0]) %= mod, (a[1] += mod - b.a[1]) %= mod;
	return *this;
}
poly operator -() {
	return poly(!a[0] ? 0 : mod - a[0], !a[1] ? 0 : mod - a[1]);
}
	
} a[N][N];
int F[N];
inline int anc(int x) { return x == F[x] ? x : F[x] = anc(F[x]); }
inline poly det(int n) {
	poly res(1, 0);
	for(int i = 0; i < n; ++i) {
		for(int j = i + 1; j < n; ++j) {
			poly tmp = a[j][i] * (a[i][i].inv());
			for(int l = i; l < n; ++l)
				a[j][l] -= a[i][l] * tmp;
		}
		res = res * a[i][i];
	}
	return res;
}
inline void init(const int&n = M - 1) {
	phi[1] = 1;
	for(int i = 2; i <= n; ++i) {
		if(!vis[i]) pri[++pct] = i, phi[i] = i - 1;
		for(int j = 1; j <= pct && i * pri[j] <= n; ++j) {
			vis[i * pri[j]] = 1;
			if(i % pri[j] == 0) {
				phi[i * pri[j]] = phi[i] * pri[j];
				break;
			} else phi[i * pri[j]] = phi[i] * phi[pri[j]];
		}
	}
}
inline int qwq(int x, int*f, int n) {
	int res = 0;
	for(int i = n - 1; i >= 0; --i) res = (1ll * res * x % mod + f[i]) % mod;
	return res;
}
map<LL, int> Map;
inline int calc(int d) {
	if(bok[d]) return 0;
	memset(mp, 0, sizeof(mp));
	int cnt = 0, h = 0;
	rep(i, 0, n - 1) F[i] = i;
	LL pw = 1;
	for(int i = 1; i <= m; ++i, pw = 2ll * pw % P ) {
		int x = e[i].u, y = e[i].v, w = e[i].w;
		if(w % d) continue;
		mp[x][y] = mp[y][x] = w, h = (h + pw) % P;
		if(anc(x) != anc(y)) F[anc(x)] = anc(y), ++cnt;
	}
	if(cnt != n - 1) {
		for(int j = d; j < M; j += d) bok[j] = 1;
		return 0;
	}
	int tmp = Map[h];
	if(tmp) return tmp;
	for(int i = 0; i < n; ++i)
		for(int j = 0; j < n; ++j)
			a[i][j] = poly(0, 0);
	for(int i = 0; i < n; ++i) {
		for(int j = 0; j < n; ++j) {
			if(!mp[i][j]) continue;
			a[i][j] = poly(1, mp[i][j]);
			a[i][i] += a[i][j], a[i][j] = -a[i][j];
		}
	}
	return Map[h] = det(n - 1).a[1];
}
signed main() {
	init();
	n = read(), m = read();
	rep(i, 1, m) e[i].u = read() - 1, e[i].v = read() - 1, e[i].w = read();
	for(int i = 1; i < M; ++i) ans = (ans + 1ll * phi[i] * calc(i)) % mod;
	cout << ans << '\n';
}
Code for Second Solution
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return f?x:-x;
}
#define mod 998244353
const LL P = 10000000000000061ll;
inline int qpow(int n, int k) {
	int res = 1;
	for(; k; k >>= 1, n = 1ll * n * n % mod)
		if(k & 1) res = 1ll * n * res % mod;
	return res;
}
const int N = 32;
const int M = 160005;
int n, m, mp[N][N], a[N][N], x[N], y[N], f[N], ans;
int phi[M], pri[M], pct;
bool vis[M], bok[M];
struct node {
	int u, v, w;
	node() { u = v = w = 0; }
	node(int u_, int v_, int w_) { u = u_, v = v_, w = w_; }
} e[N * N];
int F[N];
inline int anc(int x) { return x == F[x] ? x : F[x] = anc(F[x]); }
inline int det(int n) {
	int res = 1, flg = 0;
	for(int i = 0; i < n; ++i) {
		for(int j = i + 1; j < n; ++j) {
			while(a[j][i]) {
				int d = a[i][i] / a[j][i];
				for(int k = i; k < n; ++k)
					a[i][k] = (a[i][k] + mod - 1ll * a[j][k] * d % mod) % mod,
					swap(a[i][k], a[j][k]);
				flg ^= 1;
			}
		}
		if(!a[i][i]) return 0;
		res = 1ll * res * a[i][i] % mod;
	}
	return flg ? mod - res : res;
}
inline void lagrange(int*f, int*x, int*y, int n) {
	static int a[N], b[N], c[N];
	memset(a, 0, n << 2);
	memset(b, 0, (n + 1) << 2);
	memset(c, 0, n << 2);
	memset(f, 0, n << 2);
	for(int i = 0; i < n; ++i) {
		int A = 1;
		for(int j = 0; j < n; ++j) if(i != j)
			A = 1ll * A * (x[i] - x[j] + mod) % mod;
		a[i] = 1ll * y[i] * qpow(A, mod - 2) % mod;
	}
	b[0] = 1;
	for(int i = 0; i < n; ++i) {
		for(int j = i + 1; j >= 1; --j)
			b[j] = (b[j - 1] + 1ll * b[j] * (mod - x[i])) % mod;
		b[0] = 1ll * b[0] * (mod - x[i]) % mod;
	}
	for(int i = 0; i < n; ++i) {
		int iv = qpow(mod - x[i], mod - 2);
		c[0] = 1ll * b[0] * iv % mod;
		for(int j = 1; j < n; ++j)
			c[j] = 1ll * (b[j] - c[j - 1] + mod) * iv % mod;
		for(int j = 0; j < n; ++j)
			f[j] = (f[j] + 1ll * c[j] * a[i]) % mod;
	}
}
inline void init(const int&n = M - 1) {
	phi[1] = 1;
	for(int i = 2; i <= n; ++i) {
		if(!vis[i]) pri[++pct] = i, phi[i] = i - 1;
		for(int j = 1; j <= pct && i * pri[j] <= n; ++j) {
			vis[i * pri[j]] = 1;
			if(i % pri[j] == 0) {
				phi[i * pri[j]] = phi[i] * pri[j];
				break;
			} else phi[i * pri[j]] = phi[i] * phi[pri[j]];
		}
	}
}
inline int qwq(int x, int*f, int n) {
	int res = 0;
	for(int i = n - 1; i >= 0; --i) res = (1ll * res * x % mod + f[i]) % mod;
	return res;
}
map<LL, int> Map;
inline int calc(int d) {
	if(bok[d]) return 0;
	memset(mp, 0, sizeof(mp));
	int cnt = 0, h = 0;
	rep(i, 0, n - 1) F[i] = i;
	LL pw = 1;
	for(int i = 1; i <= m; ++i, pw = 2ll * pw % P ) {
		int x = e[i].u, y = e[i].v, w = e[i].w;
		if(w % d) continue;
		mp[x][y] = mp[y][x] = w, h = (h + pw) % P;
		if(anc(x) != anc(y)) F[anc(x)] = anc(y), ++cnt;
	}
	if(cnt != n - 1) {
		for(int j = d; j < M; j += d) bok[j] = 1;
		return 0;
	}
	int tmp = Map[h];
	if(tmp) return tmp;
	for(int z = 0; z < n; ++z) {
		memset(a, 0, sizeof(a));
		for(int i = 0; i < n; ++i) {
			for(int j = 0; j < n; ++j) {
				if(!mp[i][j]) continue;
				a[i][j] = 1 + (z + 1) * mp[i][j];
				a[i][i] += a[i][j], a[i][j] = mod - a[i][j];
			}
		}
		x[z] = z + 1, y[z] = det(n - 1);
	}
	lagrange(f, x, y, n);
	return Map[h] = f[1];
}
signed main() {
	init();
	n = read(), m = read();
	rep(i, 1, m) e[i].u = read() - 1, e[i].v = read() - 1, e[i].w = read();
	for(int i = 1; i < M; ++i) ans = (ans + 1ll * phi[i] * calc(i)) % mod;
	cout << ans << '\n';
}

消息传递

当作复习点分治板子了,居然写了 40min 才过去,我也是服了我了。。。

是震波的弱化版吧,但是只需要维护距离恰好为 \(k\) 的,而不是 \(\le k\) 的,把震波里的树状数组改成数组少一只 \(\log\),复杂度变成 \(O((n+m)\log n)\),就过去了。

想玩一下,其实也是怕 vector 不开 O2 常数爆炸,写了个什么指针+内存池,结果少开了一位调了好一会。。。

大概思路是容斥,维护 \(f(u,k)\) 表示 \(u\) 这个分治中心的分治子树内部离它距离为 \(k\) 的点的个数,\(g(u,k)\) 表示这个分治中心的分治子树内部离它在分治树上父亲距离为 \(k\) 的点的个数。

查询 \((x,k)\) 的时候,答案就是 \(\sum f(u,k-\operatorname{dis}(u,x))-g(pa_u,k-\operatorname{dis}(pa_u,x))\)

维护 \(f,g\) 要动态内存,\(f\) 只用存到分治树大小,\(g\) 要存到分治树大小加一,总大小是 \(2n\log n\) 的,带修加上点权也很方便,总之这题非常板子就是了。感觉在线写法比离线方便啊 qaq。

Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp make_pair
#define pb push_back
#define sz(v) (int)(v).size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return f?x:-x;
}
const int N = 100005;
int n, m;
int et, hed[N];
struct edge { int nx, to; } e[N << 1];
inline void adde(int u, int v) {
	e[++et].to = v, e[et].nx = hed[u], hed[u] = et;
}
int ST[20][N << 1], tmr, dfn[N], dep[N], lg[N << 1];
int rt, mx[N], siz[N], used[N], tsiz, vt[N], fsz[N];
int *f[N], *g[N], pool[N * 40], *mem;
void dfs(int u, int ft) {
	dfn[u] = ++tmr, ST[0][tmr] = dep[u];
	for(int i = hed[u]; i; i = e[i].nx) {
		int v = e[i].to;
		if(v == ft) continue;
		dep[v] = dep[u] + 1, dfs(v, u), ST[0][++tmr] = dep[u];
	}
}
inline void init_dis() {
	dfs(1, 0);
	lg[0] = -1;
	for(int i = 1; i <= tmr; ++i) lg[i] = lg[i >> 1] + 1;
	rep(i, 1, lg[tmr]) rep(j, 1, tmr - (1 << i) + 1)
		ST[i][j] = min(ST[i - 1][j], ST[i - 1][j + (1 << (i - 1))]);
}
inline int dis(int x, int y) {
	int l = dfn[x], r = dfn[y];
	if(l > r) l ^= r ^= l ^= r;
	int t = lg[r - l + 1];
	return dep[x] + dep[y] - (min(ST[t][l], ST[t][r - (1 << t) + 1]) << 1);
}
void getrt(int u, int ft) {
	siz[u] = 1, mx[u] = 0;
	for(int i = hed[u]; i; i = e[i].nx) {
		int v = e[i].to;
		if(v == ft || used[v]) continue;
		getrt(v, u), siz[u] += siz[v];
		ckmax(mx[u], siz[v]);
	}
	ckmax(mx[u], tsiz - siz[u]);
	if(mx[u] < mx[rt]) rt = u;
}

void divide(int u) {
	fsz[u] = tsiz;
	f[u] = mem, mem += tsiz;
	for(int *i = f[u]; i != mem; ++i) *i = 0;
	g[u] = mem, mem += tsiz + 1;
	for(int *i = g[u]; i != mem; ++i) *i = 0;
	used[u] = 1;
	for(int i = u; i; i = vt[i]) {
		++f[i][dis(u, i)];
		if(vt[i]) ++g[i][dis(u, vt[i])];
	}
	for(int i = hed[u]; i; i = e[i].nx) {
		int v = e[i].to;
		if(used[v]) continue;
		tsiz = siz[v] > siz[u] ? fsz[u] - siz[u] : siz[v];
		rt = 0, getrt(v, 0), vt[rt] = u, divide(rt);
	}
}
inline int query(int x, int k) {
	int res = 0;
	for(int i = x, d; i; i = vt[i]) {
		d = k - dis(x, i);
		if(0 <= d && d < fsz[i]) {
			res += f[i][d];
		}
		if(vt[i]) {
			d = k - dis(x, vt[i]);
			if(0 <= d && d < fsz[i] + 1) {
				res -= g[i][d];
			}
		}
	}
	return res;
}
inline void clear() {
	et = 0;
	tmr = 0;
	mem = pool;
	memset(vt, 0, sizeof(vt));
	memset(hed, 0, sizeof(hed));
	memset(used, 0, sizeof(used));
}
void Main() {
	n = read(), m = read();
	clear();
	rep(i, 2, n) {
		int x = read(), y = read();
		adde(x, y), adde(y, x);
	}
	init_dis();
	mx[rt = 0] = n, tsiz = n, getrt(1, 0), divide(rt);
	while(m--) {
		int x = read(), k = read();
		printf("%d\n", query(x, k));
	}
}
signed main() {
	for(int T = read(); T; --T) Main();
}

总结

感觉整体难度不大,给我时间我都能做出来。

问题就是,在这么有限的时间内把这么多东西打出来,还要勇于想正解,还是很考验决策能力以及心态的。

自信一点吧!

posted @ 2021-04-03 23:39  zzctommy  阅读(227)  评论(0编辑  收藏  举报