题解 P7468【[NOI Online 2021 提高组] 愤怒的小 N】

题解 P7468【[NOI Online 2021 提高组] 愤怒的小 N】

\(parity(x) = popcount(x) \bmod 2\)

problem

首先是有一个字符串 \(S=\texttt{"0"}\),做无限次“将 \(S\) 的每一位取反接在 \(S\) 后面”的操作,形如 \(S=0110100110010110\cdots\)

另外给一个 \(k-1\) 次多项式 \(f\),求 \(\sum_{i=0}^{n-1}S_if(i).\)

\(n\leq 2^{5\times 10^5}, k\leq 500\)

solution 0

第一个观察是 \(S_i=parity(i)\)。因为长度为 \(2^{m+1}\) 的前缀中，后一半 \(S_{i+2^m}\)（\(i<2^m\)）恰是前一半 \(S_i\) 的取反，也就是说去掉 \(i\) 的最高位的 \(1\)，值就反转一次，归纳即得。

考虑 dp。\(dp(i, j, 0/1)\) 表示 \([0,2^i)\)\(parity=0/1\) 的数字的 \(j\) 次方和。

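转移（把 \([2^{i-1},2^i)\) 中的数写成 \(l+2^{i-1}\)，\(0\le l<2^{i-1}\)；加上最高位后 parity 翻转，所以要取 \(parity(l)\neq e\)）：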

初值为 \(dp(0, j, 0)=[j=0]\) 表示只有 \(0\) 一个数字。

\[\begin{aligned} dp(i, j, e)&=dp(i-1, j, e)+\sum_{l=0, parity(l)\neq e}^{2^{i-1}-1}(l+2^{i-1})^j\\ &=dp(i-1, j, e)+\sum_{l=0, parity(l)\neq e}^{2^{i-1}-1}\sum_{t=0}^{j}\binom{j}{t}l^t(2^{i-1})^{j-t}\\ &=dp(i-1, j, e)+\sum_{t=0}^{j}\binom{j}{t}(2^{i-1})^{j-t}\sum_{l=0, parity(l)\neq e}^{2^{i-1}-1}l^t\\ &=dp(i-1, j, e)+\sum_{t=0}^{j}\binom{j}{t}(2^{i-1})^{j-t}dp(i-1, t, e\oplus 1)\\ \end{aligned} \]

统计答案

  • 取出 \(2^T=lowbit(n), L=n-2^T\)

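  • 答案累加区间 \([L, L+2^T)\) 的贡献 \(\displaystyle\sum_{x=L,\ parity(x)=1}^{L+2^T-1}f(x)=\sum_{l=0,\ parity(l)\neq parity(L)}^{2^T-1}f(l+L)\)。注意这里 \(l<2^T\)，而 \(L\) 的低 \(T+1\) 位全是 \(0\)，所以 \(l\) 与 \(L\) 相加不进位，\(parity(l+L)=parity(l)\oplus parity(L)\)。记 \(f(x)=\sum_{j}f_jx^j\)，这玩意等于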

  • \[\begin{aligned} \displaystyle\sum_{l=0,\ parity(l)\neq parity(L)}^{2^T-1}\sum_{j=0}^{k-1}f_j(l+L)^j &=\displaystyle\sum_{l=0,\ parity(l)\neq parity(L)}^{2^T-1}\sum_{j=0}^{k-1}\sum_{t=0}^{j}\binom{j}{t} f_jl^tL^{j-t}\\ &=\sum_{j=0}^{k-1}\sum_{t=0}^{j}\binom{j}{t} f_jL^{j-t}\displaystyle\sum_{l=0,\ parity(l)\neq parity(L)}^{2^T-1}l^t\\ &=\sum_{j=0}^{k-1}\sum_{t=0}^{j}\binom{j}{t} f_jL^{j-t}dp(T, t, parity(L)\oplus 1)\\ &=\sum_{t=0}^{k-1}dp(T, t, parity(L)\oplus 1)\sum_{j=t}^{k-1}\binom{j}{t} f_jL^{j-t}\\ \end{aligned} \]

  • \(n:=L\)

  • 重复上述过程直到 \(n=0\)。每一步取出的区间 \([L,L+2^T)\) 互不相交，并起来恰好是 \([0,n)\)，所以枚举到了所有区间。

optimize

现在的复杂度是 \(O(k^2\log n)\)

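重量级结论是，\(i>j\) 时 \(dp(i, j, 0)=dp(i, j, 1)=\frac{1}{2}\sum_{l=0}^{2^i-1}l^j\)。（证明：关键是对 \(i-1\to i\) 归纳，用二项式定理展开，考察各项系数。令 \(D(i,j)=dp(i,j,0)-dp(i,j,1)=\sum_{l=0}^{2^i-1}(-1)^{parity(l)}l^j\)。高半段的数 parity 翻转，展开得 \(D(i,j)=D(i-1,j)-\sum_{t=0}^{j}\binom{j}{t}(2^{i-1})^{j-t}D(i-1,t)=-\sum_{t=0}^{j-1}\binom{j}{t}(2^{i-1})^{j-t}D(i-1,t)\)。当 \(i>j\) 时，右边每一项都有 \(t\le j-1\le i-2<i-1\)，由归纳假设 \(D(i-1,t)=0\)，故 \(D(i,j)=0\)；\(j=0\) 时右边为空和，作为起点。）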

换句话来说，对于满足 \(i>j\) 的一大段区间，我们直接求出整段区间上 \(f\) 的和，再除以二，就是这段区间的答案。这一大段区间只取 \(i\geq k\) 的部分（此时对所有 \(j<k\) 都有 \(i>j\)），合起来恰好是 \(0\) 到 “把 \(n\) 的二进制表示中最低 \(k\) 位改成 \(0\) 后的数” 减一，于是可以计算。再观察到 \(f\) 的前缀和是 \(k\) 次多项式，可以直接拉格朗日插值，预处理 \(O(k^2)\)、求值 \(O(n+k)\) 完成这一部分。

可能发生 \(i\le j\) 的区间只有 \(i<k\) 的那些（至多 \(k\) 个），对它们用 dp 暴力计算，总共是 \(O(k^3)\) 的。

所以总的复杂度是 \(O(\log n+k^3)\)。就是将其中一个很大的 \(\log n\) 用结论打成 \(k\)

code



#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
// debug(...) forwards to fprintf(stderr, ...) only when LOCAL is defined;
// in every other build it expands to a no-op expression.
#ifndef LOCAL
#define debug(...) ((void)0)
#else
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#endif
using LL = long long;
// Residue class modulo the prime P.
// Invariant: 0 <= v < P. `v` is public so callers may read or write the raw
// residue directly.
template <unsigned P>
struct modint {
  unsigned v;

  modint() : v(0) {}

  // Implicit conversion from any integer type; a negative input is mapped to
  // its non-negative representative.
  template <class T>
  modint(T x) {
    x %= (int)P;
    v = x < 0 ? static_cast<unsigned>(x + P) : static_cast<unsigned>(x);
  }

  // ---- compound assignment ------------------------------------------------
  modint &operator+=(const modint &o) {
    v += o.v;
    if (v >= P) v -= P;
    return *this;
  }
  modint &operator-=(const modint &o) {
    v -= o.v;            // may wrap around below zero (unsigned) ...
    if (v >= P) v += P;  // ... in which case adding P wraps it back
    return *this;
  }
  modint &operator*=(const modint &o) {
    v = static_cast<unsigned>(static_cast<unsigned long long>(v) * o.v % P);
    return *this;
  }
  modint &operator/=(const modint &o) { return *this *= o.inv(); }

  // ---- unary --------------------------------------------------------------
  modint operator+() const { return *this; }
  modint operator-() const { return modint(0) - *this; }

  // Multiplicative inverse by Fermat's little theorem (P prime, value != 0).
  modint inv() const {
    assert(v);
    return qpow(*this, P - 2);
  }

  // ---- free functions (found through ADL) ---------------------------------
  friend int raw(const modint &x) { return x.v; }

  // base^e by binary exponentiation; e is any unsigned-like integer.
  template <class T>
  friend modint qpow(modint base, T e) {
    modint acc = 1;
    while (e) {
      if (e & 1) acc *= base;
      base *= base;
      e >>= 1;
    }
    return acc;
  }

  friend modint operator+(modint a, const modint &b) { return a += b; }
  friend modint operator-(modint a, const modint &b) { return a -= b; }
  friend modint operator*(modint a, const modint &b) { return a *= b; }
  friend modint operator/(modint a, const modint &b) { return a /= b; }
  friend bool operator==(const modint &a, const modint &b) { return a.v == b.v; }
  friend bool operator!=(const modint &a, const modint &b) { return !(a == b); }
};
using mint = modint<1000000007>;
vector<mint> multiple(const vector<mint> &a, const vector<mint> &b) {
  vector<mint> c(a.size() + b.size() - 1);
  for (int i = 0; i < a.size(); i++) {
    for (int j = 0; j < b.size(); j++) c[i + j] += a[i] * b[j];
  }
  return c;
}
vector<mint> addition(const vector<mint> &a, const vector<mint> &b) {
  vector<mint> c(max(a.size(), b.size()));
  for (int i = 0; i < a.size(); i++) c[i] += a[i];
  for (int i = 0; i < b.size(); i++) c[i] += b[i];
  return c;
}
vector<mint> divide(vector<mint> a, mint b1) {
  vector<mint> res(a.size() - 1);
  for (int i = (int)a.size() - 1; i >= 1; i--) {
    mint coe = res[i - 1] = a[i];
    a[i - 1] -= a[i] * b1;
  }
  return res;
}
vector<mint> numes[510];
mint idenos[510];
vector<mint> lagrange(const vector<mint> &a, const vector<mint> &b) {
  assert(a.size() == b.size());
  vector<mint> ans(a.size());
  for (int i = 0; i < a.size(); i++) {
    mint coe = b[i];
    for (int j = 0; j < a.size(); j++) ans[j] += numes[i][j] * coe;
  }
  return ans;
}
// Evaluates the polynomial with coefficients `a` (index = power of x) at `x`
// using Horner's rule, starting from the highest-degree coefficient.
mint getValue(const vector<mint> &a, mint x) {
  mint acc = 0;
  for (auto it = a.rbegin(); it != a.rend(); ++it) acc = acc * x + *it;
  return acc;
}
// n: number of binary digits of the input N; k: number of coefficients of f.
int n, k;
// Binary digits of N; main() converts them to 0/1 and reverses them, so a[i]
// is bit i (least significant first).
char a[1 << 19];
// f: coefficients f_0..f_{k-1}; sumF: polynomial with sumF(m) = sum{x=0..m-1} f(x).
// sumG is assigned by init() in this file but not read anywhere visible here.
vector<mint> f, sumG[510], sumF;  // sumG[j](n) = sum{i=0..n-1} i^j
// dp[i][j][e] = sum of x^j over x in [0, 2^i) with parity(x) == e;
// qp2[i] = 2^i (mod P); binom[i][j] = C(i, j).
mint dp[510][510][2], qp2[1 << 19], binom[510][510];
// Modular inverse of 2.
const mint inv2 = 1 / mint(2);
void init() {
  for (int i = raw(qp2[0] = 1); i <= max(k * k, n); i++)
    qp2[i] = qp2[i - 1] + qp2[i - 1];
  for (int i = 0; i < k; i++) {
    binom[i][0] = 1;
    for (int j = 1; j <= i; j++)
      binom[i][j] = binom[i - 1][j] + binom[i - 1][j - 1];
  }
  vector<mint> per = {};
  for (int i = 1; i <= k + 1; i++) per.push_back(i);
  vector<mint> ans(per.size()), product = {1};
  for (int i = 0; i < per.size(); i++)
    product = multiple(product, {-per[i], 1});
  for (int i = 0; i < per.size(); i++) {
    numes[i] = divide(product, -per[i]);
    idenos[i] = 1;
    for (int j = 0; j < per.size(); j++)
      if (i != j) idenos[i] *= per[i] - per[j];
    idenos[i] = 1 / idenos[i];
    for (int j = 0; j < per.size(); j++) numes[i][j] *= idenos[i];
  }
  for (int j = 0; j < k; j++) {  //这一段没用,,,
    vector<mint> tmp = {};
    for (int i = 1; i <= k + 1; i++) tmp.push_back(qpow(mint(i - 1), j));
    for (int i = 1; i <= k; i++) tmp[i] += tmp[i - 1];
    sumG[j] = lagrange(per, tmp);
  }
  {
    vector<mint> tmp = {};
    for (int i = 1; i <= k + 1; i++) tmp.push_back(getValue(f, i - 1));
    for (int i = 1; i <= k; i++) tmp[i] += tmp[i - 1];
    sumF = lagrange(per, tmp);
  }
}
void DP() {
  for (int j = 0; j < k; j++) dp[0][j][0] = !j;
  for (int i = 1; i < min(n, k); i++) {
    // for (int i = 1; i < n; i++) {
    memcpy(dp[i], dp[i - 1], sizeof dp[i]);
    for (int j = 0; j < k; j++) {
      for (int e : {0, 1}) {
        for (int t = 0; t <= j; t++) {
          dp[i][j][e] +=
              dp[i - 1][t][1 - e] * binom[j][t] * qp2[(i - 1) * (j - t)];
        }
      }
    }
  }
  // forall i > j, dp[i][j][e] = sumg[j](2^i) / 2
}
// Returns sum_{x=0}^{N-1} parity(x) * f(x) (mod P), where N is the number whose
// bits are a[0..n) (a[i] = bit i).
//
// [0, N) is split into aligned blocks [L, L + 2^i), one per set bit i of N,
// visited from the most significant bit down. L is the sum of the set bits
// above i, so L has no bits at or below position i and
// parity(L + l) = parity(L) ^ parity(l) for 0 <= l < 2^i.
mint solve() {
  mint L = 0, ans = 0;
  bool flag = 0;  // parity(L)

  // High part: blocks with i >= k. Since i > j for every exponent j < k, the
  // odd- and even-parity halves of such a block have equal power sums, so
  // exactly half of the block's f-sum counts. Together these blocks form
  // [0, lim) with lim = N with its low k bits cleared; that value is L after
  // this loop. (Previously this was computed by two identical loops.)
  if (n > k) {
    for (int i = n - 1; i >= k; i--) {
      if (a[i]) {
        L += qp2[i];
        flag ^= 1;
      }
    }
    ans += getValue(sumF, L) * inv2;
  }

  // Low part: blocks with i < k, evaluated exactly with dp.
  for (int i = min(k, n) - 1; i >= 0; i--) {
    if (!a[i]) continue;
    for (int t = 0; t < k; t++) {
      // coe = sum_{j >= t} C(j, t) f_j L^{j-t}: coefficient of l^t in f(l + L).
      mint coe = 0, now = 1;
      for (int j = t; j < k; j++, now *= L) coe += binom[j][t] * f[j] * now;
      // parity(L + l) = 1  <=>  parity(l) = flag ^ 1.
      ans += dp[i][t][flag ^ 1] * coe;
    }
    L += qp2[i];
    flag ^= 1;
  }
  return ans;
}
// Reads N in binary (most significant bit first), then k, then the k
// coefficients f_0..f_{k-1}, and prints the answer. Returns 1 on malformed input.
int main() {
  // The %s width keeps scanf inside a[] (one byte reserved for the terminator).
  static_assert(sizeof(a) == (1 << 19), "update the %s width below");
  if (scanf("%524287s%d", a, &k) != 2) return 1;
  n = strlen(a);
  for (int i = 0; i < n; i++) a[i] -= '0';
  reverse(a, a + n);  // a[i] is now bit i of N
  f = vector<mint>(k);
  for (int i = 0; i < k; i++) {
    unsigned long long c;
    if (scanf("%llu", &c) != 1) return 1;
    // Converting through mint reduces mod P; scanning straight into .v would
    // leave values >= P unreduced and break every later operation.
    f[i] = c;
  }
  init(), DP();
  printf("%d\n", raw(solve()));
  return 0;
}

posted @ 2023-10-16 22:52  caijianhong  阅读(12)  评论(0编辑  收藏  举报