P5175题解

题意简述

给出数列 ${a_n}(1\le n\le10^{18})$ 的两项 $a_1,a_2$ 与递推公式 $a_n=xa_{n-1}+ya_{n-2}$,求:$$S_n=\sum_{k=1}^{n}a_k^2\mod (10^9+7)$$

题目分析

一看见 $1\le n\le 10^{18}$,就大概能知道要用 $O(\log n)$ 级别的算法。再一看递推,就知道要用矩阵快速幂了。

题目所求的答案是前 $n$ 项的 $a_k^2$ 之和,如果我们能快速求出每个 $a_k^2$,那么再用矩阵快速幂递推求和是十分简单的。

注意到 $a_{k+1}=xa_k+ya_{k-1}$,那么就有 $a_{k+1}^2=x^2a_k^2+y^2a_{k-1}^2+2xya_{k}a_{k-1}$,如果我们维护了 $a_k^2$ 与 $a_{k-1}^2$,就可以递推地求出 $a_{k+1}^2,a_{k+2}^2$ 等等。

还没完,注意到 $a_{k+1}$ 展开后还有一项 $2xya_{k}a_{k-1}$,去掉系数就还剩下 $a_{k}a_{k-1}$。维护它应该怎么处理呢?将 $a_k$ 拆开得到 $a_{k}a_{k-1}=(xa_{k-1}+ya_{k-2})a_{k-1}=xa_{k-1}^2+ya_{k-1}a_{k-2}$,其中 $a_{k-1}^2$ 已经维护了,所剩的只有 $ya_{k-1}a_{k-2}$。这就意味着我们维护每一个 $a_{k}a_{k-1}$ 就可以求出 $a_{k+1}a_k$ 和 $a_{k+2}a_{k+1}$ 等等。

综上,我们使用答案矩阵$\begin{bmatrix} S_{k} & a_k^2 & a_{k-1}^2 & a_ka_{k-1} \end{bmatrix}$ 进行维护。初始值为$\begin{bmatrix} S_{2}=a_1^2+a_2^2 & a_2^2 & a_1^2 & a_2a_1 \end{bmatrix}$。

为了将其递推到 $\begin{bmatrix} S_{k+1} & a_{k+1}^2 & a_{k}^2 & a_{k+1}a_k \end{bmatrix}$,需要有一个转移矩阵:$$\begin{bmatrix} 1 & 0 & 0 & 0\\ x^2 & x^2 & 1 & x\\ y^2 & y^2 & 0 & 0\\ 2xy & 2xy & 0 & y \end{bmatrix}$$

我们一列一列地进行解释。

对于第 $2$ 列,有 $a_{k+1}^2=0S_k+x^2a_k^2+y^2a_{k-1}^2+2xya_ka_{k-1}$。

那么对于第 $1$ 列,有 $S_{k+1}=S_k+a_{k+1}^2=1S_k+x^2a_k^2+y^2a_{k-1}^2+2xya_ka_{k-1}$。

对于第 $3$ 列,有 $a_k^2=0S_k+1a_k^2+0a_{k-1}^2+0a_ka_{k-1}$,没什么可解释的。

对于第 $4$ 列,有 $a_{k+1}a_k=0S_k+xa_k^2+0a_{k-1}^2+ya_ka_{k-1}$。

现在我们只需要计算出转移矩阵的 $n-2$ 次幂,再乘到答案矩阵上就好了。

需要注意的是要特判 $n=1$,否则会挂掉。然后就是该开 long long 的地方一定要开,不然也会挂。

分析一下时间复杂度。一次矩阵乘法是 $O(4^3)=O(64)$,因此矩阵快速幂的时间复杂度为 $O(64\log n)$,总复杂度为 $O(64T\log n)$,在 $T\le 3\times 10^4,1\le n\le 10^{18}$,时限 $1.50s$ 的情况下用个快读快写就可以过了。

代码实现

#include<bits/stdc++.h>
using namespace std;
const long long mod=1000000007;
int t;
long long n,x,y,a1,a2;
void rd(long long &x)
{
    x=0ll;
    char c=getchar();
    for(;c>'9'||c<'0';c=getchar());
    for(;c<='9'&&c>='0';c=getchar())
        x=(x<<3ll)+(x<<1ll)+c-'0';
}//long long快读 
void rd(int &x)
{
    x=0;
    char c=getchar();
    for(;c>'9'||c<'0';c=getchar());
    for(;c<='9'&&c>='0';c=getchar())
        x=(x<<3)+(x<<1)+c-'0';
}//int快读 
void pt(long long x)
{
    if(x>=10ll)
        pt(x/10ll);
    putchar(x%10ll+'0');
}//快写 
struct matrix
{
    int n,m;
    long long a[5][5];
    void init(long long k)
    {
        for(int i=1;i<=n;i++)
            for(int j=1;j<=m;j++)
                if(i!=j)
                    a[i][j]=0ll;
                else
                    a[i][j]=k;
    }//初始化 k=0零矩阵,k=1单位矩阵 
    friend matrix operator *(matrix a,matrix b)
    {
        matrix c;
        c.n=a.n,c.m=a.m;
        c.init(0ll);
        for(int i=1;i<=a.n;i++)
            for(int k=1;k<=a.m;k++)
                for(int j=1;j<=b.m;j++)
                    c.a[i][j]=(c.a[i][j]+a.a[i][k]*b.a[k][j]%mod)%mod;
        return c;
    }//矩阵乘法 
    friend matrix operator ^(matrix a,long long b)
    {
        matrix res;
        res.n=res.m=a.n;
        res.init(1ll);
        for(;b;b>>=1ll)
        {
            if(b&1ll)
                res=res*a;
            a=a*a;
        }
        return res;
    }//矩阵快速幂 
}ans,trans;
int main()
{
    rd(t);
    while(t--)
    {
        rd(n),rd(a1),rd(a2),rd(x),rd(y);
        if(n==1ll)//特判 n=1,不然会 T 得很惨
        {
            pt(a1*a1%mod);
            putchar('\n');
            continue;
        }
        ans.n=1,ans.m=4;
        ans.a[1][1]=(a2*a2%mod+a1*a1%mod)%mod,ans.a[1][2]=a2*a2%mod,ans.a[1][3]=a1*a1%mod,ans.a[1][4]=a1*a2%mod;//初始化答案矩阵 
        trans.n=trans.m=4;
        trans.init(0ll);
        trans.a[2][1]=trans.a[2][2]=x*x%mod,trans.a[3][1]=trans.a[3][2]=y*y%mod,trans.a[4][1]=trans.a[4][2]=2ll*x*y%mod;
        trans.a[1][1]=trans.a[2][3]=1ll;
        trans.a[2][4]=x,trans.a[4][4]=y;//初始化转移矩阵 
        trans=trans^(n-2ll);//转移矩阵快速幂 
        ans=ans*trans;//答案矩阵乘上转移矩阵 
        pt(ans.a[1][1]);
        putchar('\n');
    }
    return 0;
}
posted @ 2023-06-26 21:42  Hadtsti  阅读(3)  评论(0编辑  收藏  举报  来源