TCP粘包处理
TCP传输的是字节流，不能保证每次收到的数据与发送时的消息边界完全一致。比如连续发送了两个消息abc和def，接收方收到的可能是ab和cdef。
为了解决这个问题，需要在消息中加上边界标识，以便从字节流中拆分出发送的原始消息。
在此使用了简单的方式：在每个消息前加上4字节的包长度（只计消息内容，不含这4字节本身；示例代码用BitConverter.ToInt32按本机字节序解析该长度，发送端需使用相同的字节序），收到数据后检查是否已包含完整的消息包，若不完整则先缓存起来，等到下一次收到数据再处理，直到接收到完整的数据包。
负载包 = 4bytes(消息包长度) + nbytes(消息包内容)。
下面是完整的实现:
PacketBufferManager负责拆包以及未完成数据的缓存处理；它是泛型类，类型参数TPacket为消息包的类型。
收到的消息包一般会有一定的结构，比如JSON、Protobuf等，因此抽象出了IPacketFactory接口，用于自定义消息的反序列化；若消息包直接以字节数组的方式处理，可使用示例中的BytePacketFactory。
/// <summary>
/// Turns the body bytes of one framed packet into a packet object.
/// </summary>
/// <typeparam name="TPacket">
/// Packet type produced. Declared covariant (<c>out</c>), so a factory for a derived
/// type can be used where a factory for a base type is expected.
/// </typeparam>
public interface IPacketFactory<out TPacket>
{
    /// <summary>
    /// Reads one packet body from <paramref name="stream"/>.
    /// </summary>
    /// <param name="stream">Stream positioned at the first byte of the packet body.</param>
    /// <param name="count">
    /// Number of body bytes that belong to this packet; implementations should consume
    /// exactly this many bytes.
    /// </param>
    /// <returns>The deserialized packet.</returns>
    TPacket ReadPacket(Stream stream, int count);
}
/// <summary>
/// <see cref="IPacketFactory{TPacket}"/> that returns each packet body unchanged, as a byte array.
/// </summary>
public class BytePacketFactory : IPacketFactory<byte[]>
{
    // NOTE(review): not read anywhere in the visible code, and IKeyInfo is declared
    // elsewhere — confirm whether this property is still needed.
    public IKeyInfo KeyInfo { get; set; }

    /// <summary>
    /// Reads exactly <paramref name="count"/> bytes from <paramref name="stream"/>.
    /// </summary>
    /// <param name="stream">Stream positioned at the first byte of the packet body.</param>
    /// <param name="count">Number of body bytes to read.</param>
    /// <returns>A new array holding the packet body.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
    /// <exception cref="EndOfStreamException">The stream ends before <paramref name="count"/> bytes were read.</exception>
    public byte[] ReadPacket(Stream stream, int count)
    {
        if (stream == null)
        {
            throw new ArgumentNullException(nameof(stream));
        }

        if (count < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(count));
        }

        var bytes = new byte[count];
        var filled = 0;

        // Stream.Read may legally return fewer bytes than requested, so keep reading
        // until the whole body has arrived; a 0 return means the stream has ended.
        while (filled < count)
        {
            var read = stream.Read(bytes, filled, count - filled);
            if (read == 0)
            {
                throw new EndOfStreamException($"expected {count} bytes but the stream ended after {filled}");
            }

            filled += read;
        }

        return bytes;
    }
}
/// <summary>
/// Splits a received byte stream into length-prefixed packets.
/// Each packet on the wire is a 4-byte length (parsed with <see cref="BitConverter.ToInt32(byte[], int)"/>,
/// i.e. machine byte order) followed by that many body bytes. Bytes that do not yet form a
/// complete packet are buffered until later calls supply the rest.
/// </summary>
/// <typeparam name="TPacket">Type produced by the packet factory for each body.</typeparam>
public class PacketBufferManager<TPacket>
{
    // Largest body length accepted (1 GiB); anything bigger is treated as corrupt input.
    private const int MaxPacketLength = 1024 * 1024 * 1024;

    // Size of the length prefix in bytes.
    private const int HeaderLength = 4;

    private readonly ArraySegmentStream _bufferStream = new ArraySegmentStream();
    private readonly IPacketFactory<TPacket> _packetFactory;

    // Bytes from earlier calls that have not yet formed a complete packet.
    // Every segment kept here refers to an array owned by this instance, never to a caller's buffer.
    private List<ArraySegment<byte>> _datas = new List<ArraySegment<byte>>();

    /// <param name="packetFactory">Deserializes each packet body.</param>
    /// <exception cref="ArgumentNullException"><paramref name="packetFactory"/> is null.</exception>
    public PacketBufferManager(IPacketFactory<TPacket> packetFactory)
    {
        _packetFactory = packetFactory ?? throw new ArgumentNullException(nameof(packetFactory));
    }

    /// <summary>
    /// Appends newly received bytes to any buffered remainder and extracts every complete packet.
    /// </summary>
    /// <param name="data">Receive buffer. Any bytes that must be kept for a later call are copied,
    /// so the caller may reuse the buffer once this returns.</param>
    /// <param name="offset">Index of the first received byte in <paramref name="data"/>.</param>
    /// <param name="count">Number of received bytes.</param>
    /// <returns>
    /// <c>null</c> when fewer than 4 bytes are buffered in total; otherwise the complete packets
    /// found, in order (possibly an empty array).
    /// </returns>
    /// <exception cref="InvalidDataException">A length prefix is negative or larger than the maximum.</exception>
    public TPacket[] ReadPackets(byte[] data, int offset, int count)
    {
        // The ArraySegment constructor validates data/offset/count for us.
        var segments = new List<ArraySegment<byte>>(_datas) { new ArraySegment<byte>(data, offset, count) };
        long total = segments.Sum(segment => (long)segment.Count);

        if (total < HeaderLength)
        {
            _datas = SliceUnconsumed(segments, 0);
            return null;
        }

        _bufferStream.Reset(segments.ToArray());
        var packets = new List<TPacket>();
        var header = new byte[HeaderLength];
        long consumed = 0;

        // Only read a header when 4 bytes are really available: trailing 1-3 bytes are a
        // partial length prefix that must wait for the next call, not an error.
        while (total - consumed >= HeaderLength)
        {
            // Re-align in case the factory did not consume exactly the declared body length.
            if (_bufferStream.Position != consumed)
            {
                _bufferStream.Position = consumed;
            }

            _bufferStream.Read(header, 0, HeaderLength);
            var packetLength = BitConverter.ToInt32(header, 0);
            if (packetLength < 0)
            {
                throw new InvalidDataException("packet length is negative");
            }

            if (packetLength > MaxPacketLength)
            {
                throw new InvalidDataException("packet excced max length");
            }

            if (total - consumed - HeaderLength < packetLength)
            {
                break; // body not fully received yet
            }

            packets.Add(_packetFactory.ReadPacket(_bufferStream, packetLength));
            consumed += HeaderLength + packetLength;
        }

        _datas = SliceUnconsumed(segments, consumed);
        return packets.ToArray();
    }

    // Returns the bytes of `segments` after the first `consumed` bytes. Segments from earlier
    // calls are sliced without copying; the last segment (the caller's buffer) is copied.
    private static List<ArraySegment<byte>> SliceUnconsumed(List<ArraySegment<byte>> segments, long consumed)
    {
        var pending = new List<ArraySegment<byte>>();
        var skip = consumed;
        var last = segments.Count - 1;

        for (var i = 0; i <= last; i++)
        {
            var segment = segments[i];
            if (skip >= segment.Count)
            {
                skip -= segment.Count;
                continue;
            }

            var start = segment.Offset + (int)skip;
            var length = segment.Count - (int)skip;
            skip = 0;

            if (i == last)
            {
                var copy = new byte[length];
                Buffer.BlockCopy(segment.Array, start, copy, 0, length);
                pending.Add(new ArraySegment<byte>(copy));
            }
            else
            {
                pending.Add(new ArraySegment<byte>(segment.Array, start, length));
            }
        }

        return pending;
    }
}
/// <summary>
/// Read-only <see cref="Stream"/> that presents an array of byte segments as one contiguous
/// sequence, reading directly from the segments without concatenating them.
/// </summary>
public class ArraySegmentStream : Stream
{
    /// <summary>Index into <see cref="Datas"/> of the segment the next byte comes from
    /// (may equal <c>Datas.Length</c> once everything has been read).</summary>
    public int SegmentIndex { get; private set; }

    /// <summary>Offset of the next byte within the current segment.</summary>
    public int SegmentPosition { get; private set; }

    /// <summary>Segments backing the stream, in read order. Empty until <see cref="Reset"/> is called.</summary>
    public ArraySegment<byte>[] Datas { get; private set; } = Array.Empty<ArraySegment<byte>>();

    public ArraySegmentStream()
    {
    }

    public ArraySegmentStream(ArraySegment<byte>[] datas)
    {
        Reset(datas);
    }

    /// <summary>Replaces the backing segments and rewinds to the start.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="datas"/> is null.</exception>
    public void Reset(ArraySegment<byte>[] datas)
    {
        Datas = datas ?? throw new ArgumentNullException(nameof(datas));
        SegmentIndex = 0;
        SegmentPosition = 0;
    }

    // Nothing is buffered for writing, so there is nothing to flush.
    public override void Flush()
    {
    }

    /// <summary>
    /// Copies up to <paramref name="count"/> bytes, crossing segment boundaries as needed.
    /// Follows the <see cref="Stream.Read(byte[], int, int)"/> contract: returns fewer bytes than
    /// requested when fewer remain, and 0 at the end of the data.
    /// </summary>
    public override int Read(byte[] buffer, int offset, int count)
    {
        if (buffer == null)
        {
            throw new ArgumentNullException(nameof(buffer));
        }

        if (offset < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(offset));
        }

        if (count < 0 || count > buffer.Length - offset)
        {
            throw new ArgumentOutOfRangeException(nameof(count));
        }

        var copied = 0;
        while (copied < count && SegmentIndex < Datas.Length)
        {
            var segment = Datas[SegmentIndex];
            var chunk = Math.Min(count - copied, segment.Count - SegmentPosition);
            if (chunk > 0)
            {
                Array.Copy(segment.Array, segment.Offset + SegmentPosition, buffer, offset + copied, chunk);
                SegmentPosition += chunk;
                copied += chunk;
            }

            if (SegmentPosition == segment.Count)
            {
                // Current segment exhausted: move to the start of the next one.
                SegmentIndex++;
                SegmentPosition = 0;
            }
        }

        return copied;
    }

    private int GetLeftCount() => GetLeftCount(Datas, SegmentIndex, SegmentPosition);

    /// <summary>
    /// Counts the bytes from (<paramref name="currentSegmentIndex"/>, <paramref name="currentSegmentPosition"/>)
    /// to the end of <paramref name="segments"/>. Returns 0 when the index is past the last segment.
    /// </summary>
    public static int GetLeftCount(ArraySegment<byte>[] segments, int currentSegmentIndex, int currentSegmentPosition)
    {
        var count = 0;
        var isCurrent = true;
        for (var i = currentSegmentIndex; i < segments.Length; i++)
        {
            // Only the first (current) segment is partially consumed.
            count += isCurrent ? segments[i].Count - currentSegmentPosition : segments[i].Count;
            isCurrent = false;
        }

        return count;
    }

    public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();

    public override void SetLength(long value) => throw new NotSupportedException();

    public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();

    public override bool CanRead => true;

    public override bool CanSeek => false;

    public override bool CanWrite => false;

    /// <summary>Total number of bytes across all segments.</summary>
    public override long Length => GetLeftCount(Datas, 0, 0);

    /// <summary>
    /// Absolute offset of the next byte. May be set to any value in [0, Length];
    /// setting it to Length positions the stream at the end.
    /// </summary>
    /// <exception cref="InvalidOperationException">The value is negative or greater than Length.</exception>
    public override long Position
    {
        get => Length - GetLeftCount();
        set
        {
            if (value < 0 || value > Length)
            {
                throw new InvalidOperationException("position out of range");
            }

            var remaining = value;
            for (var index = 0; index < Datas.Length; index++)
            {
                if (remaining < Datas[index].Count)
                {
                    SegmentIndex = index;
                    SegmentPosition = (int)remaining;
                    return;
                }

                remaining -= Datas[index].Count;
            }

            // value == Length: park just past the last segment.
            SegmentIndex = Datas.Length;
            SegmentPosition = 0;
        }
    }
}
// Usage example
// NOTE(review): assumes `networkStream` is a connected NetworkStream created elsewhere.
var bufferManager = new PacketBufferManager<byte[]>(new BytePacketFactory());
var readBufferSize = 8192;
var readBuffer = new byte[readBufferSize];
while (true)
{
    try
    {
        var receivedLength = networkStream.Read(readBuffer, 0, readBufferSize);

        // Read returns 0 once the peer has closed the connection; without this check the
        // loop would keep calling Read and spin forever.
        if (receivedLength == 0)
        {
            break;
        }

        // Hand the received bytes to bufferManager. It returns null while fewer than 4 bytes
        // (one length prefix) are buffered; otherwise it returns every complete packet found,
        // which may be none or several. Incomplete trailing data is copied and kept for the
        // next call, so readBuffer can be reused right away.
        var packets = bufferManager.ReadPackets(readBuffer, 0, receivedLength);
        if (packets == null)
        {
            continue;
        }

        foreach (var packet in packets)
        {
            Console.WriteLine($"received packet: {packet.Length} bytes");
        }
    }
    catch (Exception ex)
    {
        Console.WriteLine(ex);
        break;
    }
}