poj 3735 Training little cats (矩阵快速幂)

 

 

【题意】:有n只猫咪,开始时每只猫咪有花生0颗,现有一组操作,由下面三个中的k个操作组成:
               1. g i 给i只猫咪一颗花生米
               2. e i 让第i只猫咪吃掉它拥有的所有花生米
               3. s i j 将猫咪i与猫咪j的拥有的花生米交换

               现将上述一组操作做m次后,问每只猫咪有多少颗花生?


【题解】:m达到10^9,显然不能直接算。
              因为k个操作给出之后就是固定的,所以想到用矩阵,矩阵快速幂可以把时间复杂度降到O(logm)。问题转化为如何构造转置矩阵?
              说下我的思路,观察以上三种操作,发现第二,三种操作比较容易处理,重点落在第一种操作上。
              有一个很好的办法就是添加一个辅助,使初始矩阵变为一个n+1元组,编号为0到n,下面以3个猫为例:
              定义初始矩阵A = [1 0 0 0],0号元素固定为1,1~n分别为对应的猫所拥有的花生数。
              对于第一种操作g i,我们在单位矩阵基础上使Mat[0][i]变为1,例如g 1:
              1 1 0 0
              0 1 0 0
              0 0 1 0
              0 0 0 1,显然[1 0 0 0]*Mat = [1 1 0 0]
              对于第二种操作e i,我们在单位矩阵基础使Mat[i][i] = 0,例如e 2:
              1 0 0 0
              0 1 0 0
              0 0 0 0
              0 0 0 1, 显然[1 2 3 4]*Mat = [1 2 0 4]
              对于第三种操作s i j,我们在单位矩阵基础上使第i列与第j互换,例如s 1 2:
              1 0 0 0
              0 0 0 1
              0 0 1 0
              0 1 0 0,显然[1 2 0 4]*Mat = [1 4 0 2]
              现在,对于每一个操作我们都可以得到一个转置矩阵,把k个操作的矩阵相乘我们可以得到一个新的转置矩阵T。
              A * T 表示我们经过一组操作,类似我们可以得到经过m组操作的矩阵为 A * T ^ m,最终矩阵的[0][1~n]即为答案。

              上述的做法比较直观,但是实现过于麻烦,因为要构造k个不同矩阵。

              有没有别的方法可以直接构造转置矩阵T?答案是肯定的。
              我们还是以单位矩阵为基础:
              对于第一种操作g i,我们使Mat[0][i] = Mat[0][i] + 1;
              对于第二种操作e i,我们使矩阵的第i列清零;
              对于第三种操作s i j,我们使第i列与第j列互换。
              这样实现的话,我们始终在处理一个矩阵,免去构造k个矩阵的麻烦。

              至此,构造转置矩阵T就完成了,接下来只需用矩阵快速幂求出 A * T ^ m即可,还有一个注意的地方,该题需要用到long long。

              具体实现可以看下面的代码。

              首先想想 为什么这样是可以的呢?

             其实这道题可一看成 1*n 的矩阵 和n*n 的矩阵相乘(因为 对我们有用的信息只是第一行)

            例如 :  上面的例子,第一次是 {0,0,0,1} * T (最后一个 1 的作用是将    加起来 (模拟一下就明白了) )  一次后 变为 {2,0,1,1}* T

 

 

#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<set>
#include<map>
#include<queue>
#include<vector>
#include<string>
#define Min(a,b) a<b?a:b
#define Max(a,b) a>b?a:b
#define CL(a,num) memset(a,num,sizeof(a));
#define maxn  40
#define eps  1e-6
#define inf 9999999
#define mx 1<<60
#define ll   __int64
using namespace std;
ll n,m,k;
struct matrix
{
    ll m[110][110];
    void  clear()
    {
        memset(m,0,sizeof(m));
    }
    void unit()
    {
        clear();                       //错在了这里,没清空
        for(int i = 0; i <= n;i++)
        {
            m[i][i] = 1;
        }
    }
};

matrix a;

void init()
{
    int  i;
   a.clear();
    for(i = 0; i <= n;i++)
    {
        a.m[i][i] = 1;

    }


}

matrix mtmul(matrix a,matrix b)
{
    int i,j,k;
    matrix c;
    c.clear();
    for(i =0 ; i <= n;i++)
    {
        for(k = 0; k <=n;k++)
        {

            if(a.m[i][k])
            for(j = 0 ; j <=n;j++)
            {
                c.m[i][j] += a.m[i][k]*b.m[k][j];

            }
        }
    }
    return c;
}

/*matrix mtpow(matrix a,int k)
{
    matrix b;
    if(k == 1) return a;
    int mid = k / 2;
    b = mtpow(a,mid);
    b = mtmul(b,b);
    if(k&1)
    {
        b = mtmul(b,a);
    }

    return b;


}
*/

matrix  mtpow(matrix a,ll k)
{
    if(k == 1return a;
    matrix e;
    e.unit();
    while(k)
    {
        if(k&1)
        {
           e = mtmul(e,a);
        }

        k>>=1;
        a = mtmul(a,a);


    }
    return e;
}

int main()
{
      ll i ,j,x,y;
     char c[4] ;
    while(scanf("%I64d%I64d%I64d",&n,&m,&k)!=EOF)
    {   if(n ==0 && m ==0 && k == 0break;
         init();
        for(i = 0; i < k;i++)
        {
            scanf("%s",c);
            if(c[0] == 'g')
            {

                scanf("%I64d",&x);

                a.m[0][x] ++ ;
            }
            if(c[0] == 'e')
            {
                 scanf("%I64d",&x);

                for(j = 0; j <= n;j++)
                     a.m[j][x] = 0;
            }
            if(c[0] == 's')
            {
                scanf("%I64d%I64d",&x,&y);
                  if(x == y) continue ;
                for(j = 0; j <= n;j++)
                {
                    swap(a.m[j][x],a.m[j][y]) ;
                }
            }
        }

        if(m == 0)
        {
            for(i = 0; i < n;i++)
            if(i == 0)printf("0");
            else printf(" 0");

            printf("\n");

            continue ;
        }

        a = mtpow(a,m);

        for(i = 1;i <= n;i++)
        {
            if(i == 1 )printf("%I64d",a.m[0][i]);
            else  printf(" %I64d",a.m[0][i]);
        }
        printf("\n");
    }

}

 

 

 

 

   

       

             

posted @ 2012-08-20 21:05  Szz  阅读(683)  评论(0编辑  收藏  举报