CF 1042 E. Vasya and Magic Matrix

E. Vasya and Magic Matrix

http://codeforces.com/contest/1042/problem/E

题意:

  一个 $n \times m$ 的矩阵,每个位置有一个元素。给定起点,每次等概率地随机走到一个值严格小于当前位置值的格子,每走一步得分加上两点欧几里得距离的平方,直到无法再走为止。求得分的期望(对 998244353 取模)。

分析:

  逆推期望。

  将所有格子取出并按值从小到大排序,值最小的格子期望为 0;依次处理每个格子,把它的期望连同距离平方转移给所有值严格大于它的格子即可。设格子总数为 $nm$,复杂度为 $O((nm)^2)$,见下方考试代码。

  反过来,考虑每个点从所有值严格小于它的元素转移。排序后,维护 $x$、$y$、$x^2$、$y^2$ 以及 $f$ 的前缀和(值相同的一组元素处理完后,才把它们的 $f$ 计入前缀和),即可 $O(1)$ 转移(外加一次求逆元)。

  $f[i] = \frac{1}{cnt_i}\sum\limits_{j:\ val[j]<val[i]} \left(f[j] + (x_j - x_i)^2 + (y_j - y_i)^2\right)$,其中 $cnt_i$ 为值严格小于 $val[i]$ 的元素个数($cnt_i = 0$ 时 $f[i] = 0$)。

  $f[i] = \frac{1}{cnt_i}\left(\sum f[j] + \sum x_j^2 - 2x_i\sum x_j + cnt_i\,x_i^2 + \sum y_j^2 - 2y_i\sum y_j + cnt_i\,y_i^2\right)$,其中各 $\sum$ 均对满足 $val[j]<val[i]$ 的 $j$ 求和,$cnt_i$ 为这样的 $j$ 的个数。

代码:

 1 #include<cstdio>
 2 #include<algorithm>
 3 #include<cstring>
 4 #include<cmath>
 5 #include<iostream>
 6 #include<cctype>
 7 #include<set>
 8 #include<vector>
 9 #include<queue>
10 #include<map>
// Optional file-redirect helpers (not used by this solution).
#define fi(s) freopen(s, "r", stdin);
#define fo(s) freopen(s, "w", stdout);
using namespace std;
using LL = long long;  // 64-bit integer type used for all modular arithmetic
15 
/// Reads one decimal integer from stdin.
/// Non-digit characters before the number are skipped; a '-' among them makes the result negative.
/// If EOF is reached before any digit, returns 0 instead of spinning forever.
inline int read() {
    int x = 0, f = 1;
    // int, not char: EOF must stay distinguishable, and isdigit() is only defined
    // for EOF or values representable as unsigned char.
    int ch = getchar();
    for (; ch != EOF && !isdigit(ch); ch = getchar())
        if (ch == '-') f = -1;
    for (; ch != EOF && isdigit(ch); ch = getchar())
        x = x * 10 + (ch - '0');
    return x * f;
}
20 
21 const LL mod = 998244353;
22 const int N = 1000010;
23 
24 struct Node {
25     int x, y, val;
26     bool zh;
27     bool operator < (const Node &A) const {
28         return val < A.val;
29     }
30 }A[N];
31 LL f[N], cnt[N], sumx[N], sumy[N], sumx2[N], sumy2[N];
32 
33 LL ksm(LL a,LL b) {
34     LL ans = 1;
35     while (b) {
36         if (b & 1) ans = 1ll * ans * a % mod;
37         a = 1ll * a * a % mod;
38         b >>= 1;
39     }
40     return ans;
41 }
42 
43 inline void add(LL &x,LL y) { (x += y) >= mod ? (x -= mod) : x; }
44 inline void sub(LL &x,LL y) { (x -= y) < 0 ? (x += mod) : x; }
45 
46 void solve2(int n) {
47     
48     A[0].val = -1;
49     for (int i=1; i<=n; ++i) {
50         if (A[i].val == A[i - 1].val) cnt[i] = cnt[i - 1];
51         else cnt[i] = i - 1;
52         sumx[i] = (sumx[i - 1] + A[i].x) % mod;
53         sumy[i] = (sumy[i - 1] + A[i].y) % mod;
54         sumx2[i] = (sumx2[i - 1] + 1ll * A[i].x * A[i].x % mod) % mod;
55         sumy2[i] = (sumy2[i - 1] + 1ll * A[i].y * A[i].y % mod) % mod;
56     }
57     
58     LL sum = 0, tmp = 0;
59     for (int i=1; i<=n; ++i) {
60         LL x2 = sumx2[cnt[i]];
61         LL y2 = sumy2[cnt[i]];
62         LL z1 = 1ll * sumx[cnt[i]] * 2 % mod * A[i].x % mod;
63         LL z2 = 1ll * sumy[cnt[i]] * 2 % mod * A[i].y % mod;
64         LL h1 = 1ll * cnt[i] * A[i].x % mod * A[i].x % mod;
65         LL h2 = 1ll * cnt[i] * A[i].y % mod * A[i].y % mod;
66         
67         add(f[i], x2); add(f[i], y2); 
68         sub(f[i], z1); sub(f[i], z2);
69         add(f[i], h1); add(f[i], h2);
70         add(f[i], sum);
71         
72         f[i] = 1ll * f[i] * ksm(cnt[i], mod - 2) % mod;
73         if (A[i].zh) {
74             cout << f[i]; return ;
75         }
76         add(tmp, f[i]); // 只有小于的时候才转移!!! 
77         if (A[i].val < A[i + 1].val) add(sum, tmp), tmp = 0;
78     }
79     
80 }
81 
82 int main() {
83     int n = read(), m = read(), tot = 0;
84     for (int i=1; i<=n; ++i) 
85         for (int j=1; j<=m; ++j) 
86             A[++tot].x = i, A[tot].y = j, A[tot].val = read(), A[tot].zh = false;
87     
88     int x = read(), y = read(), z = (x - 1) * m + y;
89     A[z].zh = true;
90     
91     sort(A + 1, A + tot + 1);
92 
93     solve2(tot);
94     return 0;
95 }

 

比赛时代码,记录调试历程。

  1 #include<cstdio>
  2 #include<algorithm>
  3 #include<cstring>
  4 #include<cmath>
  5 #include<iostream>
  6 #include<cctype>
  7 #include<set>
  8 #include<vector>
  9 #include<queue>
 10 #include<map>
 11 #define fi(s) freopen(s,"r",stdin);
 12 #define fo(s) freopen(s,"w",stdout);
 13 using namespace std;
 14 typedef long long LL;
 15 
 16 inline int read() {
 17     int x=0,f=1;char ch=getchar();for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
 18     for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';return x*f;
 19 }
 20 
 21 const LL mod = 998244353;
 22 const int N = 1000010;
 23 
 24 struct Node {
 25     int x, y, val;
 26     bool zh;
 27     bool operator < (const Node &A) const {
 28         return val < A.val;
 29     }
 30 }A[N];
 31 
 32 LL inv[N], deg[N], f[N];
 33 //double dp[N];
 34 
 35 LL ksm(LL a,LL b) {
 36     LL ans = 1;
 37     while (b) {
 38         if (b & 1) ans = 1ll * ans * a % mod;
 39         a = 1ll * a * a % mod;
 40         b >>= 1;
 41     }
 42     return ans;
 43 }
 44 
 45 LL Calc(int i,int j) {
 46     return ((A[i].x - A[j].x) * (A[i].x - A[j].x) % mod + (A[i].y - A[j].y) * (A[i].y - A[j].y) % mod) % mod;
 47 }
 48 
 49 void solve1(int n) {
 50     for (int i=1; i<=n; ++i) {
 51 //        cout << A[i].val << ": ";
 52         if (deg[i]) {
 53 //            dp[i] = dp[i] / (double)(deg[i]);
 54             f[i] = 1ll * ksm(deg[i], mod - 2) * f[i] % mod;
 55         }
 56         if (A[i].zh) {
 57             cout << f[i]; return ;
 58         }
 59         for (int j=i+1; j<=n; ++j) 
 60             if (A[j].val > A[i].val) {
 61 //                dp[j] = dp[j] + dp[i] + Calc(i, j);
 62 //                cout << A[j].val << " " << Calc(i, j) <<"--";
 63                 f[j] = (f[j] + f[i] + Calc(i, j)) % mod;
 64                 deg[j] ++;
 65             }
 66 //        puts("");
 67     }
 68 }
 69 
 70 int cnt[N], sumx[N], sumy[N], sumx2[N], sumy2[N];
 71 
 72 inline void add(LL &x,LL y) { (x += y) >= mod ? (x -= mod) : x; }
 73 inline void sub(LL &x,LL y) { (x -= y) < 0 ? (x += mod) : x; }
 74 
 75 void solve2(int n) {
 76     
 77     A[0].val = -1;
 78     for (int i=1; i<=n; ++i) {
 79         if (A[i].val == A[i - 1].val) cnt[i] = cnt[i - 1];
 80         else cnt[i] = i - 1;
 81         sumx[i] = (sumx[i - 1] + A[i].x) % mod;
 82         sumy[i] = (sumy[i - 1] + A[i].y) % mod;
 83         sumx2[i] = (sumx2[i - 1] + 1ll * A[i].x * A[i].x % mod) % mod;
 84         sumy2[i] = (sumy2[i - 1] + 1ll * A[i].y * A[i].y % mod) % mod;
 85     }
 86     
 87     LL sum = 0, tmp = 0;
 88     for (int i=1; i<=n; ++i) {
 89         LL x2 = sumx2[cnt[i]];
 90         LL y2 = sumy2[cnt[i]];
 91         LL z1 = 1ll * sumx[cnt[i]] * 2 % mod * A[i].x % mod;
 92         LL z2 = 1ll * sumy[cnt[i]] * 2 % mod * A[i].y % mod;
 93         LL h1 = 1ll * cnt[i] * A[i].x % mod * A[i].x % mod;
 94         LL h2 = 1ll * cnt[i] * A[i].y % mod * A[i].y % mod;
 95         
 96         add(f[i], x2); add(f[i], y2); 
 97         sub(f[i], z1); sub(f[i], z2);
 98         add(f[i], h1); add(f[i], h2);
 99         add(f[i], sum);
100         
101         f[i] = 1ll * f[i] * ksm(cnt[i], mod - 2) % mod;
102         if (A[i].zh) {
103             cout << f[i]; return ;
104         }
105         add(tmp, f[i]); // 只有小于的时候才转移!!! 
106         if (A[i].val < A[i + 1].val) add(sum, tmp), tmp = 0;
107     }
108     
109 }
110 
111 int main() {
112     int n = read(), m = read(), tot = 0;
113     for (int i=1; i<=n; ++i) 
114         for (int j=1; j<=m; ++j) 
115             A[++tot].x = i, A[tot].y = j, A[tot].val = read(), A[tot].zh = false;
116     
117     int x = read(), y = read(), z = (x - 1) * m + y;
118     A[z].zh = true;
119     
120     sort(A + 1, A + tot + 1);
121     
122     
123 //    if (tot <= 1000) {
124 //        solve1(tot) ;return 0;
125 //    }
126     solve2(tot);
127     return 0;
128 }
View Code

 

 

 

 

posted @ 2018-09-17 21:04  MJT12044  阅读(405)  评论(0编辑  收藏  举报