协程 + epoll 的两个小例子

getcontext/makecontext/swapcontext/setcontext 方式的协程实现

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <poll.h>
#include <errno.h>
#include <unistd.h>
#include <ucontext.h>
#include <sys/time.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <netinet/in.h>

#define MYPORT 12345
/* Stack size in bytes of every coroutine context.
 * NOTE(review): 8 KiB leaves little headroom for printf plus the 1 KiB
 * recv buffer in read_sock -- consider raising it if stack overflows appear. */
#define MAX_STACK 8192
#define MAX_EVENTS 1024

/* epoll instance shared by main() and every coroutine. */
int epollfd;

/* Per-socket state: the descriptor plus the coroutine contexts serving it. */
typedef struct sock_ctx {
    int sock ;          /* socket file descriptor */
    ucontext_t * rctx ; /* context run on EPOLLIN (accept_sock or read_sock) */
    ucontext_t * wctx ; /* context run on EPOLLOUT (write_sock); NULL for the listening socket */
}sock_ctx;

/* Coroutine entry points. The sock_ctx pointer arrives split into two 32-bit
 * halves because makecontext() forwards its extra arguments as int. */
void read_sock(int low, int high);
void write_sock(int low, int high);

/* Returns the upper 32 bits of a 64-bit value. */
uint32_t high32(uint64_t value) {
    const uint64_t upper = value >> 32;
    return (uint32_t)upper;
}

/* Returns the lower 32 bits of a 64-bit value; the upper half is discarded. */
uint32_t low32(uint64_t value) {
    return (uint32_t)(value & UINT32_MAX);
}

/*
 * Reassembles a pointer from the two 32-bit halves produced by low32/high32.
 * The integer goes through uintptr_t, the type intended for pointer/integer
 * round-trips, so the conversion is also well-formed where pointers are
 * narrower than 64 bits.
 */
void * make64(uint32_t low, uint32_t high) {
    uint64_t value = ((uint64_t)high << 32) | low;
    return (void *)(uintptr_t)value;
}

/*
 * Prepares ctx to run func(low, high), where (low, high) are the two 32-bit
 * halves of sockctx; the callee rebuilds the pointer with make64().
 * Callers initialise ctx with getcontext() and give it a stack beforehand.
 * The pointer is converted through uintptr_t, the portable pointer-sized
 * integer type, before being split.
 */
void setupcontext(ucontext_t * ctx, sock_ctx * sockctx, void(*func)()) {
    uint64_t bits = (uint64_t)(uintptr_t)sockctx;
    makecontext( ctx, func, 2, low32(bits), high32(bits) );
}

/*
 * Adds (EPOLL_CTL_ADD) or modifies (EPOLL_CTL_MOD) sockctx->sock in the
 * global epoll instance. sockctx is stored as the event's user pointer, which
 * is how the main loop finds the socket's coroutine contexts.
 * Returns epoll_ctl()'s result: 0 on success, -1 with errno set on failure.
 */
int epoll_action(int epollevt, sock_ctx* sockctx, int action) {
    struct epoll_event event;
    memset(&event, 0, sizeof(event));   /* no indeterminate bytes handed to the kernel */
    event.events = (uint32_t)epollevt;
    event.data.ptr = sockctx;
    return epoll_ctl(epollfd, action, sockctx->sock, &event);
}

/* Convenience wrappers. Every parameter is parenthesised so that expression
 * arguments (e.g. ctx + i) expand correctly. */
#define epoll_mod(epollevt, sockctx) epoll_action((epollevt), (sockctx), EPOLL_CTL_MOD)
#define epoll_add(epollevt, sockctx) epoll_action((epollevt), (sockctx), EPOLL_CTL_ADD)
#define epoll_del(sockctx)  epoll_ctl(epollfd, EPOLL_CTL_DEL, (sockctx)->sock, NULL)

void write_sock(int low, int high) {
    sock_ctx * sockctx = (sock_ctx *)make64(low, high);
    epoll_mod(EPOLLIN, sockctx);
    printf("write_sock[%d] --- function\n", (int)sockctx->sock);
    setcontext( sockctx->wctx->uc_link );
}

/*
 * Coroutine run on EPOLLIN for a connected socket.
 *
 * (low, high) are the halves of the sock_ctx pointer (see setupcontext).
 * Reads the available data in chunks of up to 1024 bytes and logs it, then
 * switches the socket to EPOLLOUT and jumps back to the main loop through
 * uc_link. main() swaps into the unmodified makecontext() state each time,
 * so every activation starts this function from the top.
 *
 * On peer close or a hard recv error, the socket is removed from epoll and
 * closed before returning to the main loop.
 * NOTE(review): the sock_ctx, its contexts and their stacks are still leaked;
 * they cannot be freed from the stack this coroutine is running on.
 */
void read_sock(int low, int high) {
    sock_ctx * sockctx = (sock_ctx *)make64(low, high);
    while ( 1 ) {
        char body[1024];
        int ret = recv( sockctx->sock, body, sizeof(body), 0 );
        if ( ret < 0 && errno == EINTR ) {
            continue;   /* interrupted by a signal before any data: retry */
        }
        if ( ret <= 0 ) {
            if ( ret == 0 ) {
                printf("sock disconnect\n");
            } else {
                printf("sock error : %d\n", errno);
            }
            epoll_del(sockctx);
            close(sockctx->sock);
            setcontext(sockctx->rctx->uc_link);
            return;   /* only reached if setcontext fails; returning also resumes uc_link */
        }

        /* body is not NUL-terminated when recv fills all of it, so bound the
         * print by the byte count instead of using plain %s. */
        printf("read_sock[%d] --- buf = %.*s\n", sockctx->sock, ret, body);
        if ( ret < (int)sizeof(body) ) {
            break;   /* short read: assume the socket buffer is drained */
        }
    }

    epoll_mod(EPOLLOUT, sockctx);
    printf("read_sock[%d] --- function\n", (int)sockctx->sock);
    setcontext( sockctx->rctx->uc_link );
}

/*
 * Coroutine run on EPOLLIN for the listening socket.
 *
 * (low, high) are the halves of the listening socket's sock_ctx pointer
 * (see setupcontext). Accepts one connection and builds a sock_ctx for it
 * with two contexts:
 *   - a read context running read_sock,
 *   - a write context running write_sock.
 * Each has its own MAX_STACK-byte stack, and both use the listener's uc_link.
 * The new socket is then registered for EPOLLIN and control jumps back to the
 * main loop.
 *
 * On any failure the half-built connection (memory and fd) is released and
 * the function returns. Returning from a makecontext() function also resumes
 * uc_link.
 */
void accept_sock(int low, int high) {
    sock_ctx * tmpctx = (sock_ctx *)make64(low, high);
    struct sockaddr_in sin;
    socklen_t len = sizeof(sin);
    sock_ctx * sockctx = NULL;
    ucontext_t * rctx = NULL;
    ucontext_t * wctx = NULL;
    void * rstack = NULL;
    void * wstack = NULL;

    memset(&sin, 0, sizeof(sin));
    int confd = accept(tmpctx->sock, (struct sockaddr*)&sin, &len);
    if ( confd < 0 ) {
       printf("bad accept\n");
       return;
    }

    sockctx = malloc(sizeof(*sockctx));
    rctx = malloc(sizeof(*rctx));
    wctx = malloc(sizeof(*wctx));
    rstack = malloc(MAX_STACK);
    wstack = malloc(MAX_STACK);
    if ( !sockctx || !rctx || !wctx || !rstack || !wstack ) {
        printf("accept_sock ---- out of memory\n");
        goto fail;
    }
    if ( getcontext(rctx) < 0 || getcontext(wctx) < 0 ) {
        printf("accept_sock ---- getcontext failed\n");
        goto fail;
    }

    sockctx->sock = confd;
    sockctx->rctx = rctx;
    sockctx->wctx = wctx;

    /* Both connection coroutines return wherever the listener returns. */
    rctx->uc_link = tmpctx->rctx->uc_link;
    rctx->uc_stack.ss_size = MAX_STACK;
    rctx->uc_stack.ss_sp = rstack;
    wctx->uc_link = tmpctx->rctx->uc_link;
    wctx->uc_stack.ss_size = MAX_STACK;
    wctx->uc_stack.ss_sp = wstack;
    setupcontext( rctx, sockctx, (void(*)())read_sock );
    setupcontext( wctx, sockctx, (void(*)())write_sock );
    printf("accept_sock ---- new connection %d\n", confd);

    if ( epoll_add(EPOLLIN, sockctx) < 0 ) {
       printf( "epoll_ctl failed\n" ) ;
       goto fail;
    }
    setcontext ( tmpctx->rctx->uc_link ) ;
    return;   /* only reached if setcontext fails */

fail:
    free(wstack);
    free(rstack);
    free(wctx);
    free(rctx);
    free(sockctx);
    close(confd);
}

//-------------------------------------------------------
// This approach is not necessarily fast; it only demonstrates
// how the coroutines are used.
//-------------------------------------------------------
/*
 * Event loop for the ucontext variant.
 * Creates a listening TCP socket on MYPORT and gives it an accept_sock
 * coroutine. For every epoll event it swaps into the socket's read context
 * (EPOLLIN) or write context (EPOLLOUT). Coroutines come back here through
 * uc_link, which points at ctx_main.
 * Returns -1 if setup fails, and 0 once epoll_wait fails with an error other
 * than EINTR. Heap memory is left for the process exit to reclaim.
 */
int main() {
    ucontext_t ctx_main;
    const int timeout = 100000;                 /* epoll_wait timeout, ms */
    struct sockaddr_in server_addr;
    struct epoll_event eventList[MAX_EVENTS];
    sock_ctx * listenctx = NULL;
    int rc = -1;

    epollfd = -1;   /* so the cleanup path never closes fd 0 by accident */
    int sockListen = socket(AF_INET, SOCK_STREAM, 0);
    if ( sockListen < 0 ) {
        printf("socket error\n");
        return -1;
    }

    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(MYPORT);
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    if ( bind(sockListen, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0 ) {
        printf("bind error\n");
        goto out;
    }

    if ( listen(sockListen, 5) < 0 ) {
        printf("listen error\n");
        goto out;
    }

    epollfd = epoll_create(MAX_EVENTS);
    if ( epollfd < 0 ) {
        printf("epoll_create error\n");
        goto out;
    }

    /* The listening socket only needs a read-side coroutine (accept_sock). */
    listenctx = malloc(sizeof(*listenctx));
    if ( listenctx == NULL ) {
        printf("out of memory\n");
        goto out;
    }
    listenctx->sock = sockListen;
    listenctx->wctx = NULL;
    listenctx->rctx = malloc(sizeof(ucontext_t));
    if ( listenctx->rctx == NULL || getcontext( listenctx->rctx ) < 0 ) {
        printf("coroutine setup error\n");
        goto out;
    }

    listenctx->rctx->uc_link = &ctx_main;
    listenctx->rctx->uc_stack.ss_size = MAX_STACK;
    listenctx->rctx->uc_stack.ss_sp = malloc(MAX_STACK);
    if ( listenctx->rctx->uc_stack.ss_sp == NULL ) {
        printf("out of memory\n");
        goto out;
    }
    setupcontext( listenctx->rctx, listenctx, (void(*)())accept_sock );

    if ( epoll_add(EPOLLIN, listenctx) < 0 ) {
        printf("epoll add fail : fd = %d\n", sockListen);
        goto out;
    }

    while ( 1 ) {
        int ret = epoll_wait(epollfd, eventList, MAX_EVENTS, timeout);
        if ( ret == 0 ) {
            continue;
        } else if ( ret < 0 ) {
            if ( errno == EINTR ) {
                continue;   /* a signal interrupted the wait: not fatal */
            }
            break;
        }

        printf("epoll_wait wakeup enter\n");
        for ( int i = 0; i < ret; i++ ) {
            sock_ctx * sockctx = (sock_ctx *)eventList[i].data.ptr ;
            uint32_t events = eventList[i].events;
            if ( events & (EPOLLERR | EPOLLHUP) ) {
                printf ( "sock[%d] error\n", sockctx->sock );
                close (sockctx->sock);
                continue;
            }
            if ( events & EPOLLIN ) {
                if ( swapcontext( &ctx_main, sockctx->rctx ) == -1 ) {
                    printf ( "swapcontext read error\n");
                }
            } else if ( events & EPOLLOUT ) {
                if ( swapcontext( &ctx_main, sockctx->wctx ) == -1 ) {
                    printf ( "swapcontext write error\n");
                }
            }
        }
        printf("epoll_wait wakeup leave\n");
    }
    rc = 0;

out:
    if ( epollfd >= 0 ) {
        close(epollfd);
    }
    close(sockListen);
    return rc;
}

setjmp/longjmp 的实现方式（注意：用 longjmp 跳回一个已经返回的函数栈帧，在 C 标准中属于未定义行为；本例仅作原理演示，不可用于生产）

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <poll.h>
#include <errno.h>
#include <unistd.h>
#include <setjmp.h>
#include <sys/time.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <netinet/in.h>

#define MYPORT      12345
#define MAX_STACK   8192    /* not referenced by this variant: coroutines run on the main stack */
#define MAX_EVENTS  1024

/* epoll instance shared by main() and every coroutine. */
int epollfd;
/* Results of the last epoll_wait(). After each resume, a coroutine re-reads
 * its sock_ctx from here, at the event index that main's co_yield passed
 * as i + 1. */
struct epoll_event eventList[MAX_EVENTS];

/* Per-socket coroutine state. */
typedef struct sock_ctx sock_ctx;
struct sock_ctx {
    int sock;        /* socket file descriptor */
    jmp_buf rjmp;    /* resume point of the read-side coroutine (accept_sock or read_sock) */
    jmp_buf wjmp;    /* resume point of the write-side coroutine (write_sock) */
    jmp_buf bjmp;    /* resume point back in main()'s event loop */
};

//-----------------------------------------------------------------------------------------
// Many parameters and local variables are NOT reliably restored across setjmp/longjmp,
// so after each resume the coroutines re-derive them (e.g. from the global eventList).
// setjmp/longjmp can also be interrupted by signals, and stack contents are not
// restored; a production-quality framework would need further work.
//-----------------------------------------------------------------------------------------

// Value passed by the most recent co_yield. C11 7.13.1.1 allows setjmp() only as
// (part of) a controlling expression or as a whole expression statement, so its
// result may not be stored with `int x = setjmp(...)`. The value travels through
// this variable instead. It is volatile because it is written before longjmp and
// read after the jump lands.
static volatile int co_value;

// Suspend: save the current point in oldjmp and resume newjmp, passing value.
// When another co_yield later jumps back to oldjmp, execution continues after the
// macro with `setjmp_ret` holding the value that co_yield passed. A raw longjmp
// that bypasses co_yield leaves co_value unchanged.
// It must be a macro, not a function: setjmp has to run in the caller's own frame.
// It declares `setjmp_ret`, so use it at most once per block scope.
#define co_yield(oldjmp, newjmp, value) \
    int setjmp_ret; \
    if ( setjmp(oldjmp) == 0 ) { \
        co_value = (value); \
        longjmp((newjmp), co_value); \
    } \
    setjmp_ret = co_value

// Adds (EPOLL_CTL_ADD) or modifies (EPOLL_CTL_MOD) sockctx->sock in the global
// epoll instance. sockctx is stored as the event's user pointer, which is how the
// main loop finds the socket's jmp_bufs.
// Returns epoll_ctl()'s result: 0 on success, -1 with errno set on failure.
int epoll_action(int epollevt, sock_ctx* sockctx, int action) {
    struct epoll_event event;
    memset(&event, 0, sizeof(event));   // no indeterminate bytes handed to the kernel
    event.events = (uint32_t)epollevt;
    event.data.ptr = sockctx;
    return epoll_ctl(epollfd, action, sockctx->sock, &event);
}

// Convenience wrappers. Every parameter is parenthesised so that expression
// arguments expand correctly.
#define epoll_mod(epollevt, sockctx) epoll_action((epollevt), (sockctx), EPOLL_CTL_MOD)
#define epoll_add(epollevt, sockctx) epoll_action((epollevt), (sockctx), EPOLL_CTL_ADD)
#define epoll_del(sockctx)  epoll_ctl(epollfd, EPOLL_CTL_DEL, (sockctx)->sock, NULL)

// Coroutine entry point. It receives the jmp_buf to yield back to on its first
// suspension, plus its socket context.
typedef void (*co_func)(jmp_buf *, sock_ctx*);
// Starts func as a "coroutine" on the current stack.
// Saves the caller's resume point in *tjmp and calls func(tjmp, arg). The coroutines
// in this file immediately co_yield back to *tjmp; that makes setjmp return non-zero,
// and start_coroutine then returns 1. It returns 0 only if func returns normally.
// NOTE(review): func's frame is nested under this function's frame, and that frame
// is popped when start_coroutine returns. Later longjmps back into func therefore
// target a function that has already terminated, which C11 7.13.2.1 makes undefined
// behaviour. It only works while the coroutines never return and never rely on stack
// locals across a resume. A fortified longjmp (glibc _FORTIFY_SOURCE) may also abort
// such jumps -- TODO confirm on the target toolchain.
int start_coroutine(jmp_buf * tjmp, co_func func, void* arg) {
    if ( 0 == setjmp(*tjmp) ) {
        func(tjmp, arg);
        return 0;
    }
    return 1;
}

// Write-side coroutine for one connection, resumed by main() on EPOLLOUT.
// First activation (from start_coroutine): saves its resume point in wjmp and jumps
// straight back to *newjmp. Every later resume comes from main's co_yield with the
// value i + 1. The sock_ctx is therefore re-read from eventList[setjmp_ret - 1]
// rather than trusting the (possibly clobbered) parameter.
// No data is written: each activation re-arms the socket for EPOLLIN, logs, and
// yields back to main through bjmp. The loop never exits, so the function never returns.
void write_sock(jmp_buf * newjmp, sock_ctx* sockctx) {
    co_yield(sockctx->wjmp, *newjmp, 1);
    // Resumed by main: recover this socket's context from the event that woke it.
    sockctx = (sock_ctx*)eventList[setjmp_ret - 1].data.ptr;
    printf("write_sock[%d] run %p\n", (int)sockctx->sock, sockctx);

    while ( 1 ) {
        epoll_mod(EPOLLIN, sockctx);
        printf("write_sock[%d] --- function1\n", (int)sockctx->sock);

        // Hand control back to main's event loop; main resumes us at the next EPOLLOUT.
        co_yield(sockctx->wjmp, sockctx->bjmp, 1);
        sockctx = (sock_ctx*)eventList[setjmp_ret - 1].data.ptr;
        printf("write_sock[%d] --- function2\n", (int)sockctx->sock);
    }
}

// Read-side coroutine for one connection, resumed by main() on EPOLLIN.
// First activation (from start_coroutine): saves its resume point in rjmp and jumps
// straight back to *newjmp. Every later resume comes from main's co_yield with the
// value i + 1. The sock_ctx is therefore re-read from eventList[setjmp_ret - 1]
// rather than trusting the (possibly clobbered) parameter.
// Each activation reads the socket in chunks of up to 1024 bytes, switches it to
// EPOLLOUT and yields back to main through bjmp. On peer close or a hard recv error
// the socket is deregistered and closed, and control jumps to main for good.
// NOTE(review): the sock_ctx itself is still leaked on that path.
void read_sock(jmp_buf* newjmp, sock_ctx* sockctx) {
    co_yield(sockctx->rjmp, *newjmp, 1);
    sockctx = (sock_ctx*)eventList[setjmp_ret - 1].data.ptr;
    printf("read_sock[%d] run %p\n", (int)sockctx->sock, (void *)sockctx);

    while ( 1 ) {
        while ( 1 ) {
            char body[1024];
            int ret = recv(sockctx->sock, body, sizeof(body), 0);
            if ( ret < 0 && errno == EINTR ) {
                continue;   // interrupted by a signal before any data: retry
            }
            if ( ret <= 0 ) {
                if ( ret == 0 ) {
                    printf("sock disconnect\n");
                } else {
                    printf("sock error : %d\n", errno);
                }
                epoll_del(sockctx);
                close(sockctx->sock);
                // Back to main's event loop; no further events arrive for this socket.
                longjmp(sockctx->bjmp, 2);
            }

            // body is not NUL-terminated when recv fills all of it, so the print is
            // bounded by the byte count instead of using plain %s.
            printf("read_sock[%d] --- buf = %.*s\n", sockctx->sock, ret, body);
            if ( ret < (int)sizeof(body) ) {
                break;   // short read: assume the socket buffer is drained
            }
        }

        epoll_mod(EPOLLOUT, sockctx);
        printf("read_sock[%d] --- function1\n", (int)sockctx->sock);

        co_yield(sockctx->rjmp, sockctx->bjmp, 1);
        sockctx = (sock_ctx*)eventList[setjmp_ret - 1].data.ptr;
        printf("read_sock[%d] --- function2\n", (int)sockctx->sock);
    }
}

// Coroutine for the listening socket, resumed by main() on EPOLLIN.
// First activation (from start_coroutine): saves its resume point in rjmp and jumps
// straight back to *newjmp. On each later resume it re-reads the listener's context
// from eventList[setjmp_ret - 1] and accepts one connection. For that connection it:
//   - allocates a sock_ctx,
//   - starts its read_sock and write_sock coroutines,
//   - registers it for EPOLLIN.
// It then yields back to main through bjmp.
// It must never `return`: its caller (start_coroutine) has already returned, so
// there is no valid frame to return to. Failures are therefore logged and followed
// by the normal yield.
void accept_sock(jmp_buf* newjmp, sock_ctx* tmpctx) {
    co_yield(tmpctx->rjmp, *newjmp, 1);
    tmpctx = (sock_ctx*)eventList[setjmp_ret - 1].data.ptr;
    printf("accept_sock run %d: %p\n", setjmp_ret, (void *)tmpctx);

    while ( 1 ) {
        struct sockaddr_in sin;
        socklen_t len = sizeof(struct sockaddr_in);
        memset(&sin, 0, len);
        int confd = accept(tmpctx->sock, (struct sockaddr*)&sin, &len);
        if ( confd < 0 ) {
            printf("%d bad accept(%d)\n", tmpctx->sock, errno);
        } else {
            sock_ctx* sockctx = malloc(sizeof(sock_ctx));
            if ( sockctx == NULL ) {
                printf("accept_sock ---- out of memory\n");
                close(confd);
            } else {
                sockctx->sock = confd;
                start_coroutine(&tmpctx->rjmp, read_sock, sockctx);
                start_coroutine(&tmpctx->rjmp, write_sock, sockctx);
                printf("accept_sock ---- new connection %d\n", confd);

                if ( epoll_add(EPOLLIN, sockctx) < 0 ) {
                    // Never registered, so its coroutines can never be resumed: release it.
                    printf("epoll_ctl failed\n");
                    close(confd);
                    free(sockctx);
                }
            }
        }
        printf("accept_sock[%d] --- function1\n", (int)tmpctx->sock);

        co_yield(tmpctx->rjmp, tmpctx->bjmp, 1);
        tmpctx = (sock_ctx*)eventList[setjmp_ret - 1].data.ptr;
        printf("accept_sock[%d] --- function2\n", (int)tmpctx->sock);
    }
}

int main() {
    int i, timeout = 100000;
    fd_set readfds, writefds;
    struct sockaddr_in server_addr;
    struct sockaddr_in client_addr;
 
    int sockListen = socket(AF_INET, SOCK_STREAM, 0);
    if ( sockListen < 0 ) {
        printf("socket error\n");
        return -1;
    }

    bzero(&server_addr, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(MYPORT);
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    if ( bind(sockListen, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0 ) {
        printf("bind error\n");
        return -1;
    }
 
    if ( listen(sockListen, 5) < 0 ) {
        printf("listen error\n");
        return -1;
    }

    epollfd = epoll_create(MAX_EVENTS);

    jmp_buf mjmp;
    sock_ctx* sockctx = malloc(sizeof(sock_ctx));
    sockctx->sock = sockListen;
    start_coroutine(&mjmp, accept_sock, sockctx);
    if ( -1 == epoll_add(EPOLLIN, sockctx)) {
        printf("epoll_ctl error\n");
        return -1;
    }

    while ( 1 ) {
        int ret = epoll_wait(epollfd, eventList, MAX_EVENTS, timeout);
        if ( ret == 0 ) {
            continue;
        } else if ( ret < 0 ) {
            break;
        }

        printf("epoll_wait wakeup enter\n");
        for ( i = 0; i < ret; i++ ) {
            sock_ctx * sockctx = (sock_ctx *)eventList[i].data.ptr ;
            if ( (eventList[i].events & EPOLLERR) || (eventList[i].events & EPOLLHUP) ) {
                printf ( "sock[%d] error\n", sockctx->sock );
                close (sockctx->sock);
                continue;
            }
            if ( eventList[i].events & EPOLLIN ) {
                co_yield(sockctx->bjmp, sockctx->rjmp, i + 1);
            } else if ( eventList[i].events & EPOLLOUT ) {
                co_yield(sockctx->bjmp, sockctx->wjmp, i + 1);
            }
        }
        printf("epoll_wait wakeup leave\n");
    }

    close(epollfd);
    close(sockListen);
 
    return 0;
}

  

posted @ 2022-11-10 07:01  superconvert  阅读(159)  评论(0编辑  收藏  举报