计算图(dfs+高数)

计算图(dfs+高数)

 

 

 

 

 

 解题思路:反向建图,拓扑dfs

AC_Code:

  1 #include <bits/stdc++.h>
  2 using namespace std;
  3 typedef long long ll;
  4 const int maxn = 5e4+10;
  5 
  6 int s[maxn][2],_map[maxn];
  7 double ans[maxn],der[maxn],x[maxn];//求导:derivative
  8 int in[maxn],f;
  9 bool vis[maxn];
 10 int cnt;
 11 vector<int>vec;
 12 
 13 int n;
 14 
 15 void dfs(int h){
 16     if( vis[h] ) return ;
 17     vis[h] = true;
 18 
 19     if( _map[h]==0 ){//叶子节点,变量
 20         ans[h] = x[h];
 21         //根据友情提示的第二条
 22         if( h==f ) der[h] = 1;
 23         else der[h] = 0;
 24     } else if( _map[h]==1 ){//加法求值,求偏导
 25         int l=s[h][0], r=s[h][1];
 26         dfs(l); dfs(r);
 27         ans[h] = ans[l] + ans[r];
 28         der[h] = der[l] + der[r];
 29     } else if( _map[h]==2 ){//减法求值,求偏导
 30         int l=s[h][0], r=s[h][1];
 31         dfs(l), dfs(r);
 32         ans[h] = ans[l] - ans[r];
 33         der[h] = der[l] - der[r];
 34     } else if( _map[h]==3 ){//乘法求值,求偏导
 35         int l=s[h][0], r=s[h][1];
 36         dfs(l); dfs(r);
 37         //复合函数求导根据友情提示第3条
 38         ans[h] = ans[l] * ans[r];
 39         der[h] = der[l] * ans[r] + der[r] * ans[l];
 40     } else if( _map[h]==4 ){//指数求导
 41         int v = s[h][0];
 42         dfs(v);
 43         ans[h] = exp(ans[v]);
 44         der[h] = exp(ans[v]) * der[v];
 45     } else if( _map[h]==5 ){//lnx求导
 46         int v = s[h][0];
 47         dfs(v);
 48         ans[h] = log(ans[v]);
 49         der[h] = der[v] / ans[v];
 50     } else if( _map[h]==6 ){//sin求导
 51         int v = s[h][0];
 52         dfs(v);
 53         ans[h] = sin(ans[v]);
 54         der[h] = cos(ans[v]) * der[v];
 55     }
 56 }
 57 int main()
 58 {
 59     scanf("%d",&n);
 60     for(int i=0;i<n;i++){
 61         int k; scanf("%d",&k);
 62         _map[i] = k;
 63         if( !k ){
 64             scanf("%lf",&x[i]);
 65             vec.push_back(i);
 66         }
 67         else if( k>=1 && k<=3 ){
 68             int u,v; scanf("%d%d",&u,&v);
 69             s[i][0] = u;
 70             s[i][1] = v;
 71             in[u]++;
 72             in[v]++;
 73         }
 74         else{
 75             int u; scanf("%d",&u);
 76             s[i][0] = u;
 77             in[u]++;
 78         }
 79     }
 80     int s = 0;
 81     for(int i=0;i<n;i++){
 82         if( !in[i] ){
 83             s = i;
 84             break;
 85         }
 86     }
 87     queue<double>q;
 88     cnt = vec.size();
 89     for(int i=0;i<cnt;i++){
 90         f = vec[i];//指定要偏导的变量
 91         memset(vis,false,sizeof(vis));
 92         dfs(s);
 93         q.push(der[s]);
 94     }
 95     printf("%.3f\n",ans[s]);
 96     bool flag = false;
 97     while( !q.empty()){
 98         double h = q.front(); q.pop();
 99         if( !flag ) printf("%.3f",h);
100         else printf(" %.3f",h);
101         flag = true;
102     }
103     printf("\n");
104     return 0;
105 }

 

posted @ 2020-11-23 09:16  swsyya  阅读(185)  评论(0编辑  收藏  举报

回到顶部