boost asio multithreading mode - IOThreadPool

Today I'll introduce the second asio multithreading mode. Previously we covered the IOServicePool approach: an IOServicePool starts n threads and n io_contexts, each thread runs its own io_context, and each io_context watches the sockets bound to it; when a socket becomes ready, its callback fires in that io_context's own thread. To avoid thread-safety problems, network data is packaged into logic packets and posted to the logic system, which processes them on a single dedicated thread. This decouples network IO from logic processing and greatly improves the server's throughput at the IO level. This time we look at another multithreading mode, IOThreadPool: we initialize only one io_context to watch all of the server's read/write events, including accepting new connections. The difference is that io_context.run is called from multiple threads, so callbacks are triggered on different threads; from that perspective the callbacks are invoked concurrently.
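Before the full implementation, here is a minimal sketch of the core idea (names are illustrative, not the classes used later; it assumes a reasonably recent Boost with make_work_guard): a single io_context whose run() is called from several threads, so any of those threads may end up executing a completion handler.

#include <boost/asio.hpp>
#include <thread>
#include <vector>

int main() {
    boost::asio::io_context ioc;
    auto work = boost::asio::make_work_guard(ioc);   // keep run() alive while there is no pending work
    std::vector<std::thread> threads;
    for (unsigned i = 0; i < std::thread::hardware_concurrency(); ++i) {
        threads.emplace_back([&ioc] { ioc.run(); }); // handlers fire on whichever thread is free
    }
    // ... register acceptors / sockets and their async operations on ioc here ...
    work.reset();                                    // allow run() to return once the queue drains
    for (auto& t : threads) t.join();
}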

Architecture diagram

The scheduling structure of the thread-pool multithreading model is shown below:

The code is as follows:

const.h

#pragma once
#define MAX_LENGTH  1024*2
//total header length
#define HEAD_TOTAL_LEN 4
//length of the id field in the header
#define HEAD_ID_LEN 2
//length of the data-length field in the header
#define HEAD_DATA_LEN 2
#define MAX_RECVQUE  10000
#define MAX_SENDQUE 1000


enum MSG_IDS {
    MSG_HELLO_WORD = 1001
};

Singleton.h

#pragma once
#include <memory>
#include <mutex>
#include <iostream>
using namespace std;
template <typename T>
class Singleton {
protected:
    Singleton() = default;
    Singleton(const Singleton<T>&) = delete;
    Singleton& operator=(const Singleton<T>& st) = delete;
    
    static std::shared_ptr<T> _instance;
public:
    static std::shared_ptr<T> GetInstance() {
        static std::once_flag s_flag;
        std::call_once(s_flag, [&]() {
            _instance = shared_ptr<T>(new T);
            });

        return _instance;
    }
    void PrintAddress() {
        std::cout << _instance.get() << endl;
    }
    ~Singleton() {
        std::cout << "this is singleton destruct" << std::endl;
    }
};

template <typename T>
std::shared_ptr<T> Singleton<T>::_instance = nullptr;
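For clarity, a hypothetical usage sketch of this template (ExampleService is made up; LogicSystem and AsioThreadPool below follow the same pattern): the derived class declares Singleton<T> as a friend so that GetInstance() can reach its non-public constructor.

class ExampleService : public Singleton<ExampleService> {
    friend class Singleton<ExampleService>;
public:
    void DoWork() { std::cout << "working" << std::endl; }
private:
    ExampleService() = default;
};

// elsewhere: ExampleService::GetInstance()->DoWork();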

MsgNode.h

#pragma once
#include <string>
#include <cstring>
#include "const.h"
#include <iostream>
#include <boost/asio.hpp>
using namespace std;
using boost::asio::ip::tcp;
class LogicSystem;
class MsgNode
{
public:
    MsgNode(short max_len) :_total_len(max_len), _cur_len(0) {
        _data = new char[_total_len + 1]();
        _data[_total_len] = '\0';
    }

    ~MsgNode() {
        std::cout << "destruct MsgNode" << endl;
        delete[] _data;
    }

    void Clear() {
        ::memset(_data, 0, _total_len);
        _cur_len = 0;
    }

    short _cur_len;
    short _total_len;
    char* _data;
};

class RecvNode :public MsgNode {
    friend class LogicSystem;
public:
    RecvNode(short max_len, short msg_id);
private:
    short _msg_id;
};

class SendNode:public MsgNode {
    friend class LogicSystem;
public:
    SendNode(const char* msg,short max_len, short msg_id);
private:
    short _msg_id;
};

MsgNode.cpp

#include "MsgNode.h"
RecvNode::RecvNode(short max_len, short msg_id):MsgNode(max_len),
_msg_id(msg_id){

}


SendNode::SendNode(const char* msg, short max_len, short msg_id):MsgNode(max_len + HEAD_TOTAL_LEN)
, _msg_id(msg_id){
    //send the id first, converted to network byte order
    short msg_id_host = boost::asio::detail::socket_ops::host_to_network_short(msg_id);
    memcpy(_data, &msg_id_host, HEAD_ID_LEN);
    //convert the length to network byte order
    short max_len_host = boost::asio::detail::socket_ops::host_to_network_short(max_len);
    memcpy(_data + HEAD_ID_LEN, &max_len_host, HEAD_DATA_LEN);
    memcpy(_data + HEAD_ID_LEN + HEAD_DATA_LEN, msg, max_len);
}

LogicSystem.h

#pragma once
#include "Singleton.h"
#include <queue>
#include <thread>
#include <mutex>
#include <condition_variable>
#include "CSession.h"
#include <map>
#include <functional>
#include "const.h"
#include <json/json.h>
#include <json/value.h>
#include <json/reader.h>

typedef  function<void(shared_ptr<CSession>, const short &msg_id, const string &msg_data)> FunCallBack;
class LogicSystem:public Singleton<LogicSystem>
{
    friend class Singleton<LogicSystem>;
public:
    ~LogicSystem();
    void PostMsgToQue(shared_ptr < LogicNode> msg);
private:
    LogicSystem();
    void DealMsg();
    void RegisterCallBacks();
    void HelloWordCallBack(shared_ptr<CSession>, const short &msg_id, const string &msg_data);
    std::thread _worker_thread;
    std::queue<shared_ptr<LogicNode>> _msg_que;
    std::mutex _mutex;
    std::condition_variable _consume;
    bool _b_stop;
    std::map<short, FunCallBack> _fun_callbacks;
};

LogicSystem.cpp

#include "LogicSystem.h"

using namespace std;

LogicSystem::LogicSystem():_b_stop(false){
    RegisterCallBacks();
    _worker_thread = std::thread (&LogicSystem::DealMsg, this);
}

LogicSystem::~LogicSystem(){
    _b_stop = true;
    _consume.notify_one();
    _worker_thread.join();
}

void LogicSystem::PostMsgToQue(shared_ptr < LogicNode> msg) {
    std::unique_lock<std::mutex> unique_lk(_mutex);
    _msg_que.push(msg);
    //the queue went from 0 to 1, so send a notification
    if (_msg_que.size() == 1) {
        unique_lk.unlock();
        _consume.notify_one();
    }
}

void LogicSystem::DealMsg() {
    for (;;) {
        std::unique_lock<std::mutex> unique_lk(_mutex);
        //if the queue is empty, block on the condition variable (which releases the lock)
        while (_msg_que.empty() && !_b_stop) {
            _consume.wait(unique_lk);
        }

        //if we are shutting down, drain all remaining messages and then exit the loop
        if (_b_stop ) {
            while (!_msg_que.empty()) {
                auto msg_node = _msg_que.front();
                cout << "recv_msg id  is " << msg_node->_recvnode->_msg_id << endl;
                auto call_back_iter = _fun_callbacks.find(msg_node->_recvnode->_msg_id);
                if (call_back_iter == _fun_callbacks.end()) {
                    _msg_que.pop();
                    continue;
                }
                call_back_iter->second(msg_node->_session, msg_node->_recvnode->_msg_id,
                    std::string(msg_node->_recvnode->_data, msg_node->_recvnode->_cur_len));
                _msg_que.pop();
            }
            break;
        }

        //not shutting down, so the queue is guaranteed to have data here
        auto msg_node = _msg_que.front();
        cout << "recv_msg id  is " << msg_node->_recvnode->_msg_id << endl;
        auto call_back_iter = _fun_callbacks.find(msg_node->_recvnode->_msg_id);
        if (call_back_iter == _fun_callbacks.end()) {
            _msg_que.pop();
            continue;
        }
        call_back_iter->second(msg_node->_session, msg_node->_recvnode->_msg_id, 
            std::string(msg_node->_recvnode->_data, msg_node->_recvnode->_cur_len));
        _msg_que.pop();
    }
}

void LogicSystem::RegisterCallBacks() {
    _fun_callbacks[MSG_HELLO_WORD] = std::bind(&LogicSystem::HelloWordCallBack, this,
        placeholders::_1, placeholders::_2, placeholders::_3);
}

void LogicSystem::HelloWordCallBack(shared_ptr<CSession> session, const short &msg_id, const string &msg_data) {
    Json::Reader reader;
    Json::Value root;
    reader.parse(msg_data, root);
    std::cout << "recevie msg id  is " << root["id"].asInt() << " msg data is "
        << root["data"].asString() << endl;
    root["data"] = "server has received msg, msg data is " + root["data"].asString();
    std::string return_str = root.toStyledString();
    session->Send(return_str, root["id"].asInt());
}
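For reference, a minimal sketch of the JSON body that HelloWordCallBack expects and echoes back (it matches what the client test code at the end of this post sends; the helper name is made up):

#include <json/json.h>
#include <string>

std::string MakeHelloBody() {
    Json::Value root;
    root["id"] = 1001;              // MSG_HELLO_WORD
    root["data"] = "hello world";
    return root.toStyledString();   // the 2-byte id and 2-byte length header are prepended separately
}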

 

Note that because this mode triggers callbacks on multiple threads, the CSession implementation below already wraps its asynchronous handlers in an asio strand; the reasoning is explained in the strand section later in this post.

CSession.h

#pragma once
#include <boost/asio.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <queue>
#include <mutex>
#include <memory>
#include "const.h"
#include "MsgNode.h"
using namespace std;

using boost::asio::ip::tcp;
class CServer;
class LogicSystem;
using boost::asio::strand;
using boost::asio::io_context;
class CSession: public std::enable_shared_from_this<CSession>
{
public:
    CSession(boost::asio::io_context& io_context, CServer* server);
    ~CSession();
    tcp::socket& GetSocket();
    std::string& GetUuid();
    void Start();
    void Send(char* msg,  short max_length, short msgid);
    void Send(std::string msg, short msgid);
    void Close();
    std::shared_ptr<CSession> SharedSelf();
private:
    void HandleRead(const boost::system::error_code& error, size_t  bytes_transferred, std::shared_ptr<CSession> shared_self);
    void HandleWrite(const boost::system::error_code& error, std::shared_ptr<CSession> shared_self);
    tcp::socket _socket;
    std::string _uuid;
    char _data[MAX_LENGTH];
    CServer* _server;
    bool _b_close;
    std::queue<shared_ptr<SendNode> > _send_que;
    std::mutex _send_lock;
    //structure for the received message body
    std::shared_ptr<RecvNode> _recv_msg_node;
    bool _b_head_parse;
    //structure for the received header
    std::shared_ptr<MsgNode> _recv_head_node;
    strand<io_context::executor_type> _strand;
};

class LogicNode {
    friend class LogicSystem;
public:
    LogicNode(shared_ptr<CSession>, shared_ptr<RecvNode>);
private:
    shared_ptr<CSession> _session;
    shared_ptr<RecvNode> _recvnode;
};

CSession.cpp

#include "CSession.h"
#include "CServer.h"
#include <iostream>
#include <sstream>
#include <json/json.h>
#include <json/value.h>
#include <json/reader.h>
#include "LogicSystem.h"

CSession::CSession(boost::asio::io_context& io_context, CServer* server):
    _socket(io_context), _server(server), _b_close(false),
    _b_head_parse(false), _strand(io_context.get_executor()){
    boost::uuids::uuid  a_uuid = boost::uuids::random_generator()();
    _uuid = boost::uuids::to_string(a_uuid);
    _recv_head_node = make_shared<MsgNode>(HEAD_TOTAL_LEN);
}
CSession::~CSession() {
    std::cout << "~CSession destruct" << endl;
}

tcp::socket& CSession::GetSocket() {
    return _socket;
}

std::string& CSession::GetUuid() {
    return _uuid;
}

void CSession::Start(){
    ::memset(_data, 0, MAX_LENGTH);

    _socket.async_read_some(boost::asio::buffer(_data, MAX_LENGTH),
        boost::asio::bind_executor(_strand, std::bind(&CSession::HandleRead, this,
            std::placeholders::_1, std::placeholders::_2, SharedSelf())));
}

void CSession::Send(std::string msg, short msgid) {
    std::lock_guard<std::mutex> lock(_send_lock);
    int send_que_size = _send_que.size();
    if (send_que_size > MAX_SENDQUE) {
        std::cout << "session: " << _uuid << " send que fulled, size is " << MAX_SENDQUE << endl;
        return;
    }

    _send_que.push(make_shared<SendNode>(msg.c_str(), msg.length(), msgid));
    if (send_que_size > 0) {
        return;
    }
    auto& msgnode = _send_que.front();
    boost::asio::async_write(_socket, boost::asio::buffer(msgnode->_data, msgnode->_total_len),
        boost::asio::bind_executor(_strand,std::bind(&CSession::HandleWrite, this, std::placeholders::_1, SharedSelf()))
        );
}

void CSession::Send(char* msg, short max_length, short msgid) {
    std::lock_guard<std::mutex> lock(_send_lock);
    int send_que_size = _send_que.size();
    if (send_que_size > MAX_SENDQUE) {
        std::cout << "session: " << _uuid << " send que fulled, size is " << MAX_SENDQUE << endl;
        return;
    }

    _send_que.push(make_shared<SendNode>(msg, max_length, msgid));
    if (send_que_size>0) {
        return;
    }
    auto& msgnode = _send_que.front();
    boost::asio::async_write(_socket, boost::asio::buffer(msgnode->_data, msgnode->_total_len), 
        boost::asio::bind_executor(_strand, std::bind(&CSession::HandleWrite, this, std::placeholders::_1, SharedSelf()))
        );
}

void CSession::Close() {
    _socket.close();
    _b_close = true;
}

std::shared_ptr<CSession>CSession::SharedSelf() {
    return shared_from_this();
}

void CSession::HandleWrite(const boost::system::error_code& error, std::shared_ptr<CSession> shared_self) {
    //add exception handling
    try {
        if (!error) {
            std::lock_guard<std::mutex> lock(_send_lock);
            //cout << "send data " << _send_que.front()->_data+HEAD_LENGTH << endl;
            _send_que.pop();
            if (!_send_que.empty()) {
                auto& msgnode = _send_que.front();
                boost::asio::async_write(_socket, boost::asio::buffer(msgnode->_data, msgnode->_total_len),
                    boost::asio::bind_executor(_strand, std::bind(&CSession::HandleWrite, this, std::placeholders::_1, shared_self))
                    );
            }
        }
        else {
            std::cout << "handle write failed, error is " << error.what() << endl;
            Close();
            _server->ClearSession(_uuid);
        }
    }
    catch (std::exception& e) {
        std::cerr << "Exception code : " << e.what() << endl;
    }
    
}

void CSession::HandleRead(const boost::system::error_code& error, size_t  bytes_transferred, std::shared_ptr<CSession> shared_self){
    try {
        if (!error) {
            //number of bytes of _data already consumed
            int copy_len = 0;
            while (bytes_transferred > 0) {
                if (!_b_head_parse) {
                    //the data received so far is still smaller than the header
                    if (bytes_transferred + _recv_head_node->_cur_len < HEAD_TOTAL_LEN) {
                        memcpy(_recv_head_node->_data + _recv_head_node->_cur_len, _data + copy_len, bytes_transferred);
                        _recv_head_node->_cur_len += bytes_transferred;
                        ::memset(_data, 0, MAX_LENGTH);
                        _socket.async_read_some(boost::asio::buffer(_data, MAX_LENGTH),
                            boost::asio::bind_executor(_strand, std::bind(&CSession::HandleRead, this, std::placeholders::_1, std::placeholders::_2, shared_self))
                            );
                        return;
                    }
                    //received more data than the header needs
                    //number of header bytes still missing
                    int head_remain = HEAD_TOTAL_LEN - _recv_head_node->_cur_len;
                    memcpy(_recv_head_node->_data + _recv_head_node->_cur_len, _data + copy_len, head_remain);
                    //update the consumed length and the remaining unprocessed length
                    copy_len += head_remain;
                    bytes_transferred -= head_remain;
                    //extract the MSGID from the header
                    short msg_id = 0;
                    memcpy(&msg_id, _recv_head_node->_data, HEAD_ID_LEN);
                    //convert from network to host byte order
                    msg_id = boost::asio::detail::socket_ops::network_to_host_short(msg_id);
                    std::cout << "msg_id is " << msg_id << endl;
                    //invalid id
                    if (msg_id > MAX_LENGTH) {
                        std::cout << "invalid msg_id is " << msg_id << endl;
                        _server->ClearSession(_uuid);
                        return;
                    }
                    short msg_len = 0;
                    memcpy(&msg_len, _recv_head_node->_data+HEAD_ID_LEN, HEAD_DATA_LEN);
                    //convert from network to host byte order
                    msg_len = boost::asio::detail::socket_ops::network_to_host_short(msg_len);
                    std::cout << "msg_len is " << msg_len << endl;
                    //invalid length
                    if (msg_len > MAX_LENGTH) {
                        std::cout << "invalid data length is " << msg_len << endl;
                        _server->ClearSession(_uuid);
                        return;
                    }

                    _recv_msg_node = make_shared<RecvNode>(msg_len, msg_id);

                    //the data received is shorter than the length given in the header, so the message is incomplete; buffer the partial message in the receive node
                    if (bytes_transferred < msg_len) {
                        memcpy(_recv_msg_node->_data + _recv_msg_node->_cur_len, _data + copy_len, bytes_transferred);
                        _recv_msg_node->_cur_len += bytes_transferred;
                        ::memset(_data, 0, MAX_LENGTH);
                        _socket.async_read_some(boost::asio::buffer(_data, MAX_LENGTH),
                            boost::asio::bind_executor(_strand, std::bind(&CSession::HandleRead, 
                                this, std::placeholders::_1, std::placeholders::_2, shared_self))
                            );
                        //header fully parsed
                        _b_head_parse = true;
                        return;
                    }

                    memcpy(_recv_msg_node->_data + _recv_msg_node->_cur_len, _data + copy_len, msg_len);
                    _recv_msg_node->_cur_len += msg_len;
                    copy_len += msg_len;
                    bytes_transferred -= msg_len;
                    _recv_msg_node->_data[_recv_msg_node->_total_len] = '\0';
                    //cout << "receive data is " << _recv_msg_node->_data << endl;
                    //post the message to the logic queue here
                    LogicSystem::GetInstance()->PostMsgToQue(make_shared<LogicNode>(shared_from_this(), _recv_msg_node));
                
                    //continue processing any remaining data
                    _b_head_parse = false;
                    _recv_head_node->Clear();
                    if (bytes_transferred <= 0) {
                        ::memset(_data, 0, MAX_LENGTH);
                        _socket.async_read_some(boost::asio::buffer(_data, MAX_LENGTH),
                            boost::asio::bind_executor(_strand, std::bind(&CSession::HandleRead, this,
                                std::placeholders::_1, std::placeholders::_2, shared_self))
                            );
                        return;
                    }
                    continue;
                }

                //header already parsed; continue with the message body left over from last time
                //the data received is still less than what the message needs
                int remain_msg = _recv_msg_node->_total_len - _recv_msg_node->_cur_len;
                if (bytes_transferred < remain_msg) {
                    memcpy(_recv_msg_node->_data + _recv_msg_node->_cur_len, _data + copy_len, bytes_transferred);
                    _recv_msg_node->_cur_len += bytes_transferred;
                    ::memset(_data, 0, MAX_LENGTH);
                    _socket.async_read_some(boost::asio::buffer(_data, MAX_LENGTH),
                        boost::asio::bind_executor(_strand, std::bind(&CSession::HandleRead, 
                            this, std::placeholders::_1, std::placeholders::_2, shared_self))
                        );
                    return;
                }
                memcpy(_recv_msg_node->_data + _recv_msg_node->_cur_len, _data + copy_len, remain_msg);
                _recv_msg_node->_cur_len += remain_msg;
                bytes_transferred -= remain_msg;
                copy_len += remain_msg;
                _recv_msg_node->_data[_recv_msg_node->_total_len] = '\0';
                //cout << "receive data is " << _recv_msg_node->_data << endl;
                //post the message to the logic queue here
                LogicSystem::GetInstance()->PostMsgToQue(make_shared<LogicNode>(shared_from_this(), _recv_msg_node));
                
                //continue processing any remaining data
                _b_head_parse = false;
                _recv_head_node->Clear();
                if (bytes_transferred <= 0) {
                    ::memset(_data, 0, MAX_LENGTH);
                    _socket.async_read_some(boost::asio::buffer(_data, MAX_LENGTH),
                        boost::asio::bind_executor(_strand, std::bind(&CSession::HandleRead, 
                            this, std::placeholders::_1, std::placeholders::_2, shared_self))
                        );
                    return;
                }
                continue;
            }
        }
        else {
            std::cout << "handle read failed, error is " << error.what() << endl;
            Close();
            _server->ClearSession(_uuid);
        }
    }
    catch (std::exception& e) {
        std::cout << "Exception code is " << e.what() << endl;
    }
}

LogicNode::LogicNode(shared_ptr<CSession>  session, 
    shared_ptr<RecvNode> recvnode):_session(session),_recvnode(recvnode) {
    
}

CServer.h

#pragma once
#include <boost/asio.hpp>
#include "CSession.h"
#include <memory>
#include <map>
#include <mutex>
using namespace std;
using boost::asio::ip::tcp;
class CServer
{
public:
    CServer(boost::asio::io_context& io_context, short port);
    ~CServer();
    void ClearSession(std::string);
private:
    void HandleAccept(shared_ptr<CSession>, const boost::system::error_code & error);
    void StartAccept();
    boost::asio::io_context &_io_context;
    short _port;
    tcp::acceptor _acceptor;
    std::map<std::string, shared_ptr<CSession>> _sessions;
    std::mutex _mutex;
    boost::asio::executor_work_guard<boost::asio::io_context::executor_type> work_guard_;
};

CServer.cpp

#include "CServer.h"
#include <iostream>
#include "AsioThreadPool.h"
CServer::CServer(boost::asio::io_context& io_context, short port):_io_context(io_context), _port(port),
_acceptor(io_context, tcp::endpoint(tcp::v4(),port)), work_guard_(boost::asio::make_work_guard(io_context))
{
    cout << "Server start success, listen on port : " << _port << endl;
    StartAccept();
}

CServer::~CServer() {
    cout << "Server destruct listen on port : " << _port << endl;
}

void CServer::HandleAccept(shared_ptr<CSession> new_session, const boost::system::error_code& error){
    if (!error) {
        new_session->Start();
        lock_guard<mutex> lock(_mutex);
        _sessions.insert(make_pair(new_session->GetUuid(), new_session));
    }
    else {
        cout << "session accept failed, error is " << error.what() << endl;
    }

    StartAccept();
}

void CServer::StartAccept() {
    shared_ptr<CSession> new_session = make_shared<CSession>(_io_context, this);
    _acceptor.async_accept(new_session->GetSocket(), std::bind(&CServer::HandleAccept, this, new_session, std::placeholders::_1));
}

void CServer::ClearSession(std::string uuid) {
    lock_guard<mutex> lock(_mutex);
    _sessions.erase(uuid);
}

AsioThreadPool.h

#pragma once
#include <boost/asio.hpp>
#include <thread>
#include <vector>
#include <memory>
#include "Singleton.h"
class AsioThreadPool:public Singleton<AsioThreadPool>
{
public:
    friend class Singleton<AsioThreadPool>;
    ~AsioThreadPool(){}
    AsioThreadPool& operator=(const AsioThreadPool&) = delete;
    AsioThreadPool(const AsioThreadPool&) = delete;
    boost::asio::io_context& GetIOService();
    void Stop();
private:
    AsioThreadPool(int threadNum = std::thread::hardware_concurrency());
    boost::asio::io_context _service;
    std::unique_ptr<boost::asio::io_context::work> _work;
    std::vector<std::thread> _threads;

};

AsioThreadPool.cpp

#include "AsioThreadPool.h"

AsioThreadPool::AsioThreadPool(int threadNum ):_work(new boost::asio::io_context::work(_service)){
    for (int i = 0; i < threadNum; ++i) {
        _threads.emplace_back([this]() {
            _service.run();
            });
    }
}

boost::asio::io_context& AsioThreadPool::GetIOService() {
    return _service;
}

void AsioThreadPool::Stop() {
    _service.stop();
    _work.reset();
    for (auto& t : _threads) {
        t.join();
    }
    std::cout << "AsioThreadPool::Stop" << endl;
}

main.cpp

#include <iostream>
#include "CServer.h"
#include "Singleton.h"
#include "LogicSystem.h"
#include <csignal>
#include <thread>
#include <mutex>
#include "AsioThreadPool.h"
using namespace std;


int main()
{
    try {
        auto pool = AsioThreadPool::GetInstance();
        boost::asio::io_context io_context;
        boost::asio::signal_set signals(io_context, SIGINT, SIGTERM);
        signals.async_wait([pool,&io_context](auto, auto) {
            io_context.stop();
            pool->Stop();
            });

        CServer s(pool->GetIOService(), 10086);
    
        //run the io_context that waits for the exit signal
        io_context.run();
        std::cout << "server exited ...." << std::endl;
    }
    catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << endl;
    }

}

More details:

https://llfc.club/category?catid=225RaiVNI8pFDD5L4m807g7ZwmF#!aid/2R2LjXaJqlXvTqHLY4034ojk5pa

Code:

https://gitee.com/secondtonone1/boostasio-learn/tree/master/network/day15-ThreadPoolServer

First, implement the IOThreadPool

#include <boost/asio.hpp>
#include "Singleton.h"
class AsioThreadPool : public Singleton<AsioThreadPool>
{
public:
    friend class Singleton<AsioThreadPool>;
    ~AsioThreadPool() {}
    AsioThreadPool& operator=(const AsioThreadPool&) = delete;
    AsioThreadPool(const AsioThreadPool&) = delete;
    boost::asio::io_context& GetIOService();
    void Stop();
private:
    AsioThreadPool(int threadNum = std::thread::hardware_concurrency());
    boost::asio::io_context _service;
    std::unique_ptr<boost::asio::io_context::work> _work;
    std::vector<std::thread> _threads;
};

AsioThreadPool inherits from Singleton<AsioThreadPool> and provides a GetIOService function for retrieving the io_context.

Next, let's look at the concrete implementation.

#include"AsioThreadPool.h"
AsioThreadPool::AsioThreadPool(int threadNum ):_work(new boost::asio::io_context::work(_service)){
    for(int i =0; i < threadNum;++i){
    _threads.emplace_back([this](){
    _service.run();
 });
 }
}
boost::asio::io_context&AsioThreadPool::GetIOService(){
   return _service;
}
void AsioThreadPool::Stop(){
     _work.reset();
     for(auto& t : _threads){
        t.join();
     }
}

 

The constructor builds a thread pool in which every thread runs _service.run. Internally, _service.run fetches ready descriptors and their bound callbacks from iocp or epoll and then invokes those callbacks. Because the callbacks are invoked on different threads, callbacks for the same socket may end up being called from different threads.
On Linux, _service.run ends up calling epoll_wait, which returns the list of ready descriptors; on Windows it loops on GetQueuedCompletionStatus, which returns the ready descriptors. The two mechanisms are similar in principle: the descriptor is used to look up the registered callback, which is then invoked.
For example, the IOCP workflow looks like this:

  1. Create a completion port (iocp) object
  2. Create one or more worker threads that run against the completion port and process the I/O requests posted to it
  3. Associate sockets with the iocp object and post network events on those sockets
  4. Worker threads call GetQueuedCompletionStatus to retrieve completion packets, extract the event information and handle it

The epoll workflow looks like this (a minimal Linux-only sketch follows the list):

  1. Call epoll_create to create an epoll instance in the kernel
  2. Allocate a contiguous buffer large enough to hold n epoll_event structures
  3. Register the sockets to be monitored in the epoll instance
  4. Call epoll_wait, passing in the buffer allocated earlier; epoll_wait returns the list of ready epoll_events, and epoll writes the ready sockets' information into that buffer
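Those steps correspond roughly to the following sketch (illustrative only; asio does all of this for you inside io_context::run, including looking up and invoking the registered handler):

#include <sys/epoll.h>
#include <vector>

void poll_loop(int listen_fd) {
    int epfd = epoll_create1(0);                    // step 1: create the epoll instance
    std::vector<epoll_event> events(64);            // step 2: space for ready events
    epoll_event ev{};
    ev.events = EPOLLIN;
    ev.data.fd = listen_fd;
    epoll_ctl(epfd, EPOLL_CTL_ADD, listen_fd, &ev); // step 3: register the socket
    for (;;) {
        int n = epoll_wait(epfd, events.data(), static_cast<int>(events.size()), -1); // step 4
        for (int i = 0; i < n; ++i) {
            // events[i].data.fd is ready; look up its registered handler and invoke it here
        }
    }
}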

Pitfall

The IOThreadPool mode has one pitfall: when the same socket becomes ready repeatedly, the triggered callbacks may run on different threads — say the first on thread 1 and the second on thread 3. If the two triggers are close in time, different threads can end up accessing the same data concurrently. For example, while handling read events, the first callback reads data out of the socket's receive buffer; if the second callback fires and also reads from the same receive buffer, two threads are reading from the socket at the same time, which scrambles the data.

Improving it with strand

For callbacks triggered from multiple threads, we can wrap them with asio's serialization class, strand, so that they are invoked serially. The basic idea is that threads no longer call the handler directly; instead the handler is posted into a queue managed by a strand object, and the handlers queued in the strand are dispatched one at a time, so the calls are serialized and the safety problems caused by concurrent invocation disappear.

https://cdn.llfc.club/_20230607192843.png

As the figure shows, when sockets become ready their registered callbacks are not invoked directly by multiple threads; instead the callbacks are posted into the queue managed by the strand, and the strand schedules and dispatches them in order.
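As a self-contained illustration of that guarantee (a separate sketch, not part of the server code, assuming Boost 1.66+ for make_strand): handlers posted through a strand never run concurrently, even though run() is executing on several threads.

#include <boost/asio.hpp>
#include <iostream>
#include <thread>
#include <vector>

int main() {
    boost::asio::io_context ioc;
    auto work = boost::asio::make_work_guard(ioc);
    auto strand = boost::asio::make_strand(ioc);

    int counter = 0;  // not atomic: safe only because every increment goes through the strand
    for (int i = 0; i < 10000; ++i) {
        boost::asio::post(strand, [&counter] { ++counter; });
    }

    std::vector<std::thread> threads;
    for (int i = 0; i < 4; ++i) {
        threads.emplace_back([&ioc] { ioc.run(); });
    }
    work.reset();
    for (auto& t : threads) t.join();
    std::cout << "counter = " << counter << std::endl;  // always 10000
}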

To have callbacks dispatched through the strand's queue, we only need to add a strand wrapper when registering each callback.

Add a member variable to the CSession class:

    strand<io_context::executor_type> _strand;

CSession's constructor:

CSession::CSession(boost::asio::io_context& io_context, CServer* server):
    _socket(io_context), _server(server), _b_close(false),
    _b_head_parse(false), _strand(io_context.get_executor()){
    boost::uuids::uuid a_uuid = boost::uuids::random_generator()();
    _uuid = boost::uuids::to_string(a_uuid);
    _recv_head_node = make_shared<MsgNode>(HEAD_TOTAL_LEN);
}

Notice that _strand is initialized in the initializer list, constructed from the executor returned by io_context.get_executor().

In asio, both io_context and strand dispatch work through an executor underneath; you can simply think of the executor as a scheduler. If several io_contexts and strands share the same scheduler, all of their message dispatching is performed by that one scheduler.

We construct the strand from the io_context's executor so that both are managed by the same scheduler; when binding an executor to a callback, we simply bind the strand.
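For reference, a tiny sketch of two equivalent ways to obtain such a strand (the make_strand helper assumes Boost 1.66 or newer):

#include <boost/asio.hpp>

int main() {
    boost::asio::io_context ioc;
    // the way CSession does it: construct the strand from the io_context's executor
    boost::asio::strand<boost::asio::io_context::executor_type> s1(ioc.get_executor());
    // equivalent convenience helper
    auto s2 = boost::asio::make_strand(ioc);
    (void)s1; (void)s2;
}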

For example, in the Start function we add the binding so that the callback's executor is _strand:

void CSession::Start(){
    ::memset(_data, 0, MAX_LENGTH);
    _socket.async_read_some(boost::asio::buffer(_data, MAX_LENGTH),
        boost::asio::bind_executor(_strand, std::bind(&CSession::HandleRead, this,
            std::placeholders::_1, std::placeholders::_2, SharedSelf())));
}

By the same logic, bind the executor to _strand everywhere we send or receive. For example, the send path is changed to:

auto& msgnode = _send_que.front();
boost::asio::async_write(_socket, boost::asio::buffer(msgnode->_data, msgnode->_total_len),
    boost::asio::bind_executor(_strand, std::bind(&CSession::HandleWrite, this, std::placeholders::_1, SharedSelf()))
    );

Make the corresponding change in the callback handlers as well.

 

In principle, an alternative:

You can also use _strand.wrap instead of boost::asio::bind_executor. _strand.wrap likewise binds the executor (the strand) together with the handler, so the handler executes under the strand's protection and remains thread-safe. A version using _strand.wrap looks like this:

_socket.async_read_some(boost::asio::buffer(_data, MAX_LENGTH),
    _strand.wrap(std::bind(&CSession::HandleRead, this,
    std::placeholders::_1, std::placeholders::_2, SharedSelf())));

Here _strand.wrap binds the executor _strand to the handler CSession::HandleRead, so that when the asynchronous operation completes the handler runs under the strand's protection, i.e. serialized on that executor.

In effect, _strand.wrap and boost::asio::bind_executor do the same job: both attach an executor (the strand) to a handler so the handler is protected by the executor and runs serialized on it. They differ only in usage and syntax:

  1. boost::asio::bind_executor is a free function provided by Boost.Asio that binds an executor to a handler. Its syntax is explicit: you pass the executor and the handler, and you can freely specify the handler's arguments.

  2. _strand.wrap is a member function of the strand object that does the binding for you; you call wrap directly on the strand and pass the handler, which is slightly more concise.

Both achieve the same effect, so which one to use is largely a matter of personal or team preference. One caveat: wrap() is provided by the older boost::asio::io_context::strand type, and as far as I can tell the strand<io_context::executor_type> template used in CSession above does not expose a wrap member, so with that declaration bind_executor is the form that compiles.

Performance comparison

To compare the performance of the two multithreaded server modes, we reuse the earlier test client: it opens a new connection every 10 ms, 100 connections in total, and each connection performs 500 send/receive round trips — 100,000 packets overall.

The client test code is as follows:

#include <iostream>
#include <boost/asio.hpp>
#include <thread>
#include <json/json.h>
#include <json/value.h>
#include <json/reader.h>
#include <chrono>
using namespace std;
using namespace boost::asio::ip;
const int MAX_LENGTH = 1024 * 2;
const int HEAD_LENGTH = 2;
const int HEAD_TOTAL = 4;
std::vector<thread> vec_threads;

int main()
{
    auto start = std::chrono::high_resolution_clock::now();  // record the start time
    for (int i = 0; i < 100; i++) {
        vec_threads.emplace_back([]() {
            try {
                //create the io context
                boost::asio::io_context ioc;
                //build the endpoint
                tcp::endpoint remote_ep(address::from_string("127.0.0.1"), 10086);
                tcp::socket sock(ioc);
                boost::system::error_code error = boost::asio::error::host_not_found;
                sock.connect(remote_ep, error);
                if (error) {
                    cout << "connect failed, code is " << error.value() << " error msg is " << error.message();
                    return;
                }
                int i = 0;
                while (i < 500) {
                    Json::Value root;
                    root["id"] = 1001;
                    root["data"] = "hello world";
                    std::string request = root.toStyledString();
                    size_t request_length = request.length();
                    char send_data[MAX_LENGTH] = { 0 };
                    int msgid = 1001;
                    int msgid_host = boost::asio::detail::socket_ops::host_to_network_short(msgid);
                    memcpy(send_data, &msgid_host, 2);
                    //convert the length to network byte order
                    int request_host_length = boost::asio::detail::socket_ops::host_to_network_short(request_length);
                    memcpy(send_data + 2, &request_host_length, 2);
                    memcpy(send_data + 4, request.c_str(), request_length);
                    boost::asio::write(sock, boost::asio::buffer(send_data, request_length + 4));
                    cout << "begin to receive..." << endl;
                    char reply_head[HEAD_TOTAL];
                    size_t reply_length = boost::asio::read(sock, boost::asio::buffer(reply_head, HEAD_TOTAL));
                    msgid = 0;
                    memcpy(&msgid, reply_head, HEAD_LENGTH);
                    short msglen = 0;
                    memcpy(&msglen, reply_head + 2, HEAD_LENGTH);
                    //convert to host byte order
                    msglen = boost::asio::detail::socket_ops::network_to_host_short(msglen);
                    msgid = boost::asio::detail::socket_ops::network_to_host_short(msgid);
                    char msg[MAX_LENGTH] = { 0 };
                    size_t msg_length = boost::asio::read(sock, boost::asio::buffer(msg, msglen));
                    Json::Reader reader;
                    reader.parse(std::string(msg, msg_length), root);
                    std::cout << "msg id is " << root["id"] << " msg is " << root["data"] << endl;
                    i++;
                }
            }
            catch (std::exception& e) {
                std::cerr << "Exception: " << e.what() << endl;
            }
            });
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
    for (auto& t : vec_threads) {
        t.join();
    }
    auto end = std::chrono::high_resolution_clock::now();  // record the end time
    auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);  // elapsed time in seconds
    std::cout << "Time spent: " << duration.count() << " seconds." << std::endl;  // print the result
    getchar();
    return 0;
}

First we start the previously implemented AsioIOServicePool server: sending and receiving all 100,000 packets takes 46 seconds in total.

https://cdn.llfc.club/1686192334657.jpg

Next we start the ASIOThreadPool server: sending and receiving all 100,000 packets takes 53 seconds in total.

https://cdn.llfc.club/1686193554693.jpg

So the multithreading mode implemented today is about 7 seconds slower than the earlier IOServicePool version.

posted @ 2023-07-31 20:08  白伟碧一些小心得