TCP 粘包处理

TCP 传输的是字节流,不能保证每次收到的数据与发送的数据包边界完全一致。比如依次发送了两个消息 abc 和 def,接收方可能分两次收到 ab 和 cdef。

为了解决这个问题,需要在消息中加入能够标识消息边界的信息,以便从字节流中拆分出发送的原始消息。

在此使用了简单的方式,在消息前加上4字节的包长度,收到消息后查看是否完整,若不完整则等到下一次收到数据再处理,直到接收到完整的数据包。

负载包 = 4 字节(消息包长度 n) + n 字节(消息包内容)。长度字段由 BitConverter 按本机字节序解析(x86 上为小端序),收发双方必须使用相同的字节序。

下面是完整的实现:

PacketBufferManager 负责拆包以及缓存不完整的数据包。它是一个泛型类,类型参数 TPacket 为消息包的类型。

收到的消息包一般具有一定的结构,比如 Json、Protobuf 等,因此抽象出了 IPacketFactory 接口,以便自定义消息的反序列化方式。若直接以字节数组的形式处理消息包,可使用示例中的 BytePacketFactory。

    /// <summary>
    /// Builds a packet object from the payload bytes of one length-prefixed packet.
    /// </summary>
    /// <typeparam name="TPacket">Type of the packet produced.</typeparam>
    public interface IPacketFactory<out TPacket>
    {
        /// <summary>
        /// Reads one packet payload of <paramref name="count"/> bytes from <paramref name="stream"/>.
        /// </summary>
        /// <remarks>
        /// Implementations are expected to consume exactly <paramref name="count"/> bytes: the caller
        /// reads the next packet's length prefix from wherever the stream is left afterwards.
        /// </remarks>
        /// <param name="stream">Stream positioned at the first payload byte.</param>
        /// <param name="count">Payload length in bytes.</param>
        /// <returns>The deserialized packet.</returns>
        TPacket ReadPacket(Stream stream, int count);
    }

    /// <summary>
    /// Packet factory that returns the raw payload bytes unchanged.
    /// </summary>
    public class BytePacketFactory : IPacketFactory<byte[]>
    {
        // NOTE(review): not used by this class; IKeyInfo is presumably declared elsewhere in the project.
        public IKeyInfo KeyInfo { get; set; }

        /// <summary>
        /// Reads exactly <paramref name="count"/> bytes from <paramref name="stream"/> into a new array.
        /// </summary>
        /// <param name="stream">Source stream.</param>
        /// <param name="count">Number of payload bytes to read.</param>
        /// <returns>A new array holding the payload.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="stream"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
        /// <exception cref="EndOfStreamException">The stream ended before <paramref name="count"/> bytes were read.</exception>
        public byte[] ReadPacket(Stream stream, int count)
        {
            if (stream == null) throw new ArgumentNullException(nameof(stream));
            if (count < 0) throw new ArgumentOutOfRangeException(nameof(count));

            var bytes = new byte[count];
            var total = 0;
            // Stream.Read may return fewer bytes than requested, so keep reading until the payload is complete.
            while (total < count)
            {
                var read = stream.Read(bytes, total, count - total);
                if (read == 0)
                {
                    throw new EndOfStreamException("stream ended before the whole packet was read");
                }
                total += read;
            }
            return bytes;
        }
    }
    /// <summary>
    /// Splits a TCP byte stream into length-prefixed packets, buffering incomplete data between calls.
    /// </summary>
    /// <remarks>
    /// Each packet on the wire is a 4-byte length prefix, decoded with
    /// <see cref="BitConverter.ToInt32(byte[], int)"/> (machine byte order), followed by that many payload bytes.
    /// No locking is performed; feed it from a single receive loop.
    /// </remarks>
    /// <typeparam name="TPacket">Type produced by the packet factory.</typeparam>
    public class PacketBufferManager<TPacket>
    {
        // Size of the length prefix in bytes.
        private const int HeaderLength = 4;
        private const int MaxPacketLength = 1024 * 1024 * 1024;
        private readonly ArraySegmentStream _bufferStream = new ArraySegmentStream();
        private readonly IPacketFactory<TPacket> _packetFactory;
        // Bytes left over from previous calls; the first byte is always the start of a packet header.
        private List<ArraySegment<byte>> _datas = new List<ArraySegment<byte>>();

        public PacketBufferManager(IPacketFactory<TPacket> packetFactory)
        {
            _packetFactory = packetFactory ?? throw new ArgumentNullException(nameof(packetFactory));
        }

        /// <summary>
        /// Appends newly received bytes to the buffered data and returns every packet that is now complete.
        /// </summary>
        /// <param name="data">Receive buffer. Bytes of an incomplete packet are copied, so the caller may reuse it.</param>
        /// <param name="offset">Index of the first received byte in <paramref name="data"/>.</param>
        /// <param name="count">Number of received bytes.</param>
        /// <returns>The complete packets, in arrival order; an empty array when none is complete yet (never null).</returns>
        /// <exception cref="InvalidDataException">A length prefix is negative or exceeds the maximum packet length.
        /// The stream is then corrupt and the connection should be closed.</exception>
        public TPacket[] ReadPackets(byte[] data, int offset, int count)
        {
            // The ArraySegment constructor validates data/offset/count.
            var received = new ArraySegment<byte>(data, offset, count);
            var segments = new ArraySegment<byte>[_datas.Count + 1];
            _datas.CopyTo(segments);
            segments[_datas.Count] = received;
            _bufferStream.Reset(segments);

            var packets = new List<TPacket>();
            var header = new byte[HeaderLength];
            while (true)
            {
                var available = _bufferStream.Length - _bufferStream.Position;
                if (available < HeaderLength)
                {
                    // Not even a complete length prefix (possibly nothing at all): keep the tail for next time.
                    // Checking first avoids asking the stream for more bytes than it holds.
                    _datas = CollectUnread(segments);
                    return packets.ToArray();
                }

                var headerStart = _bufferStream.Position;
                _bufferStream.Read(header, 0, HeaderLength);
                var packetLength = BitConverter.ToInt32(header, 0);
                if (packetLength > MaxPacketLength)
                {
                    throw new InvalidDataException("packet excced max length");
                }
                if (packetLength < 0)
                {
                    throw new InvalidDataException("packet length is negative");
                }

                if (available - HeaderLength < packetLength)
                {
                    // Payload incomplete: rewind to the header so the packet is parsed as a whole next time.
                    _bufferStream.Position = headerStart;
                    _datas = CollectUnread(segments);
                    return packets.ToArray();
                }

                packets.Add(_packetFactory.ReadPacket(_bufferStream, packetLength));
            }
        }

        // Returns the bytes of `segments` not yet consumed by _bufferStream, dropping empty slices.
        // The last segment is the caller's receive buffer, so its unread part is copied rather than referenced.
        private List<ArraySegment<byte>> CollectUnread(ArraySegment<byte>[] segments)
        {
            var remaining = new List<ArraySegment<byte>>();
            var firstIndex = _bufferStream.SegmentIndex;
            for (var i = firstIndex; i < segments.Length; i++)
            {
                var segment = segments[i];
                // Only the current segment may be partially consumed.
                var skip = i == firstIndex ? _bufferStream.SegmentPosition : 0;
                var length = segment.Count - skip;
                if (length <= 0)
                {
                    continue;
                }

                if (i == segments.Length - 1)
                {
                    var copy = new byte[length];
                    Array.Copy(segment.Array, segment.Offset + skip, copy, 0, length);
                    remaining.Add(new ArraySegment<byte>(copy));
                }
                else
                {
                    // Earlier segments are already private copies made by previous calls.
                    remaining.Add(new ArraySegment<byte>(segment.Array, segment.Offset + skip, length));
                }
            }
            return remaining;
        }
    }
    /// <summary>
    /// Read-only stream that presents a sequence of byte <see cref="ArraySegment{T}"/>s as one
    /// contiguous stream without copying them.
    /// </summary>
    public class ArraySegmentStream : Stream
    {
        private static readonly ArraySegment<byte>[] NoSegments = new ArraySegment<byte>[0];

        /// <summary>Index in <see cref="Datas"/> of the segment the next byte is read from
        /// (equals <c>Datas.Length</c> once everything has been read).</summary>
        public int SegmentIndex { get; private set; }

        /// <summary>Offset of the next byte, relative to the start of the current segment.</summary>
        public int SegmentPosition { get; private set; }

        /// <summary>The segments being read; empty until <see cref="Reset"/> is called.</summary>
        public ArraySegment<byte>[] Datas { get; private set; } = NoSegments;

        public ArraySegmentStream() { }

        public ArraySegmentStream(ArraySegment<byte>[] datas)
        {
            Reset(datas);
        }

        /// <summary>Replaces the segments being read and rewinds to the beginning.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="datas"/> is null.</exception>
        public void Reset(ArraySegment<byte>[] datas)
        {
            Datas = datas ?? throw new ArgumentNullException(nameof(datas));
            SegmentIndex = 0;
            SegmentPosition = 0;
        }

        public override void Flush()
        {
            // Read-only stream: nothing to flush.
        }

        /// <summary>
        /// Copies up to <paramref name="count"/> bytes into <paramref name="buffer"/>, crossing segment
        /// boundaries as needed.
        /// </summary>
        /// <returns>The number of bytes copied; fewer than requested when the data runs out, 0 at the end.</returns>
        public override int Read(byte[] buffer, int offset, int count)
        {
            if (buffer == null) throw new ArgumentNullException(nameof(buffer));
            if (offset < 0) throw new ArgumentOutOfRangeException(nameof(offset));
            if (count < 0) throw new ArgumentOutOfRangeException(nameof(count));
            if (buffer.Length - offset < count) throw new ArgumentException("offset and count exceed the buffer length");

            var copied = 0;
            while (copied < count && SegmentIndex < Datas.Length)
            {
                var segment = Datas[SegmentIndex];
                var chunk = Math.Min(count - copied, segment.Count - SegmentPosition);
                if (chunk > 0)
                {
                    Array.Copy(segment.Array, segment.Offset + SegmentPosition, buffer, offset + copied, chunk);
                    SegmentPosition += chunk;
                    copied += chunk;
                }

                if (SegmentPosition == segment.Count)
                {
                    // Current segment exhausted (or empty): continue with the next one.
                    SegmentIndex++;
                    SegmentPosition = 0;
                }
            }
            return copied;
        }

        private int GetLeftCount()
        {
            return GetLeftCount(Datas, SegmentIndex, SegmentPosition);
        }

        /// <summary>
        /// Counts the bytes from (<paramref name="currentSegmentIndex"/>, <paramref name="currentSegmentPosition"/>)
        /// to the end of <paramref name="segments"/>.
        /// </summary>
        public static int GetLeftCount(ArraySegment<byte>[] segments, int currentSegmentIndex, int currentSegmentPosition)
        {
            var count = 0;
            var isCurrent = true;
            for (var i = currentSegmentIndex; i < segments.Length; i++)
            {
                // Only the first (current) segment is partially consumed.
                count += isCurrent
                    ? segments[i].Count - currentSegmentPosition
                    : segments[i].Count;
                isCurrent = false;
            }
            return count;
        }

        public override long Seek(long offset, SeekOrigin origin)
        {
            throw new NotSupportedException();
        }

        public override void SetLength(long value)
        {
            throw new NotSupportedException();
        }

        public override void Write(byte[] buffer, int offset, int count)
        {
            throw new NotSupportedException();
        }

        public override bool CanRead => true;
        public override bool CanSeek => false;
        public override bool CanWrite => false;

        /// <summary>Total number of bytes in all segments.</summary>
        public override long Length => GetLeftCount(Datas, 0, 0);

        /// <summary>
        /// Absolute read position across all segments. Valid values are 0..<see cref="Length"/> inclusive.
        /// </summary>
        /// <exception cref="InvalidOperationException">The value is negative or greater than <see cref="Length"/>.</exception>
        public override long Position
        {
            get => Length - GetLeftCount();
            set
            {
                if (value < 0 || value > Length) throw new InvalidOperationException("position out of range");

                long segmentStart = 0;
                for (var i = 0; i < Datas.Length; i++)
                {
                    var segmentCount = Datas[i].Count;
                    if (value < segmentStart + segmentCount)
                    {
                        SegmentIndex = i;
                        SegmentPosition = (int)(value - segmentStart);
                        return;
                    }
                    segmentStart += segmentCount; // start of the next segment
                }

                // value == Length: same state Read leaves behind after consuming everything.
                SegmentIndex = Datas.Length;
                SegmentPosition = 0;
            }
        }
    }
            // Usage example
            var bufferManager = new PacketBufferManager<byte[]>(new BytePacketFactory());
            const int readBufferSize = 8192;
            var readBuffer = new byte[readBufferSize];
            while (true)
            {
                try
                {
                    var receivedLength = networkStream.Read(readBuffer, 0, readBufferSize);
                    if (receivedLength == 0)
                    {
                        // Read returns 0 once the remote side has closed the connection;
                        // without this check the loop would spin forever.
                        break;
                    }

                    // Hand the received bytes to bufferManager. It may return several packets at once,
                    // and null or an empty array when no packet is complete yet.
                    var packets = bufferManager.ReadPackets(readBuffer, 0, receivedLength);
                    if (packets != null)
                    {
                        foreach (var packet in packets)
                        {
                            // Handle each complete packet here.
                        }
                    }
                }
                catch (Exception ex)
                {
                    Console.WriteLine(ex);
                    break;
                }
            }

 

posted @ 2017-12-09 21:33  ChrisHuang  阅读(305)  评论(0)  编辑  收藏  举报