P8350 [SDOI/SXOI2022] 进制转换
前 \(i\) 位的意思是从低往高 \(i\) 位,同理,后 \(i\) 位的意思是高往低 \(i\) 位。
一个性质是三进制的前 \(i\) 位只会影响二进制的前 \(\lceil \log 3^i \rceil\) 位,对更高位的影响仅为简单进位。
于是就着这个性质设计状态:\(f(i, j, 0/1)\) 表示还剩三进制的前 \(i\) 位没填,此时二进制的前 \(\lceil \log 3^i\rceil\) 位是 \(j\),是否向前进一的答案。
转移就是暴力枚举这一位填什么。
算算时间复杂度罢~
把所有三进制位对半分,高位部分只有 \(\mathcal O(\sqrt n)\) 个状态会被访问,低位部分可以记忆化,状态总数是 \(\mathcal O(\sqrt n)\),所以就 \(\mathcal O(\sqrt n)\) 了。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
constexpr int N = 30, TOTAL = 28, HALF = 14, MOD = 998244353;
char c1;
int mem[1 << 25][2], l[N], s[N], w[N];
ll n, x[N][3], y[60], z[3], cnt[N];
ll qp(ll base, ll e) {
ll res = 1;
while (e) {
if (e & 1) res = res * base % MOD;
base = base * base % MOD;
e >>= 1;
}
return res;
}
ll dp(int cur, ll S, bool add, bool lim) {
if (cur == -1) return add ? 0 : y[__builtin_popcount(S)];
if (!lim && cur < HALF && mem[s[cur] + S][add] != -1) return mem[s[cur] + S][add];
int res = 0;
for (int i = 0; i <= (lim ? w[cur] : 2); i++) for (ll t : {0, 1}) {
ll curS = S + i * cnt[cur] + (t << l[cur]);
if (add ^ (curS >= (1ll << l[cur + 1]))) continue;
curS &= (1ll << l[cur + 1]) - 1;
res = (res + dp(cur - 1, curS & ((1ll << l[cur]) - 1), t, lim & (i == w[cur])) * x[cur][i] % MOD * y[__builtin_popcount(curS >> l[cur])] % MOD * z[i]) % MOD;
}
if (!lim && cur < HALF) mem[s[cur] + S][add] = res;
return res;
}
char c2;
int main() {
// freopen("ex_conversion3.in", "r", stdin), freopen("conversion.out", "w", stdout);
ios_base::sync_with_stdio(0); cin.tie(nullptr), cout.tie(nullptr);
cerr << (&c2 - &c1) / 1024.0 / 1024 << " MiB\n";
cin >> n >> x[0][1] >> y[1] >> z[1];
y[0] = z[0] = 1, z[2] = z[1] * z[1] % MOD;
for (int i = 2; i < 60; i++) y[i] = y[i - 1] * y[1] % MOD;
cnt[0] = 1;
for (int i = 1; i <= TOTAL + 1; i++) cnt[i] = cnt[i - 1] * 3;
for (int i = 0; i < TOTAL; i++) x[i][0] = 1, x[i][1] = qp(x[0][1], cnt[i]), x[i][2] = x[i][1] * x[i][1] % MOD;
for (int i = 0; i <= TOTAL; i++) {while ((1ll << l[i]) <= cnt[i]) l[i]++; l[i + 1] = l[i];}
for (int i = 1; i <= HALF; i++) s[i] = s[i - 1] + (1 << l[i]);
for (int i = 0; i < TOTAL; i++) w[i] = (n / cnt[i]) % 3;
memset(mem, -1, sizeof(mem));
cout << (dp(TOTAL - 1, 0, 0, 1) - 1 + MOD) % MOD;
cerr << 1e3 * clock() / CLOCKS_PER_SEC << " ms\n";
return 0;
}