130行C语言实现个用户态线程库(1)
准确的说是除掉头文件,测试代码和非关键的纯算法代码(只有双向环形链表的ADT),核心代码只有130行左右,已经是蝇量级的用户态线程库了。把这个库取名为ezthread,意思是,这太easy了,人人都可以读懂并且实现这个用户态线程库。我把该项目放在github上,欢迎来拍砖: https://github.com/Yuandong-Chen/coroutine/tree/old-version(注意,最新的版本已经用了共享栈技术,能够支持1000K数量级的协程了,读完这篇博文后可以进一步参考后续的博文:http://www.cnblogs.com/github-Yuandong-Chen/p/6973932.html)。那么下面谈谈怎么实现这个ezthread。
大家都会双向环形链表(就是头尾相连的双向链表),我们构造这个ADT结构:
首先是每个节点:
1 typedef struct __pnode pNode; 2 struct __pnode 3 { 4 pNode *next; 5 pNode *prev; 6 Thread_t *data; 7 };
显然,next指向下一个节点,prev指向上一个节点,data指向该节点数据,那么这个Thread_t是什么类型的数据结构呢?
typedef struct __ez_thread Thread_t; struct __ez_thread { Regs regs; int tid; unsigned int stacktop; unsigned int stacksize; void *stack; void *retval; };
这个结构体包含了线程内部的信息,比如第一项为Regs,记录的是各个寄存器的取值(我们在下面给出具体的结构),tid就是线程的ID了,stacktop记录的是线程栈的顶部(和页对齐的最大地址,每个线程都有自己的运行时的栈,用于构成他们相对独立的运行时环境),stacksize就是栈的大小了,stack指针指向我们给该线程栈分配的堆的指针(什么?怎么一会栈一会堆的?我们其实用了malloc函数分配出一些堆空间,把这些空间用于线程栈,当线程退出时候,我们再free这些堆),retval就是线程运行完了的返回值(pthread_join里头拿到的线程返回值就是这个了)。
下面是寄存器结构体:
typedef struct __thread_table_regs Regs; struct __thread_table_regs { int _ebp; int _esp; int _eip; int _eflags; };
真是好懂,一看就知道了,这个结构体只能支持X86体系的计算机了。那么还有个问题,为何只有这些寄存器,没用其他的比如:eax,ebx,edi,esi等等呢?因为我们在转换状态函数switch_to里头当返回时(准确地说是从上次切换的点切换回来时)用了return来切换回线程运行时环境,return会自动帮助我们把这些其他的寄存器的值恢复原状的(具体我们放到switch_to的时候再详细说明)。
然后呢,我们定义了一个游标去取这个环形链表的值,否则我们怎么读取这个环形链表里头的数据呢?总得有个东西指向其中某个节点吧。
typedef struct __loopcursor Cursor; struct __loopcursor { int total; pNode *current; };
这个游标结构体记录了现在指向的节点地址和这个环形链表里头一共有多少节点。
我们得用两个这样的环形链表结构体来支持我们的线程库,为何是俩呢?一个是正在运行的线程,我们把他们串成一个环形链表,取名为live(活的),然后用另外一个链表把运行结束的线程串成一串,取名为dead(死的)。然后最开始我们就有个线程在运行了,那就是主线程main,我们用pmain节点来记录主线程:
extern Cursor live; extern Cursor dead; extern Thread_t pmain;
好了,剩下的只有在这些结构体上操作的函数了:
void init(); void switch_to(int, ...); int threadCreat(Thread_t **, void *(*)(void *), void *); int threadJoin(Thread_t *, void **);
我们开始时调用init,以初始化我们的live,dead和pmain。然后当我们想创造线程时,就threadCreat就可以了,用法和pthread_create基本一模一样,熟悉posix多线程的人一看就明白了,threadJoin也是仿照pthread_join接口写的。这里的switch_to就是最关键的运行时环境转换函数了,当线程调用这个函数时候,我们就切换到其他线程上次暂停的点去执行了(这些状态都保存在我们的Thread_t结构体里,所以我们能够记录下切换前的状态,从而能够从容地去切换到下一个线程中)。我们没有用定时器每隔几微秒去激发switch_to(实现起来也是非常简单的,但是得添加多个signal_block函数,非常不简洁),而是让线程里头的函数主动调用switch_to来切换线程,这有点类似协程。
好了,现在讲具体的实现了。首先是对双向链表的操作函数,这个东西不是我们的重点,懂基础算法数据结构的人都能实现,具体是双向环形链表的增查删操作:
1 void initCursor(Cursor *cur) 2 { 3 cur->total = 0; 4 cur->current = NULL; 5 } 6 7 Thread_t *findThread(Cursor *cur, int tid) 8 { 9 int counter = cur->total; 10 if(counter == 0){ 11 return NULL; 12 } 13 14 int i; 15 pNode *tmp = cur->current; 16 for (int i = 0; i < counter; ++i) 17 { 18 if((tmp->data)->tid == tid){ 19 return tmp->data; 20 } 21 22 tmp = tmp->next; 23 } 24 return NULL; 25 } 26 27 int appendThread(Cursor *cur, Thread_t *pth) 28 { 29 if(cur->total == 0) 30 { 31 cur->current = (pNode *)malloc(sizeof(pNode)); 32 assert(cur->current); 33 (cur->current)->data = pth; 34 (cur->current)->prev = cur->current; 35 (cur->current)->next = cur->current; 36 cur->total++; 37 return 0; 38 } 39 else 40 { 41 if(cur->total > MAXCOROUTINES) 42 { 43 assert((cur->total == MAXCOROUTINES)); 44 return -1; 45 } 46 47 pNode *tmp = malloc(sizeof(pNode)); 48 assert(tmp); 49 tmp->data = pth; 50 tmp->prev = cur->current; 51 tmp->next = (cur->current)->next; 52 ((cur->current)->next)->prev = tmp; 53 (cur->current)->next = tmp; 54 cur->total++; 55 return 0; 56 } 57 } 58 59 pNode *deleteThread(Cursor *cur, int tid) 60 { 61 int counter = cur->total; 62 int i; 63 pNode *tmp = cur->current; 64 for (int i = 0; i < counter; ++i) 65 { 66 if((tmp->data)->tid == tid){ 67 (tmp->prev)->next = tmp->next; 68 (tmp->next)->prev = tmp->prev; 69 if(tmp == cur->current) 70 { 71 cur->current = cur->current->next; 72 } 73 74 cur->total--; 75 assert(cur->total >= 0); 76 return tmp; 77 } 78 tmp = tmp->next; 79 } 80 return NULL; 81 }
抛开这部分纯算法代码,我们只剩下130行代码了。这还不如某些函数的代码量大。但是我们就是在这130行代码里头实现了switch_to,threadCreat以及threadJoin等等关键代码。
先说下init怎么实现的:
1 void init() 2 { 3 initCursor(&live); 4 initCursor(&dead); 5 appendThread(&live, &pmain); 6 }
其实关键点只有一句,那就是第5行的append(&live,&pmain);往live链表里头添加pmain节点,但是我们的pmain还没初始化呢,里头stack,regs等等通通都是0,但是没事呢,因为当我们第一次进入switch_to的时候,switch_to在跳转前会帮助我们保存当前线程,这时也就是pmain的运行时状态。
然后我们看看threadCreat怎么实现:
1 int threadCreat(Thread_t **pth, void *(*start_rtn)(void *), void *arg) 2 { 3 4 *pth = malloc(sizeof(Thread_t)); 5 (*pth)->stack = malloc(PTHREAD_STACK_MIN); 6 assert((*pth)->stack); 7 (*pth)->stacktop = (((int)(*pth)->stack + PTHREAD_STACK_MIN)&(0xfffff000)); 8 (*pth)->stacksize = PTHREAD_STACK_MIN - (((int)(*pth)->stack + PTHREAD_STACK_MIN) - (*pth)->stacktop); 9 (*pth)->tid = fetchTID(); 10 /* set params */ 11 void *dest = (*pth)->stacktop - 12; 12 memcpy(dest, pth, 4); 13 dest += 4; 14 memcpy(dest, &start_rtn, 4); 15 dest += 4; 16 memcpy(dest, &arg, 4); 17 (*pth)->regs._eip = &real_entry; 18 (*pth)->regs._esp = (*pth)->stacktop - 16; 19 (*pth)->regs._ebp = 0; 20 appendThread(&live, (*pth)); 21 22 return 0; 23 }
我们在第4行分配了堆空间,然后让线程栈顶变量stacktop对齐页,设置stacksize大小(这个其实对我们的线程库没有用,因为我们还没有实现类似stackguard之类的检查机制),设置tid,这里fetchTID函数如下:
1 int fetchTID() 2 { 3 static int tid; 4 return ++tid; 5 }
接着,我们在threadCreat函数的11-16行代码中,在栈顶压入变量pth,start_rtn以及arg(我们用memcpy来操作线程栈空间),这些都是作为real_entry这个函数的参数压入线程栈的。我们不难发现,其实每个线程的最初入口地址都是real_entry函数(注意到我们在17行把eip设置为real_entry的地址)。最后,我们于17-19行设置寄存器变量,以满足刚进入该real_entry时的栈的状态,在live链表中添加该线程结构体指针,返回。这一系列操作导致的效果就是,比如我们第一次调用threadCreat函数,当发生switch_to的时候,当然我们先保存当前线程状态,然后就从主线程main中切换到了real_entry里头去了,而且对应的参数我们设置好了,就好像我们在主线程里头直接调用了real_entry一样。下面看下real_entry做了些什么:
1 void real_entry(Thread_t *pth, void *(*start_rtn)(void *), void* args) 2 { 3 ALIGN(); 4 5 pth->retval = (*start_rtn)(args); 6 7 deleteThread(&live, pth->tid); 8 appendThread(&dead, pth); 9 10 switch_to(-1); 11 }
第3行是对齐栈操作,我们先不做说明。接下来就是调用start_rtn函数,并且把args作为参数,返回值赋给线程的retval。当返回时,说明线程已经运行结束,在live链表里头删除该节点,在dead链表里头添加该节点。在第10行最后调用switch_to(-1),也就是在switch_to里头直接跳到下一个线程去执行,且不保存当前状态。
我们再看下threadJoin函数的实现:
1 int threadJoin(Thread_t *pth, void **rval_ptr) 2 { 3 4 Thread_t *find1, *find2; 5 find1 = findThread(&live, pth->tid); 6 find2 = findThread(&dead, pth->tid); 7 8 9 if((find1 == NULL)&&(find2 == NULL)){ 10 11 return -1; 12 } 13 14 if(find2){ 15 if(rval_ptr != NULL) 16 *rval_ptr = find2->retval; 17 18 pNode *tmp = deleteThread(&dead, pth->tid); 19 free(tmp); 20 free((Stack_t)find2->stack); 21 free(find2); 22 return 0; 23 } 24 25 while(1) 26 { 27 switch_to(0); 28 if((find2 = findThread(&dead, pth->tid))!= NULL){ 29 if(rval_ptr!= NULL) 30 *rval_ptr = find2->retval; 31 32 pNode *tmp = deleteThread(&dead, pth->tid); 33 free(tmp); 34 free((Stack_t)find2->stack); 35 free(find2); 36 return 0; 37 } 38 } 39 return -1; 40 }
threadJoin是用于回收线程资源并得到返回值的。实现大体的思路就是,我们先查找live和dead里头有没有这个线程,如果都没有,说明根本不存在这个线程,如果dead链表里头有,那么我们就得到返回值(15-16行),然后释放堆空间(19-22行)。如果在live里头,说明该线程还没执行结束,我们进入循环,先调用switch_to(0),保存当前线程状态,然后切换到下一个线程去。当再次回到这个循环时候,我们继续看看dead里头有没有这个线程,有就设置返回值(29-30行),然后释放资源(32-35行),否则继续切换并循环。
最后,最关键的,我们给出switch_to的实现:
1 void switch_to(int signo, ...) 2 { 3 4 va_list ap; 5 va_start(ap, signo); 6 7 Regs regs; 8 9 if(signo == -1) 10 { 11 regs = live.current->data->regs; 12 JMP(regs); 13 assert(0); 14 } 15 16 int _ebp; 17 int _esp; 18 int _eip = &&_REENTERPOINT; 19 int _eflags; 20 /* save current context */ 21 SAVE(); 22 /* save context in current thread */ 23 live.current->data->regs._eip = _eip; 24 live.current->data->regs._esp = _esp; 25 live.current->data->regs._ebp = _ebp; 26 live.current->data->regs._eflags = _eflags; 27 28 if(va_arg(ap,int) == -1){ 29 _REENTERPOINT: 30 assert(va_arg(ap,int) != -1); 31 return; 32 } 33 34 va_end(ap); 35 regs = live.current->next->data->regs; 36 live.current = live.current->next; 37 JMP(regs); 38 assert(0); 39 }
先看11-13行,我们把自动变量regs的值赋为当前线程的寄存器的结构体,然后跳转到当前线程(第12行JMP是跳转语句,13行永远不会执行)。这里大家有个疑问,从当前线程跳转到当前线程,那么还不是当前线程么?然后执行assert(0)报错退出?!其实只有当线程返回时,也就是在real_entry里头才可能执行switch_to(-1),注意到real_entry最后的几行代码,里头已经把当前线程从live里头删除,并添加到dead里了,所以现在live里头的当前线程其实是下一个线程。然后我们看21-26行,我们保存当前寄存器的值到当前线程中,注意第18行,我们把返回点设置在了_REENTERPOINT这个标签上,也就是以后如果再次切换到该线程时,我们会在第30行继续向下执行,很简单,第30行的有意义的代码只有return,也就是恢复其他寄存器(eax,edi,esi等等),然后返回到线程继续执行。我们继续看34-38行代码:我们把自动变量regs的值赋值为下一个线程的寄存器,然后live的当前线程指针current也指向了下一个线程,通过37行JMP,我们调到了下一个线程去执行,下个一个线程可能是real_entry处开始执行,也可能是_REENTERPOINT处开始执行。最后再从新说说31行的return到底return到哪里去了,我们看一下测试代码:
1 #include "ezthread.h" 2 #include <stdio.h> 3 #include <stdlib.h> 4 5 void *sum1tod(void *d) 6 { 7 int i, j=0; 8 9 for (i = 0; i <= d; ++i) 10 { 11 j += i; 12 printf("thread %d is grunting... %d\n",live.current->data->tid , i); 13 switch_to(0); // Give up control to next thread 14 } 15 16 return ((void *)j); 17 } 18 19 int main(int argc, char const *argv[]) 20 { 21 int res = 0; 22 int i; 23 init(); 24 Thread_t *tid1, *tid2; 25 int *res1, *res2; 26 27 threadCreat(&tid1, sum1tod, 10); 28 threadCreat(&tid2, sum1tod, 10); 29 30 for (i = 0; i <= 5; ++i){ 31 res+=i; 32 printf("main is grunting... %d\n", i); 33 switch_to(0); //Give up control to next thread 34 } 35 threadJoin(tid1, &res1); //Collect and Release the resourse of tid1 36 threadJoin(tid2, &res2); //Collect and Release the resourse of tid2 37 printf("parallel compute: %d = (1+2+3+4+5) + (1+2+...+10)*2\n", (int)res1+(int)res2+(int)res); 38 return 0; 39 }
注意到我们在测试代码里头sum1tod里头调用了switch_to(0),如果这个循环加法(11-13行)还未结束,那么上述的那个_REENTERPOINT里头的return就会return回这个循环继续执行,就如在sum1tod里的switch_to(0)函数直接调用return,什么事情也没干一样,但是其实我们经过了无数其他线程的执行,但是在sum1tod里头毫无感觉,简直好像其他线程不存在一样(除非我们在这里头调用threadJoin等待其他线程结束)。
现在我们给出讨厌的内嵌汇编:
1 #define JMP(r) asm volatile \ 2 ( \ 3 "pushl %3\n\t" \ 4 "popf\n\t" \ 5 "movl %0, %%esp\n\t" \ 6 "movl %2, %%ebp\n\t" \ 7 "jmp *%1\n\t" \ 8 : \ 9 : "m"(r._esp),"a"(r._eip),"m"(r._ebp), "m"(r._eflags) \ 10 : \ 11 ) 12 13 #define SAVE() asm volatile \ 14 ( \ 15 "movl %%esp, %0\n\t" \ 16 "movl %%ebp, %1\n\t" \ 17 "pushf\n\t" \ 18 "movl (%%esp), %%eax\n\t" \ 19 "movl %%eax, %2\n\t" \ 20 "popf\n\t" \ 21 : "=m"(_esp),"=m"(_ebp), "=m"(_eflags) \ 22 : \ 23 : \ 24 ) 25 26 #define ALIGN() asm volatile \ 27 ( \ 28 "andl $-16, %%esp\n\t" \ 29 : \ 30 : \ 31 :"%esp" \ 32 )
第一个就是起到跳转作用,第二个是保存寄存器到自动变量作用,最后一个是栈对齐作用。为何要栈对齐?因为我们在堆里头设置了这个栈的空间,这个和普通的栈空间并不完全一样,我们需要做对齐处理。
到这里我们就几乎完全明白了这个线程库的实现,还有一小点就是switch_to里头的可变参数怎么回事,其实那个是防止编译器中消除冗余代码造成我们_REENTERPOINT中的代码被优化而整个删除用的。如果我们在_REENTERPOINT前加入goto语句跳到下面执行,然后删除这个_REENTERPOINT之前的判断语句,我们会发现,编译器会把switch_to里头的第28-32行作为冗余代码全部删除。
谢谢你能看到最后,告诉你们一个消息,其实我们的实现是介于longjmp和汇编实现版本之间的某种实现:我们用汇编保存了运行时状态,但是其中的return又有点类似longjmp中自动恢复寄存器的作用。而且我们的库比纯汇编实现更具可移植性,但比longjmp实现版本又弱了点。