skynet source reading <6> -- thread scheduling

    Compared with the coroutine scheduling covered in the previous installment, skynet's thread scheduling is logically much simpler. Let's walk through it in detail, starting naturally from the entry point in skynet_start.c:

static void
start(int thread) {
    pthread_t pid[thread+3];

    struct monitor *m = skynet_malloc(sizeof(*m));
    memset(m, 0, sizeof(*m));
    m->count = thread;
    m->sleep = 0;

    m->m = skynet_malloc(thread * sizeof(struct skynet_monitor *));
    int i;
    for (i=0;i<thread;i++) {
        m->m[i] = skynet_monitor_new();
    }
    if (pthread_mutex_init(&m->mutex, NULL)) {
        fprintf(stderr, "Init mutex error");
        exit(1);
    }
    if (pthread_cond_init(&m->cond, NULL)) {
        fprintf(stderr, "Init cond error");
        exit(1);
    }

    // three service threads: monitor, timer, network
    create_thread(&pid[0], thread_monitor, m);
    create_thread(&pid[1], thread_timer, m);
    create_thread(&pid[2], thread_socket, m);

    // per-worker weights: -1 = one message per pass, 0 = drain the queue,
    // w > 0 = queue length >> w messages per pass
    static int weight[] = { 
        -1, -1, -1, -1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 
        2, 2, 2, 2, 2, 2, 2, 2, 
        3, 3, 3, 3, 3, 3, 3, 3, };
    struct worker_parm wp[thread];
    for (i=0;i<thread;i++) {
        wp[i].m = m;
        wp[i].id = i;
        if (i < sizeof(weight)/sizeof(weight[0])) {
            wp[i].weight= weight[i];
        } else {
            wp[i].weight = 0;  // workers beyond the 32-entry table default to 0
        }
        create_thread(&pid[i+3], thread_worker, &wp[i]);
    }

    for (i=0;i<thread+3;i++) {
        pthread_join(pid[i], NULL); 
    }

    free_monitor(m);
}
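
    Two helpers referenced above but not shown are worth a quick look: the create_thread wrapper and the monitor bookkeeping struct. They look roughly like this in skynet_start.c (reconstructed here from the fields start() touches, so verify against your checkout):

struct monitor {
    int count;                   // number of worker threads
    struct skynet_monitor ** m;  // one skynet_monitor per worker
    pthread_cond_t cond;         // idle workers sleep on this
    pthread_mutex_t mutex;
    int sleep;                   // how many workers are currently asleep
    int quit;
};

static void
create_thread(pthread_t *thread, void *(*start_routine) (void *), void *arg) {
    if (pthread_create(thread, NULL, start_routine, arg)) {
        fprintf(stderr, "Create thread failed");
        exit(1);
    }
}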

    start() first spawns three service threads (monitor, timer, socket) that handle monitoring, timers, and network I/O respectively, then creates the requested number of worker threads (the thread parameter), assigning each a weight before it starts running. Per the weight table, the first four workers get -1, the next four get 0, and the remaining groups of eight get 1, 2, and 3. Let's see what thread_worker does:

static void *
thread_worker(void *p) {
    struct worker_parm *wp = p;
    int id = wp->id;
    int weight = wp->weight;
    struct monitor *m = wp->m;
    struct skynet_monitor *sm = m->m[id];
    skynet_initthread(THREAD_WORKER);
    struct message_queue * q = NULL;
    while (!m->quit) {
        // dispatch a batch from q and get the queue for the next round
        q = skynet_context_message_dispatch(sm, q, weight);
        if (q == NULL) {
            // nothing to do: sleep on the condition variable until woken
            if (pthread_mutex_lock(&m->mutex) == 0) {
                ++ m->sleep;
                // "spurious wakeup" is harmless,
                // because skynet_context_message_dispatch() can be called at any time.
                if (!m->quit)
                    pthread_cond_wait(&m->cond, &m->mutex);
                -- m->sleep;
                if (pthread_mutex_unlock(&m->mutex)) {
                    fprintf(stderr, "unlock mutex error");
                    exit(1);
                }
            }
        }
    }
    return NULL;
}

    Each worker thread dispatches a batch of messages and gets back the message queue to handle on its next pass. If there is nothing to dispatch, it increments m->sleep and waits on m->cond until a suitable moment to be woken. Whenever the socket thread receives new network messages, or on each timer update, a sleeping worker is signaled on m->cond so dispatching can resume.
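
    The signaling side lives in a small wakeup helper in skynet_start.c; a close paraphrase (check your checkout for the exact text):

static void
wakeup(struct monitor *m, int busy) {
    // only signal when at least (count - busy) workers are asleep;
    // a spurious wakeup of a busy worker would be harmless anyway
    if (m->sleep >= m->count - busy) {
        pthread_cond_signal(&m->cond);
    }
}

    The timer thread calls wakeup(m, m->count - 1) on every tick (signal if anyone is asleep), while the socket thread calls wakeup(m, 0) (signal only when every worker is asleep). Note that pthread_cond_signal wakes a single sleeping worker, not all of them. Now for the dispatch itself: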

struct message_queue * 
skynet_context_message_dispatch(struct skynet_monitor *sm, struct message_queue *q, int weight) {
    if (q == NULL) {
        q = skynet_globalmq_pop();
        if (q==NULL)
            return NULL;
    }

    uint32_t handle = skynet_mq_handle(q);

    struct skynet_context * ctx = skynet_handle_grab(handle);
    if (ctx == NULL) {
        // the service is gone: drop its pending messages and move on
        struct drop_t d = { handle };
        skynet_mq_release(q, drop_message, &d);
        return skynet_globalmq_pop();
    }

    int i,n=1;
    struct skynet_message msg;

    for (i=0;i<n;i++) {
        if (skynet_mq_pop(q,&msg)) {
            // q is empty: release the context and fetch the next queue
            skynet_context_release(ctx);
            return skynet_globalmq_pop();
        } else if (i==0 && weight >= 0) {
            // fix the batch size on the first iteration: length >> weight
            n = skynet_mq_length(q);
            n >>= weight;
        }
        int overload = skynet_mq_overload(q);
        if (overload) {
            skynet_error(ctx, "May overload, message queue length = %d", overload);
        }

        // tell the monitor thread we are entering a dispatch...
        skynet_monitor_trigger(sm, msg.source , handle);

        if (ctx->cb == NULL) {
            skynet_free(msg.data);
        } else {
            dispatch_message(ctx, &msg);
        }

        // ...and that we have left it
        skynet_monitor_trigger(sm, 0,0);
    }

    assert(q == ctx->queue);
    struct message_queue *nq = skynet_globalmq_pop();
    if (nq) {
        // If global mq is not empty, push q back and return the next queue (nq).
        // Else global mq is empty or blocked: don't push q back, and return q
        // again (for the next dispatch).
        skynet_globalmq_push(q);
        q = nq;
    } 
    skynet_context_release(ctx);

    return q;
}

    The worker first resolves the handle and context for the current queue q, then pops up to n messages from q and hands them to dispatch_message. The batch size n follows the weight: -1 means exactly one message per pass, 0 means the whole queue length, and w > 0 means length >> w, so a larger positive weight actually means a smaller batch. If q runs dry inside the batch, the next queue is popped from the global queue and returned for the next round. After a full batch, the worker pops the next queue nq from global_mq; if there is one, q is pushed back onto global_mq and nq is returned, so every message queue gets a reasonably fair share of execution; if global_mq is empty, q itself is returned and dispatched again next time. Locking is light: threads lock global_mq only while popping or pushing a queue, and lock an individual q only while pushing a message into it (for example when sending to a skynet context) or while popping messages out here.
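    To make the weight rule concrete, here is a tiny standalone sketch (a hypothetical helper, not skynet code) that reproduces the n computed in the loop above:

#include <stdio.h>

/* Mirrors the batch-size rule in skynet_context_message_dispatch():
   weight < 0 processes one message per pass; otherwise the batch is
   queue length >> weight, but never less than one, because the loop
   starts with n = 1 and only adjusts n after the first pop. */
static int batch_size(int len, int weight) {
    if (weight < 0)
        return 1;
    int n = len >> weight;
    return n > 1 ? n : 1;
}

int main(void) {
    printf("%d\n", batch_size(32, -1)); /* 1: one message per pass   */
    printf("%d\n", batch_size(32, 0));  /* 32: drain the whole queue */
    printf("%d\n", batch_size(32, 2));  /* 8: a quarter of the queue */
    return 0;
}
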
    Both q and global_mq are simple linked-list structures; I won't belabor their implementation here. Finally, let's look at how a single message gets dispatched:

static void
dispatch_message(struct skynet_context *ctx, struct skynet_message *msg) {
    assert(ctx->init);
    CHECKCALLING_BEGIN(ctx)
    pthread_setspecific(G_NODE.handle_key, (void *)(uintptr_t)(ctx->handle));
    // the message type is packed into the top byte of msg->sz
    int type = msg->sz >> MESSAGE_TYPE_SHIFT;
    size_t sz = msg->sz & MESSAGE_TYPE_MASK;
    if (ctx->logfile) {
        skynet_log_output(ctx->logfile, msg->source, type, msg->session, msg->data, sz);
    }
    ++ctx->message_count;
    int reserve_msg;
    if (ctx->profile) {
        // measure per-message CPU time when profiling is enabled
        ctx->cpu_start = skynet_thread_time();
        reserve_msg = ctx->cb(ctx, ctx->cb_ud, type, msg->session, msg->source, msg->data, sz);
        uint64_t cost_time = skynet_thread_time() - ctx->cpu_start;
        ctx->cpu_cost += cost_time;
    } else {
        reserve_msg = ctx->cb(ctx, ctx->cb_ud, type, msg->session, msg->source, msg->data, sz);
    }
    if (!reserve_msg) {
        skynet_free(msg->data);
    }
    CHECKCALLING_END(ctx)
}

    The key is the call through the ctx->cb function pointer.
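    Before following it, one detail above is worth pausing on: the message type rides in the top byte of msg->sz. The constants involved are defined in skynet.h roughly as follows (quoted from memory, so double-check your version):

#define MESSAGE_TYPE_MASK (SIZE_MAX >> 8)
#define MESSAGE_TYPE_SHIFT ((sizeof(size_t)-1) * 8)

so msg->sz >> MESSAGE_TYPE_SHIFT recovers the type and msg->sz & MESSAGE_TYPE_MASK the true payload size.
    You may well wonder how ctx->cb gets us into the Lua code. Recall skynet.start(f) in skynet.lua: it calls c.callback to register the callback f with the skynet core. So let's go back to lua-skynet.c and look at how lcallback handles this: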

static int
lcallback(lua_State *L) {
    struct skynet_context * context = lua_touserdata(L, lua_upvalueindex(1));
    int forward = lua_toboolean(L, 2);
    luaL_checktype(L,1,LUA_TFUNCTION);
    lua_settop(L,1);
    // registry[_cb] = f (the Lua callback now on the stack top)
    lua_rawsetp(L, LUA_REGISTRYINDEX, _cb);

    // fetch the main thread; it will reach _cb as its ud argument
    lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_MAINTHREAD);
    lua_State *gL = lua_tothread(L,-1);

    if (forward) {
        skynet_callback(context, gL, forward_cb);
    } else {
        skynet_callback(context, gL, _cb);
    }

    return 0;
}

    The line to focus on is lua_rawsetp(L, LUA_REGISTRYINDEX, _cb). At this point the stack top holds the LUA_TFUNCTION value f. The call takes the table t at pseudo-index LUA_REGISTRYINDEX, which is the global registry, sets t[_cb] = f using the address of _cb as the key, and pops the stack top; being a raw assignment, it triggers no metamethod. This ties _cb to f. Finally, skynet_callback stores _cb into context->cb and the main-thread state gL into context->cb_ud.
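
    This address-as-registry-key pattern is standard Lua C API practice. A self-contained toy example (hypothetical, not skynet code) of stashing and later fetching a function this way:

#include <lua.h>
#include <lauxlib.h>
#include <lualib.h>
#include <stdio.h>

static int key;  /* any static object's address works as a unique key */

int main(void) {
    lua_State *L = luaL_newstate();
    luaL_openlibs(L);
    luaL_dostring(L, "return function(x) return x * 2 end");
    lua_rawsetp(L, LUA_REGISTRYINDEX, &key);  /* registry[&key] = f, pops f */

    lua_rawgetp(L, LUA_REGISTRYINDEX, &key);  /* push registry[&key] */
    lua_pushinteger(L, 21);
    lua_call(L, 1, 1);
    printf("%lld\n", (long long)lua_tointeger(L, -1));  /* prints 42 */
    lua_close(L);
    return 0;
}

    With the registration clear, let's return to _cb and see what happens on an actual call: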

static int
_cb(struct skynet_context * context, void * ud, int type, int session, uint32_t source, const void * msg, size_t sz) {
    lua_State *L = ud;
    int trace = 1;  // stack index of the traceback handler
    int r;
    int top = lua_gettop(L);
    if (top == 0) {
        // first call: install traceback (index 1) and the cached callback (index 2)
        lua_pushcfunction(L, traceback);
        lua_rawgetp(L, LUA_REGISTRYINDEX, _cb);
    } else {
        assert(top == 2);
    }
    lua_pushvalue(L,2);  // copy the callback f so it can be called

    lua_pushinteger(L, type);
    lua_pushlightuserdata(L, (void *)msg);
    lua_pushinteger(L,sz);
    lua_pushinteger(L, session);
    lua_pushinteger(L, source);

    r = lua_pcall(L, 5, 0 , trace);

    if (r == LUA_OK) {
        return 0;
    }
    // (error reporting elided in this excerpt: the full _cb logs the
    // failure via skynet_error according to r)
    lua_pop(L,1);  // pop the error message left by lua_pcall
    return 0;
}

    Using the address of _cb as the key, lua_rawgetp fetches the Lua function f registered above from the registry. Note the caching: the first call leaves the traceback handler and f sitting at stack indices 1 and 2, so every later dispatch (asserted by top == 2) only has to copy f with lua_pushvalue, push the five arguments, and invoke it through lua_pcall under the traceback handler.
    With that, the whole thread-scheduling flow of skynet is clear. Next time we'll talk about locks.

posted on 2017-05-26 15:41 莫行