协程学习
协程学习
很早之前,其实也就半年前吧,我当时用golang写了mit6.824这门课程,众所周知,golang里面协程算是这门语言核心的特点。我一直很想搞清楚协程到底是啥,但是有点畏难情绪,趁着在学校隔离期间,我也是抽出时间看了看协程的实现。没想到协程原理其实挺简单的,我是搭配大佬的笔记和代码学习的
调度器
调度器这个东西让我想到了当时学习操作系统这门课程中。进程的切换会走到一个sched进程,这个sched线程是一个内核态进程,会从proc数组中获取可以执行的线程,然后分配相应的CPU core执行进程。
而协程的调度器也是起到类似的作用,走到调度器中,调度器负责切换。不过调度器也是在用户态中,这是最大的不同,这意味着我们的调度器其实就是像用户自己写的一段函数而已。每次进入调度器,实际上就是相当于调用了调度器函数而已。
/**
* 协程调度器
*/
struct schedule {
char stack[STACK_SIZE]; // 运行时栈
ucontext_t main; // 主协程的上下文
int nco; // 当前存活的协程个数
int cap; // 协程管理器的当前最大容量,即可以同时支持多少个协程。如果不够了,则进行扩容
int running; // 正在运行的协程ID
struct coroutine **co; // 一个一维数组,用于存放协程
};
里面有个stack有意思,每个协程的栈都共用一个stack,这也就是共享栈协程的来源。
协程
struct coroutine {
coroutine_func func; // 协程所用的函数
void *ud; // 协程参数
ucontext_t ctx; // 协程上下文
struct schedule * sch; // 该协程所属的调度器
ptrdiff_t cap; // 已经分配的内存大小
ptrdiff_t size; // 当前协程运行时栈,保存起来后的大小
int status; // 协程当前的状态
char *stack; // 当前协程的保存起来的运行时栈
};
其中ctx
这个数据结构ucontext_t
也就是用户态上下文,可以看这篇文章 https://langzi989.github.io/2017/10/06/ucontext函数详解/
对ucontext_t
进行操作主要过程
- 使用
getcontext
获取上下文,保存在ctx中 - 对上下文进行设置,一般设置栈顶和栈大小以及
uc_link
等数据结构 - 使用
makecontext
修改该context的入口函数,我猜测是否是修改了对应寄存器,比如EIP的值?使得指令流发生了改变?这部分还不太明了。 - 使用
swapcontext
进行上下文切换。
这里有个非常令人迷惑的点,设置栈顶的时候是直接设置了栈的最顶端
char stack[10*1204];
//get current context
ucontext_t curContext;
getcontext(&curContext);
//modify the current context
ucontext_t newContext = curContext;
newContext.uc_stack.ss_sp = stack;
我知道这句话有点语无伦次,但是栈指针不是从栈底(高地址)向低地址生长吗?那么栈顶指针应该指向的是栈的当前问斩,也就是栈顶元素所在位置。此时初始化context,应该是栈顶等于栈底?我觉得我哪里可能理解错了。
创建协程
int
coroutine_new(struct schedule *S, coroutine_func func, void *ud) {
struct coroutine *co = _co_new(S, func , ud); // 创建一个协程数据结构
if (S->nco >= S->cap) {
// 如果目前协程的数量已经大于调度器的容量,那么进行扩容
int id = S->cap; // 新的协程的id直接为当前容量的大小
// 扩容的方式为,扩大为当前容量的2倍,这种方式和Hashmap的扩容略像
S->co = realloc(S->co, S->cap * 2 * sizeof(struct coroutine *));
// 初始化内存
memset(S->co + S->cap , 0 , sizeof(struct coroutine *) * S->cap);
//将协程放入调度器
S->co[S->cap] = co;
// 将容量扩大为两倍
S->cap *= 2;
// 尚未结束运行的协程的个数
++S->nco;
return id;
} else {
// 如果目前协程的数量小于调度器的容量,则取一个为NULL的位置,放入新的协程
int i;
for (i=0;i<S->cap;i++) {
/*
* 为什么不 i%S->cap,而是要从nco+i开始呢
* 这其实也算是一种优化策略吧,因为前nco有很大概率都非NULL的,直接跳过去更好
*/
int id = (i+S->nco) % S->cap;
if (S->co[id] == NULL) {
S->co[id] = co;
++S->nco;
return id;
}
}
}
assert(0);
return -1;
}
创建协程数据结构,然后放到调度器的协程队列中。其中还有调度器队列的扩容。
执行协程
/*
* 通过low32和hi32 拼出了struct schedule的指针,这里为什么要用这种方式,而不是直接传struct schedule*呢?
* 因为makecontext的函数指针的参数是int可变列表,在64位下,一个int没法承载一个指针
*/
static void
mainfunc(uint32_t low32, uint32_t hi32) {
uintptr_t ptr = (uintptr_t)low32 | ((uintptr_t)hi32 << 32);
struct schedule *S = (struct schedule *)ptr;
int id = S->running;
struct coroutine *C = S->co[id];
C->func(S,C->ud); // 中间有可能会有不断的yield
_co_delete(C);
S->co[id] = NULL;
--S->nco;
S->running = -1;
}
/**
* 切换到对应协程中执行
*
* @param S 协程调度器
* @param id 协程ID
*/
void
coroutine_resume(struct schedule * S, int id) {
assert(S->running == -1);
assert(id >=0 && id < S->cap);
// 取出协程
struct coroutine *C = S->co[id];
if (C == NULL)
return;
int status = C->status;
switch(status) {
case COROUTINE_READY:
//初始化ucontext_t结构体,将当前的上下文放到C->ctx里面
getcontext(&C->ctx);
// 将当前协程的运行时栈的栈顶设置为S->stack
// 每个协程都这么设置,这就是所谓的共享栈。(注意,这里是栈顶)
C->ctx.uc_stack.ss_sp = S->stack;
C->ctx.uc_stack.ss_size = STACK_SIZE;
C->ctx.uc_link = &S->main; // 如果协程执行完,将切换到主协程中执行
S->running = id;
C->status = COROUTINE_RUNNING;
// 设置执行C->ctx函数, 并将S作为参数传进去
uintptr_t ptr = (uintptr_t)S;
makecontext(&C->ctx, (void (*)(void)) mainfunc, 2, (uint32_t)ptr, (uint32_t)(ptr>>32));
// 将当前的上下文放入S->main中,并将C->ctx的上下文替换到当前上下文
swapcontext(&S->main, &C->ctx);
break;
case COROUTINE_SUSPEND:
// 将协程所保存的栈的内容,拷贝到当前运行时栈中
// 其中C->size在yield时有保存
memcpy(S->stack + STACK_SIZE - C->size, C->stack, C->size);
S->running = id;
C->status = COROUTINE_RUNNING;
swapcontext(&S->main, &C->ctx);
break;
default:
assert(0);
}
}
原理就像上面提到到,改变当前协程的指令流,我曾经问过不少人,协程到底是什么?都回答我一句协程就是函数。这句话我现在可算是理解了。因为栈就是用于存储函数调用的数据结构,
这里我直接截图了:
我们切换用户态的执行这让我想到了之前做的操作系统课程https://songlinlife.top/2022/MIT6-s081-lab7thread/,我现在一想咦咦咦,这不就是写了一个协程的调度吗?
这个课程当时做的其实确实有一点一知半解的,现在我可谓是完全领悟了。可以说我悟了!咦!悟!
.globl thread_switch
thread_switch:
/* YOUR CODE HERE */
sd ra, 0(a0)
sd sp, 8(a0)
sd s0, 16(a0)
sd s1, 24(a0)
sd s2, 32(a0)
sd s3, 40(a0)
sd s4, 48(a0)
sd s5, 56(a0)
sd s6, 64(a0)
sd s7, 72(a0)
sd s8, 80(a0)
sd s9, 88(a0)
sd s10, 96(a0)
sd s11, 104(a0)
ld ra, 0(a1)
ld sp, 8(a1)
ld s0, 16(a1)
ld s1, 24(a1)
ld s2, 32(a1)
ld s3, 40(a1)
ld s4, 48(a1)
ld s5, 56(a1)
ld s6, 64(a1)
ld s7, 72(a1)
ld s8, 80(a1)
ld s9, 88(a1)
ld s10, 96(a1)
ld s11, 104(a1)
ret /* return to ra */
这里做的实际上就是保存ra
寄存器和sp
寄存器,然后切换就是从内存中在把之前保存的寄存器load上来。注意这里保存的寄存器都是caller save寄存器,整个过程太清晰。
这里ra
寄存器就是函数返回的地址。我这里记下我当时的实现:
#include "kernel/types.h"
#include "kernel/stat.h"
#include "user/user.h"
/* Possible states of a thread: */
#define FREE 0x0
#define RUNNING 0x1
#define RUNNABLE 0x2
#define STACK_SIZE 8192
#define MAX_THREAD 4
struct thread {
char stack[STACK_SIZE]; /* the thread's stack */
int state; /* FREE, RUNNING, RUNNABLE */
};
struct thread all_thread[MAX_THREAD];
struct thread *current_thread;
extern void thread_switch(uint64, uint64);
void
thread_init(void)
{
// main() is thread 0, which will make the first invocation to
// thread_schedule(). it needs a stack so that the first thread_switch() can
// save thread 0's state. thread_schedule() won't run the main thread ever
// again, because its state is set to RUNNING, and thread_schedule() selects
// a RUNNABLE thread.
current_thread = &all_thread[0];
current_thread->state = RUNNING;
}
void
thread_schedule(void)
{
struct thread *t, *next_thread;
/* Find another runnable thread. */
next_thread = 0;
t = current_thread + 1;
for(int i = 0; i < MAX_THREAD; i++){
if(t >= all_thread + MAX_THREAD)
t = all_thread;
if(t->state == RUNNABLE) {
next_thread = t;
break;
}
t = t + 1;
}
if (next_thread == 0) {
printf("thread_schedule: no runnable threads\n");
exit(-1);
}
if (current_thread != next_thread) { /* switch threads? */
next_thread->state = RUNNING;
t = current_thread;
current_thread = next_thread;
thread_switch((uint64)t->context, (uint64)current_thread->context);
} else
next_thread = 0;
}
void
thread_create(void (*func)())
{
struct thread *t;
for (t = all_thread; t < all_thread + MAX_THREAD; t++) {
if (t->state == FREE) break;
}
t->state = RUNNABLE;
// YOUR CODE HERE
*(uint64*)(t->context) = (uint64)func; // 这个就是使用 func
*(uint64*)(t->context + 8) = (uint64)(t->stack+STACK_SIZE);
}
void
thread_yield(void)
{
current_thread->state = RUNNABLE;
thread_schedule();
}
volatile int a_started, b_started, c_started;
volatile int a_n, b_n, c_n;
void
thread_a(void)
{
int i;
printf("thread_a started\n");
a_started = 1;
while(b_started == 0 || c_started == 0)
thread_yield();
for (i = 0; i < 100; i++) {
printf("thread_a %d\n", i);
a_n += 1;
thread_yield();
}
printf("thread_a: exit after %d\n", a_n);
current_thread->state = FREE;
thread_schedule();
}
void
thread_b(void)
{
int i;
printf("thread_b started\n");
b_started = 1;
while(a_started == 0 || c_started == 0)
thread_yield();
for (i = 0; i < 100; i++) {
printf("thread_b %d\n", i);
b_n += 1;
thread_yield();
}
printf("thread_b: exit after %d\n", b_n);
current_thread->state = FREE;
thread_schedule();
}
void
thread_c(void)
{
int i;
printf("thread_c started\n");
c_started = 1;
while(a_started == 0 || b_started == 0)
thread_yield();
for (i = 0; i < 100; i++) {
printf("thread_c %d\n", i);
c_n += 1;
thread_yield();
}
printf("thread_c: exit after %d\n", c_n);
current_thread->state = FREE;
thread_schedule();
}
int
main(int argc, char *argv[])
{
a_started = b_started = c_started = 0;
a_n = b_n = c_n = 0;
thread_init();
thread_create(thread_a);
thread_create(thread_b);
thread_create(thread_c);
thread_schedule();
exit(0);
}
通常修改ra寄存器然后改变函数返回地址,从而实现协程。简直秀的发麻。
当时我做的时候,就觉得很秀,现在一看简直秀麻了。
- 创建协程也就是uthread,将ra寄存器的值设置为对应的func address,这样第一次切换uthread就会去执行对应的function。
- 核心的
thread_schedule
函数,其会调用thread_switch
这个asm函数,起作用就是保存ra sp寄存器和切换ra和sp寄存器。这样thread_switch
执行完毕后,就会ret ra,也就是跑去执行新的切换的ra地址。 - thread yield也是执行thread_shcedule 这个函数,相当于切换执行不同的进程。thread_a执行完毕后,同样再调用一次thread_schedule,并且将当前进程设置为FREE,这样接下来就不会返回thread_a中,也同样不会返回thread_a 的return address。
我当时没有搞懂这个协程,但是今晚上可算是整的明明白白了。