[CF2039G] Shohag Loves Pebae 做题记录
高级筛法题。
每条路径的条件是很难求的,考虑将其转化。
发现对于一条路径,点数为 \(c = a\cdot b\),那么其条件是无用的:考虑其包含的所有点数为 \(a\) 的路径,需要满足这 \(c\) 个点的权值乘积不被 \(a\) 整除。
进一步的,只有点数为质数的路径条件才有用。对于每个点 \(i\),求出 \(a_i\) 表示最长的包含点 \(i\) 的路径点数是多少,那么点 \(i\) 分配的权值不能包含 \(\le a_i\) 的质数。
筛出质数,求出 \(c_i\) 表示 \(\le a_i\) 的质数个数,那么点 \(i\) 的权值不能包含前 \(c_i\) 个质数。
不难发现这和 min25 筛前半部分要求的东西很像,设 \(f_{i, j}\) 表示 \([1, j]\) 中除去最小质因子是前 \(i\) 个质数的合数后,剩下的数的个数,那么点 \(i\) 可以分配的权值个数为 \(f_{c_i, m}\)。
加上所有点 \(\gcd\) 为 \(1\) 的条件,考虑莫比乌斯反演。令 \(t \gets \max\limits_i c_i\),\(p_i\) 为第 \(i\) 个质数,答案即为:
枚举 \(d\) 可以整除分块,这样我们需要求:
-
\([1, r]\) 中除去最小质因子是前若干个质数的数后,剩下数的莫比乌斯函数之和。
-
固定 \(t\),求 \(\prod_i f_{c_i, t}\)。
前者可以先杜教筛求一个莫比乌斯函数前缀和,再类似于 min25 前半部分的 DP 求出。
后者考虑统计 \(cnt_j\) 表示 \(\sum\limits_i [c_i = j]\)。设 \(k = \lfloor \sqrt m \rfloor\),则我们只需要用 \(\le k\) 的质数去筛,但是可能 \(p_{c_i} > k\),所以需要特殊处理。
但是这样复杂度会出现问题:\(f\) 的第二维大小为 \(\mathcal O(\sqrt m)\),而 \(c_i\) 可达 \(\mathcal O(\dfrac n {\log n})\),加上快速幂,时间复杂度为 \(\mathcal O(\dfrac {n \sqrt m \log (\frac n {\sqrt m})} {\log n})\)。
瓶颈在于 \(c_i\) 的大小。此时又注意到一个性质:\(\max\limits_i a_i \le 2\min\limits_i a_i\),因为每个点可以到直径的任意一端。
所以我们可以分讨处理:
-
当 \(\max\limits_i a_i \ge 2k\),此时每个点的权值一定是一个质数,用单步容斥代替莫比乌斯反演。
-
当 \(\max\limits_i a_i \ge 2k\),此时 \(c_i\) 的上界得到保证 \(\mathcal O(\dfrac {\sqrt m} {\log m})\),时间复杂度 \(\mathcal O(\dfrac {m \log (\frac n {\sqrt m})} {\log n})\)
点击查看代码
#include <bits/stdc++.h>
namespace Initial {
#define ll long long
#define ull unsigned long long
#define fi first
#define se second
#define mkp make_pair
#define pir pair <ll, ll>
#define pb push_back
#define i128 __int128
using namespace std;
const ll maxn = 1e6 + 10, inf = 1e13, mod = 998244353, L = 1e7 + 10;
ll power(ll a, ll b = mod - 2, ll p = mod) {
ll s = 1;
while(b) {
if(b & 1) s = 1ll * s * a %p;
a = 1ll * a * a %p, b >>= 1;
} return s;
}
template <class T>
const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
template <class T>
const inline void chkmin(T &x, const T y) { x = x > y? y : x; }
} using namespace Initial;
namespace Read {
char buf[1 << 22], *p1, *p2;
// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
template <class T>
const inline void rd(T &x) {
char ch; bool neg = 0;
while(!isdigit(ch = getchar()))
if(ch == '-') neg = 1;
x = ch - '0';
while(isdigit(ch = getchar()))
x = (x << 1) + (x << 3) + ch - '0';
if(neg) x = -x;
}
} using Read::rd;
ll n, m, a[maxn], c[maxn];
namespace Init {
vector <ll> to[maxn];
ll len, pre[maxn], suf[maxn], f[maxn], g[maxn];
void dfs1(ll u, ll fa = 0) {
for(ll v: to[u])
if(v ^ fa) dfs1(v, u), chkmax(f[u], f[v] + 1);
}
void dfs2(ll u, ll fa = 0) {
len = 0;
for(ll v: to[u])
if(v ^ fa) pre[len] = suf[++len] = f[v] + 1;
suf[len + 1] = pre[0] = 0;
for(ll i = 1; i <= len; i++) chkmax(pre[i], pre[i - 1]);
for(ll i = len; i; i--) chkmax(suf[i], suf[i + 1]);
ll c = 0;
for(ll v: to[u])
if(v ^ fa)
++c, g[v] = 1 + max(g[u], max(pre[c - 1], suf[c + 1]));
for(ll v: to[u])
if(v ^ fa) dfs2(v, u);
}
void solve() {
for(ll i = 1; i < n; i++) {
ll u, v; rd(u), rd(v);
to[u].pb(v), to[v].pb(u);
} dfs1(1), dfs2(1);
for(ll u = 1; u <= n; u++) {
ll mx = g[u], se = 0;
for(ll v: to[u])
if(f[v] < f[u]) {
if(f[v] + 1 > mx) se = mx, mx = f[v] + 1;
else chkmax(se, f[v] + 1);
}
a[u] = mx + se + 1;
}
}
}
ll sq, pri[maxn], pr, f[maxn], len, w[maxn], mu[maxn]; bool vis[maxn];
ll id1[maxn], id2[maxn], g[maxn], k, cnt[maxn], sum[maxn];
ll h[maxn];
void xxs() {
for(ll i = 2; i <= 1e6; i++) {
if(!vis[i]) pri[++pr] = i, mu[i] = mod - 1;
for(ll j = 1; j <= pr && i * pri[j] <= 1e6; j++) {
ll k = i * pri[j]; vis[k] = true;
if(i % pri[j]) mu[k] = mod - mu[i];
else break;
}
} k = pr;
while(pri[k] > sq) --k;
mu[1] = 1;
for(ll i = 1; i <= 1e6; i++) sum[i] = pls(sum[i - 1], mu[i]);
}
ll Id(ll x) {return x <= sq? id1[x] : id2[m / x];}
unordered_map <ll, ll> mp;
ll Sum(ll n) {
if(n <= 1e6) return sum[n];
if(mp.count(n)) return mp[n];
ll ret = 1;
for(ll i = 2; i <= n; i++) {
ll d = n / i, r = n / d;
ret = (ret - Sum(d) * (r - i + 1)) %mod;
i = r;
} return mp[n] = pls(ret, mod);
}
int main() {
rd(n), rd(m); Init::solve();
sq = sqrt(m), xxs();
ll mx = 0;
for(ll i = 1; i <= n; i++) {
chkmax(mx, a[i]);
c[i] = upper_bound(pri + 1, pri + 1 + pr, a[i]) - pri - 1;
++cnt[c[i]];
}
for(ll i = 1; i <= m; i++) {
ll d = m / i, r = m / d; w[++len] = d;
if(d <= sq) id1[d] = len;
else id2[m / d] = len;
i = r, f[len] = d, h[len] = Sum(d), g[len] = 1;
}
for(ll i = 1; i <= pr; i++) {
ll o = 0;
for(ll j = 1; j <= len && pri[i] * pri[i] <= w[j]; j++)
f[j] -= f[Id(w[j] / pri[i])] - i, o = j;
if(pri[i] <= mx)
for(ll j = o; j; j--)
h[j] = (h[j] + h[Id(w[j] / pri[i])] + i - 1 + mod) %mod;
if(mx < sq * 2 && cnt[i]) {
for(ll j = 1; j <= len; j++)
g[j] = g[j] * power(max(f[j] - i, 1ll), cnt[i]) %mod;
}
}
if(mx >= sq * 2) {
ll ans = 1;
for(ll i = 1; i <= n; i++) ans = ans * max(1ll, f[1] - c[i]) %mod;
mx = 0;
for(ll i = 1; i <= n; i++) chkmax(mx, c[i]);
if(mx <= f[1]) ans = (ans + mx - f[1] + 1) %mod;
printf("%lld\n", ans); return 0;
} ll ans = 0; mx = 0;
for(ll i = 1; i <= n; i++) chkmax(mx, c[i]);
for(ll i = 1; i <= len; i++) h[i] = (h[i] + min(f[i] - 1, mx)) %mod;
for(ll i = 1; i <= m; i++) {
ll d = m / i, r = m / d;
ans = (ans + (h[Id(r)] - h[Id(i - 1)] + mod) * g[Id(d)]) %mod;
i = r;
} printf("%lld\n", ans);
return 0;
}