技术宅,fat-man

增加语言的了解程度可以避免写出愚蠢的代码

导航

重构oceanbase的一个函数

今天读了一下 OceanBase(下文简称 ob)的源码,感觉有点乱。当作练手,我重构了其中的一个函数。原始代码如下:

/**
 * Decode one MySQL command packet from the bytes buffered in an easy message.
 *
 * Reads the packet header (length via get_uint3, sequence via get_uint1) and,
 * when the whole payload is already buffered, builds an ObMySQLCommandPacket
 * in memory taken from m->pool and copies the command bytes right behind it.
 *
 * @param m  message whose input buffer [pos, last) holds the received bytes
 * @return   the decoded packet, or NULL when the argument is invalid, the
 *           header or payload is not fully buffered yet, the payload is
 *           empty, or the pool allocation fails
 */
void* ObMySQLCallback::decode(easy_message_t* m)
    {
      uint32_t pkt_len = 0;
      uint8_t pkt_seq = 0;
      uint8_t pkt_type = 0;
      ObMySQLCommandPacket* packet = NULL;
      char* buffer = NULL;

      if (NULL == m)
      {
        TBSYS_LOG(ERROR, "invalid argument m is %p", m);
      }
      else if (NULL == m->input)
      {
        TBSYS_LOG(ERROR, "invalide argument m->input is %p", m->input);
      }
      else
      {
        // Nothing is consumed until a complete header is available.
        if (static_cast<int32_t>(m->input->last - m->input->pos) >= OB_MYSQL_PACKET_HEADER_SIZE)
        {
          //1. decode length from net buffer
          //2. decode seq from net buffer
          ObMySQLUtil::get_uint3(m->input->pos, pkt_len);
          ObMySQLUtil::get_uint1(m->input->pos, pkt_seq);

          if (0 == pkt_len)
          {
            // pkt_len is unsigned: with an empty payload every "pkt_len - 1"
            // below would wrap to 0xFFFFFFFF, and there is no command-type
            // byte to read. The header stays consumed, so the empty packet
            // is dropped.
            // NOTE(review): confirm that dropping is the desired policy.
            TBSYS_LOG(WARN, "empty packet payload, seq=%u, no command byte to decode",
                      static_cast<uint32_t>(pkt_seq));
          }
          //message has enough buffer
          else if (pkt_len <= m->input->last - m->input->pos)
          {
            ObMySQLUtil::get_uint1(m->input->pos, pkt_type);
            // allocate the application-level packet from the pool carried by the message
            buffer = reinterpret_cast<char*>(easy_pool_alloc(m->pool,
                                                              static_cast<uint32_t>(sizeof(ObMySQLCommandPacket) + pkt_len)));
            if (NULL == buffer)
            {
              TBSYS_LOG(ERROR, "alloc packet buffer(length=%lu) from m->pool failed", sizeof(ObMySQLCommandPacket) + pkt_len);
            }
            else
            {
              TBSYS_LOG(DEBUG, "alloc packet buffer length = %lu", sizeof(ObMySQLCommandPacket) + pkt_len);
              // The packet object lives at the start of the buffer; the
              // command bytes (payload minus the type byte) follow it.
              packet = new(buffer)ObMySQLCommandPacket();
              packet->set_header(pkt_len, pkt_seq);
              packet->set_type(pkt_type);
              packet->set_receive_ts(tbsys::CTimeUtil::getTime());
              memcpy(buffer + sizeof(ObMySQLCommandPacket), m->input->pos, pkt_len - 1);
              packet->get_command().assign(buffer + sizeof(ObMySQLCommandPacket), pkt_len - 1);
              TBSYS_LOG(DEBUG, "decode comand packet command is \"%.*s\"", packet->get_command().length(),
                        packet->get_command().ptr());
              if (PACKET_RECORDER_FLAG)
              {
                // record the packet to FIFO stream if required
                ObMySQLServer* server = reinterpret_cast<ObMySQLServer*>(m->c->handler->user_data);
                ObMySQLCommandPacketRecord record;
                record.socket_fd_ = m->c->fd;
                record.cseq_ = m->c->seq;
                record.addr_ = m->c->addr;
                record.pkt_length_ = pkt_len;
                record.pkt_seq_ = pkt_seq;
                record.cmd_type_ = pkt_type;
                // Two segments: the fixed-size record, then the raw command bytes.
                struct iovec buffers[2];
                buffers[0].iov_base = &record;
                buffers[0].iov_len = sizeof(record);
                buffers[1].iov_base = m->input->pos;
                buffers[1].iov_len = pkt_len - 1;
                int err = OB_SUCCESS;
                if (OB_SUCCESS != (err = server->get_packet_recorder().push(buffers, 2)))
                {
                  // Recording is best effort: a failure is logged, decoding goes on.
                  TBSYS_LOG(WARN, "failed to record MySQL packet, err=%d", err);
                }
              }
              // Step over the command bytes that were just copied.
              m->input->pos += pkt_len - 1;
            }
          }
          else
          {
            // Ask for the missing bytes and rewind over the header so the
            // next call decodes this packet again from its first byte.
            m->next_read_len = static_cast<int>(pkt_len - (m->input->last - m->input->pos));
            TBSYS_LOG(DEBUG, "not enough data in message, packet length = %u, data in message is %ld",
                      pkt_len, m->input->last - m->input->pos);
            m->input->pos -= OB_MYSQL_PACKET_HEADER_SIZE;
          }
        }
      }
      return packet;
    }

 

问题:函数太长,嵌套太深,不容易读。下面把它拆成几个小函数,并用提前返回(guard clause)来减少嵌套:

/**
 * Build an ObMySQLCommandPacket from the payload at m->input->pos.
 *
 * Reads the command-type byte, allocates the packet plus room for the command
 * bytes from m->pool, and copies the remaining (*pkt_len - 1) payload bytes
 * right behind the packet object. The copied command bytes are NOT stepped
 * over here; m->input->pos is left pointing at them for the caller.
 *
 * @param m         message that provides the input buffer and the pool
 * @param pkt_len   in: payload length taken from the packet header
 * @param pkt_seq   in: sequence number taken from the packet header
 * @param pkt_type  out: receives the command-type byte read from the payload
 * @return          the new packet, or NULL when the payload is empty or the
 *                  pool allocation fails
 */
ObMySQLCommandPacket* ObMySQLCallback::make_packet(easy_message_t* m, uint32_t *pkt_len, uint8_t *pkt_seq, uint8_t *pkt_type)
{
    if (0 == *pkt_len)
    {
        // *pkt_len is unsigned: "*pkt_len - 1" below would wrap to 0xFFFFFFFF,
        // and an empty payload has no command-type byte to read.
        TBSYS_LOG(WARN, "empty packet payload, seq=%u, no command byte to decode",
                static_cast<uint32_t>(*pkt_seq));
        return NULL;
    }

    ObMySQLUtil::get_uint1(m->input->pos, *pkt_type);
    // allocate the application-level packet from the pool carried by the message
    char* buffer = reinterpret_cast<char*>(easy_pool_alloc(m->pool,
                static_cast<uint32_t>(sizeof(ObMySQLCommandPacket) + *pkt_len)));

    if (NULL == buffer)
    {
        TBSYS_LOG(ERROR, "alloc packet buffer(length=%lu) from m->pool failed", sizeof(ObMySQLCommandPacket) + *pkt_len);
        return NULL;
    }

    TBSYS_LOG(DEBUG, "alloc packet buffer length = %lu", sizeof(ObMySQLCommandPacket) + *pkt_len);
    // The packet object lives at the start of the buffer; the command bytes
    // (payload minus the type byte) are stored immediately after it.
    ObMySQLCommandPacket* packet = new(buffer)ObMySQLCommandPacket();
    packet->set_header(*pkt_len, *pkt_seq);
    packet->set_type(*pkt_type);
    packet->set_receive_ts(tbsys::CTimeUtil::getTime());
    memcpy(buffer + sizeof(ObMySQLCommandPacket), m->input->pos, *pkt_len - 1);
    packet->get_command().assign(buffer + sizeof(ObMySQLCommandPacket), *pkt_len - 1);
    TBSYS_LOG(DEBUG, "decode comand packet command is \"%.*s\"", packet->get_command().length(),
            packet->get_command().ptr());
    return packet;
}

/**
 * Push a copy of the packet being decoded to the server's packet recorder.
 *
 * Two segments are pushed: a fixed-size ObMySQLCommandPacketRecord describing
 * the connection and the packet header, followed by the raw command bytes at
 * m->input->pos. A failed push is only logged.
 *
 * @param m         message being decoded; m->input->pos must point at the
 *                  command bytes (just after the command-type byte)
 * @param pkt_len   in: payload length taken from the packet header
 * @param pkt_seq   in: sequence number taken from the packet header
 * @param pkt_type  in: command-type byte of the packet
 */
void ObMySQLCallback::record_packet(easy_message_t* m, uint32_t *pkt_len, uint8_t *pkt_seq, uint8_t *pkt_type)
{
    // record the packet to FIFO stream if required
    ObMySQLServer* server = reinterpret_cast<ObMySQLServer*>(m->c->handler->user_data);
    ObMySQLCommandPacketRecord record;
    record.socket_fd_ = m->c->fd;
    record.cseq_ = m->c->seq;
    record.addr_ = m->c->addr;
    record.pkt_length_ = *pkt_len;
    record.pkt_seq_ = *pkt_seq;
    record.cmd_type_ = *pkt_type;
    struct iovec buffers[2];
    buffers[0].iov_base = &record;
    buffers[0].iov_len = sizeof(record);
    buffers[1].iov_base = m->input->pos;
    // pkt_len is a pointer and must be dereferenced. The zero check keeps the
    // unsigned "*pkt_len - 1" from wrapping for an empty payload.
    buffers[1].iov_len = (*pkt_len > 0) ? (*pkt_len - 1) : 0;
    int err = OB_SUCCESS;
    if (OB_SUCCESS != (err = server->get_packet_recorder().push(buffers, 2)))
    {
        // Recording is best effort: a failure is logged and otherwise ignored.
        TBSYS_LOG(WARN, "failed to record MySQL packet, err=%d", err);
    }
}

void ObMySQLCallback::init_pkt_variables(uint32_t *pkt_len, uint8_t *pkt_seq)
{
    //1. decode length from net buffer
    //2. decode seq from net buffer 
    
    ObMySQLUtil::get_uint3(m->input->pos, *pkt_len);
    ObMySQLUtil::get_uint1(m->input->pos, *pkt_seq);
}

void* ObMySQLCallback::decode(easy_message_t* m)
{
    uint32_t pkt_len = 0,  pkt_seq = 0 ,  pkt_type = 0;
    
    if (NULL == m || NULL == m->input)
    {
        TBSYS_LOG(ERROR, "invalid argument m %p", m);
        return NULL;
    }
    
    int32_t msg_buffer_size = static_cast<int32_t>(m->input->last - m->input->pos);
    if ( msg_buffer_size < OB_MYSQL_PACKET_HEADER_SIZE)
    {
        return NULL;
    }
    
    init_pkt_variables(&pkt_len, &pkt_seq);
    if (pkt_len > msg_buffer_size) //message has not enough buffer
    {
        m->next_read_len = static_cast<int>(pkt_len - msg_buffer_size);
        TBSYS_LOG(DEBUG, "not enough data in message, packet length = %u, data in message is %ld",pkt_len, msg_buffer_size);
        m->input->pos -= OB_MYSQL_PACKET_HEADER_SIZE;
        return NULL;
    }
    
    ObMySQLCommandPacket* packet = make_packet(m, &pkt_len,  &pkt_seq ,  &pkt_type);
    if (PACKET_RECORDER_FLAG)
    {
        record_packet(m, &pkt_len,  &pkt_seq ,  &pkt_type);
    }
    m->input->pos += pkt_len - 1;
    return packet;
}

 

 

posted on 2014-01-14 17:01  codestyle  阅读(569)  评论(0)  编辑  收藏  举报