[Luogu] P7077 函数调用
Description
某数据库应用程序提供了若干函数用以维护数据。已知这些函数的功能可分为三类:
\(1.\)将数据中的指定元素加上一个值;
\(2.\)将数据中的每一个元素乘以一个相同值;
\(3.\)依次执行若干次函数调用,保证不会出现递归(即不会直接或间接地调用本身)。
在使用该数据库应用时,用户可一次性输入要调用的函数序列(一个函数可能被调用多次),在依次执行完序列中的函数后,系统中的数据被加以更新。为了计算出正确数据,小\(A\)查阅了软件的文档,了解到每个函数的具体功能信息,现在他想请你根据这些信息帮他计算出更新后的数据应该是多少。
Solution
可以考虑依次执行完操作序列后,所有数先一起被乘上一个数,有一些位置再被加上所加上的数乘上这个数的贡献。
注意到依次执行,不会出现递归等关键字眼,可以想到拓扑排序。(一定要注意这里的\(m\)才是原来拓扑排序的\(n\),因为把\(m\)个函数当作点)
可以新建一个点\(0\),作为主函数,同时将它和输入的所有调用函数连边。那么拓扑排序循环就要从\(0\sim{m}\)。
前一步是很好做的,可以考虑记忆化搜索,或者先建反图,跑一遍拓扑排序。维护一个乘法标记\(mul\),对于\(1\)类函数和\(3\)类函数,它们的\(mul=1\),对于\(2\)类函数,它的\(mul=v_i\)。然后从下到上,累乘\(mul\)即可。然后所有的\(a_i\)就要乘上\(mul[0]\)。
而如何处理加上的数究竟被加上了多少次呢?我们会发现,当操作序列不断执行\(\times,+,\times,+...\)时,对于一个\(+\),只有它后面的\(\times\)会对它产生乘法贡献。我们在原图上再跑一边拓扑排序,从上到下算出每个函数被执行的次数\(add\),也就是加法的次数。
对于\(u\rightarrow{v}\),设\(now=\prod\limits_{u\rightarrow{p},t_p>t_v}mul[p]\)(\(t_i\)表示\(i\)被执行的顺序先后),那么\(add[v]=add[v]+add[u]*now\)(\(u\)先执行\(add[u]\)次,每次都使\(add[v]\)乘上\(now\))。算完\(add\)后,处理加法即可。
(这个的确不是很难,但是谁有时间想啊,T1出题人1582.10.5~1582.10.14)
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const ll mod = 998244353;
queue < int > q;
int n, m, t, tot, hd[100005], nxt[1100005], to[1100005], rd[1100005], cnt[1100005];
ll a[100005];
struct node
{
int tp, pos, sz;
ll v, mul, add;
vector < int > w;
}f[100005];
int read()
{
int x = 0, fl = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') fl = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + ch - '0'; ch = getchar();}
return x * fl;
}
void add(int x, int y)
{
tot ++ ;
to[tot] = y;
nxt[tot] = hd[x];
hd[x] = tot;
return;
}
void topo1()
{
for (int i = 0; i <= m; i ++ ) cnt[i] = f[i].sz;
for (int i = 0; i <= m; i ++ ) if (!cnt[i]) q.push(i);
while (q.size())
{
int x = q.front(); q.pop();
for (int i = hd[x]; i; i = nxt[i])
{
int y = to[i];
f[y].mul = f[y].mul * f[x].mul % mod;
cnt[y] -- ;
if (!cnt[y]) q.push(y);
}
}
return;
}
void topo2()
{
for (int i = 0; i <= m; i ++ ) cnt[i] = rd[i];
for (int i = 0; i <= m; i ++ ) if (!cnt[i]) q.push(i);
while (q.size())
{
int x = q.front(); q.pop();
ll now = 1ll;
for (int i = f[x].sz - 1; i >= 0; i -- )
{
int y = f[x].w[i];
f[y].add = (f[y].add + f[x].add * now)% mod;
now = now * f[y].mul % mod;
cnt[y] -- ;
if (!cnt[y]) q.push(y);
}
}
return;
}
int main()
{
n = read();
for (int i = 1; i <= n; i ++ )
a[i] = (ll)read();
m = read();
f[0].mul = 1ll;
for (int i = 1; i <= m; i ++ )
{
f[i].tp = read();
if (f[i].tp == 1)
{
f[i].pos = read();
f[i].v = (ll)read();
f[i].mul = 1ll;
}
else if (f[i].tp == 2)
{
f[i].v = (ll)read();
f[i].mul = f[i].v;
}
else
{
f[i].sz = read();
f[i].mul = 1ll;
for (int j = 1; j <= f[i].sz; j ++ )
{
int x = read();
f[i].w.push_back(x);
add(x, i);
rd[x] ++ ;
}
}
}
f[0].add = 1ll;
t = read();
while (t -- )
{
int x = read();
add(x, 0);
rd[x] ++ ;
f[0].sz ++ ;
f[0].w.push_back(x);
}
topo1(); topo2();
for (int i = 1; i <= n; i ++ )
a[i] = a[i] * f[0].mul % mod;
for (int i = 1; i <= m; i ++ )
if (f[i].tp == 1)
a[f[i].pos] = (a[f[i].pos] + f[i].add * f[i].v % mod) % mod;
for (int i = 1; i <= n; i ++ )
printf("%lld ", a[i]);
puts("");
return 0;
}