CSP-S 2019 简要题解
从这里开始
又考炸了,sad.....明年应该在准备高考了,考完把坑填了好了。
一半题都被卡常,qswl。[我汤姆要报警.jpg]
dfs 怎么这么慢呀,sad.....
i7 牛逼!
写的比较混乱,可以将就着看就看吧。
Day 1
Problem A
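考虑先求出最高位是 1 还是 0:若 $k < 2^{n-1}$,最高位为 0,剩下的 $n - 1$ 位就是 $n - 1$ 位格雷码中排名为 $k$ 的串;否则最高位为 1,由于后一半是前一半的倒序,剩下的 $n - 1$ 位是 $n - 1$ 位格雷码中排名为 $2^n - 1 - k$ 的串。逐位处理即可。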
// CSP-S 2019 Day 1 A (code): print the k-th (0-indexed) string of the n-bit
// Gray code, most significant bit first.
//
// The n-bit code is "0" + the (n-1)-bit code, followed by "1" + the (n-1)-bit
// code in reverse order, so every step fixes one bit and maps k to its rank in
// the (n-1)-bit code.
#include <cstdio>
#include <string>

int main() {
    freopen("code.in", "r", stdin);
    freopen("code.out", "w", stdout);
    int n = 0;
    unsigned long long k = 0;
    scanf("%d%llu", &n, &k);
    std::string code;
    for (int bit = n - 1; bit >= 0; --bit) {
        const unsigned long long half = 1ull << bit;
        if (k < half) {
            code += '0';
        } else {
            // Upper half is mirrored: rank k maps to rank 2 * half - 1 - k.
            code += '1';
            k = half - (k - half) - 1;
        }
    }
    puts(code.c_str());
    return 0;
}
Problem B
考虑计算以每个位置结尾的串的个数。
把左括号看做 $+1$,右括号看做 $-1$,一个括号串合法的充要条件是:所有前缀和都大于等于 0,且总和为 0。
沿着根到当前点的路径,用单调栈维护前缀和(并记录每个值当前出现的次数),DFS 回溯时撤销修改即可。
// CSP-S 2019 Day 1 B (brackets).
//
// Treat '(' as +1 and ')' as -1 and let sum[p] be the prefix sum on the path
// root..p.  g[p] accumulates, along the root path, the number of valid
// substrings ending at each node, computed with a monotone stack of
// (prefix sum, multiplicity) that is updated when a node is entered and rolled
// back when it is left.  Output: XOR over i of i * g[i].
//
// The traversal uses an explicit stack instead of recursion: the tree may be a
// chain with n = 5e5 nodes, which overflows a typical 8 MiB call stack.
#include <cstdio>
#include <vector>

typedef unsigned long long ull;

const int N = 500005;

// One monotone-stack entry: a prefix-sum value and how many positions on the
// current root path are counted for it.
struct Data {
    int s, c;
    Data() : s(0), c(0) { }
    Data(int s, int c) : s(s), c(c) { }
};

int n;
char s[N];
int sum[N];
ull g[N];
std::vector<int> G[N];     // children lists
int tp;
Data stk[N];
int saved_tp[N];           // stack height before node p was entered
Data saved_v[N];           // stack slot overwritten when node p was entered
int path[N];               // explicit DFS stack: nodes from the root to the current one
size_t child_it[N];        // next child index to visit for each node on the path

// Enter node p (child of fa): update the monotone stack and compute g[p].
static void enter(int p, int fa) {
    sum[p] = sum[fa] + ((s[p] == '(') ? 1 : -1);
    saved_tp[p] = tp;
    while (tp > 1 && sum[p] < stk[tp].s)
        tp--;
    if (stk[tp].s == sum[p]) {
        saved_v[p] = stk[tp];
        stk[tp].c++;
    } else {
        saved_v[p] = stk[tp + 1];
        stk[++tp] = Data(sum[p], 1);
    }
    g[p] = g[fa] + stk[tp].c - 1;
}

// Leave node p: undo enter(p).  All of p's children have already been left.
static void leave(int p) {
    stk[tp] = saved_v[p];
    tp = saved_tp[p];
}

// Iterative DFS that calls enter/leave in exactly the order a recursive
// pre-order DFS over G would.
static void dfs(int root) {
    int depth = 0;
    path[0] = root;
    child_it[root] = 0;
    enter(root, 0);
    while (depth >= 0) {
        int p = path[depth];
        if (child_it[p] < G[p].size()) {
            int c = G[p][child_it[p]++];
            child_it[c] = 0;
            enter(c, p);
            path[++depth] = c;
        } else {
            leave(p);
            --depth;
        }
    }
}

int main() {
    freopen("brackets.in", "r", stdin);
    freopen("brackets.out", "w", stdout);
    scanf("%d", &n);
    scanf("%s", s + 1);
    for (int i = 2, x; i <= n; i++) {
        scanf("%d", &x);
        G[x].push_back(i);
    }
    stk[tp = 1] = Data(0, 1);
    dfs(1);
    ull ans = 0;
    for (int i = 1; i <= n; i++)
        ans ^= 1ull * i * g[i];
    printf("%llu\n", ans);
    return 0;
}
Problem C
这题都没当场过。sad.....
不知道为啥之前的做法这么复杂,看起来好像绕了一个弯路。先考虑一下菊花图的情况,这题就好想多了。
考虑一个足够强的、删边顺序能实现目标所必须满足的必要条件,然后猜它也是充分的,用并查集维护一下就做完了。我应该是智障。
考虑字典序贪心。
当 $n = 1$ 的时候,直接输出答案就行了。
假设点 $i$ 最终要到 $p_i$,那么在树上加入一条有向路径 $i \rightarrow p_i$。
考虑每次寻找最小的合法的点,问题变成怎么判断能否加入一条路径 $i \rightarrow p_i$。
首先不难注意到任意一条边在同一方向会被经过恰好 1 次,以及当 $n > 1$ 的时候 $p_i \neq i$。
考虑进入点 $x$ 和离开点 $x$ 的路径最终大概形如:
可以发现,如果把点 $x$ 自身也看作 $x$ 的一棵"子树",那么经过 $x$ 的每条路径都恰好从 $x$ 的一棵子树进入、从另一棵子树离开。于是可以把原来的合法性判断转化为这样一个条件:合法当且仅当对每个点,它的所有子树按路径的进出关系恰好连成 1 个环,并且任意两条路径都不会同方向经过同一条边。
考虑必要性
- 不难发现一个点周围几条边的操作顺序是唯一确定的,因为每次只可能操作连接 $x$ 和 $p_x$ 所在子树的根的那条边。因此如果存在超过 1 个环,那么一定无解。
- 第二点比较显然。
考虑充分性,当只有 1 个点的时候显然成立。
当 $n > 1$ 的时候,考虑每次操作一条边 $(x, y)$,满足:以 $x$ 为根时 $p_x$ 在 $y$ 的子树内,且以 $y$ 为根时 $p_y$ 在 $x$ 的子树内。
首先证明这样一条边一定存在:任取一个点 $x$,设 $p_x$ 在 $x$ 的邻点 $y$ 的子树内。如果 $(x, y)$ 不满足条件,那么我们删掉这条边(只是在证明中删去,不是题目中的删除操作),令 $x \leftarrow y$ 继续这一过程。因为树是有限的,所以这一过程一定会停止。
考虑找到这一条边后,操作它,不难证明新得到的两棵树仍然满足上述条件或大小等于 1。
然后用并查集判一下就好了。
时间复杂度 $O(Tn^2)$。
Code
// CSP-S 2019 Day 1 C (tree).
// For the digits i = 1..n in order, this prints the smallest node j that the
// digit starting at node pos[i] can still be sent to, given the paths already
// committed for smaller digits (greedy for the lexicographically smallest
// answer).  Each committed path updates deg[], vise and the union-find below.
#include <bits/stdc++.h>
using namespace std;
typedef bool boolean;

const int N = 2005;

// One directed adjacency-list edge: target node and id of the next edge.
typedef class Edge {
  public:
    int ed, nx;

    Edge() { }
    Edge(int ed, int nx) : ed(ed), nx(nx) { }
} Edge;

// Linked-list adjacency storage.  Edge ids are insertion indices; solve()
// adds (u, v) and (v, u) back to back, so ids i and i ^ 1 are the two
// directions of the same tree edge.
typedef class MapManager {
  public:
    int h[N];         // h[u]: id of the most recently added edge out of u, -1 if none
    vector<Edge> E;

    void init(int n) {
        memset(h, -1, sizeof(int) * (n + 1));
        E.clear();
    }
    void add_edge(int u, int v) {
        E.push_back(Edge(v, h[u]));
        h[u] = (signed) E.size() - 1;
    }
    Edge& operator [] (int p) {
        return E[p];
    }
} MapManager;

// deg[x] : starts at 2 * degree(x); decreased as paths through x are committed.
// n2     : 2n.  Id n2 + x is an extra union-find id attached to node x
//          (edge ids are all < 2(n - 1), so they never collide).
// uf     : union-find over ids 0 .. 3n (edge ids and the n2 + x ids).
// fa, fe : parent node and incoming edge id recorded by the latest dfs().
// vise   : directed edges already used by a committed path.
// occ    : nodes already chosen as a destination.
// vis    : destinations the current dfs() found feasible.
int T, n, n2;
int deg[N];
int pos[N];
MapManager g;
int uf[N * 3];
int fa[N], fe[N];
bitset<N * 2> vise;
bitset<N> occ, vis;

// Union-find lookup with path compression.
int find(int x) {
    return uf[x] == x ? x : (uf[x] = find(uf[x]));
}

// Explores every node reachable from the start along directed edges that can
// still be appended to a path.  fe is the incoming edge id; for the start node
// the caller passes (n2 + p) ^ 1 so that fe ^ 1 == n2 + p.
// A visited node p (other than the start, fa != 0) is recorded in vis when it
// is not yet a destination and either deg[p] == 1 or its id n2 + p is not in
// the same set as the reverse of the incoming edge.
void dfs(int p, int fa, int fe) {
    if (!occ.test(p) && fa && (deg[p] == 1 || (find(n2 + p) ^ find(fe ^ 1))))
        vis.set(p);
    ::fa[p] = fa, ::fe[p] = fe;
    for (int i = g.h[p]; ~i; i = g[i].nx) {
        int e = g[i].ed;
        if (e == fa)
            continue;
        // directed edge already used by an earlier path
        if (vise.test(i))
            continue;
        // Leaving through i is rejected when i is already in the same set as the
        // reverse of the incoming edge, unless deg[p] is exactly the number of
        // remaining slots (2 - !fa) and p is not an unoccupied node with deg 2.
        if ((deg[p] != 2 - !fa || (!occ.test(p) && deg[p] == 2)) && find(fe ^ 1) == find(i))
            continue;
        dfs(e, p, i);
    }
}

// Handles one test case: reads pos[] and the edges, then greedily fixes the
// destination of each digit and prints the destinations separated by spaces.
void solve() {
    scanf("%d", &n);
    g.init(n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", pos + i);
    }
    memset(deg, 0, sizeof(int) * (n + 1));
    for (int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        g.add_edge(u, v);
        g.add_edge(v, u);
        deg[u] += 2;
        deg[v] += 2;
    }
    // a single node keeps its digit
    if (n == 1) {
        printf("%d\n", 1);
        return;
    }
    occ.reset();
    vise.reset();
    n2 = n << 1;
    for (int i = 0; i <= (3 * n); i++)
        uf[i] = i;
    for (int i = 1; i <= n; i++) {
        int p = pos[i];
        vis.reset();
        dfs(p, 0, (n2 + p) ^ 1);
        // the smallest feasible destination wins
        for (int j = 1; j <= n; j++) {
            if (vis.test(j)) {
                printf("%d ", j);
                occ.set(j);
                // Commit the path p -> j by walking back from j via fa[]:
                // at each node x, union the previous id (n2 + j first, then the
                // edge taken out of x) with the reverse of x's incoming edge, and
                // mark the incoming edge as used.
                int x = j, e1 = n2 + j, e2;
                while (x ^ p) {
                    deg[x] -= 2 - (e1 > n2);
                    e2 = fe[x];
                    uf[find(e1)] = find(e2 ^ 1);
                    vise.set(e2);
                    swap(e1, e2);
                    x = fa[x];
                }
                deg[p]--;
                // tie the start node's own id to the first edge of the path
                uf[find(n2 + p)] = find(e1);
                break;
            }
        }
    }
    putchar('\n');
}

int main() {
    // freopen("tree.in", "r", stdin);
    // freopen("tree.out", "w", stdout);
    scanf("%d", &T);
    while (T--) {
        solve();
    }
    return 0;
}
Day 2
手速慢,脑速也慢,sad....
Problem A
记 $s_i$ 为第 $i$ 种烹饪方法下所有菜的方案数之和。先用 $\prod_i (s_i + 1) - 1$ 算出不考虑食材限制的方案数,再枚举哪一种主要食材被用得超过一半,dp 记录选择它的菜的数量与选择其他食材的菜的数量之差,减去差大于 0 的方案即可。
时间复杂度 $O(mn^2)$。好像有点卡常。
Code
// CSP-S 2019 Day 2 A (meal).
//
// total = prod_i (rowSum[i] + 1) - 1 counts every non-empty choice in which
// each method i contributes at most one dish (a[i][j] ways for ingredient j).
// For each ingredient b, dp[i][d] counts the choices over the first i methods
// where (#dishes using b) - (#dishes using another ingredient) == d; choices
// with d > 0 use b too often and are subtracted from total.
#include <cstdio>
#include <cstring>

typedef long long ll;

const int Mod = 998244353;
const int N = 105, M = 2005;
const int OFFSET = 110;  // dp column OFFSET + d stores difference d (|d| <= n)

int n, m;
int a[N][M];
ll rowSum[N];            // sum over j of a[i][j], modulo Mod
ll dp[N][250];

int main() {
    freopen("meal.in", "r", stdin);
    freopen("meal.out", "w", stdout);
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            scanf("%d", &a[i][j]);
            rowSum[i] = (rowSum[i] + a[i][j]) % Mod;
        }
    }

    ll total = 1;
    for (int i = 1; i <= n; i++)
        total = total * ((rowSum[i] + 1) % Mod) % Mod;
    total = (total + Mod - 1) % Mod;

    for (int b = 1; b <= m; b++) {
        memset(dp, 0, sizeof(dp));
        dp[0][OFFSET] = 1;
        for (int i = 1; i <= n; i++) {
            const ll pick = a[i][b];
            const ll other = (rowSum[i] - a[i][b] + Mod) % Mod;
            const ll* prev = dp[i - 1] + OFFSET;
            ll* cur = dp[i] + OFFSET;
            for (int d = 1 - i; d <= i - 1; d++) {
                const ll ways = prev[d];
                if (ways == 0)
                    continue;
                // Every cell receives at most three products below Mod^2,
                // so the sums stay well inside long long before reduction.
                cur[d + 1] += ways * pick;
                cur[d - 1] += ways * other;
                cur[d] += ways;
            }
            for (int d = -i; d <= i; d++)
                cur[d] %= Mod;
        }
        ll bad = 0;
        for (int d = 1; d <= n; d++)
            bad += dp[n][OFFSET + d];
        total = (total - bad % Mod + Mod) % Mod;
    }
    printf("%d\n", (int) total);
    return 0;
}
Problem B
对每个前缀,只需考虑最后一段长度最小的合法划分。
我还不会证,但它看上去很对。
设第 $i$ 个前缀最后一段的和为 $f_i$,当前考虑的前缀为 $j$,记 $S(l, r) = a_l + \dots + a_r$。问题相当于找出满足 $f_i \le S(i + 1, j)$ 的最大的 $i$。这个条件等价于 $s_i + f_i \le s_j$,其中 $s$ 为前缀和。
单调队列优化即可。
高精度 rush 成功,读入 rush 失败。
Code
// CSP-S 2019 Day 2 B (partition).
//
// For every prefix i only the partition whose last segment [g[i], i] is
// shortest is kept.  The split point for i is the largest j whose last segment
// sum does not exceed S(j + 1, i); candidates live in a monotone queue ordered
// by s[j] + (last segment sum of j).  The program prints the exact sum of the
// squared segment sums of the resulting partition.
//
// Reads stdin, writes stdout.
#include <cstdio>
#include <vector>

typedef long long ll;
typedef unsigned int uint;

// Minimal fread-backed integer reader.  Characters are kept in an int so the
// EOF test is correct whether plain `char` is signed or unsigned.
struct Input {
    FILE* file;
    char buf[1 << 16];
    size_t pos, len;

    explicit Input(FILE* f) : file(f), pos(0), len(0) { }

    // Next byte of input, or EOF.
    int pick() {
        if (pos == len) {
            len = fread(buf, 1, sizeof(buf), file);
            pos = 0;
            if (len == 0)
                return EOF;
        }
        return static_cast<unsigned char>(buf[pos++]);
    }
};

// Reads one (optionally negative) decimal integer; u becomes 0 at EOF.
template <typename T>
Input& operator >> (Input& in, T& u) {
    int x = in.pick();
    while (x != EOF && x != '-' && (x < '0' || x > '9'))
        x = in.pick();
    bool neg = (x == '-');
    if (neg)
        x = in.pick();
    for (u = 0; x >= '0' && x <= '9'; x = in.pick())
        u = u * 10 + (x - '0');
    if (neg)
        u = -u;
    return in;
}

Input in(stdin);

const ll Base = 1000000000;

// Exact running sum of squares, stored as little-endian base-1e9 limbs and
// updated in place (the previous BigInteger allocated several vectors per
// segment).  Six limbs hold values below 1e54.  Assuming the problem limits
// n <= 4e7 and a_i <= 1e9, the answer is at most (4e16)^2 = 1.6e33.
struct SquareSum {
    static const int Limbs = 6;
    ll limb[Limbs];

    SquareSum() {
        for (int i = 0; i < Limbs; i++)
            limb[i] = 0;
    }

    // Adds x * x for 0 <= x < 1e18, splitting x = hi * Base + lo.
    void add_square(ll x) {
        ll hi = x / Base, lo = x % Base;
        limb[0] += lo * lo;
        limb[1] += 2 * hi * lo;
        limb[2] += hi * hi;
        for (int i = 0; i + 1 < Limbs; i++) {
            limb[i + 1] += limb[i] / Base;
            limb[i] %= Base;
        }
    }

    // Prints the value in decimal followed by a newline.
    void print(FILE* f) const {
        int top = Limbs - 1;
        while (top > 0 && limb[top] == 0)
            top--;
        fprintf(f, "%lld", limb[top]);
        for (int i = top - 1; i >= 0; i--)
            fprintf(f, "%09lld", limb[i]);
        fputc('\n', f);
    }
};

int main() {
    int n = 0, op = 0;
    in >> n >> op;
    std::vector<int> a(n + 1);
    if (op == 0) {
        for (int i = 1; i <= n; i++)
            in >> a[i];
    } else {
        uint x = 0, y = 0, z = 0, b1 = 0, b2 = 0;
        int m = 0;
        in >> x >> y >> z >> b1 >> b2 >> m;
        // p/l/r describe m ranges; the old code allocated n + 1 entries for
        // l and r.  One spare slot keeps indexing safe even when m == 0.
        std::vector<int> p(m + 2), l(m + 2), r(m + 2);
        for (int i = 1; i <= m; i++)
            in >> p[i] >> l[i] >> r[i];
        const uint msk = (1u << 30) - 1;
        // b_i = (x * b_{i-1} + y * b_{i-2} + z) mod 2^30 is generated on the
        // fly, so all n values are never stored and b[2] is never written
        // past the end when n < 2.
        uint prev2 = b1, prev1 = b2;
        for (int i = 1, j = 1; i <= n; i++) {
            uint b;
            if (i == 1) {
                b = b1;
            } else if (i == 2) {
                b = b2;
            } else {
                b = (x * prev1 + y * prev2 + z) & msk;
                prev2 = prev1;
                prev1 = b;
            }
            if (j < m && i > p[j])
                j++;
            a[i] = b % (r[j] - l[j] + 1) + l[j];
        }
    }

    std::vector<ll> s(n + 1);
    s[0] = 0;
    for (int i = 1; i <= n; i++)
        s[i] = s[i - 1] + a[i];
    std::vector<int>().swap(a);  // a is no longer needed; release it early

    // S(l, r) = a[l] + ... + a[r]; l == 0 means "from the very start".
    auto S = [&s](int l, int r) -> ll { return l ? s[r] - s[l - 1] : s[r]; };

    std::vector<int> g(n + 1), Q(n + 2);
    int st = 1, ed = 1;
    Q[1] = 0, g[0] = 0;
    for (int i = 1; i <= n; i++) {
        // move to the latest candidate whose last segment fits before i
        while (st < ed && S(g[Q[st + 1]], Q[st + 1]) <= S(Q[st + 1] + 1, i))
            st++;
        g[i] = Q[st] + 1;
        // drop tail candidates whose key is not smaller than i's key
        while (st <= ed && S(g[Q[ed]], Q[ed]) - S(Q[ed] + 1, i) >= S(g[i], i))
            ed--;
        Q[++ed] = i;
    }

    SquareSum ans;
    for (int i = n; i; i = g[i] - 1)
        ans.add_square(S(g[i], i));
    ans.print(stdout);
    return 0;
}
Problem C
先讲一下 $O(n\log n)$ 的垃圾做法:
考虑计算每个点作为中心的答案。
考虑以任意一个重心为根,重心的答案可以暴力计算,对于剩下的点必须删掉重心所在的子树中的一条边。
讨论一下这条边是在这个点到根的路径,还是其他地方。然后设删掉的边较浅一端的大小为 $x$,之后就是一个傻逼二维数点问题。树状数组维护即可。
有线性做法,我先咕着。
下面是线性做法,虽然我的线性做法比较垃圾,不开 O2 常数被吊锤。
考虑树链剖分,设根所在的重链为 $L$。如果一个点 $p$ 不在 $L$ 上,那么要让 $p$ 成为重心,删掉的边必须满足下面两个条件之一(设删掉的边中较深的那个端点为 $x$):
- $x$ 是 $p$ 的祖先,并且 $x$ 和 $p$ 在同一条重链上。
- 设 $y$ 是 $p$ 的祖先中第一个在 $L$ 的点,$x$ 是 $y$ 或者 $x$ 在 $y$ 的重子树内。
对于 $L$ 上计算答案,注意到查询左端点或右端点是单调的,简单维护一下可以做到 $O(n)$。
计算不在 $L$ 上的点的第一部分答案类似,不难做到和链长线性相关的复杂度。
然后计算第二部分的答案:沿 $L$ 从深到浅做扫描线。注意到查询的点总在当前点的轻子树内,并且查询下标都不小于 $n - 2sz$,其中 $sz$ 为最大轻子树的大小。因此只需暴力预处理长度为 $O(sz)$ 的这一段后缀和,再计算轻子树内各点的答案。因为 $L$ 上各点的轻子树互不相交,所以预处理的总复杂度为 $O(n)$。
所以总复杂度 $O(n)$。
(第一份是带 log 的做法,第二份是线性)
Code1
// CSP-S 2019 Day 2 C (centroid), O(n log n) version.
//
// cnt[p] accumulates the number of deleted edges after which p is a centroid of
// its part; each test case prints sum over p of p * cnt[p].  The tree is rooted
// at a centroid g.  g's own count is computed by calc(); every other vertex
// gets its count from Fenwick-tree sweeps over subtree sizes (dfs1, dfs2) plus
// a global size histogram.
//
// Reads stdin, writes stdout.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;

// Minimal fread-backed integer reader.  Characters are kept in an int so the
// EOF test works whether plain `char` is signed or unsigned (the previous
// `~(char)-1` test never terminates at EOF when `char` is unsigned).
struct Input {
    FILE* file;
    char buf[1 << 16];
    size_t pos, len;

    explicit Input(FILE* f) : file(f), pos(0), len(0) { }

    // Next byte of input, or EOF.
    int pick() {
        if (pos == len) {
            len = fread(buf, 1, sizeof(buf), file);
            pos = 0;
            if (len == 0)
                return EOF;
        }
        return static_cast<unsigned char>(buf[pos++]);
    }
};

// Reads one (optionally negative) decimal integer; u becomes 0 at EOF.
template <typename T>
Input& operator >> (Input& in, T& u) {
    int x = in.pick();
    while (x != EOF && x != '-' && (x < '0' || x > '9'))
        x = in.pick();
    bool neg = (x == '-');
    if (neg)
        x = in.pick();
    for (u = 0; x >= '0' && x <= '9'; x = in.pick())
        u = u * 10 + (x - '0');
    if (neg)
        u = -u;
    return in;
}

Input in(stdin);

const int N = 3e5 + 5;

// Fenwick tree over subtree sizes 1..n.
typedef class Fenwick {
  public:
    int n;
    int a[N];

    void init(int n) {
        this->n = n;
        memset(a, 0, sizeof(int) * (n + 1));
    }
    void add(int idx, int val) {
        for ( ; idx <= n; idx += (idx & (-idx)))
            a[idx] += val;
    }
    int query(int idx) {
        int rt = 0;
        for ( ; idx; idx -= (idx & (-idx)))
            rt += a[idx];
        return rt;
    }
    // Sum over [l, r]; empty ranges give 0.
    int query(int l, int r) {
        if (l > r)
            return 0;
        return query(r) - query(l - 1);
    }
} Fenwick;

#define forv(_i, _v) for (vector<int>::iterator _i = (_v).begin(); (_i) != (_v).end(); (_i)++)
#define ll long long

int T, n, g;
int sz[N];
Fenwick fen;
vector<int> G[N];

// Fills sz[] for the tree rooted at p (parent fa); returns sz[p].
int get_sz(int p, int fa) {
    sz[p] = 1;
    forv (e, G[p]) {
        if (*e ^ fa) {
            get_sz(*e, p);
            sz[p] += sz[*e];
        }
    }
    return sz[p];
}

// Descends into any child subtree larger than n / 2; returns a centroid.
int get_centroid(int p, int fa) {
    forv (e, G[p]) {
        if ((*e ^ fa) && sz[*e] > (n >> 1)) {
            return get_centroid(*e, p);
        }
    }
    return p;
}

int cnt[N], mxsz[N];  // mxsz[p]: largest child subtree size of p

// While p is visited, fen holds -1 at sz[a] for every proper ancestor a of p
// (p itself is added before the second query and removed on exit).  The +1s
// from the global histogram are added later in solve().
void dfs1(int p, int fa) {
    mxsz[p] = 0;
    forv (e, G[p]) {
        if (*e ^ fa) {
            mxsz[p] = max(mxsz[p], sz[*e]);
        }
    }
    int L = max(1, 2 * (n - sz[p]) - n);
    int R = n - 2 * mxsz[p];
    cnt[p] += fen.query(L, R);
    fen.add(sz[p], -1);
    L = max(2 * mxsz[p], 1);
    R = min(2 * sz[p], n - 1);
    cnt[p] -= fen.query(L, R);
    forv (e, G[p]) {
        if (*e ^ fa) {
            dfs1(*e, p);
        }
    }
    fen.add(sz[p], 1);
}

// Inserts each visited node's size and never removes it, so the difference of
// the two queries adds the number of nodes in p's subtree (p included) whose
// size lies in [L, R].
void dfs2(int p, int fa) {
    int L = max(1, 2 * (n - sz[p]) - n);
    int R = n - 2 * mxsz[p];
    cnt[p] += fen.query(L, R);
    fen.add(sz[p], 1);
    forv (e, G[p]) {
        if (*e ^ fa) {
            dfs2(*e, p);
        }
    }
    cnt[p] -= fen.query(L, R);
}

// Counts deletions of the edge above p (p inside one child subtree of g) that
// keep g a centroid: the child subtree containing p shrinks to szcur - sz[p],
// szmx is the largest other child subtree, and both must be at most
// (n - sz[p]) / 2.
void calc(int p, int fa, int szcur, int szmx) {
    int lim = ((n - sz[p]) >> 1);
    cnt[g] += (szcur - sz[p] <= lim && szmx <= lim);
    forv (e, G[p]) {
        if (*e ^ fa) {
            calc(*e, p, szcur, szmx);
        }
    }
}

int tmp[N];

// Solves one test case and prints its answer.
void solve() {
    in >> n;
    for (int i = 1; i <= n; i++)
        G[i].clear();
    for (int i = 1, u, v; i < n; i++) {
        in >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    get_sz(1, 0);
    g = get_centroid(1, 0);
    get_sz(g, 0);
    memset(cnt, 0, sizeof(int) * (n + 1));
    // part 1: ancestor sweep plus the histogram of all subtree sizes
    fen.init(n);
    dfs1(g, 0);
    memset(tmp, 0, sizeof(int) * (n + 1));
    for (int i = 1; i <= n; i++)
        tmp[sz[i]]++;
    for (int i = 1; i <= n; i++)
        tmp[i] += tmp[i - 1];
    for (int p = 1; p <= n; p++) {
        int L = max(1, 2 * (n - sz[p]) - n);
        int R = n - 2 * mxsz[p];
        if (L <= R) {
            cnt[p] += tmp[R] - tmp[L - 1];
        }
    }
    // part 2: subtree sweep
    fen.init(n);
    dfs2(g, 0);
    // the centroid itself: find the largest (mx, child id) and second largest
    // (sc) child subtree sizes, then count directly
    cnt[g] = 0;
    int mx = 0, id = -1, sc = 0;
    forv (p, G[g]) {
        if (sz[*p] > mx) {
            swap(mx, sc);
            mx = sz[*p];
            id = *p;
        } else if (sz[*p] > sc) {
            sc = sz[*p];
        }
    }
    forv (p, G[g]) {
        if (*p == id) {
            calc(*p, g, mx, sc);
        } else {
            calc(*p, g, sz[*p], mx);
        }
    }
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        ans += 1ll * i * cnt[i];
    }
    printf("%lld\n", ans);
}

int main() {
    in >> T;
    while (T--) {
        solve();
    }
    return 0;
}
Code2
// CSP-S 2019 Day 2 C (centroid), linear version.
//
// cnt[p] accumulates the number of deleted edges after which p is a centroid
// of its part; each test case prints sum over p of p * cnt[p].  The tree is
// rooted at a centroid g and split into heavy chains (zson = heavy child).
// work() handles each finished chain with two monotone pointers; the sweeps
// over the chain through g (left in stk by dfs2) handle the remaining cases,
// and calc() counts g itself.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;

// Minimal fread-backed integer reader.  Characters are kept in an int so the
// EOF test works whether plain `char` is signed or unsigned (the previous
// `~(char)-1` test never terminates at EOF when `char` is unsigned).
struct Input {
    FILE* file;
    char buf[1 << 16];
    size_t pos, len;

    explicit Input(FILE* f) : file(f), pos(0), len(0) { }

    // Next byte of input, or EOF.
    int pick() {
        if (pos == len) {
            len = fread(buf, 1, sizeof(buf), file);
            pos = 0;
            if (len == 0)
                return EOF;
        }
        return static_cast<unsigned char>(buf[pos++]);
    }
};

// Reads one (optionally negative) decimal integer; u becomes 0 at EOF.
template <typename T>
Input& operator >> (Input& in, T& u) {
    int x = in.pick();
    while (x != EOF && x != '-' && (x < '0' || x > '9'))
        x = in.pick();
    bool neg = (x == '-');
    if (neg)
        x = in.pick();
    for (u = 0; x >= '0' && x <= '9'; x = in.pick())
        u = u * 10 + (x - '0');
    if (neg)
        u = -u;
    return in;
}

Input in(stdin);

const int N = 3e5 + 5;

#define forv(_i, _v) for (vector<int>::iterator _i = (_v).begin(); (_i) != (_v).end(); (_i)++)
#define ll long long

int T, n, g;
int sz[N];
int cnt[N];
vector<int> G[N];

// Fills sz[] for the tree rooted at p (parent fa); returns sz[p].
int get_sz(int p, int fa) {
    sz[p] = 1;
    forv (e, G[p]) {
        if (*e ^ fa) {
            get_sz(*e, p);
            sz[p] += sz[*e];
        }
    }
    return sz[p];
}

// Descends into any child subtree larger than n / 2; returns a centroid.
int get_centroid(int p, int fa) {
    forv (e, G[p]) {
        if ((*e ^ fa) && sz[*e] > (n >> 1)) {
            return get_centroid(*e, p);
        }
    }
    return p;
}

int zson[N];  // child with the largest subtree, 0 for leaves

// Fills zson[] for the subtree of p.
void dfs1(int p, int fa) {
    int mx = -1, &id = zson[p];
    id = 0;
    forv (e, G[p]) {
        if (*e ^ fa) {
            dfs1(*e, p);
            if (sz[*e] > mx) {
                mx = sz[*e];
                id = *e;
            }
        }
    }
}

int tp;
int stk[N];  // a heavy chain, filled by dfs2 from its bottom (stk[1]) upwards

// Processes the chain in stk[1..tp].  After the reverse, stk[1] is the top of
// the chain and subtree sizes strictly decrease along it.
void work() {
    reverse(stk + 1, stk + tp + 1);
    stk[tp + 1] = 0;
    for (int *p = stk + 1, *q = stk + 1, *_p = stk + tp + 1; p != _p; p++) {
        while (sz[*q] > (sz[*p] << 1))
            q++;
        cnt[*p] += p - q;
    }
    for (int *p = stk + 1, *q = stk + 1, *_p = stk + tp + 1; p != _p; p++) {
        int sc = sz[zson[*p]];
        while (q != _p && sz[*q] >= (sc << 1))
            q++;
        cnt[*p] -= p - q;
    }
}

// Light children first (each finished chain is processed by work()), then the
// heavy child, so on return stk holds the chain through p, bottom first.
void dfs2(int p, int fa) {
    if (zson[p]) {
        forv (e, G[p]) {
            if ((*e ^ fa) && (*e ^ zson[p])) {
                dfs2(*e, p);
                work();
            }
        }
        dfs2(zson[p], p);
    } else {
        tp = 0;
    }
    stk[++tp] = p;
}

// Counts deletions of the edge above p (p inside one child subtree of g) that
// keep g a centroid: the child subtree containing p shrinks to szcur - sz[p],
// szmx is the largest other child subtree, and both must be at most
// (n - sz[p]) / 2.
void calc(int p, int fa, int szcur, int szmx) {
    int lim = ((n - sz[p]) >> 1);
    cnt[g] += (szcur - sz[p] <= lim && szmx <= lim);
    forv (e, G[p]) {
        if (*e ^ fa) {
            calc(*e, p, szcur, szmx);
        }
    }
}

int tmp[N];  // histogram / suffix sums of subtree sizes

// Adds every node of p's subtree to tmp and counts those with sz <= R in sum.
void dfs3(int p, int fa, int R, int& sum) {
    tmp[sz[p]]++;
    if (sz[p] <= R)
        sum++;
    forv (e, G[p]) {
        if (*e ^ fa) {
            dfs3(*e, p, R, sum);
        }
    }
}

vector<int> P;

// Appends every node of p's subtree to P.
void dfs4(int p, int fa) {
    P.push_back(p);
    forv (e, G[p]) {
        if (*e ^ fa) {
            dfs4(*e, p);
        }
    }
}

// Solves one test case and prints its answer.
void solve() {
    in >> n;
    for (int i = 1; i <= n; i++)
        G[i].clear();
    for (int i = 1, u, v; i < n; i++) {
        in >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    get_sz(1, 0);
    g = get_centroid(1, 0);
    get_sz(g, 0);
    memset(cnt, 0, sizeof(int) * (n + 1));
    dfs1(g, 0);
    dfs2(g, 0);
    work();  // stk now holds the chain through g, top (g) first
    // tmp[n + 1] is read by the suffix sums in the last sweep, so clear it
    // too instead of leaving a value from a previous test case.
    memset(tmp, 0, sizeof(int) * (n + 2));
    int R = 0, sum = 0;
    for (int *p = stk + 1, *_p = stk + tp + 1; p != _p; p++) {
        int q = n - (sz[zson[*p]] << 1);
        while (R < q)
            sum += tmp[++R];
        cnt[*p] += sum;
        forv (e, G[*p]) {
            if (sz[*e] < sz[*p] && (*e ^ zson[*p])) {
                dfs3(*e, *p, R, sum);
            }
        }
    }
    memset(tmp, 0, sizeof(int) * (n + 2));
    R = 0, sum = 0;
    for (int *p = stk + 1, *_p = stk + tp + 1; p != _p; p++) {
        int q = n - (sz[*p] << 1) - 1;
        while (R < q)
            sum += tmp[++R];
        cnt[*p] -= sum;
        forv (e, G[*p]) {
            if (sz[*e] < sz[*p] && (*e ^ zson[*p])) {
                dfs3(*e, *p, R, sum);
            }
        }
    }
    // sweep the chain from deep to shallow
    reverse(stk + 1, stk + tp + 1);
    memset(tmp, 0, sizeof(int) * (n + 2));
    for (int *p = stk + 1, *_p = stk + tp + 1; p != _p; p++) {
        int sc = 0;
        P.clear();
        forv (e, G[*p]) {
            if (sz[*e] < sz[*p] && (*e ^ zson[*p])) {
                sc = max(sz[*e], sc);
                dfs4(*e, *p);
            }
        }
        // turn only the top sc entries of tmp into suffix sums, answer the
        // queries of the light-subtree nodes, then undo the suffix sums
        sc = min(sc << 1 | 1, n);
        for (int j = n, t = sc; t; j--, t--) {
            tmp[j] += tmp[j + 1];
        }
        forv (e, P) {
            int l = max(n - (sz[*e] << 1), 1);
            int r = min(n - (sz[zson[*e]] << 1), n - 1);
            cnt[*e] += tmp[l] - tmp[r + 1];
        }
        for (int j = n - sc + 1; j <= n; j++) {
            tmp[j] -= tmp[j + 1];
        }
        tmp[sz[*p]]++;
        forv (e, P) {
            tmp[sz[*e]]++;
        }
    }
    // the centroid itself: find the largest (mx, child id) and second largest
    // (sc) child subtree sizes, then count directly
    cnt[g] = 0;
    int mx = 0, id = -1, sc = 0;
    forv (p, G[g]) {
        if (sz[*p] > mx) {
            swap(mx, sc);
            mx = sz[*p];
            id = *p;
        } else if (sz[*p] > sc) {
            sc = sz[*p];
        }
    }
    forv (p, G[g]) {
        if (*p == id) {
            calc(*p, g, mx, sc);
        } else {
            calc(*p, g, sz[*p], mx);
        }
    }
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        ans += 1ll * i * cnt[i];
    }
    printf("%lld\n", ans);
}

int main() {
    freopen("centroid.in", "r", stdin);
    freopen("centroid.out", "w", stdout);
    in >> T;
    while (T--) {
        solve();
    }
    return 0;
}