/* Coroutines (协程): a minimal stackful coroutine library for x86-64. */

#include "co.h"
#include <setjmp.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>


/*
 * Lifecycle of a coroutine as tracked by the scheduler in this file.
 * The numeric values are spelled out explicitly; CREATED must stay 0.
 */
enum state {
  CREATED  = 0,  /* made by co_start(), entry function not yet entered   */
  RUNNING  = 1,  /* the coroutine currently executing                    */
  HALT     = 2,  /* suspended and eligible to be resumed by co_yield()   */
  WAIT     = 3,  /* blocked inside co_wait() until another one finishes  */
  FINISHED = 4   /* entry function returned; never scheduled again       */
};

#define STACK_SIZE 4 * 1024 * 1024 * sizeof(char)

/*
 * Per-coroutine control block. co_start() allocates one of these with
 * malloc for every coroutine (and one for main), so each 4 MiB stack lives
 * inside the heap allocation. `stack` is deliberately the first member:
 * RSP is initialised to the end of this array.
 * NOTE(review): jmp_buf and uint64_t need <setjmp.h>/<stdint.h>; presumably
 * provided via co.h or the includes above -- confirm.
 */
struct co {
  char         stack[STACK_SIZE];  // private stack memory; grows downward from its end
  const char*  name;               // name given to co_start (pointer stored, string not copied)
  void*        arg;                // argument handed to the entry function
  enum state   s;                  // current lifecycle state
  jmp_buf      context;            // saved by setjmp() in co_yield, restored by longjmp() in change_thread
  uint64_t     RIP;                // address of the entry function passed to co_start
  uint64_t     RBP;                // NOTE(review): never read or written anywhere in this file
  uint64_t     RSP;                // initial stack pointer: one past the end of `stack`
  struct co*   wait;               // the coroutine this one is blocked on in co_wait (cleared by finish_co)
  // The coroutine that is waiting on this one (set by co_wait, woken by finish_co)
  struct co*   waited;
};


/* Maximum number of coroutines, including the slot reserved for main. */
enum { CO_MAX = 100 };

/*
 * Global scheduler state. It is private to this file; the public API is
 * co_start / co_yield / co_wait.
 */
static struct {
  // coroutine[0] always holds main's context; slots 1..count hold coroutines
  // created by co_start, in creation order.
  struct co* coroutine[CO_MAX];
  int        count;       // index of the highest occupied slot
  int        cur_idx;     // slot index of the coroutine currently executing
  int        cnt;         // number of not-yet-finished coroutines (main included)
  uint64_t   yield_func;  // address of finish_co, recorded by co_start
} coroutines;

/* finish_co() is defined further down in this file; declare it so taking its
 * address below does not rely on an implicit declaration (invalid since C99). */
void finish_co(void);

/*
 * Create a new coroutine that will run func(arg) when first scheduled.
 *
 * On the first call, a control block for main is also created (slot 0) so
 * that main can be switched away from and resumed later.
 *
 * name: stored as-is (the string is not copied).
 * Returns the new coroutine, or NULL if memory could not be allocated or the
 * coroutine table is full.
 */
struct co *co_start(const char *name, void (*func)(void *), void *arg) {
  const int capacity =
      (int)(sizeof coroutines.coroutine / sizeof coroutines.coroutine[0]);

  // First call: give main its own control block.
  if (coroutines.coroutine[0] == NULL) {
    // calloc so every pointer field (wait/waited/...) starts out NULL.
    struct co *main_co = calloc(1, sizeof *main_co);
    if (main_co == NULL) return NULL;
    coroutines.yield_func = (uint64_t)(uintptr_t)finish_co;
    main_co->s = HALT;
    coroutines.coroutine[0] = main_co;
    coroutines.cnt = 1;
  }

  // The new coroutine goes into slot count + 1, which must exist.
  if (coroutines.count + 1 >= capacity) return NULL;

  // Zeroed allocation: finish_co() reads `waited`, which must start as NULL.
  struct co *co = calloc(1, sizeof *co);
  if (co == NULL) return NULL;
  co->name = name;
  co->arg  = arg;
  co->s    = CREATED;
  co->RIP  = (uint64_t)(uintptr_t)func;
  // The stack grows downward, so execution starts at the end of the array.
  co->RSP  = (uint64_t)(uintptr_t)(co->stack + STACK_SIZE);

  // Only count the coroutine once it is actually registered.
  ++coroutines.cnt;
  coroutines.coroutine[++coroutines.count] = co;
  return co;
}

void co_wait(struct co *co) {
  if(co->s == FINISHED) return;
  struct co* cur = coroutines.coroutine[coroutines.cur_idx];
  
  co->waited = cur;
  cur->wait = co;

  cur->s = WAIT;
  co_yield();
}

/*
 * Called when a coroutine's entry function returns: mark it FINISHED, wake
 * the coroutine (if any) blocked on it in co_wait, and switch away for good.
 */
void finish_co(void) {
  struct co *co = coroutines.coroutine[coroutines.cur_idx];
  co->s = FINISHED;
  --coroutines.cnt;

  struct co *waiter = co->waited;
  if (waiter != NULL) {
    co->waited = NULL;   // the wait relationship is resolved
    waiter->wait = NULL;
    waiter->s = HALT;    // make the waiter runnable again
  }

  // A FINISHED coroutine is never selected again, so this does not return.
  co_yield();

  // Falling off here would return into memory above the coroutine's stack.
  fprintf(stderr, "co: finished coroutine was resumed\n");
  abort();
}

void change_thread(int next_idx){
  struct co* current = coroutines.coroutine[coroutines.cur_idx];
  struct co* next = coroutines.coroutine[next_idx];
  coroutines.cur_idx = next_idx;

  if(current->s == RUNNING) current->s = HALT;
  int flag = next->s;
  next->s = RUNNING;
  
  if(flag == CREATED){
    asm volatile(
        "movq    %0, %%rsp\n\t"
        "pushq   %3\n\t"
        "movq    %1, %%rdi\n\t"
        "jmp     *%2\n\t"
      :
      : "m"(next->RSP),
        "m"(next->arg),
        "m"(next->RIP),
        "m"(coroutines.yield_func)
      );
  } else {
    longjmp(next->context, 1);
  }
}

void co_yield() {
  if(!coroutines.cnt) return;
  int nxt = coroutines.cur_idx;
  struct co* cur_co = coroutines.coroutine[nxt];
  
  if(!setjmp(cur_co->context))
  {
    // jump to other program
    for(nxt = (nxt+1) % (coroutines.count + 1); 
        nxt <= coroutines.count;
        nxt=(nxt+1)%(coroutines.count+1))
    {
      if(coroutines.coroutine[nxt]->s == HALT 
          || coroutines.coroutine[nxt]->s == CREATED)
      {
        change_thread(nxt);
        break;
      }
    }
  }
}
/* Posted 2023-07-06 12:12 by INnoVation-V2 (views: 13, comments: 0). */