Codeforces Round #533 (Div. 2) C. Ayoub and Lost Array

 

C. Ayoub and Lost Array

Ayoub had an array 𝑎 of integers of size 𝑛 and this array had two interesting properties:

All the integers in the array were between 𝑙 and 𝑟 (inclusive).
The sum of all the elements was divisible by 3.

Unfortunately, Ayoub has lost his array, but he remembers the size of the array 𝑛 and the numbers 𝑙 and 𝑟, so he asked you to find the number of ways to restore the array. Since the answer could be very large, print it modulo 10^9+7 (i.e. the remainder when dividing by 10^9+7). In case there are no satisfying arrays (Ayoub has a wrong memory), print 0.

Input
The first and only line contains three integers 𝑛, 𝑙 and 𝑟 (1 ≤ 𝑛 ≤ 2⋅10^5, 1 ≤ 𝑙 ≤ 𝑟 ≤ 10^9): the size of the lost array and the range of numbers in the array.

Output
Print the remainder when dividing by 10^9+7 the number of ways to restore the array.

Examples
input
2 1 3
output
3

input
3 2 2
output
1

input
9 9 99
output
711426616

Note
In the first example, the possible arrays are: [1,2], [2,1], [3,3].
In the second example, the only possible array is [2,2,2].
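For tiny inputs, you can check the statement directly by enumerating every array. The sketch below is a hypothetical cross-check helper (`bruteCount` is not part of the original solution). It steps through all arrays of length n over [l, r] like an odometer and counts those whose sum is divisible by 3. It runs in O((r - l + 1)^n) time, so it is only useful for confirming the small examples, not as a solution.

```cpp
#include <cassert>
#include <vector>

// Hypothetical cross-check helper: counts arrays of length n with entries in
// [l, r] whose sum is divisible by 3, by trying every array.
// Exponential time: only for tiny n and tiny ranges. No modulo is applied.
long long bruteCount(int n, long long l, long long r) {
    assert(n >= 1 && l <= r);
    std::vector<long long> a(n, l);  // current array, starts as [l, l, ..., l]
    long long count = 0;
    while (true) {
        long long sum = 0;
        for (long long x : a) sum += x;
        if (sum % 3 == 0) ++count;
        // Advance to the next array like an odometer whose digits run from l to r.
        int i = n - 1;
        while (i >= 0 && a[i] == r) {
            a[i] = l;  // this position wraps around; carry into the one before it
            --i;
        }
        if (i < 0) break;  // every position wrapped: all arrays have been visited
        ++a[i];
    }
    return count;
}
```

For example, `bruteCount(2, 1, 3)` visits the nine arrays over {1, 2, 3} and keeps [1,2], [2,1] and [3,3], which matches the first sample.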

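DP: the first dimension of dp is the number of elements chosen so far, and the second is the remainder of their sum modulo 3. First count c0, c1, c2, the numbers of integers in [l, r] that leave remainder 0, 1, 2 modulo 3. Then dp[1][k] = c_k. Appending an element with remainder k to a prefix with remainder j gives remainder (j + k) mod 3, so dp[i][(j + k) mod 3] accumulates dp[i-1][j] · c_k. The answer is dp[n][0]: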

#include <cstdio>
#define ll long long
using namespace std;

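const int mod = 1e9+7;   // answers are reported modulo 10^9+7
const int maxn = 2e5+5;  // n <= 2*10^5
ll dp[maxn][3];          // dp[i][j]: number of length-i prefixes whose sum % 3 == j
int n;

// c0, c1, c2: how many integers in [l, r] leave remainder 0, 1, 2 modulo 3
void solve(ll c0, ll c1, ll c2)
{
    // A single element: the prefix remainder is the element's own remainder.
    dp[1][0] = c0;
    dp[1][1] = c1;
    dp[1][2] = c2;
    for(int i = 2; i <= n; i++)
    {
        // Appending an element with remainder k to a prefix with remainder j
        // gives remainder (j + k) % 3; sum over all (j, k) pairs for each target.
        dp[i][0] = (dp[i-1][1] * c2 % mod + dp[i-1][2] * c1 % mod + dp[i-1][0] * c0 % mod) % mod;
        dp[i][1] = (dp[i-1][1] * c0 % mod + dp[i-1][2] * c2 % mod + dp[i-1][0] * c1 % mod) % mod;
        dp[i][2] = (dp[i-1][1] * c1 % mod + dp[i-1][2] * c0 % mod + dp[i-1][0] * c2 % mod) % mod;
    }
}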

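int main()
{
    ll l, r;
    scanf("%d%lld%lld", &n, &l, &r);
    // [l, r] splits into `uni` complete blocks of three consecutive integers
    // (one of each remainder) plus `res` leftover numbers at the end.
    ll uni = (r - l + 1) / 3;
    ll res = (r - l + 1) % 3;
    ll c0 = 0, c1 = 0, c2 = 0;
    // The leftovers l + 3*uni and l + 3*uni + 1 have remainders l % 3 and (l + 1) % 3.
    if(l % 3 == 0)        // leftovers have remainders 0, then 1
    {
        c0 = uni + (res >= 1);
        c1 = uni + (res >= 2);
        c2 = uni;
    }
    else if(l % 3 == 1)   // leftovers have remainders 1, then 2
    {
        c0 = uni;
        c1 = uni + (res >= 1);
        c2 = uni + (res >= 2);
    }
    else                  // l % 3 == 2: leftovers have remainders 2, then 0
    {
        c0 = uni + (res >= 2);
        c1 = uni;
        c2 = uni + (res >= 1);
    }
    solve(c0, c1, c2);
    printf("%lld\n", dp[n][0]);  // arrays whose total sum is divisible by 3
    return 0;
}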
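The same recurrence can be written more compactly. The sketch below is a hedged rewrite, not the author's code, and makes two changes. First, it keeps only the previous row of dp (a rolling array, O(1) memory instead of the dp[maxn][3] table). Second, it replaces the three-way case split on l % 3 with a prefix count. `countUpTo(m, k)` and `countArrays` are hypothetical helper names: `countUpTo(m, k)` counts the integers in [0, m] with remainder k, so the count in [l, r] is `countUpTo(r, k) - countUpTo(l - 1, k)`.

```cpp
#include <array>
#include <cassert>

const long long MOD = 1000000007LL;

// Hypothetical helper: number of integers x with 0 <= x <= m and x % 3 == k.
// Those integers are k, k + 3, ..., so there are (m - k) / 3 + 1 of them when m >= k.
long long countUpTo(long long m, int k) {
    return m < k ? 0 : (m - k) / 3 + 1;
}

// Sketch of the same DP with a rolling array:
// cur[j] = number of arrays of the current length whose sum % 3 == j.
long long countArrays(int n, long long l, long long r) {
    assert(n >= 1 && 1 <= l && l <= r);
    std::array<long long, 3> c{};  // c[k] = how many values in [l, r] have remainder k
    for (int k = 0; k < 3; ++k)
        c[k] = (countUpTo(r, k) - countUpTo(l - 1, k)) % MOD;
    std::array<long long, 3> cur = c;  // arrays of length 1
    for (int i = 2; i <= n; ++i) {
        std::array<long long, 3> next{};
        for (int j = 0; j < 3; ++j)      // remainder of the sum so far
            for (int k = 0; k < 3; ++k)  // remainder of the appended element
                next[(j + k) % 3] = (next[(j + k) % 3] + cur[j] * c[k]) % MOD;
        cur = next;
    }
    return cur[0];  // arrays whose total sum is divisible by 3
}
```

Calling `countArrays(9, 9, 99)` should reproduce the third sample, 711426616. Every step multiplies the row by the same 3×3 matrix, so fast matrix exponentiation would also work in O(log n). The plain O(n) loop is already fast enough for n ≤ 2⋅10^5.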

 

posted @ 2019-01-21 12:28 HazelNuto