基于epoll的TP传输层实现

1. 抽象TP传输层设计

  在使用epoll实现实际的传输层之前,先设计一个抽象的传输层,这个抽象的传输层是传输层实现的接口层。

  接口层中一共有以下几个通用的类或者接口:

(1)Socket:通用的套接字层,用于封装本地套接字,同时会在析构时自动关闭套接字,避免资源泄漏

(2)DataSink:通用的数据接收层,当传输层接收到数据时,会通过用户定义的DataSink对象传输到外部

(3)IStream:通用的数据流接口,代表可读/写的字节流

(4)IConnectable:一个接口,表示可以连接到其它服务器

(5)BasicServer:基本的服务器类,继承了Socket类

(6)BasicStream:基本的数据流类,继承IStream和Socket类

1.1 抽象层类图

1.2 Socket类实现

  #ifndef SOCKET
    #define SOCKET int32_t
    #endif

    typedef SOCKET NativeSocket;

    /// Owning wrapper around a native socket descriptor.
    ///
    /// The descriptor is closed in the destructor. -1 means "no descriptor";
    /// the default constructor starts in that state so that destroying an
    /// unused Socket does not close an unrelated descriptor (previously it
    /// held 0 and closed stdin).
    class Socket
    {
    public:
        Socket() : _nativeSocket(-1)
        {
        }

        /// Takes ownership of nativeSocket; it is closed when this object is destroyed.
        Socket(NativeSocket nativeSocket)
            : _nativeSocket(nativeSocket)
        {
        }

        // Copying would make two objects close the same descriptor.
        Socket(const Socket&) = delete;
        Socket& operator=(const Socket&) = delete;

        virtual ~Socket()
        {
            if (_nativeSocket >= 0) {
                close(_nativeSocket);
            }
        }

        NativeSocket GetNativeSocket() const { return _nativeSocket; }

        /// Replaces the stored descriptor. The previously stored descriptor is
        /// NOT closed here; callers that replace a live descriptor must close it.
        void SetNativeSocket(NativeSocket nativeSocket) { _nativeSocket = nativeSocket; }

    private:
        NativeSocket _nativeSocket;
    };

1.3 DataSink类实现

  class IStream;  // forward declaration: DataSink only needs IStream as a pointer type

    /// Receiver of data produced by the transport layer.
    class DataSink
    {
    public:
        /// Virtual so concrete sinks can be destroyed through a DataSink pointer.
        virtual ~DataSink() = default;

        /// Called with `bytes` bytes starting at `buf` that arrived on `stream`.
        virtual int32_t OnDataIndication(IStream *stream, const char *buf, int64_t bytes) = 0;
    };

1.4 IStream类实现

  /// Abstract bidirectional byte stream.
    class IStream
    {
    public:
        /// Callback type receiving a pointer to incoming bytes and their count.
        typedef std::function<void(const char* buf, int64_t size)> DataIndicationHandler;

        /// Virtual so concrete streams can be destroyed through an IStream pointer.
        virtual ~IStream() = default;

        /// Reads into buffer (capacity buffSize) and reports the stored byte count via readSize.
        /// The meaning of the return value is defined by the concrete stream.
        virtual int32_t Receive(char *buffer, int32_t buffSize, int32_t& readSize) = 0;
        /// Sends byteArray; the meaning of the return value is defined by the concrete stream.
        virtual int32_t Send(const ByteArray& byteArray) = 0;

        /// Installs the handler for incoming data.
        virtual void OnDataIndication(DataIndicationHandler handler) = 0;
        /// Returns the currently installed data handler.
        virtual DataIndicationHandler GetDataIndication() = 0;
    };

1.5 BasicStream类实现

  /// Socket-backed stream base: owns the descriptor through Socket and keeps
    /// an optional DataSink pointer that this class never deletes.
    class BasicStream : public Socket, public IStream
    {
    public:
        // _dataSink must start as nullptr: GetDataSink() may be called before SetDataSink().
        BasicStream(NativeSocket nativeSocket) : Socket(nativeSocket), _dataSink(nullptr) {}

        BasicStream(const BasicStream& stream) = delete;

        virtual void SetDataSink(DataSink* dataSink) {
            _dataSink = dataSink;
        }

        /// Returns the installed sink, or nullptr if none was set.
        virtual DataSink* GetDataSink() {
            return _dataSink;
        }

        virtual const DataSink* GetDataSink() const {
            return _dataSink;
        }

    private:
        DataSink* _dataSink;  // not deleted by this class
    };

1.6 IConnectable类实现

  /// Something that can open a connection to a remote host:port.
    class IConnectable
    {
    public:
        /// Virtual so implementations can be destroyed through an IConnectable pointer.
        virtual ~IConnectable() = default;

        virtual void Connect(const std::string& host, int32_t port) = 0;
    };

1.7 BasicServer类实现

  template <class ConnectionType>
    class BasicServer : public Socket
    {
        public:
        typedef std::function<void(IStream* stream)> ConnectIndicationHandler;
        typedef std::function<void(IStream* stream)> DisconnectIndicationHandler;

        BasicServer() { }

        virtual int32_t Listen(const std::string& host, int32_t port, int backlog) = 0;
        virtual void OnConnectIndication(ConnectIndicationHandler handler) = 0;
        virtual void OnDisconnectIndication(DisconnectIndicationHandler handler) = 0;

        virtual ConnectionType Accept(int32_t listenfd) = 0;           
    };

2. 基于epoll实现服务器和客户端

   在前面的内容中已经完成了抽象TP传输层和基础工具(消息队列、线程池、缓冲区抽象、事件循环和日志工具)的实现,接下来在抽象TP传输层和基础工具的基础上完成基于epoll机制的服务器和客户端的实现。

2.1 类图设计

  

2.2 EpollStream代码实现

#pragma once

#include <cerrno>
#include <memory>
#include <string>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>

#include "../include/Net.h"

namespace meshy
{
    class EpollStream : public BasicStream
    {
    public:
        EpollStream(NativeSocket nativeSocket) : BasicStream(nativeSocket){}

        int32_t Send(const ByteArray& byteArray) override
        {
            const char *buf = byteArray.data();
            int32_t size = byteArray.size();
            int32_t n = size;
            while(n > 0)
            {
                int32_t writeSize = write(GetNativeSocket(), buf + size - n, n);
                if(writeSize < n)
                {
                    if (writeSize == -1 && errno != EAGAIN) {
                        TRACE_ERROR("FATAL write data to peer failed!");
                    }
                    break;
                }

                n -= writeSize;
            }

            return size;
            
        }
        
        int32_t Receive(char *buffer, int32_t buffSize, int32_t& readSize) override
        {
            readSize = 0;
            int32_t nread = 0;
            //NativeSocketEvent ev;

            while ((nread = read(GetNativeSocket(), buffer + readSize, buffSize - 1)) > 0) {
                readSize += nread;
            }

            return nread;
        }

        void OnDataIndication(DataIndicationHandler handler) override
        {

        }

        DataIndicationHandler GetDataIndication() override
        {
            return _handler;
        }

    private:
        DataIndicationHandler _handler;
    };

    typedef shared_ptr<EpollStream> EpollStreamPtr;
}
EpollStream.h

2.3 EpollServer代码实现

#pragma once 

#include <string>

#include "../include/Net.h"
#include "EpollStream.h"

namespace meshy
{
    /// Listening TCP server whose accept events are dispatched by EpollLoop.
    ///
    /// Listen() registers `this` with EpollLoop by raw pointer. The server must
    /// therefore stay alive while the loop runs, and it must not be copied: a
    /// copy would also close the same listening descriptor twice.
    class EpollServer : public BasicServer<EpollStreamPtr>
    {
    public:
        EpollServer(){}
        EpollServer(const EpollServer&) = delete;
        EpollServer& operator=(const EpollServer&) = delete;

        /// Creates and binds the listening socket, calls listen(), then registers
        /// the socket with EpollLoop. Returns 0 on success and a negative value
        /// on failure.
        virtual int32_t Listen(const std::string& host, int32_t port, int backlog = 1) override;

        /// Accepts at most one pending connection and registers it with epoll.
        /// Returns nullptr when nothing could be accepted.
        virtual EpollStreamPtr Accept(int32_t eventfd) override;

        /// Stores the handler. NOTE(review): nothing visible here invokes it yet.
        virtual void OnConnectIndication(ConnectIndicationHandler handler) override
        {
            _connectHandler = handler;
        }
        /// Stores the handler. NOTE(review): nothing visible here invokes it yet.
        virtual void OnDisconnectIndication(DisconnectIndicationHandler handler) override
        {
            _disconnectIndication = handler;
        }

    private:
        int32_t _SetNonBlocking(int32_t fd);
        int32_t _bind(const std::string& host, int32_t port);

        ConnectIndicationHandler _connectHandler;
        DisconnectIndicationHandler _disconnectIndication;
    };
}
EpollServer.h
#include "EpollServer.h"
#include "EpollLoop.h"

#include <unistd.h>
#include <fcntl.h>
#include <sys/types.h>  
#include <sys/socket.h>

namespace meshy
{
    /// Creates a non-blocking IPv4 TCP socket with SO_REUSEADDR, stores it as
    /// this server's native socket, and binds it to host:port.
    ///
    /// @param host dotted-quad IPv4 address to bind to.
    /// @return 0 on success, -1 on failure (already logged).
    int32_t EpollServer::_bind(const std::string& host, int32_t port)
    {
        sockaddr_in addr;
        memset(&addr, 0, sizeof(addr));
        addr.sin_family = AF_INET;
        addr.sin_port = htons(static_cast<uint16_t>(port));
        // inet_addr() silently returns INADDR_NONE (255.255.255.255) for bad
        // input; inet_pton() reports the failure explicitly.
        if (inet_pton(AF_INET, host.c_str(), &addr.sin_addr) != 1) {
            TRACE_ERROR("Invalid listen address!");
            return -1;
        }

        int32_t listenfd = socket(AF_INET, SOCK_STREAM, 0);
        if (listenfd < 0) {
            TRACE_ERROR("Create socket failed!");
            return -1;
        }
        // Hand the fd to the Socket base right away so the destructor closes
        // it on every later failure path.
        SetNativeSocket(listenfd);

        int32_t option = 1;
        if (setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &option, sizeof(option)) < 0) {
            // Not fatal: binding may still succeed, just without address reuse.
            TRACE_ERROR("setsockopt(SO_REUSEADDR) failed!");
        }
        _SetNonBlocking(listenfd);

        if (bind(listenfd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) < 0) {
            TRACE_ERROR("Bind socket failed!");
            return -1;
        }

        return 0;
    }

    /// Binds to host:port, starts listening, and registers the listening socket
    /// with EpollLoop (edge-triggered EPOLLIN) so accepts are dispatched to this server.
    ///
    /// @return 0 on success, a negative value on failure (already logged).
    int32_t EpollServer::Listen(const std::string& host, int32_t port, int backlog)
    {
        // Previously the _bind() result was ignored, so listen() and epoll
        // registration ran on a socket that failed to bind.
        int32_t errorCode = _bind(host, port);
        if (errorCode < 0) {
            return errorCode;
        }

        int32_t listenfd = GetNativeSocket();

        errorCode = listen(listenfd, backlog);
        if (-1 == errorCode) {
            TRACE_ERROR("Listen socket failed!");
            return errorCode;
        }
        TRACE_DEBUG("Listen Success");

        errorCode = EpollLoop::Get()->AddEpollEvents(EPOLLIN | EPOLLET, listenfd);
        if (errorCode < 0)
        {
            TRACE_ERROR("epoll_ctl faild : listen socket");
            return errorCode;
        }

        // NOTE(review): if the loop thread is already running, an event can
        // arrive between AddEpollEvents() and AddServer(). The fd is then not
        // yet known as a server.
        EpollLoop::Get()->AddServer(listenfd, this);
        return 0;
    }

    /// Accepts at most ONE pending connection on the listening socket, makes it
    /// non-blocking, and registers it with EpollLoop (EPOLLIN | EPOLLET).
    ///
    /// The listening socket is edge-triggered, so the caller should call this
    /// repeatedly until it returns nullptr to drain the accept queue.
    ///
    /// @param eventfd unused here; kept for interface compatibility.
    /// @return the new stream, or nullptr when nothing could be accepted.
    ///         EAGAIN (queue empty) is expected and is not logged.
    EpollStreamPtr EpollServer::Accept(int32_t eventfd)
    {
        sockaddr_in addr;
        memset(&addr, 0, sizeof(addr));
        // accept() reads addrlen as the buffer size, so it must be initialised
        // (the previous version passed an uninitialised int32_t).
        socklen_t addrlen = sizeof(addr);
        NativeSocket listenfd = GetNativeSocket();

        NativeSocket connSocket;
        do {
            connSocket = accept(listenfd, reinterpret_cast<sockaddr*>(&addr), &addrlen);
        } while (connSocket < 0 && errno == EINTR);

        if (connSocket < 0)  // 0 is a valid descriptor; only negative means failure
        {
            if (errno != EAGAIN && errno != EWOULDBLOCK) {
                TRACE_ERROR("Accept error");
            }
            return EpollStreamPtr(nullptr);
        }

        // accept() already filled in the peer address; no getpeername() needed.
        std::ostringstream os;
        os << "client:" << inet_ntoa(addr.sin_addr) << "connect success";
        std::string strInfo = os.str();
        TRACE_DEBUG(strInfo);

        _SetNonBlocking(connSocket);
        // Create the owner first so the fd is closed if the stream is dropped.
        EpollStreamPtr connection = std::make_shared<EpollStream>(connSocket);

        int32_t errorCode = EpollLoop::Get()->AddEpollEvents(EPOLLIN | EPOLLET, connSocket);
        if (errorCode < 0)
        {
            TRACE_ERROR("epoll_ctl faild : conn socket");
        }

        //if ( _connectHandler ) {
           // _connectHandler(connection.get());
        //}

        return connection;
    }

    // Put fd into non-blocking mode (adds O_NONBLOCK, keeps the other status flags).
    // Returns 0 on success, -1 if reading or writing the flags fails.
    int32_t EpollServer::_SetNonBlocking(int32_t fd)
    {
        const int32_t currentFlags = fcntl(fd, F_GETFL);
        if (currentFlags < 0)
        {
            TRACE_ERROR("fcntl(F_GETFL)");
            return -1;
        }

        const bool setFailed = fcntl(fd, F_SETFL, currentFlags | O_NONBLOCK) < 0;
        if (setFailed)
        {
            TRACE_ERROR("fcntl(F_SETFL)");
            return -1;
        }
        return 0;
    }
}
EpollServer.cpp

2.4 EpollClient代码实现

// Epoll-based TCP client.
#pragma once

#include "../include/Net.h"
#include "EpollStream.h"

#include <functional>
#include <memory>
#include <string>

namespace meshy
{
    class EpollClient;
    typedef std::shared_ptr<EpollClient> EpollClientPtr;

    /// TCP client socket driven by EpollLoop.
    ///
    /// Derives from EpollStream (which is a BasicStream) so that it reuses the
    /// stream's Send()/Receive() and can be passed to EpollLoop::AddStream(),
    /// which accepts only EpollStreamPtr.
    class EpollClient : public IConnectable, public EpollStream
    {
    public:
        typedef std::function<void(IStream* stream)> ConnectIndicationHandler;
        typedef std::function<void(IStream* stream)> DisconnectIndicationHandler;

        EpollClient(const EpollClient&) = delete;
        EpollClient& operator=(const EpollClient&) = delete;

        virtual void OnConnectIndication(ConnectIndicationHandler handler)
        {
            _connectHandler = handler;
        }
        virtual void OnDisconnectIndication(DisconnectIndicationHandler handler)
        {
            _disconnectIndication = handler;
        }

        /// Starts a non-blocking connect of this client's socket to host:port.
        /// Failures are logged (IConnectable::Connect has no way to report them).
        void Connect(const std::string& host, int32_t port) override;

        /// Creates a socket, starts connecting it to host:port, and registers
        /// it with EpollLoop. Returns nullptr on failure.
        /// (Named Create because a static Connect with the same parameters
        /// cannot coexist with the member Connect above.)
        static EpollClientPtr Create(const std::string& host, int32_t port);

    protected:
        explicit EpollClient(NativeSocket nativeSocket)
            : EpollStream(nativeSocket)
        {
        }

        /// Shared implementation of Connect(); returns 0 on success or while the
        /// connection is in progress, and -1 on failure.
        int32_t _Connect(const std::string& host, int32_t port);
        int32_t _SetNonBlocking(int32_t fd);

    private:
        ConnectIndicationHandler _connectHandler;
        DisconnectIndicationHandler _disconnectIndication;
    };
}
EpollClient.h
#include "EpollClient.h"
#include "EpollLoop.h"

#include <arpa/inet.h>
#include <cerrno>
#include <cstring>
#include <fcntl.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

namespace meshy
{
    int32_t EpollClient::_Connect(const std::string& host, int32_t port)
    {
        sockaddr_in addr;
        memset(&addr, 0, sizeof(addr));
        addr.sin_family = AF_INET;
        addr.sin_port = htons(static_cast<uint16_t>(port));
        if (inet_pton(AF_INET, host.c_str(), &addr.sin_addr) != 1)
        {
            TRACE_ERROR("EpollClient invalid host address");
            return -1;
        }

        NativeSocket sock = GetNativeSocket();
        if (_SetNonBlocking(sock) < 0)
        {
            return -1;
        }

        // On a non-blocking socket connect() normally fails with EINPROGRESS
        // while the handshake continues in the background. That is not an
        // error; completion or failure is reported later through epoll.
        if (connect(sock, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)) < 0 && errno != EINPROGRESS)
        {
            TRACE_ERROR("connect faild");
            return -1;
        }
        return 0;
    }

    void EpollClient::Connect(const std::string& host, int32_t port)
    {
        _Connect(host, port);
    }

    EpollClientPtr EpollClient::Create(const std::string& host, int32_t port)
    {
        NativeSocket sock = socket(AF_INET, SOCK_STREAM, 0);
        if (sock < 0)
        {
            TRACE_ERROR("EpollClient socket faild");
            return nullptr;
        }

        // The client owns sock from here on, so every early return below
        // closes it through the Socket destructor.
        EpollClientPtr client(new EpollClient(sock));
        if (client->_Connect(host, port) < 0)
        {
            return nullptr;
        }

        EpollLoop *pEpollLoop = EpollLoop::Get();
        if (pEpollLoop->AddEpollEvents(EPOLLIN | EPOLLET, sock) < 0)
        {
            TRACE_ERROR("cleint AddEpollEvents faild");
            return nullptr;
        }

        pEpollLoop->AddStream(client);
        return client;
    }

    // Put fd into non-blocking mode (adds O_NONBLOCK, keeps the other status flags).
    int32_t EpollClient::_SetNonBlocking(int32_t fd)
    {
        int32_t opts = fcntl(fd, F_GETFL);
        if (opts < 0) {
            TRACE_ERROR("fcntl(F_GETFL)");
            return -1;
        }
        opts = (opts | O_NONBLOCK);
        if (fcntl(fd, F_SETFL, opts) < 0)
        {
            TRACE_ERROR("fcntl(F_SETFL)");
            return -1;
        }

        return 0;
    }
}
EpollClient.cpp

2.5 EpollLoop事件循环代码实现

#pragma once

#include "../include/Loop.h"
#include "../include/Logger.h"

#include "EpollServer.h"

#include <atomic>
#include <map>
#include <stdint.h>
#include <sys/epoll.h>

namespace meshy
{
    #define FD_SIZE 1024
    #define BUFF_SIZE 1024

    /// Process-wide epoll event loop (singleton obtained through Get()).
    ///
    /// _Run() processes events on a detached background thread.
    /// NOTE(review): _servers and _streams are not protected by a lock. Calls
    /// to AddServer()/AddStream() from another thread while the loop runs race
    /// with the event thread.
    class EpollLoop : public Loop
    {
    public:
        static EpollLoop* Get();
        virtual ~EpollLoop();

        // Singleton: copying would duplicate the epoll descriptor and maps.
        EpollLoop(const EpollLoop&) = delete;
        EpollLoop& operator=(const EpollLoop&) = delete;

        int32_t AddEpollEvents(int32_t events, int32_t fd); // fd: descriptor to watch
        int32_t ModifyEpollEvents(int32_t events, int32_t fd);
        int32_t DelEpollEvents(int32_t events, int32_t fd);

        void AddServer(NativeSocket socket, EpollServer* server);
        void AddStream(EpollStreamPtr stream);

    protected:
        EpollLoop();
        void _Run() override;
        void _EpollThread();
        void _HandleEvent(int32_t eventfd, struct epoll_event* events, int32_t nfds);
        void _Initialize();
        int32_t _Accept(int32_t eventfd, int32_t listenfd);
        void _Read(int32_t eventfd, int32_t fd, uint32_t events);

    private:
        int32_t _eventfd;
        // Written by the destructor and read by the epoll thread: atomic to avoid a data race.
        std::atomic<bool> _shutdown;

        std::map<NativeSocket, EpollServer*> _servers;
        std::map <NativeSocket, EpollStreamPtr> _streams;
    };
}
EpollLoop.h
#include "EpollLoop.h"
#include "../include/util.h"

#include <thread>

namespace meshy
{
    // Returns the single process-wide loop; the function-local static is
    // constructed on first use.
    EpollLoop* EpollLoop::Get()
    {
        static EpollLoop instance;
        EpollLoop* const self = &instance;
        return self;
    }

    // Creates the epoll instance, then marks the loop as not shut down.
    EpollLoop::EpollLoop()
        : _eventfd(-1)
    {
        _Initialize();
        this->_shutdown = false;
    }

    // Only raises the shutdown flag; the epoll thread checks it before each
    // epoll_wait. NOTE(review): _eventfd is not closed here.
    EpollLoop::~EpollLoop()
    {
        this->_shutdown = true;
    }

    // Starts watching fd for `events` (EPOLL_CTL_ADD); the fd is stored in the
    // event payload so handlers can identify it. Returns epoll_ctl's result.
    int32_t EpollLoop::AddEpollEvents(int32_t events, int32_t fd)
    {
        epoll_event registration;
        registration.data.fd = fd;
        registration.events = events;

        const int32_t rc = epoll_ctl(_eventfd, EPOLL_CTL_ADD, fd, &registration);
        return rc;
    }

    int32_t EpollLoop::ModifyEpollEvents(int32_t events, int32_t fd)
    {
        epoll_event event;
        event.events = events;
        event.data.fd = fd;

        return epoll_ctl(_eventfd, EPOLL_CTL_MOD, fd, &event);
    }

    int32_t EpollLoop::DelEpollEvents(int32_t events, int32_t fd)
    {
        epoll_event event;
        event.events = events;
        event.data.fd = fd;

        return epoll_ctl(_eventfd, EPOLL_CTL_DEL, fd, &event);
    }

    void EpollLoop::_Run()
    {
        std::thread ThreadProc(&EpollLoop::_EpollThread, this);
        ThreadProc.detach();
    }

    // Event-thread body: waits (without timeout) for up to FD_SIZE events at a
    // time and dispatches them, until _shutdown is observed.
    void EpollLoop::_EpollThread()
    {
        TRACE_DEBUG("_EPollThread");
        epoll_event events[FD_SIZE];

        while (!_shutdown)
        {
            int32_t nfds = epoll_wait(_eventfd, events, FD_SIZE, -1);
            if (-1 == nfds)
            {
                // A signal handled by this thread interrupts epoll_wait with
                // EINTR. That is not a failure, so just wait again instead of
                // terminating the process.
                if (errno == EINTR)
                {
                    continue;
                }
                TRACE_ERROR("FATAL epoll_wait failed!");
                exit(EXIT_FAILURE);
            }

            _HandleEvent(_eventfd, events, nfds);
        }
    }

    // Dispatches one batch of epoll events. Listening sockets go to _Accept;
    // every other fd goes to _Read when it is readable or reports an error/hangup.
    void EpollLoop::_HandleEvent(int32_t eventfd, struct epoll_event* events, int32_t nfds)
    {
        for (int32_t i = 0; i < nfds; ++i)
        {
            const int32_t fd = events[i].data.fd;  // the descriptor this event belongs to
            const uint32_t mask = events[i].events;

            if (_servers.find(fd) != _servers.end())
            {
                _Accept(eventfd, fd);
                continue;
            }

            // epoll always reports EPOLLERR/EPOLLHUP, even when only EPOLLIN was
            // requested, and they may arrive without EPOLLIN. Routing them to
            // _Read lets the failing read() reveal the condition, so the stream
            // is released instead of staying registered forever.
            if (mask & (EPOLLIN | EPOLLERR | EPOLLHUP))
            {
                _Read(eventfd, fd, mask);
            }

            if (mask & EPOLLOUT)
            {
                // Writable notifications are not handled yet; the registrations
                // visible in this codebase only request EPOLLIN | EPOLLET.
            }
        }
    }

    // Creates the epoll instance (the size argument is only a hint to the
    // kernel). On failure _eventfd stays negative and no success is logged.
    void EpollLoop::_Initialize()
    {
        _eventfd = epoll_create(FD_SIZE);
        if (_eventfd < 0)
        {
            TRACE_ERROR("epoll_create failed");
            return;  // previously fell through and logged "success"
        }

        std::ostringstream os;
        os << "epoll_create success, fdsize:" << FD_SIZE;
        std::string strInfo = os.str();
        TRACE_DEBUG(strInfo);
    }

    int32_t EpollLoop::_Accept(int32_t eventfd, int32_t listenfd)
    {
        TRACE_DEBUG("_Accept");
        EpollServer* server = _servers.find(listenfd)->second;
        EpollStreamPtr connection = server->Accept(eventfd);

        if (connection != nullptr) {
           // _streams[connection->GetNativeSocket()] = connection;
           AddStream(connection);
            return 0;
        }

        return -1;
    }

    void EpollLoop::_Read(int32_t eventfd, int32_t fd, uint32_t events)
    {
        TRACE_DEBUG("_Read");

        EpollStreamPtr stream = _streams[fd];

        char buffer[BUFF_SIZE] = {0};
        int32_t readSize;
        int32_t nread = stream->Receive(buffer, BUFF_SIZE, readSize);

        //stream->SetEvents(events);

        if ((nread == -1 && errno != EAGAIN) || readSize == 0) {
            _streams.erase(fd);

            // Print error message
            char message[50];
            sprintf(message, "errno: %d: %s, nread: %d, n: %d", errno, strerror(errno), nread, readSize);
            TRACE_WARNING(message);
            return;
        }
       
        char utf8_buff[BUFF_SIZE] = {0};
        int32_t destreadsize = GB2312ToUTF8(buffer, utf8_buff, BUFF_SIZE);
        std::ostringstream os;
        os.str("");
        os << "srcreadSize:" << readSize << "--destreadsize:" << BUFF_SIZE - destreadsize;
        std::string strInfo = os.str();
        TRACE_DEBUG(strInfo);

        TRACE_INFO(std::string(utf8_buff));
        // Write buf to the receive queue.
        //_Enqueue(stream, buffer, readSize);
    }

    // Records which server owns listening socket `socket`. An existing entry
    // for the same socket is kept, because insert() does not overwrite.
    void EpollLoop::AddServer(NativeSocket socket, EpollServer* server)
    {
        _servers.insert(std::make_pair(socket, server));
    }
            
    // Registers (or replaces) the stream for its native socket descriptor.
    void EpollLoop::AddStream(EpollStreamPtr stream)
    {
        const NativeSocket key = stream->GetNativeSocket();
        _streams[key] = stream;
    }
}
EpollLoop.cpp

2.6 main.cpp测试

#include "stdio.h"

#include "./epoll/EpollServer.h"
#include "./epoll/EpollLoop.h"

#include <thread>
#include <chrono>

// Demo entry point: listens on ip:port and lets the epoll loop log incoming data.
// Usage: out [ip] [port]   (defaults: 192.168.1.40 1122)
int main(int argc, char* argv[])
{
    std::string ip = (argc > 1) ? argv[1] : "192.168.1.40";
    int32_t port = 1122;
    if (argc > 2)
    {
        try {
            port = std::stoi(argv[2]);
        } catch (...) {
            port = -1;  // rejected by the range check below
        }
        if (port <= 0 || port > 65535)
        {
            fprintf(stderr, "invalid port: %s\n", argv[2]);
            return 1;
        }
    }

    meshy::EpollLoop *pEpollLoop = meshy::EpollLoop::Get();

    // Register the listening socket before starting the loop thread, so the
    // thread never sees an event for a listen fd that is not yet known as a server.
    meshy::EpollServer server;
    if (server.Listen(ip, port, 1) < 0)
    {
        fprintf(stderr, "failed to listen on %s:%d\n", ip.c_str(), port);
        return 1;
    }

    pEpollLoop->Start();

    while (1)
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    }
    return 0;
}

2.7 makefile

# Build the epoll demo server (GNU make). Recipe lines must start with a TAB.

all: out

COMPILE=g++
# -pthread (rather than -lpthread) sets the threading macros at compile time and
# links libpthread; a -l option placed before the object files can be dropped
# by the linker.
FLAGS=-Wall -g -O0 -std=c++11 -pthread
OBJS=main.o Logger.o EpollServer.o EpollLoop.o

.PHONY: all clean

out: $(OBJS)
	$(COMPILE) $(FLAGS) -o out $(OBJS)

main.o: main.cpp
	$(COMPILE) $(FLAGS) -c main.cpp

Logger.o: Logger.cpp
	$(COMPILE) $(FLAGS) -c Logger.cpp

EpollServer.o: ./epoll/EpollServer.cpp
	$(COMPILE) $(FLAGS) -c ./epoll/EpollServer.cpp

EpollLoop.o: ./epoll/EpollLoop.cpp
	$(COMPILE) $(FLAGS) -c ./epoll/EpollLoop.cpp

clean:
	rm -f *.o out
代码下载地址:https://pan.baidu.com/s/1P4wM6pCz-S8Dbtf-4xKlqA
posted @ 2018-08-06 10:22  Fate0729  阅读(500)  评论(0编辑  收藏  举报