Asio（Boost.Asio のスタンドアロン版）ベースの高性能サーバーラッパー --- Windows・Linux・Mac のクロスプラットフォーム対応！(C++11、14、17)

20455 ワード

早速コードを示します:
/**
* @file z_asio.h
* @brief Header file of the asio-based C++ network wrapper (TCP connector / acceptor / session).
* @author dqsjqian Mr.Zhang
* @mail [email protected]
* @date May 19 2019
*/

#pragma once

#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <system_error>
#include <thread>
#include <vector>

// Used only by the usage examples at the end of this post.
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <iostream>

#include <asio.hpp>
// NOTE(review): AsioTcpSession uses the ox_buffer_* C API and `struct buffer_s`;
// the header declaring them was lost in extraction — TODO include it here.

namespace z_asio {

    using namespace asio;
    using namespace asio::ip;

    class SharedSocket
    {
    public:
        using Ptr = std::shared_ptr;

        static SharedSocket::Ptr Make(tcp::socket socket, asio::io_service& ioService)
        {
            return std::make_shared(std::move(socket), ioService);
        }

        SharedSocket(tcp::socket socket, asio::io_service& ioService)
            :
            mSocket(std::move(socket)),
            mIoService(ioService)
        {
        }

        virtual ~SharedSocket() = default;

        tcp::socket& socket()
        {
            return mSocket;
        }

        asio::io_service& service()
        {
            return mIoService;
        }

    private:
        tcp::socket         mSocket;
        asio::io_service&   mIoService;
    };

    class WrapperIoService : public asio::noncopyable
    {
    public:
        using Ptr = std::shared_ptr;

        WrapperIoService(int concurrencyHint)
            :
            mTrickyIoService(std::make_shared<:io_service>(concurrencyHint)),
            mIoService(*mTrickyIoService)
        {
        }

        WrapperIoService(asio::io_service& ioService)
            :
            mIoService(ioService)
        {}

        virtual ~WrapperIoService()
        {
            stop();
        }

        void    run()
        {
            asio::io_service::work worker(mIoService);
            for (; !mIoService.stopped();)
            {
                mIoService.run();
            }
        }

        void    stop()
        {
            mIoService.stop();
        }

        asio::io_service& service()
        {
            return mIoService;
        }

        auto    runAfter(std::chrono::nanoseconds timeout, std::function callback)
        {
            auto timer = std::make_shared<:steady_timer>(mIoService);
            timer->expires_from_now(timeout);
            timer->async_wait([timer, callback](const asio::error_code & ec) {
                    if (!ec)
                    {
                        callback();
                    }
                });
            return timer;
        }

    public:
        std::shared_ptr<:io_service>   mTrickyIoService;
        asio::io_service&                   mIoService;
    };

    class IoServiceThread : public asio::noncopyable
    {
    public:
        using Ptr = std::shared_ptr;

        IoServiceThread(int concurrencyHint)
            :
            mWrapperIoService(concurrencyHint)
        {
        }

        virtual ~IoServiceThread()
        {
            stop();
        }

        void    start(size_t threadNum)
        {
            std::lock_guard<:mutex> lck(mIoThreadGuard);
            if (threadNum == 0)
            {
                throw std::runtime_error("thread num is zero");
            }
            if (!mIoThreads.empty())
            {
                return;
            }
            for (size_t i = 0; i < threadNum; i++)
            {
                mIoThreads.push_back(std::thread([this]() {
                        mWrapperIoService.run();
                    }));
            }
        }

        void    stop()
        {
            std::lock_guard<:mutex> lck(mIoThreadGuard);

            mWrapperIoService.stop();
            for (auto& thread : mIoThreads)
            {
                try
                {
                    thread.join();
                }
                catch (...)
                {
                }
            }
            mIoThreads.clear();
        }

        asio::io_service& service()
        {
            return mWrapperIoService.service();
        }

        WrapperIoService& wrapperIoService()
        {
            return mWrapperIoService;
        }

    private:
        WrapperIoService            mWrapperIoService;
        std::vector<:thread>    mIoThreads;
        std::mutex                  mIoThreadGuard;
    };

    class IoServicePool : public asio::noncopyable
    {
    public:
        using Ptr = std::shared_ptr;

        IoServicePool(size_t poolSize,
            int concurrencyHint)
            :
            mPickIoServiceIndex(0)
        {
            if (poolSize == 0)
            {
                throw std::runtime_error("pool size is zero");
            }

            for (size_t i = 0; i < poolSize; i++)
            {
                mIoServicePool.emplace_back(std::make_shared(concurrencyHint));
            }
        }

        virtual ~IoServicePool()
        {
            stop();
        }

        void  start(size_t threadNumEveryService)
        {
            std::lock_guard<:mutex> lck(mPoolGuard);

            for (const auto& service : mIoServicePool)
            {
                service->start(threadNumEveryService);
            }
        }

        void  stop()
        {
            std::lock_guard<:mutex> lck(mPoolGuard);

            for (const auto& service : mIoServicePool)
            {
                service->stop();
            }
            mIoServicePool.clear();
        }

        asio::io_service& pickIoService()
        {
            auto index = mPickIoServiceIndex.fetch_add(1, std::memory_order::memory_order_relaxed);
            return mIoServicePool[index % mIoServicePool.size()]->service();
        }

        std::shared_ptr pickIoServiceThread()
        {
            auto index = mPickIoServiceIndex.fetch_add(1, std::memory_order::memory_order_relaxed);
            return mIoServicePool[index % mIoServicePool.size()];
        }

    private:
        std::vector<:shared_ptr>>   mIoServicePool;
        std::mutex                                      mPoolGuard;
        std::atomic_int32_t                             mPickIoServiceIndex;
    };

    class AsioTcpConnector : public asio::noncopyable
    {
    public:
        using Ptr = std::shared_ptr;

        AsioTcpConnector(IoServicePool::Ptr ioServicePool)
            :
            mIoServicePool(ioServicePool)
        {
        }

        void    asyncConnect(
            asio::ip::tcp::endpoint endpoint,
            std::chrono::nanoseconds timeout,
            std::function callback,
            std::function failedCallback)
        {
            wrapperAsyncConnect(mIoServicePool->pickIoServiceThread(), { endpoint }, timeout, callback, failedCallback);
        }

        void    asyncConnect(
            std::shared_ptr ioServiceThread,
            asio::ip::tcp::endpoint endpoint,
            std::chrono::nanoseconds timeout,
            std::function callback,
            std::function failedCallback)
        {
            wrapperAsyncConnect(ioServiceThread, { endpoint }, timeout, callback, failedCallback);
        }

    private:
        void    wrapperAsyncConnect(
            IoServiceThread::Ptr ioServiceThread,
            std::vector<:ip::tcp::endpoint> endpoints,
            std::chrono::nanoseconds timeout,
            std::function callback,
            std::function failedCallback)
        {
            auto sharedSocket = SharedSocket::Make(tcp::socket(ioServiceThread->service()), ioServiceThread->service());
            auto timeoutTimer = ioServiceThread->wrapperIoService().runAfter(timeout, [=]() {
                    sharedSocket->socket().close();
                    failedCallback();
                });

            asio::async_connect(sharedSocket->socket(),
                endpoints,
                [=](std::error_code ec, tcp::endpoint) {
                    timeoutTimer->cancel();
                    if (!ec)
                    {
                        callback(sharedSocket);
                    }
                    else
                    {
                        failedCallback();
                    }
                });
        }

    private:
        IoServicePool::Ptr mIoServicePool;
    };

    /**
     * Accepts incoming connections on `listenService`. Each accepted socket is
     * created on an io_service picked from the pool, and the user callback is
     * posted to that io_service.
     *
     * Must be owned by a shared_ptr, because doAccept() uses
     * shared_from_this(). Each pending accept keeps the object alive.
     */
    class AsioTcpAcceptor : public asio::noncopyable, public std::enable_shared_from_this<AsioTcpAcceptor>
    {
    public:
        using Ptr = std::shared_ptr< AsioTcpAcceptor>;

        AsioTcpAcceptor(
            asio::io_service& listenService,
            IoServicePool::Ptr ioServicePool,
            ip::tcp::endpoint endpoint)
            :
            mIoServicePool(std::move(ioServicePool)),
            mAcceptor(std::make_shared< ip::tcp::acceptor>(listenService, endpoint))
        {
        }

        virtual ~AsioTcpAcceptor()
        {
            // Destructors must not throw, so use the error_code overload.
            asio::error_code ignored;
            mAcceptor->close(ignored);
        }

        /// Starts the accept loop. `callback` runs once per accepted socket.
        void    startAccept(std::function<void(SharedSocket::Ptr)> callback)
        {
            doAccept(std::move(callback));
        }

    private:
        void    doAccept(std::function<void(SharedSocket::Ptr)> callback)
        {
            auto& ioService = mIoServicePool->pickIoService();
            auto sharedSocket = SharedSocket::Make(tcp::socket(ioService), ioService);

            auto self = shared_from_this();
            mAcceptor->async_accept(
                sharedSocket->socket(),
                [self, callback, sharedSocket, this](std::error_code ec) {
                    if (!ec && callback)
                    {
                        sharedSocket->service().post([callback, sharedSocket]() {
                                callback(sharedSocket);
                            });
                    }
                    // Once the acceptor is closed or the operation was cancelled,
                    // every new async_accept would fail immediately, so re-arming
                    // here would spin forever.
                    if (ec == asio::error::operation_aborted || !mAcceptor->is_open())
                    {
                        return;
                    }
                    doAccept(callback);
                });
        }

    private:
        IoServicePool::Ptr                  mIoServicePool;
        std::shared_ptr< ip::tcp::acceptor> mAcceptor;
    };

    /**
     * A TCP session with a growable receive buffer and a queued sender.
     *
     * Received bytes are passed to DataCB, which returns how many it
     * consumed. Unconsumed bytes stay buffered for the next callback. Outgoing
     * messages are queued and sent with one gather-write per batch.
     *
     * NOTE(review): send() may be called from any thread, but it initiates
     * async_send directly on the caller's thread. Asio sockets are not safe for
     * concurrent use, so calling send() from outside the session's io_service
     * thread can race with the pending read — confirm callers' threading.
     */
    class AsioTcpSession : public asio::noncopyable, public std::enable_shared_from_this< AsioTcpSession>
    {
    public:
        // (readable bytes, count) -> number of bytes consumed.
        // NOTE(review): parameter types are reconstructed from how mDataCB is
        // called with the ox_buffer read pointer and valid count — confirm
        // against the ox_buffer header.
        using DataCB = std::function<size_t(const char*, size_t)>;
        using Ptr = std::shared_ptr<AsioTcpSession>;

        /// Creates a session and immediately starts receiving.
        static Ptr Make(
            SharedSocket::Ptr socket,
            size_t maxRecvBufferSize,
            DataCB cb)
        {
            // std::make_shared cannot call the private constructor. A local
            // class can, because it has the same access as this member function.
            struct make_shared_enabler : public AsioTcpSession
            {
            public:
                make_shared_enabler(
                    SharedSocket::Ptr socket,
                    size_t maxRecvBufferSize,
                    DataCB cb)
                    :
                    AsioTcpSession(std::move(socket), maxRecvBufferSize, std::move(cb))
                {}
            };
            auto session = std::make_shared<make_shared_enabler>(
                std::move(socket),
                maxRecvBufferSize,
                std::move(cb));
            session->startRecv();
            return session;
        }

        virtual ~AsioTcpSession() = default;

        /// Queues `msg` for sending. Null or empty messages are ignored.
        void    send(std::shared_ptr<std::string> msg)
        {
            // A null message would be dereferenced in trySend(). An empty one
            // at the head of the queue is never popped by adjustSendBuffer(),
            // so it would be re-sent as 0 bytes forever.
            if (!msg || msg->empty())
            {
                return;
            }
            {
                std::lock_guard<std::mutex> lck(mSendGuard);
                mPendingSendMsg.push_back({ 0, std::move(msg) });
            }
            trySend();
        }

        void    send(std::string msg)
        {
            send(std::make_shared<std::string>(std::move(msg)));
        }

        const SharedSocket::Ptr& socket() const
        {
            return mSocket;
        }

    private:
        AsioTcpSession(
            SharedSocket::Ptr socket,
            size_t maxRecvBufferSize,
            DataCB cb)
            :
            mMaxRecvBufferSize(maxRecvBufferSize),
            mSocket(std::move(socket)),
            mSending(false),
            mDataCB(std::move(cb))
        {
            // non_blocking(true) is deliberately not set. Asio async operations
            // do not need it, and it would make synchronous use of socket()
            // fail with would_block.
            mSocket->socket().set_option(asio::ip::tcp::no_delay(true));
            growRecvBuffer();
        }

        void    startRecv()
        {
            std::call_once(mRecvInitOnceFlag, [this]() {
                    doRecv();
                });
        }

        void    doRecv()
        {
            auto self = shared_from_this();
            mSocket->socket().async_read_some(
                asio::buffer(ox_buffer_getwriteptr(mRecvBuffer.get()),
                    ox_buffer_getwritevalidcount(mRecvBuffer.get())),
                [this, self](std::error_code ec, size_t bytesTransferred) {
                    onRecvCompleted(ec, bytesTransferred);
                });
        }

        void    onRecvCompleted(std::error_code ec, size_t bytesTransferred)
        {
            if (ec)
            {
                return;
            }

            // Note: growRecvBuffer() may replace mRecvBuffer, so always
            // re-read mRecvBuffer.get() instead of caching the pointer.
            ox_buffer_addwritepos(mRecvBuffer.get(), bytesTransferred);
            if (ox_buffer_getreadvalidcount(mRecvBuffer.get()) == ox_buffer_getsize(mRecvBuffer.get()))
            {
                growRecvBuffer();
            }

            if (mDataCB)
            {
                const auto readable = ox_buffer_getreadvalidcount(mRecvBuffer.get());
                const auto proclen = mDataCB(ox_buffer_getreadptr(mRecvBuffer.get()), readable);
                // Reporting more bytes than were offered is a callback bug.
                // Release builds ignore the bogus value and keep the data.
                assert(proclen <= readable);
                if (proclen <= readable)
                {
                    ox_buffer_addreadpos(mRecvBuffer.get(), proclen);
                }
            }

            if (ox_buffer_getwritevalidcount(mRecvBuffer.get()) == 0
                || ox_buffer_getreadvalidcount(mRecvBuffer.get()) == 0)
            {
                // Presumably moves unread bytes to the front to free write
                // space — TODO confirm ox_buffer_adjustto_head semantics.
                ox_buffer_adjustto_head(mRecvBuffer.get());
            }

            if (ox_buffer_getwritevalidcount(mRecvBuffer.get()) == 0)
            {
                growRecvBuffer();
            }
            if (ox_buffer_getwritevalidcount(mRecvBuffer.get()) == 0)
            {
                // No space can be made: the buffer is full of data the callback
                // will not consume and cannot grow further. A zero-length
                // async_read_some would complete at once with 0 bytes and loop
                // here forever, so drop the connection instead.
                asio::error_code ignored;
                mSocket->socket().close(ignored);
                return;
            }

            doRecv();
        }

        /// Starts one gather-write covering every pending message, unless a
        /// send is already in flight or nothing is queued.
        void    trySend()
        {
            std::lock_guard<std::mutex> lck(mSendGuard);
            if (mSending || mPendingSendMsg.empty())
            {
                return;
            }

            mBuffers.resize(mPendingSendMsg.size());
            for (std::size_t i = 0; i < mPendingSendMsg.size(); ++i)
            {
                auto& msg = mPendingSendMsg[i];
                mBuffers[i] = asio::const_buffer(msg.msg->c_str() + msg.sendPos,
                    msg.msg->size() - msg.sendPos);
            }

            auto self = shared_from_this();
            mSocket->socket().async_send(
                mBuffers,
                [this, self](std::error_code ec, size_t bytesTransferred) {
                    onSendCompleted(ec, bytesTransferred);
                });
            mSending = true;
        }

        void    onSendCompleted(std::error_code ec, size_t bytesTransferred)
        {
            {
                std::lock_guard<std::mutex> lck(mSendGuard);
                mSending = false;
                if (ec) // TODO: handle send errors (e.g. close the session)
                {
                    return;
                }
                adjustSendBuffer(bytesTransferred);
            }
            trySend();
        }

        /// Advances past `bytesTransferred` bytes of queued data, popping
        /// fully-sent messages. Caller must hold mSendGuard.
        void    adjustSendBuffer(size_t bytesTransferred)
        {
            while (bytesTransferred > 0 && !mPendingSendMsg.empty())
            {
                auto& frontMsg = mPendingSendMsg.front();
                const size_t len = std::min<size_t>(bytesTransferred, frontMsg.msg->size() - frontMsg.sendPos);
                frontMsg.sendPos += len;
                bytesTransferred -= len;
                if (frontMsg.sendPos == frontMsg.msg->size())
                {
                    mPendingSendMsg.pop_front();
                }
            }
        }

        /// Allocates the initial receive buffer (min(16 KiB, max)), or grows
        /// the existing one by up to 1 KiB without exceeding
        /// mMaxRecvBufferSize, keeping the unread bytes.
        void    growRecvBuffer()
        {
            if (mRecvBuffer == nullptr)
            {
                // Explicit template argument: `16 * 1024` is int and the max is
                // size_t, so plain std::min cannot deduce a type.
                mRecvBuffer.reset(ox_buffer_new(std::min<size_t>(16 * 1024, mMaxRecvBufferSize)));
                return;
            }

            const size_t currentSize = ox_buffer_getsize(mRecvBuffer.get());
            if (currentSize >= mMaxRecvBufferSize)
            {
                return;
            }
            // The last step may be smaller than 1 KiB, so the configured
            // maximum is actually reachable.
            const size_t newSize = std::min<size_t>(currentSize + 1024, mMaxRecvBufferSize);
            std::unique_ptr<struct buffer_s, BufferDeleter> newBuffer(ox_buffer_new(newSize));
            ox_buffer_write(newBuffer.get(),
                ox_buffer_getreadptr(mRecvBuffer.get()),
                ox_buffer_getreadvalidcount(mRecvBuffer.get()));
            mRecvBuffer = std::move(newBuffer);
        }

    private:
        const size_t                        mMaxRecvBufferSize;
        const SharedSocket::Ptr             mSocket;

        bool                                mSending;       // guarded by mSendGuard
        std::mutex                          mSendGuard;
        struct PendingMsg
        {
            size_t  sendPos;                // bytes of msg already sent
            std::shared_ptr<std::string>    msg;
        };
        // TODO: (original note partially lost) consider whether asio::async_write
        // could replace this hand-maintained pending queue.
        std::deque<PendingMsg>              mPendingSendMsg;
        std::vector<asio::const_buffer>     mBuffers;

        std::once_flag                      mRecvInitOnceFlag;
        DataCB                              mDataCB;
        struct BufferDeleter
        {
            void operator()(struct buffer_s* ptr) const
            {
                ox_buffer_delete(ptr);
            }
        };
        std::unique_ptr<struct buffer_s, BufferDeleter> mRecvBuffer;
    };

}

 
使用例を見てみましょう。以下はクライアントとサーバーの 2 つの独立したサンプルプログラムです（それぞれ別の .cpp としてビルドしてください）。
int main(int argc, char **argv)
{
    if (argc != 4)
    {
        fprintf(stderr, "Usage:   
"); exit(-1); } IoServicePool::Ptr ioServicePool = std::make_shared(2, 1); ioServicePool->start(1); const auto endpoint = asio::ip::tcp::endpoint(asio::ip::address_v4::from_string(argv[1]), std::atoi(argv[2])); AsioTcpConnector connector(ioServicePool); for (size_t i = 0; i < std::atoi(argv[3]); i++) { connector.asyncConnect(endpoint, [](SharedSocket::Ptr sharedSocket) { // }); } while (true) { std::this_thread::sleep_for(std::chrono::seconds(1)); } return 0; } int main(int argc, char **argv) { if (argc != 2) { fprintf(stderr, "Usage:
"); exit(-1); } bool stoped = false; IoServicePool::Ptr ioServicePool = std::make_shared(4, 1); ioServicePool->start(1); IoServiceThread serviceWrapper(1); serviceWrapper.start(1); AsioTcpAcceptor::Ptr acceptor = std::make_shared(serviceWrapper.service(), ioServicePool, ip::tcp::endpoint(ip::tcp::v4(), std::atoi(argv[1]))); acceptor->startAccept([](SharedSocket::Ptr socket) { handleAccept(socket); }); asio::signal_set sig(serviceWrapper.service(), SIGINT, SIGTERM); sig.async_wait([&](const asio::error_code & err, int signal) { stoped = true; } ); while (!stoped) { auto nowTime = std::chrono::system_clock::now(); std::this_thread::sleep_for(std::chrono::seconds(1)); auto diff = std::chrono::system_clock::now() - nowTime; auto mill = std::chrono::duration_cast<:chrono::milliseconds>(diff); std::cout << "count is:" << (count/ mill.count())*1000 << ", cost " << mill.count() << std::endl; count = 0; } serviceWrapper.stop(); ioServicePool->stop(); return 0; }