Validate IP-declared lengths before synthesizing direct ICMP packet

This commit is contained in:
Viktor Liu
2026-05-04 12:12:45 +02:00
parent 1b2d7777a3
commit 5f3aef3198

View File

@@ -225,45 +225,62 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
// //
// Unlike gVisor's network layer, this does not validate ICMP checksums or // Unlike gVisor's network layer, this does not validate ICMP checksums or
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor. // reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { func parseICMPv4(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
if len(payload) < header.IPv4MinimumSize {
return 0, 0, src, dst, false
}
ip := header.IPv4(payload) ip := header.IPv4(payload)
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) { if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
return 0, src, dst, false return 0, 0, src, dst, false
} }
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 { if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
return 0, src, dst, false return 0, 0, src, dst, false
} }
ipHdrLen = int(ip.HeaderLength()) ipHdrLen = int(ip.HeaderLength())
if len(payload)-ipHdrLen < header.ICMPv4MinimumSize { totalLen := int(ip.TotalLength())
return 0, src, dst, false if ipHdrLen < header.IPv4MinimumSize || ipHdrLen > totalLen || totalLen > len(payload) {
return 0, 0, src, dst, false
} }
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true icmpLen = totalLen - ipHdrLen
if icmpLen < header.ICMPv4MinimumSize {
return 0, 0, src, dst, false
}
return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
} }
func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { func parseICMPv6(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
if len(payload) < header.IPv6MinimumSize {
return 0, 0, src, dst, false
}
ip := header.IPv6(payload) ip := header.IPv6(payload)
if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) { if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) {
return 0, src, dst, false return 0, 0, src, dst, false
} }
ipHdrLen = header.IPv6MinimumSize ipHdrLen = header.IPv6MinimumSize
if len(payload)-ipHdrLen < header.ICMPv6MinimumSize { icmpLen = int(ip.PayloadLength())
return 0, src, dst, false if icmpLen < header.ICMPv6MinimumSize || ipHdrLen+icmpLen > len(payload) {
return 0, 0, src, dst, false
} }
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
} }
func (f *Forwarder) handleICMPDirect(payload []byte) bool { func (f *Forwarder) handleICMPDirect(payload []byte) bool {
if len(payload) == 0 {
return false
}
var ( var (
ipHdrLen int ipHdrLen int
icmpLen int
srcAddr tcpip.Address srcAddr tcpip.Address
dstAddr tcpip.Address dstAddr tcpip.Address
ok bool ok bool
) )
switch payload[0] >> 4 { version := payload[0] >> 4
switch version {
case 4: case 4:
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload) ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
case 6: case 6:
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload) ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
} }
if !ok { if !ok {
return false return false
@@ -280,22 +297,20 @@ func (f *Forwarder) handleICMPDirect(payload []byte) bool {
RemoteAddress: srcAddr, RemoteAddress: srcAddr,
} }
// Build a PacketBuffer with headers consumed the same way gVisor would. // Trim the buffer to the IP-declared length so gVisor doesn't see padding.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload), Payload: buffer.MakeWithData(payload[:ipHdrLen+icmpLen]),
}) })
defer pkt.DecRef() defer pkt.DecRef()
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok { if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
return false return false
} }
if _, ok := pkt.TransportHeader().Consume(icmpLen); !ok {
icmpPayload := payload[ipHdrLen:]
if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok {
return false return false
} }
if payload[0]>>4 == 6 { if version == 6 {
return f.handleICMPv6(id, pkt) return f.handleICMPv6(id, pkt)
} }
return f.handleICMP(id, pkt) return f.handleICMP(id, pkt)