协程学习

协程学习

很早之前,其实也就半年前吧,我当时用golang写了mit6.824这门课程,众所周知,golang里面协程算是这门语言核心的特点。我一直很想搞清楚协程到底是啥,但是有点畏难情绪,趁着在学校隔离期间,我也是抽出时间看了看协程的实现。没想到协程原理其实挺简单的,我是搭配大佬的笔记和代码学习的

代码 博客

调度器

调度器这个东西让我想到了当时学习操作系统这门课程中。进程的切换会走到一个sched进程,这个sched线程是一个内核态进程,会从proc数组中获取可以执行的线程,然后分配相应的CPU core执行进程。

而协程的调度器也是起到类似的作用,走到调度器中,调度器负责切换。不过调度器也是在用户态中,这是最大的不同,这意味着我们的调度器其实就是像用户自己写的一段函数而已。每次进入调度器,实际上就是相当于调用了调度器函数而已。

/**
* 协程调度器
*/
struct schedule {
	char stack[STACK_SIZE];	// 运行时栈

	ucontext_t main; // 主协程的上下文
	int nco; // 当前存活的协程个数
	int cap; // 协程管理器的当前最大容量,即可以同时支持多少个协程。如果不够了,则进行扩容
	int running; // 正在运行的协程ID
	struct coroutine **co; // 一个一维数组,用于存放协程 
};

里面有个stack有意思,每个协程的栈都共用一个stack,这也就是共享栈协程的来源。

协程

struct coroutine {
	coroutine_func func; // 协程所用的函数
	void *ud;  // 协程参数
	ucontext_t ctx; // 协程上下文
	struct schedule * sch; // 该协程所属的调度器
	ptrdiff_t cap; 	 // 已经分配的内存大小
	ptrdiff_t size; // 当前协程运行时栈,保存起来后的大小
	int status;	// 协程当前的状态
	char *stack; // 当前协程的保存起来的运行时栈
};

其中ctx这个数据结构ucontext_t也就是用户态上下文,可以看这篇文章 https://langzi989.github.io/2017/10/06/ucontext函数详解/

ucontext_t进行操作主要过程

  1. 使用getcontext获取上下文,保存在ctx中
  2. 对上下文进行设置,一般设置栈顶和栈大小以及uc_link等数据结构
  3. 使用makecontext修改该context的入口函数,我猜测是否是修改了对应寄存器,比如EIP的值?使得指令流发生了改变?这部分还不太明了。
  4. 使用swapcontext进行上下文切换。

这里有个非常令人迷惑的点,设置栈顶的时候是直接设置了栈的最顶端

char stack[10*1204];
//get current context
ucontext_t curContext;
getcontext(&curContext);
//modify the current context
ucontext_t newContext = curContext;
newContext.uc_stack.ss_sp = stack;

我知道这句话有点语无伦次,但是栈指针不是从栈底(高地址)向低地址生长吗?那么栈顶指针应该指向的是栈的当前问斩,也就是栈顶元素所在位置。此时初始化context,应该是栈顶等于栈底?我觉得我哪里可能理解错了。

创建协程

int 
coroutine_new(struct schedule *S, coroutine_func func, void *ud) {
	struct coroutine *co = _co_new(S, func , ud); // 创建一个协程数据结构
	if (S->nco >= S->cap) {
		// 如果目前协程的数量已经大于调度器的容量,那么进行扩容
		int id = S->cap;	// 新的协程的id直接为当前容量的大小
		// 扩容的方式为,扩大为当前容量的2倍,这种方式和Hashmap的扩容略像
		S->co = realloc(S->co, S->cap * 2 * sizeof(struct coroutine *));
		// 初始化内存
		memset(S->co + S->cap , 0 , sizeof(struct coroutine *) * S->cap);
		//将协程放入调度器
		S->co[S->cap] = co;
		// 将容量扩大为两倍
		S->cap *= 2;
		// 尚未结束运行的协程的个数 
		++S->nco; 
		return id;
	} else {
		// 如果目前协程的数量小于调度器的容量,则取一个为NULL的位置,放入新的协程
		int i;
		for (i=0;i<S->cap;i++) {
			/* 
			 * 为什么不 i%S->cap,而是要从nco+i开始呢 
			 * 这其实也算是一种优化策略吧,因为前nco有很大概率都非NULL的,直接跳过去更好
			*/
			int id = (i+S->nco) % S->cap;
			if (S->co[id] == NULL) {
				S->co[id] = co;
				++S->nco;
				return id;
			}
		}
	}
	assert(0);
	return -1;
}

创建协程数据结构,然后放到调度器的协程队列中。其中还有调度器队列的扩容。

执行协程

/*
 * 通过low32和hi32 拼出了struct schedule的指针,这里为什么要用这种方式,而不是直接传struct schedule*呢?
 * 因为makecontext的函数指针的参数是int可变列表,在64位下,一个int没法承载一个指针
*/
static void
mainfunc(uint32_t low32, uint32_t hi32) {
	uintptr_t ptr = (uintptr_t)low32 | ((uintptr_t)hi32 << 32);
	struct schedule *S = (struct schedule *)ptr;

	int id = S->running;
	struct coroutine *C = S->co[id];
	C->func(S,C->ud);	// 中间有可能会有不断的yield
	_co_delete(C);
	S->co[id] = NULL;
	--S->nco;
	S->running = -1;
}

/**
* 切换到对应协程中执行
* 
* @param S 协程调度器
* @param id 协程ID
*/
void 
coroutine_resume(struct schedule * S, int id) {
	assert(S->running == -1);
	assert(id >=0 && id < S->cap);

    // 取出协程
	struct coroutine *C = S->co[id];
	if (C == NULL)
		return;

	int status = C->status;
	switch(status) {
	case COROUTINE_READY:
	    //初始化ucontext_t结构体,将当前的上下文放到C->ctx里面
		getcontext(&C->ctx);
		// 将当前协程的运行时栈的栈顶设置为S->stack
		// 每个协程都这么设置,这就是所谓的共享栈。(注意,这里是栈顶)
		C->ctx.uc_stack.ss_sp = S->stack; 
		C->ctx.uc_stack.ss_size = STACK_SIZE;
		C->ctx.uc_link = &S->main; // 如果协程执行完,将切换到主协程中执行
		S->running = id;
		C->status = COROUTINE_RUNNING;

		// 设置执行C->ctx函数, 并将S作为参数传进去
		uintptr_t ptr = (uintptr_t)S;
		makecontext(&C->ctx, (void (*)(void)) mainfunc, 2, (uint32_t)ptr, (uint32_t)(ptr>>32));

		// 将当前的上下文放入S->main中,并将C->ctx的上下文替换到当前上下文
		swapcontext(&S->main, &C->ctx);
		break;
	case COROUTINE_SUSPEND:
	    // 将协程所保存的栈的内容,拷贝到当前运行时栈中
		// 其中C->size在yield时有保存
		memcpy(S->stack + STACK_SIZE - C->size, C->stack, C->size);
		S->running = id;
		C->status = COROUTINE_RUNNING;
		swapcontext(&S->main, &C->ctx);
		break;
	default:
		assert(0);
	}
}

原理就像上面提到到,改变当前协程的指令流,我曾经问过不少人,协程到底是什么?都回答我一句协程就是函数。这句话我现在可算是理解了。因为栈就是用于存储函数调用的数据结构,

这里我直接截图了:

我们切换用户态的执行这让我想到了之前做的操作系统课程https://songlinlife.top/2022/MIT6-s081-lab7thread/,我现在一想咦咦咦,这不就是写了一个协程的调度吗?

这个课程当时做的其实确实有一点一知半解的,现在我可谓是完全领悟了。可以说我悟了!咦!悟!

	.globl thread_switch
thread_switch:
	/* YOUR CODE HERE */
        sd ra, 0(a0)
        sd sp, 8(a0)
        sd s0, 16(a0)
        sd s1, 24(a0)
        sd s2, 32(a0)
        sd s3, 40(a0)
        sd s4, 48(a0)
        sd s5, 56(a0)
        sd s6, 64(a0)
        sd s7, 72(a0)
        sd s8, 80(a0)
        sd s9, 88(a0)
        sd s10, 96(a0)
        sd s11, 104(a0)
        
        
        ld ra, 0(a1)
        ld sp, 8(a1)
        ld s0, 16(a1)
        ld s1, 24(a1)
        ld s2, 32(a1)
        ld s3, 40(a1)
        ld s4, 48(a1)
        ld s5, 56(a1)
        ld s6, 64(a1)
        ld s7, 72(a1)
        ld s8, 80(a1)
        ld s9, 88(a1)
        ld s10, 96(a1)
        ld s11, 104(a1)
	ret    /* return to ra */

这里做的实际上就是保存ra寄存器和sp寄存器,然后切换就是从内存中在把之前保存的寄存器load上来。注意这里保存的寄存器都是caller save寄存器,整个过程太清晰。

这里ra 寄存器就是函数返回的地址。我这里记下我当时的实现:

#include "kernel/types.h"
#include "kernel/stat.h"
#include "user/user.h"

/* Possible states of a thread: */
#define FREE        0x0
#define RUNNING     0x1
#define RUNNABLE    0x2

#define STACK_SIZE  8192
#define MAX_THREAD  4


struct thread {
  char       stack[STACK_SIZE]; /* the thread's stack */
  int        state;             /* FREE, RUNNING, RUNNABLE */

};
struct thread all_thread[MAX_THREAD];
struct thread *current_thread;
extern void thread_switch(uint64, uint64);
              
void 
thread_init(void)
{
  // main() is thread 0, which will make the first invocation to
  // thread_schedule().  it needs a stack so that the first thread_switch() can
  // save thread 0's state.  thread_schedule() won't run the main thread ever
  // again, because its state is set to RUNNING, and thread_schedule() selects
  // a RUNNABLE thread.
  current_thread = &all_thread[0];
  current_thread->state = RUNNING;
}

void 
thread_schedule(void)
{
  struct thread *t, *next_thread;

  /* Find another runnable thread. */
  next_thread = 0;
  t = current_thread + 1;
  for(int i = 0; i < MAX_THREAD; i++){
    if(t >= all_thread + MAX_THREAD)
      t = all_thread;
    if(t->state == RUNNABLE) {
      next_thread = t;
      break;
    }
    t = t + 1;
  }

  if (next_thread == 0) {
    printf("thread_schedule: no runnable threads\n");
    exit(-1);
  }

  if (current_thread != next_thread) {         /* switch threads?  */
    next_thread->state = RUNNING;
    t = current_thread;
    current_thread = next_thread;
    thread_switch((uint64)t->context, (uint64)current_thread->context);    
  } else
    next_thread = 0;
}

void 
thread_create(void (*func)())
{
  struct thread *t;

  for (t = all_thread; t < all_thread + MAX_THREAD; t++) {
    if (t->state == FREE) break;
  }
  t->state = RUNNABLE;
  // YOUR CODE HERE
   *(uint64*)(t->context) = (uint64)func; // 这个就是使用 func
  *(uint64*)(t->context + 8) = (uint64)(t->stack+STACK_SIZE);
}

void 
thread_yield(void)
{
  current_thread->state = RUNNABLE;
  thread_schedule();
}

volatile int a_started, b_started, c_started;
volatile int a_n, b_n, c_n;

void 
thread_a(void)
{
  int i;
  printf("thread_a started\n");
  a_started = 1;
  while(b_started == 0 || c_started == 0)
    thread_yield();
  
  for (i = 0; i < 100; i++) {
    printf("thread_a %d\n", i);
    a_n += 1;
    thread_yield();
  }
  printf("thread_a: exit after %d\n", a_n);

  current_thread->state = FREE;
  thread_schedule();
}

void 
thread_b(void)
{
  int i;
  printf("thread_b started\n");
  b_started = 1;
  while(a_started == 0 || c_started == 0)
    thread_yield();
  
  for (i = 0; i < 100; i++) {
    printf("thread_b %d\n", i);
    b_n += 1;
    thread_yield();
  }
  printf("thread_b: exit after %d\n", b_n);

  current_thread->state = FREE;
  thread_schedule();
}

void 
thread_c(void)
{
  int i;
  printf("thread_c started\n");
  c_started = 1;
  while(a_started == 0 || b_started == 0)
    thread_yield();
  
  for (i = 0; i < 100; i++) {
    printf("thread_c %d\n", i);
    c_n += 1;
    thread_yield();
  }
  printf("thread_c: exit after %d\n", c_n);

  current_thread->state = FREE;
  thread_schedule();
}

int 
main(int argc, char *argv[]) 
{
  a_started = b_started = c_started = 0;
  a_n = b_n = c_n = 0;
  thread_init();
  thread_create(thread_a);
  thread_create(thread_b);
  thread_create(thread_c);
  thread_schedule();
  exit(0);
}

通常修改ra寄存器然后改变函数返回地址,从而实现协程。简直秀的发麻。

当时我做的时候,就觉得很秀,现在一看简直秀麻了。

  1. 创建协程也就是uthread,将ra寄存器的值设置为对应的func address,这样第一次切换uthread就会去执行对应的function。
  2. 核心的thread_schedule函数,其会调用thread_switch这个asm函数,起作用就是保存ra sp寄存器和切换ra和sp寄存器。这样thread_switch执行完毕后,就会ret ra,也就是跑去执行新的切换的ra地址。
  3. thread yield也是执行thread_shcedule 这个函数,相当于切换执行不同的进程。thread_a执行完毕后,同样再调用一次thread_schedule,并且将当前进程设置为FREE,这样接下来就不会返回thread_a中,也同样不会返回thread_a 的return address。

我当时没有搞懂这个协程,但是今晚上可算是整的明明白白了。

posted @ 2022-11-25 22:10  kalice  阅读(86)  评论(0编辑  收藏  举报