[bzoj3231][SDOI2008]递归数列——矩阵乘法

题目大意:

一个由自然数组成的数列按下式定义:
对于i <= k:ai = bi
对于i > k: ai = c1ai-1 + c2ai-2 + ... + ckai-k
其中bj和 cj (1<=j<=k)是给定的自然数。写一个程序,给定自然数m <= n, 计算am + am+1 + am+2 + ... + an, 并输出它除以给定自然数p的余数的值。

题解

首先显然我们可以构造一个矩阵递推\(a_i\)
如果直接从m递推到n,会超时(我也没写过,不知道),我们在矩阵中加一维,记录\(s_i\),具体可以见代码
注意一个问题:减法取模时应该加成正数。因为这个WA了好几次。

代码

#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
using namespace std;
const ll maxn = 20;
ll k, B[maxn], C[maxn], s[maxn];
ll m, n, p;
struct M {
  ll n, m;
  ll a[maxn][maxn];
} a, b;
M operator*(M a, M b) {
  M c;
  c.n = a.n;
  c.m = b.m;
  memset(c.a, 0, sizeof(c.a));
  for (ll i = 1; i <= c.n; i++) {
    for (ll j = 1; j <= c.m; j++) {
      for (ll k = 1; k <= a.m; k++)
        c.a[i][j] = (c.a[i][j] + (ull)(a.a[i][k] * b.a[k][j]) % p) % p;
    }
  }
  return c;
}
M pow(M a, ll b) {
  M ret;
  ret.n = a.n;
  ret.m = a.m;
  memset(ret.a, 0, sizeof(ret.a));
  for (ll i = 1; i <= ret.n; i++)
    ret.a[i][i] = 1;
  while (b) {
    if (b & 1)
      ret = ret * a;
    a = a * a;
    b >>= 1;
  }
  return ret;
}
void print(M x) {
  for (ll i = 1; i <= x.n; i++) {
    for (ll j = 1; j <= x.m; j++)
      cout << x.a[i][j] << ' ';
    cout << endl;
  }
}
ll calc(ll x) {
  M y = pow(a, x + 1);
  // prll (y);
  y = y * b;
  return y.a[1][1];
}
int main() {
  // freopen("input", "r", stdin);
  scanf("%lld", &k);
  a.n = k + 1;
  a.m = k + 1;
  b.n = k + 1;
  b.m = 1;
  s[0] = 0;
  for (ll i = 1; i <= k; i++) {
    scanf("%lld", &B[i]);
  }
  for (ll i = 1; i <= k; i++)
    scanf("%lld", &C[i]);
  scanf("%lld %lld %lld", &m, &n, &p);
  for (ll i = 1; i <= k; i++)
    s[i] = (B[i] + s[i - 1]) % p;
  b.a[1][1] = s[k - 1];
  for (ll i = 1; i <= k; i++)
    b.a[i + 1][1] = B[k - i + 1];
  a.a[1][1] = a.a[1][2] = 1;
  for (ll i = 1; i <= k; i++)
    a.a[2][i + 1] = C[i];
  for (ll i = 1; i < k; i++)
    a.a[i + 2][i + 1] = 1;
  // a = a * b;
  ll ans1, ans2;
  // calc(1);
  if (m - 1 > k)
    ans1 = calc(m - 1 - k);
  else
    ans1 = s[m - 1];
  if (n > k)
    ans2 = calc(n - k);
  else
    ans2 = s[n];
  printf("%lld\n", (ans2 - ans1 + p) % p);
  return 0;
}

posted on 2017-02-23 10:45  蒟蒻konjac  阅读(249)  评论(0编辑  收藏  举报