Visual Studio 2019使用C语言进行websocket编程
这些年一直在写C#,已经好多年没写C语言代码了。这里记录一下之前某个项目里用C写的一个WebSocket服务。用C的优势是写出来的程序体积小、性能高,但如果要写业务逻辑,还是得用C#、Java之类的语言,不然会折腾死人……
用Visual Studio新建一个C++项目(Visual Studio没有直接创建C语言项目的模板),这里演示就创建一个控制台项目。源文件使用.c扩展名,编译器就会按C语言来编译。项目创建完后,首先要添加socket编程所需的依赖库ws2_32.lib,添加方式如下图
也可以在代码文件里添加这句代码:`#pragma comment(lib, "ws2_32.lib")`
添加完成后就可以开始写代码了。说句题外话,用Visual Studio写C语言时最好把SDL检查也关掉,否则sprintf、inet_addr等被标记为不安全或已弃用的函数会直接报编译错误。
新建一个wsserver.h头文件,头文件相关定义代码如下
#pragma once #include <WinSock2.h> #include <stdint.h> #include <stdbool.h> #include <stdio.h> #include <string.h> #include <ctype.h> #include <windef.h> #include <stdlib.h> #include "sha1.h" #include "b64.h" #include "cJSON.h" typedef enum FrameType { frameType_continuation, frameType_text, frameType_binary, frameType_connectionClose, frameType_ping, frameType_pong } FrameType; typedef enum FrameState { frameState_init, // 未读取任何字节 frameState_firstByte, // 已读取首字节FIN、RSV、opcode frameState_mask, // 已读取掩码 frameState_7bitLength, // 已读取7bit长度 frameState_16bitLengthWait, // 等待读取16bit长度 frameState_63bitLengthWait, // 等待读取63bit长度 frameState_16bitLength, // 已读取16bit长度 frameState_63bitLength, // 已读取63bit长度 frameState_maskingKey, // 已读取Masking-key frameState_readingData, // 正在读取载荷数据 frameState_success, // 读取完毕 frameState_failure // 读取错误 } FrameState; typedef struct WsFrame { FrameState state; bool FIN; FrameType frameType; uint8_t mask[4]; unsigned char* buff; // 数据存放的空间 uint64_t buffSize; // 当前申请的buff大小 uint64_t handledLen; // 已处理的帧长度 uint64_t headerLen; // 帧头长度 只有在state为'已读取掩码'及之后才有意义 uint64_t payloadLen; // 载荷长度 只有在state为'已读取xbit长度'后才有意义 struct WsFrame* next; // 下一帧的指针 } WsFrame; void initWsFrameStruct(WsFrame* wsFrame); char* convertToWebSocketFrame(const char* data, FrameType type, size_t len, size_t* newLen); int readWebSocketFrameStream(WsFrame* wsFrame, const char* buff, int len); void freeWebSocketFrame(WsFrame* wsFrame); int wsShakeHands(const char* recvBuff, int recvLen, SOCKET socket, const char* path); int wsFrameSend(SOCKET socket, const char* buff, int len, FrameType type); void wsFrameSendToAll(const char* buff, int len, FrameType type); int serverStart(const char* address, u_short port, const char* path); void serverStop(void);
新建一个wsserver.c代码文件,我们一步一步地来实现这些方法。
首先是serverStart方法,顾名思义,启动ws服务,第一个参数是地址(一般传本机IP),第二个参数是要监听的端口,第三个参数是路径,完整代码如下
/* These must be defined before any CRT/Winsock header is included, otherwise they have no effect. */
#define _CRT_SECURE_NO_WARNINGS
#define _WINSOCK_DEPRECATED_NO_WARNINGS

#include "wsserver.h"
#include <stdarg.h>
#include <time.h>

/* Protocol spoken on a client connection: plain TCP until the handshake succeeds, then WebSocket. */
typedef enum
{
    socketProtocol,
    websocketProtocol
} Protocol;

/* One accepted client connection. */
typedef struct
{
    Protocol protocol;
    SOCKET socket;
    WsFrame wsFrame;   /* parser state of the frame currently being received */
} Client;

#define MAX_CLIENT_NUM FD_SETSIZE

/* Connected clients; clients[0 .. total-1] are in use. */
static struct
{
    int total;
    Client clients[MAX_CLIENT_NUM];
} clientSockets;

/* Listening socket created by serverStart. */
static SOCKET serverSocket;

/* Receive-thread body, defined further down in this file. */
void receiveComingData(const char* path);

/* GB18030 code page (a superset of GBK). */
enum { GB18030_CODE_PAGE = 54936 };

/*
 * Prints "[YYYY-MM-DD hh:mm:ss] type->message" to stdout, where message is produced
 * from format and the variadic arguments (truncated to 511 characters).
 */
void printLog(const char* type, const char* format, ...)
{
    char message[512] = { 0 };
    va_list args;
    va_start(args, format);
    vsnprintf(message, sizeof(message), format, args);
    va_end(args);

    time_t now = time(NULL);
    const struct tm* local = localtime(&now);
    /* The message is always printed through "%s": it may contain client data with '%' in it. */
    if (local == NULL) {
        printf("%s->%s\n", type, message);
        return;
    }
    printf("[%04d-%02d-%02d %02d:%02d:%02d] %s->%s\n",
           local->tm_year + 1900, local->tm_mon + 1, local->tm_mday,
           local->tm_hour, local->tm_min, local->tm_sec,
           type, message);
}

/*
 * Converts a NUL-terminated string between two Windows code pages via UTF-16.
 * Returns a malloc'ed NUL-terminated string the caller must free(), or NULL when
 * str is NULL, the conversion fails, or memory runs out.
 */
static char* convertCodePage(const char* str, UINT fromCodePage, UINT toCodePage)
{
    if (str == NULL) {
        return NULL;
    }

    /* Size the UTF-16 buffer from the input; a fixed stack array overflows on long strings. */
    const int wideLen = MultiByteToWideChar(fromCodePage, 0, str, -1, NULL, 0);
    if (wideLen <= 0) {
        return NULL;
    }
    wchar_t* wide = malloc((size_t)wideLen * sizeof(*wide));
    if (wide == NULL) {
        return NULL;
    }

    char* out = NULL;
    if (MultiByteToWideChar(fromCodePage, 0, str, -1, wide, wideLen) > 0) {
        const int outLen = WideCharToMultiByte(toCodePage, 0, wide, -1, NULL, 0, NULL, NULL);
        if (outLen > 0 && (out = malloc((size_t)outLen + 1)) != NULL) {
            if (WideCharToMultiByte(toCodePage, 0, wide, -1, out, outLen, NULL, NULL) > 0) {
                out[outLen] = '\0';
            } else {
                free(out);
                out = NULL;
            }
        }
    }
    free(wide);
    return out;
}

/* UTF-8 -> GB18030. Returns a malloc'ed string (free() it) or NULL on failure / NULL input. */
char* UTF8ToGBK(const char* str)
{
    return convertCodePage(str, CP_UTF8, GB18030_CODE_PAGE);
}

/* GB18030 -> UTF-8. Returns a malloc'ed string (free() it) or NULL on failure / NULL input. */
char* GBKToUTF8(const char* str)
{
    return convertCodePage(str, GB18030_CODE_PAGE, CP_UTF8);
}

/* Adapts receiveComingData to the LPTHREAD_START_ROUTINE signature CreateThread expects. */
static DWORD WINAPI receiveThreadProc(LPVOID param)
{
    receiveComingData((const char*)param);
    return 0;
}

/*
 * Starts the WebSocket server: initialises Winsock, binds and listens on address:port,
 * and spawns the thread that services all connections.
 *
 * address: dotted IPv4 address to bind (NULL binds all interfaces).
 * port:    TCP port in host byte order.
 * path:    request path accepted during the handshake. It is handed to the receive
 *          thread as-is, so it must stay valid while the server runs (a literal is fine).
 *
 * Returns 0 on success, -1 on failure (Winsock is cleaned up again in that case).
 */
int serverStart(const char* address, u_short port, const char* path)
{
    WSADATA wsaData;
    int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
    if (iResult != 0) {
        printLog("ServerStart", "WSAStartup failed");
        return -1;
    }

    if (path == NULL) {
        path = "";
    }

    struct sockaddr_in sockAddr;
    ZeroMemory(&sockAddr, sizeof(sockAddr));
    sockAddr.sin_family = AF_INET;
    sockAddr.sin_addr.s_addr = (address == NULL) ? htonl(INADDR_ANY) : inet_addr(address);
    sockAddr.sin_port = htons(port);
    if (sockAddr.sin_addr.s_addr == INADDR_NONE) {
        printLog("ServerStart", "Invalid address: %s", address);
        WSACleanup();
        return -1;
    }

    serverSocket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (serverSocket == INVALID_SOCKET) {
        printLog("ServerStart", "Error at socket(): %d", WSAGetLastError());
        WSACleanup();
        return -1;
    }

    if (bind(serverSocket, (SOCKADDR*)&sockAddr, sizeof(sockAddr)) == SOCKET_ERROR) {
        printLog("ServerStart", "Bind failed with error: %d", WSAGetLastError());
        closesocket(serverSocket);
        serverSocket = INVALID_SOCKET;
        WSACleanup();
        return -1;
    }

    if (listen(serverSocket, SOMAXCONN) == SOCKET_ERROR) {
        printLog("ServerStart", "Listen failed with error: %d", WSAGetLastError());
        closesocket(serverSocket);
        serverSocket = INVALID_SOCKET;
        WSACleanup();
        return -1;
    }

    clientSockets.total = 0;

    DWORD dwThreadId;
    /* The cast drops const only for transport through LPVOID; the thread never writes to path. */
    HANDLE hThread = CreateThread(NULL, 0, receiveThreadProc, (LPVOID)path, 0, &dwThreadId);
    if (hThread == NULL) {
        printLog("ServerStart", "CreateThread failed with error: %lu", GetLastError());
        closesocket(serverSocket);
        serverSocket = INVALID_SOCKET;
        WSACleanup();
        return -1;
    }
    /* The handle is not needed; closing it does not stop the thread. */
    CloseHandle(hThread);
    return 0;
}
serverStart方法的最后创建了一个线程来接收socket数据,线程函数receiveComingData的代码如下
void receiveComingData(const char* path) { #define RECV_BUFLEN 0X40000 char recvbuf[RECV_BUFLEN]; int iResult; int ret; fd_set fdread; struct timeval tv = { 1, 0 }; receivingDataLoop: FD_ZERO(&fdread); // 清空socket集合 FD_SET(serverSocket, &fdread); // 设置socket数据读取集合 for (int i = 0; i < clientSockets.total; i++) { FD_SET(clientSockets.clients[i].socket, &fdread); } // 检查socket是否有数据可读 ret = select(0, &fdread, NULL, NULL, &tv); if (ret == 0) { goto receivingDataLoop; // select的等待时间到达,开始下一轮等待 } // 检查socket是否在这个集合里 if (FD_ISSET(serverSocket, &fdread)) { acceptConnect(); // 处理socket连接 } for (int i = 0; i < clientSockets.total; i++) { Client* client = &clientSockets.clients[i]; if (!FD_ISSET(client->socket, &fdread)) { continue; } // 接收数据 iResult = recv(client->socket, recvbuf, RECV_BUFLEN, 0); if (iResult > 0) { printLog("receiveComingData", "Bytes received: %d", iResult); // 协议升级 if (client->protocol == socketProtocol) { int result = wsShakeHands(recvbuf, iResult, client->socket, path); if (result != 0) { removeClient(i--); } else { client->protocol = websocketProtocol; initWsFrameStruct(&client->wsFrame); // 初始化ws帧结构 } } // WebSocket通信 else if (client->protocol == websocketProtocol) { int result = wsClientDataHandle(recvbuf, iResult, client); if (result == -1) { removeClient(i--); } } } else { if (iResult == 0) { // 客户端礼貌的关闭连接 printLog("receiveComingData", "Connection closing..."); } else { // 客户端异常关闭连接等情况 printLog("receiveComingData", "Recv failed: %d", WSAGetLastError()); } removeClient(i--); } } goto receivingDataLoop; }
receiveComingData方法里处理socket协议升级的代码
// 不区分大小写的比较字符串,相等返回true bool stricasecmp(const char* a, const char* b) { do { if (*a == '\0' && *b == '\0') return true; } while (tolower(*a++) == tolower(*b++)); return false; } // 不区分大小写的比较字符串,n个字符内(包括n)相等返回true bool strnicasecmp(const char* a, const char* b, unsigned n) { do { if (n-- == 0 || (*a == '\0' && *b == '\0')) return true; } while (tolower(*a++) == tolower(*b++)); return false; } int getSecWebSocketAcceptKey(const char* key, char* b64buff, int len) { SHA1_CTX ctx; unsigned char hash[20], buff[512]; if (strlen(key) > 256) { return -1; } sprintf(buff, "%s%s", key, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); SHA1Init(&ctx); SHA1Update(&ctx, buff, strlen(buff)); SHA1Final(hash, &ctx); const char* base64 = b64_encode(hash, sizeof(hash)); strncpy(b64buff, base64, len - 1); b64buff[len - 1] = '\0'; free((void*)base64); return 0; } // 校验WebSocket握手的HTTP头,失败返回NULL,校验成功顺带返回Sec-WebSocket-Key,记得free char* verifyHandshakeHeaders(const char* str, size_t len) { char* secKey = NULL; char a[1024], b[1024]; bool connection, upgrade, version, key; connection = upgrade = version = key = false; if (strcmp(str + len - 4, "\r\n\r\n") != 0) { printLog("verifyHandshakeHeaders","HTTP header does not end with '\\r\\n\\r\\n'"); return NULL; } const char* cur1 = strstr(str, "\r\n") + 2; const char* cur2; while ((cur2 = strstr(cur1, "\r\n")) != cur1) { cur2 += 2; // 跳过\r\n const char* colon = strchr(cur1, ':'); if (colon == NULL || colon >= cur2) { printLog("verifyHandshakeHeaders", "Unexpected HTTP header"); break; } if (sscanf(cur1, "%[^:]:%s", a, b) != 2) { printLog("verifyHandshakeHeaders", "HTTP header parsing failed"); break; } if (stricasecmp(a, "connection")) { connection = true; } else if (stricasecmp(a, "upgrade")) { if (!stricasecmp(b, "websocket")) { printLog("verifyHandshakeHeaders", "Unexpected value '%s' of Upgrade filed", b); break; } upgrade = true; } else if (stricasecmp(a, "Sec-WebSocket-Version")) { if (!stricasecmp(b, "13")) { 
printLog("verifyHandshakeHeaders","Unexpected value '%s' of Sec-WebSocket-Version filed", b); break; } version = true; } else if (stricasecmp(a, "Sec-WebSocket-Key")) { if (!key) { key = true; secKey = malloc(strlen(b) + 1); strcpy(secKey, b); } } cur1 = cur2; } if (!(connection && upgrade && version && key)) { printLog("verifyHandshakeHeaders", "Missing necessary fields"); if (key) free((void*)secKey); // 释放申请的内存 return NULL; } return secKey; } int wsShakeHands(const char* recvBuff, int recvLen, SOCKET socket, const char* path) { #define RECV_BUFLEN 0X40001 #define HTTP_MAXLEN 1536 #define HTTP_400 "HTTP/1.1 400 Bad Request\r\n\r\n" // HTTP握手包太长 if (recvLen > HTTP_MAXLEN) { send(socket, HTTP_400, strlen(HTTP_400), 0); printLog("wsShakeHands","Request too long"); return -1; } // 注:recvBuff不以'\0'结尾 char resText[HTTP_MAXLEN + 1]; memcpy(resText, recvBuff, recvLen); resText[recvLen] = '\0'; char requestLine[512]; sprintf(requestLine, "GET %s%s HTTP/1.1\r\n", (strlen(path) == 0 || path[0] != '/') ? 
"/" : "", path); // 注:路径部分也被不区分大小写的比较 if (!strnicasecmp(resText, requestLine, strlen(requestLine))) { send(socket, HTTP_400, strlen(HTTP_400), 0); printLog("wsShakeHands","Unexpected request line"); printLog("wsShakeHands", resText); return -1; } const char* secKey = verifyHandshakeHeaders(resText, recvLen); if (!secKey) { send(socket, HTTP_400, strlen(HTTP_400), 0); return -1; } // 获取Sec-WebSocket-Accept char acptBuff[128]; getSecWebSocketAcceptKey(secKey, acptBuff, sizeof(acptBuff)); printLog("wsShakeHands","Sec-WebSocket-Key is '%s'", secKey); printLog("wsShakeHands", "Sec-WebSocket-Accept is '%s'", acptBuff); free((void*)secKey); // 释放secKey // 协议升级 char resBuff[256]; // 注:当前的CORS设置可能会导致安全问题 // 注:响应中没有包含Sec-Websocket-Protocol头,代表不接受任何客户端请求的ws扩展 const char resHeader[] = "HTTP/1.1 101 ojbk\r\n" "Connection: Upgrade\r\n" "Upgrade: websocket\r\n" "Sec-WebSocket-Accept: %s\r\n" "Access-Control-Allow-Origin: *\r\n" "\r\n" ; int resLen = sprintf(resBuff, resHeader, acptBuff); // Send data to the client int iSendResult = send(socket, resBuff, resLen, 0); if (iSendResult == SOCKET_ERROR) { return -1; } printLog("wsShakeHands","Bytes sent: %d", iSendResult); printLog("wsShakeHands","WebSocket handshake succeeded"); return 0; }
receiveComingData方法里处理websocket数据的方法wsClientDataHandle代码
int readWebSocketFrameStream(WsFrame* wsFrame, const char* buff, int len) { if (wsFrame->buff == NULL) { wsFrame->buff = malloc(len); wsFrame->buffSize = len; memcpy(wsFrame->buff, buff, len); } else { char* copyStartAddr; int requiedLen = wsFrame->buffSize + len; wsFrame->buff = realloc((void*)wsFrame->buff, requiedLen); copyStartAddr = wsFrame->buff + wsFrame->buffSize; wsFrame->buffSize = requiedLen; memcpy(copyStartAddr, buff, len); } // 消耗的数据量 int consumed = 0; stateTransitionBegin: switch (wsFrame->state) { case frameState_init: if (wsFrame->buffSize < 1) { return consumed; } wsFrame->FIN = !!(wsFrame->buff[0] & 0X80); // RSV位不全为0,存在扩展协议,服务器不处理扩展协议 if ((wsFrame->buff[0] & 0X70) != 0) { wsFrame->state = frameState_failure; break; } int opcode = wsFrame->buff[0] & 0X0F; if (opcode == 0X0) { wsFrame->frameType = frameType_continuation; } else if (opcode == 0X1) { wsFrame->frameType = frameType_text; } else if (opcode == 0X2) { wsFrame->frameType = frameType_binary; } else if (opcode == 0X8) { wsFrame->frameType = frameType_connectionClose; } else if (opcode == 0X9) { wsFrame->frameType = frameType_ping; } else if (opcode == 0XA) { wsFrame->frameType = frameType_pong; } else { wsFrame->state = frameState_failure; break; } consumed += 1; wsFrame->handledLen += 1; wsFrame->headerLen += 1; wsFrame->state = frameState_firstByte; break; case frameState_firstByte: if (wsFrame->buffSize < 2) { return consumed; } // 标准规定客户端传入帧的掩码位必须不为0 if ((wsFrame->buff[1] & 0X80) == 0) { wsFrame->state = frameState_failure; } wsFrame->state = frameState_mask; break; case frameState_mask: if (wsFrame->buffSize < 2) { return consumed; } uint8_t payloadLen = wsFrame->buff[1] & 0X7F; // frame-payload-length-7 if (payloadLen < 126) { wsFrame->payloadLen = payloadLen; wsFrame->state = frameState_7bitLength; } else if (payloadLen == 126) { wsFrame->state = frameState_16bitLengthWait; } else if (payloadLen == 127) { wsFrame->state = frameState_63bitLengthWait; } consumed += 1; 
wsFrame->headerLen += 1; wsFrame->handledLen += 1; break; case frameState_7bitLength: // 2字节共有字段 + 0字节附加长度字段 + 4字节掩码 if (wsFrame->buffSize < 6) { return consumed; } for (int i = 0; i < 4; i++) { wsFrame->mask[i] = wsFrame->buff[i + 2]; } consumed += 4; wsFrame->headerLen += 4; wsFrame->handledLen += 4; wsFrame->state = frameState_maskingKey; break; case frameState_16bitLengthWait: if (wsFrame->buffSize < 4) { return consumed; } wsFrame->payloadLen = ((uint16_t)wsFrame->buff[2] << 8) + (uint16_t)wsFrame->buff[3]; consumed += 2; wsFrame->headerLen += 2; wsFrame->handledLen += 2; wsFrame->state = frameState_16bitLength; break; case frameState_63bitLengthWait: if (wsFrame->buffSize < 10) { return consumed; } unsigned char* recvBuff = wsFrame->buff; // 注:标准规定64位时最高bit必须为0,这里未作处理 wsFrame->payloadLen = ((uint64_t)recvBuff[2] << (8 * 7)) + ((uint64_t)recvBuff[3] << (8 * 6)) + ((uint64_t)recvBuff[4] << (8 * 5)) + ((uint64_t)recvBuff[5] << (8 * 4)) + ((uint64_t)recvBuff[6] << (8 * 3)) + ((uint64_t)recvBuff[7] << (8 * 2)) + ((uint64_t)recvBuff[8] << (8 * 1)) + ((uint64_t)recvBuff[9] << (8 * 0)); consumed += 8; wsFrame->headerLen += 8; wsFrame->handledLen += 8; wsFrame->state = frameState_63bitLength; break; case frameState_16bitLength: // 2字节共有字段 + 2字节附加长度字段 + 4字节掩码 if (wsFrame->buffSize < 8) { return consumed; } for (int i = 0; i < 4; i++) { wsFrame->mask[i] = wsFrame->buff[i + 4]; } consumed += 4; wsFrame->headerLen += 4; wsFrame->handledLen += 4; wsFrame->state = frameState_maskingKey; break; case frameState_63bitLength: // 2字节共有字段 + 8字节附加长度字段 + 4字节掩码 if (wsFrame->buffSize < 14) { return consumed; } for (int i = 0; i < 4; i++) { wsFrame->mask[i] = wsFrame->buff[i + 10]; } consumed += 4; wsFrame->headerLen += 4; wsFrame->handledLen += 4; wsFrame->state = frameState_maskingKey; break; case frameState_maskingKey: wsFrame->state = frameState_readingData; break; case frameState_readingData: ; // case第一个语句不能是变量声明 uint64_t total = wsFrame->payloadLen + wsFrame->headerLen; // 注意 
buff的长度可能大于帧总长度 // 因为TCP是面向字节流的,buff中有可能包含下一帧的数据 // 所以读取时要根据帧头和载荷长度来判断最多读多少数据 if (wsFrame->buffSize >= total) { consumed += total - wsFrame->handledLen; wsFrame->handledLen = total; wsFrame->state = frameState_success; } else { consumed += wsFrame->buffSize - wsFrame->handledLen; wsFrame->handledLen = wsFrame->buffSize; return consumed; } break; case frameState_success: return consumed; break; case frameState_failure: return consumed; break; } goto stateTransitionBegin; return consumed; } int wsFrameSend(SOCKET socket, const char* buff, int len, FrameType type) { int newLen; const char* frame = convertToWebSocketFrame(buff, type, len, &newLen); int iSendResult = send(socket, frame, newLen, 0); if (iSendResult == SOCKET_ERROR) { printLog("wsFrameSend", 1, "Send failed: %d", WSAGetLastError()); goto wsFrameSendEnd; } wsFrameSendEnd: free((void*)frame); return iSendResult; } // 处理WebSocket帧数据,返回-1代表需要关闭连接 int wsClientDataHandle(const char* recvBuff, int recvLen, Client* client) { WsFrame* wsFrame = &client->wsFrame; if (recvLen == 0) { return 0; } int consume = readWebSocketFrameStream(wsFrame, recvBuff, recvLen); if (wsFrame->state == frameState_success) { // 暂时不处理多帧数据,遇到多帧数据关闭连接 if (wsFrame->FIN == 0) { return -1; } // 客户端希望关闭连接 if (wsFrame->frameType == frameType_connectionClose) { return -1; } // 遇到意料之外的帧类型 if (wsFrame->frameType == frameType_binary || wsFrame->frameType == frameType_pong || wsFrame->frameType == frameType_continuation ) { return -1; } uint64_t payloadLen = wsFrame->payloadLen; u_char* payload = wsFrame->buff + wsFrame->headerLen; // 解码载荷 for (uint64_t j = 0; j < payloadLen; j++) { payload[j] = payload[j] ^ wsFrame->mask[j % 4]; } int iSendResult = 0; // 心跳 if (wsFrame->frameType == frameType_ping) { wsFrameSend(client->socket, payload, payloadLen, frameType_pong); } // 处理文本数据 if (wsFrame->frameType == frameType_text) { wsClientTextDataHandle(payload, payloadLen, client->socket); } } // 一个帧接收完成并处理完毕后释放内存 if (wsFrame->state == frameState_success) { 
freeWebSocketFrame(wsFrame); } // 解析ws帧出错,释放内存并通知关闭连接 if (wsFrame->state == frameState_failure) { freeWebSocketFrame(wsFrame); return -1; } // 传入的数据不止包含当前帧,包含下一帧的数据 if (consume != recvLen) { return wsClientDataHandle(recvBuff + consume, recvLen - consume, client); } return 0; }
处理文本数据的方法wsClientTextDataHandle代码
// 处理数据,也可以写在别处做回调函数 void wsClientTextDataHandle(const char* payload, uint64_t payloadLen, SOCKET socket) { const char* parseEnd; cJSON* json = cJSON_ParseWithOpts(payload, &parseEnd, 0); if (json == NULL) { const char* error_ptr = cJSON_GetErrorPtr(); if (error_ptr != NULL) { printLog("jsonParse", "Error before: %d", error_ptr - payload); } return; } const cJSON* j_msg = cJSON_GetObjectItemCaseSensitive(json, "msg"); const cJSON_bool e_msg = cJSON_IsString(j_msg); const char* v_msg = e_msg ? j_msg->valuestring : NULL; char* gbkText = UTF8ToGBK(v_msg); sendJSON(socket, "send", GBKToUTF8(gbkText)); free((void*)gbkText); }
给客户端发送消息的代码
/* Largest length a single Winsock send() call accepts (its length parameter is an int). */
enum { WS_MAX_SEND_LEN = 0x7FFFFFFF };

/*
 * Wraps buff[0 .. len) in one WebSocket frame of the given type and sends it on socket.
 * Returns the number of frame bytes sent, or SOCKET_ERROR on failure.
 */
int wsFrameSend(SOCKET socket, const char* buff, int len, FrameType type)
{
    if (len < 0 || (len > 0 && buff == NULL)) {
        return SOCKET_ERROR;
    }

    /* convertToWebSocketFrame writes the length through a size_t*, so this must be a size_t. */
    size_t frameLen = 0;
    char* frame = convertToWebSocketFrame(buff, type, (size_t)len, &frameLen);
    if (frame == NULL) {
        printLog("wsFrameSend", "Failed to build frame");
        return SOCKET_ERROR;
    }
    if (frameLen > WS_MAX_SEND_LEN) {
        printLog("wsFrameSend", "Frame too large: %llu bytes", (unsigned long long)frameLen);
        free(frame);
        return SOCKET_ERROR;
    }

    /* send() may transmit fewer bytes than requested; loop until the whole frame is out. */
    int sentTotal = 0;
    while ((size_t)sentTotal < frameLen) {
        const int sent = send(socket, frame + sentTotal, (int)(frameLen - (size_t)sentTotal), 0);
        if (sent <= 0) {
            printLog("wsFrameSend", "Send failed: %d", WSAGetLastError());
            free(frame);
            return SOCKET_ERROR;
        }
        sentTotal += sent;
    }

    printLog("wsFrameSend", "Bytes sent: %d", sentTotal);
    free(frame);
    return sentTotal;
}

/*
 * Sends {"event": event, "data": data} as a single text frame. A NULL event or data
 * leaves that member out.
 */
void sendJSON(SOCKET socket, const char* event, const char* data)
{
    cJSON* root = cJSON_CreateObject();
    if (root == NULL) {
        printLog("sendJSON", "Failed to create JSON object");
        return;
    }
    if (event != NULL) {
        cJSON_AddItemToObject(root, "event", cJSON_CreateString(event));
    }
    if (data != NULL) {
        cJSON_AddItemToObject(root, "data", cJSON_CreateString(data));
    }

    char* jsonStr = cJSON_PrintUnformatted(root);
    cJSON_Delete(root);
    if (jsonStr == NULL) {
        printLog("sendJSON", "Failed to serialize JSON");
        return;
    }

    const size_t jsonLen = strlen(jsonStr);
    if (jsonLen <= WS_MAX_SEND_LEN) {
        wsFrameSend(socket, jsonStr, (int)jsonLen, frameType_text);
    } else {
        printLog("sendJSON", "JSON too large to send");
    }
    free(jsonStr);
}
到这里核心代码就写完了,接下来测试一下
int main() { int result = serverStart("127.0.0.1", 1024, "/"); if (result != 0) { MessageBoxA(NULL, "wsserver start failed", "WebSocket Plugin", MB_OK | MB_ICONERROR); } else { //fileLog("websocket server startup success"); } system("pause"); return 0; }
测试成功,如下图所示。