@ -22,21 +22,19 @@ import (
"errors"
"fmt"
"net"
"sort"
"net/netip"
"slices"
"strings"
"golang.org/x/exp/maps"
)
var lan4 , lan6 , special4 , special6 Netlist
var special4 , special6 Netlist
func init ( ) {
// Lists from RFC 5735, RFC 5156,
// https://www.iana.org/assignments/iana-ipv4-special-registry/
lan4 . Add ( "0.0.0.0/8" ) // "This" network
lan4 . Add ( "10.0.0.0/8" ) // Private Use
lan4 . Add ( "172.16.0.0/12" ) // Private Use
lan4 . Add ( "192.168.0.0/16" ) // Private Use
lan6 . Add ( "fe80::/10" ) // Link-Local
lan6 . Add ( "fc00::/7" ) // Unique-Local
special4 . Add ( "0.0.0.0/8" ) // "This" network.
special4 . Add ( "192.0.0.0/29" ) // IPv4 Service Continuity
special4 . Add ( "192.0.0.9/32" ) // PCP Anycast
special4 . Add ( "192.0.0.170/32" ) // NAT64/DNS64 Discovery
@ -66,7 +64,7 @@ func init() {
}
// Netlist is a list of IP networks.
type Netlist [ ] net . IPNet
type Netlist [ ] netip . Prefix
// ParseNetlist parses a comma-separated list of CIDR masks.
// Whitespace and extra commas are ignored.
@ -78,11 +76,11 @@ func ParseNetlist(s string) (*Netlist, error) {
if mask == "" {
continue
}
_ , n , err := net . ParseCIDR ( mask )
prefix , err := netip . ParsePrefix ( mask )
if err != nil {
return nil , err
}
l = append ( l , * n )
l = append ( l , prefix )
}
return & l , nil
}
@ -103,11 +101,11 @@ func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error {
return err
}
for _ , mask := range masks {
_ , n , err := net . ParseCIDR ( mask )
prefix , err := netip . ParsePrefix ( mask )
if err != nil {
return err
}
* l = append ( * l , * n )
* l = append ( * l , prefix )
}
return nil
}
@ -115,15 +113,20 @@ func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error {
// Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is
// intended to be used for setting up static lists.
func ( l * Netlist ) Add ( cidr string ) {
_ , n , err := net . ParseCIDR ( cidr )
prefix , err := netip . ParsePrefix ( cidr )
if err != nil {
panic ( err )
}
* l = append ( * l , * n )
* l = append ( * l , prefix )
}
// Contains reports whether the given IP is contained in the list.
func ( l * Netlist ) Contains ( ip net . IP ) bool {
return l . ContainsAddr ( IPToAddr ( ip ) )
}
// ContainsAddr reports whether the given IP is contained in the list.
func ( l * Netlist ) ContainsAddr ( ip netip . Addr ) bool {
if l == nil {
return false
}
@ -137,25 +140,39 @@ func (l *Netlist) Contains(ip net.IP) bool {
// IsLAN reports whether an IP is a local network address.
func IsLAN ( ip net . IP ) bool {
return AddrIsLAN ( IPToAddr ( ip ) )
}
// AddrIsLAN reports whether an IP is a local network address.
func AddrIsLAN ( ip netip . Addr ) bool {
if ip . Is4In6 ( ) {
ip = netip . AddrFrom4 ( ip . As4 ( ) )
}
if ip . IsLoopback ( ) {
return true
}
if v4 := ip . To4 ( ) ; v4 != nil {
return lan4 . Contains ( v4 )
}
return lan6 . Contains ( ip )
return ip . IsPrivate ( ) || ip . IsLinkLocalUnicast ( )
}
// IsSpecialNetwork reports whether an IP is located in a special-use network range
// This includes broadcast, multicast and documentation addresses.
func IsSpecialNetwork ( ip net . IP ) bool {
return AddrIsSpecialNetwork ( IPToAddr ( ip ) )
}
// AddrIsSpecialNetwork reports whether an IP is located in a special-use network range
// This includes broadcast, multicast and documentation addresses.
func AddrIsSpecialNetwork ( ip netip . Addr ) bool {
if ip . Is4In6 ( ) {
ip = netip . AddrFrom4 ( ip . As4 ( ) )
}
if ip . IsMulticast ( ) {
return true
}
if v4 := ip . To4 ( ) ; v4 != nil {
return special4 . Contains ( v4 )
if ip . Is4 ( ) {
return special4 . ContainsAddr ( ip )
}
return special6 . Contains ( ip )
return special6 . ContainsAddr ( ip )
}
var (
@ -175,19 +192,31 @@ var (
// - LAN addresses are OK if relayed by a LAN host.
// - All other addresses are always acceptable.
func CheckRelayIP ( sender , addr net . IP ) error {
if len ( addr ) != net . IPv4len && len ( addr ) != net . IPv6len {
return CheckRelayAddr ( IPToAddr ( sender ) , IPToAddr ( addr ) )
}
// CheckRelayAddr reports whether an IP relayed from the given sender IP
// is a valid connection target.
//
// There are four rules:
// - Special network addresses are never valid.
// - Loopback addresses are OK if relayed by a loopback host.
// - LAN addresses are OK if relayed by a LAN host.
// - All other addresses are always acceptable.
func CheckRelayAddr ( sender , addr netip . Addr ) error {
if ! addr . IsValid ( ) {
return errInvalid
}
if addr . IsUnspecified ( ) {
return errUnspecified
}
if IsSpecialNetwork ( addr ) {
if Addr IsSpecialNetwork( addr ) {
return errSpecial
}
if addr . IsLoopback ( ) && ! sender . IsLoopback ( ) {
return errLoopback
}
if IsLAN ( addr ) && ! IsLAN ( sender ) {
if Addr IsLAN( addr ) && ! Addr IsLAN( sender ) {
return errLAN
}
return nil
@ -221,17 +250,22 @@ type DistinctNetSet struct {
Subnet uint // number of common prefix bits
Limit uint // maximum number of IPs in each subnet
members map [ string ] uint
buf net . IP
members map [ netip . Prefix ] uint
}
// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
// number of existing IPs in the defined range exceeds the limit.
func ( s * DistinctNetSet ) Add ( ip net . IP ) bool {
return s . AddAddr ( IPToAddr ( ip ) )
}
// AddAddr adds an IP address to the set. It returns false (and doesn't add the IP) if the
// number of existing IPs in the defined range exceeds the limit.
func ( s * DistinctNetSet ) AddAddr ( ip netip . Addr ) bool {
key := s . key ( ip )
n := s . members [ string ( key ) ]
n := s . members [ key ]
if n < s . Limit {
s . members [ string ( key ) ] = n + 1
s . members [ key ] = n + 1
return true
}
return false
@ -239,20 +273,30 @@ func (s *DistinctNetSet) Add(ip net.IP) bool {
// Remove removes an IP from the set.
func ( s * DistinctNetSet ) Remove ( ip net . IP ) {
s . RemoveAddr ( IPToAddr ( ip ) )
}
// RemoveAddr removes an IP from the set.
func ( s * DistinctNetSet ) RemoveAddr ( ip netip . Addr ) {
key := s . key ( ip )
if n , ok := s . members [ string ( key ) ] ; ok {
if n , ok := s . members [ key ] ; ok {
if n == 1 {
delete ( s . members , string ( key ) )
delete ( s . members , key )
} else {
s . members [ string ( key ) ] = n - 1
s . members [ key ] = n - 1
}
}
}
// Contains whether the given IP is contained in the set.
func ( s DistinctNetSet ) Contains ( ip net . IP ) bool {
return s . ContainsAddr ( IPToAddr ( ip ) )
}
// ContainsAddr whether the given IP is contained in the set.
func ( s DistinctNetSet ) ContainsAddr ( ip netip . Addr ) bool {
key := s . key ( ip )
_ , ok := s . members [ string ( key ) ]
_ , ok := s . members [ key ]
return ok
}
@ -265,54 +309,30 @@ func (s DistinctNetSet) Len() int {
return int ( n )
}
// key encodes the map key for an address into a temporary buffer.
//
// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
// The remainder of the key is the IP, truncated to the number of bits.
func ( s * DistinctNetSet ) key ( ip net . IP ) net . IP {
// key returns the map key for ip.
func ( s * DistinctNetSet ) key ( ip netip . Addr ) netip . Prefix {
// Lazily initialize storage.
if s . members == nil {
s . members = make ( map [ string ] uint )
s . buf = make ( net . IP , 17 )
}
// Canonicalize ip and bits.
typ := byte ( '6' )
if ip4 := ip . To4 ( ) ; ip4 != nil {
typ , ip = '4' , ip4
s . members = make ( map [ netip . Prefix ] uint )
}
bits := s . Subnet
if bits > uint ( len ( ip ) * 8 ) {
bits = uint ( len ( ip ) * 8 )
}
// Encode the prefix into s.buf.
nb := int ( bits / 8 )
mask := ^ byte ( 0xFF >> ( bits % 8 ) )
s . buf [ 0 ] = typ
buf := append ( s . buf [ : 1 ] , ip [ : nb ] ... )
if nb < len ( ip ) && mask != 0 {
buf = append ( buf , ip [ nb ] & mask )
p , err := ip . Prefix ( int ( s . Subnet ) )
if err != nil {
panic ( err )
}
return buf
return p
}
// String implements fmt.Stringer
func ( s DistinctNetSet ) String ( ) string {
keys := maps . Keys ( s . members )
slices . SortFunc ( keys , func ( a , b netip . Prefix ) int {
return strings . Compare ( a . String ( ) , b . String ( ) )
} )
var buf bytes . Buffer
buf . WriteString ( "{" )
keys := make ( [ ] string , 0 , len ( s . members ) )
for k := range s . members {
keys = append ( keys , k )
}
sort . Strings ( keys )
for i , k := range keys {
var ip net . IP
if k [ 0 ] == '4' {
ip = make ( net . IP , 4 )
} else {
ip = make ( net . IP , 16 )
}
copy ( ip , k [ 1 : ] )
fmt . Fprintf ( & buf , "%v×%d" , ip , s . members [ k ] )
fmt . Fprintf ( & buf , "%v×%d" , k , s . members [ k ] )
if i != len ( keys ) - 1 {
buf . WriteString ( " " )
}