diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index fb5aee155..8c9b2548e 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -225,45 +225,62 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { // // Unlike gVisor's network layer, this does not validate ICMP checksums or // reassemble IP fragments. Fragmented ICMP packets fall through to gVisor. -func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { +func parseICMPv4(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) { + if len(payload) < header.IPv4MinimumSize { + return 0, 0, src, dst, false + } ip := header.IPv4(payload) if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) { - return 0, src, dst, false + return 0, 0, src, dst, false } if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 { - return 0, src, dst, false + return 0, 0, src, dst, false } ipHdrLen = int(ip.HeaderLength()) - if len(payload)-ipHdrLen < header.ICMPv4MinimumSize { - return 0, src, dst, false + totalLen := int(ip.TotalLength()) + if ipHdrLen < header.IPv4MinimumSize || ipHdrLen > totalLen || totalLen > len(payload) { + return 0, 0, src, dst, false } - return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true + icmpLen = totalLen - ipHdrLen + if icmpLen < header.ICMPv4MinimumSize { + return 0, 0, src, dst, false + } + return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true } -func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { +func parseICMPv6(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) { + if len(payload) < header.IPv6MinimumSize { + return 0, 0, src, dst, false + } ip := header.IPv6(payload) if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) { - return 0, src, dst, false + return 0, 0, src, dst, false } ipHdrLen = header.IPv6MinimumSize - if len(payload)-ipHdrLen < header.ICMPv6MinimumSize { - return 0, src, dst, false + icmpLen = int(ip.PayloadLength()) + if icmpLen < header.ICMPv6MinimumSize || ipHdrLen+icmpLen > len(payload) { + return 0, 0, src, dst, false } - return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true + return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true } func (f *Forwarder) handleICMPDirect(payload []byte) bool { + if len(payload) == 0 { + return false + } var ( ipHdrLen int + icmpLen int srcAddr tcpip.Address dstAddr tcpip.Address ok bool ) - switch payload[0] >> 4 { + version := payload[0] >> 4 + switch version { case 4: - ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload) + ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv4(payload) case 6: - ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload) + ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv6(payload) } if !ok { return false @@ -280,22 +297,20 @@ func (f *Forwarder) handleICMPDirect(payload []byte) bool { RemoteAddress: srcAddr, } - // Build a PacketBuffer with headers consumed the same way gVisor would. + // Trim the buffer to the IP-declared length so gVisor doesn't see padding. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(payload), + Payload: buffer.MakeWithData(payload[:ipHdrLen+icmpLen]), }) defer pkt.DecRef() if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok { return false } - - icmpPayload := payload[ipHdrLen:] - if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok { + if _, ok := pkt.TransportHeader().Consume(icmpLen); !ok { return false } - if payload[0]>>4 == 6 { + if version == 6 { return f.handleICMPv6(id, pkt) } return f.handleICMP(id, pkt)