Compare commits

...

2 Commits

Author SHA1 Message Date
pascal-fischer
890e09b787 Keep confiured nameservers as fallback (#1036)
* keep existing nameserver as fallback when adding netbird resolver

* fix resolvconf

* fix imports
2023-08-01 17:45:44 +02:00
Bethuel Mmbaga
48098c994d Handle authentication errors in PKCE flow (#1039)
* handle authentication errors in PKCE flow

* remove shadowing and replace TokenEndpoint for PKCE config

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2023-07-31 14:22:38 +02:00
6 changed files with 64 additions and 48 deletions

View File

@@ -25,6 +25,8 @@ var _ OAuthFlow = &PKCEAuthorizationFlow{}
const ( const (
queryState = "state" queryState = "state"
queryCode = "code" queryCode = "code"
queryError = "error"
queryErrorDesc = "error_description"
defaultPKCETimeoutSeconds = 300 defaultPKCETimeoutSeconds = 300
) )
@@ -141,9 +143,13 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
tokenValidatorFunc := func() (*oauth2.Token, error) { tokenValidatorFunc := func() (*oauth2.Token, error) {
query := req.URL.Query() query := req.URL.Query()
state := query.Get(queryState) if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on state // Prevent timing attacks on state
if subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state") return nil, fmt.Errorf("invalid state")
} }
@@ -161,12 +167,13 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
token, err := tokenValidatorFunc() token, err := tokenValidatorFunc()
if err != nil { if err != nil {
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
renderPKCEFlowTmpl(w, err) renderPKCEFlowTmpl(w, err)
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
return
} }
tokenChan <- token
renderPKCEFlowTmpl(w, nil) renderPKCEFlowTmpl(w, nil)
tokenChan <- token
}) })
if err := server.ListenAndServe(); err != nil { if err := server.ListenAndServe(); err != nil {

View File

@@ -15,7 +15,8 @@ const (
fileGeneratedResolvConfSearchBeginContent = "search " fileGeneratedResolvConfSearchBeginContent = "search "
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader + fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" + "\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
fileGeneratedResolvConfSearchBeginContent + "%s\n" fileGeneratedResolvConfSearchBeginContent + "%s\n\n" +
"%s\n"
) )
const ( const (
@@ -91,7 +92,12 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
searchDomains += " " + dConf.domain searchDomains += " " + dConf.domain
appendedDomains++ appendedDomains++
} }
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
if err != nil {
log.Errorf("Could not read existing resolv.conf")
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms) err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
if err != nil { if err != nil {
err = f.restore() err = f.restore()

View File

@@ -182,12 +182,11 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
} }
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error { func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey := s.getPrimaryService() primaryServiceKey, existingNameserver := s.getPrimaryService()
if primaryServiceKey == "" { if primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key") return fmt.Errorf("couldn't find the primary service key")
} }
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port)
if err != nil { if err != nil {
return err return err
} }
@@ -196,27 +195,32 @@ func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error
return nil return nil
} }
func (s *systemConfigurator) getPrimaryService() string { func (s *systemConfigurator) getPrimaryService() (string, string) {
line := buildCommandLine("show", globalIPv4State, "") line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line) stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands) b, err := runSystemConfigCommand(stdinCommands)
if err != nil { if err != nil {
log.Error("got error while sending the command: ", err) log.Error("got error while sending the command: ", err)
return "" return "", ""
} }
scanner := bufio.NewScanner(bytes.NewReader(b)) scanner := bufio.NewScanner(bytes.NewReader(b))
primaryService := ""
router := ""
for scanner.Scan() { for scanner.Scan() {
text := scanner.Text() text := scanner.Text()
if strings.Contains(text, "PrimaryService") { if strings.Contains(text, "PrimaryService") {
return strings.TrimSpace(strings.Split(text, ":")[1]) primaryService = strings.TrimSpace(strings.Split(text, ":")[1])
}
if strings.Contains(text, "Router") {
router = strings.TrimSpace(strings.Split(text, ":")[1])
} }
} }
return "" return primaryService, router
} }
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error { func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0)) lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer) lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines) addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand) stdinCommands := wrapCommand(addDomainCommand)

View File

@@ -4,6 +4,7 @@ package dns
import ( import (
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"strings" "strings"
@@ -59,7 +60,11 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
appendedDomains++ appendedDomains++
} }
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains) originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
if err != nil {
log.Errorf("Could not read existing resolv.conf")
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
err = r.applyConfig(content) err = r.applyConfig(content)
if err != nil { if err != nil {

View File

@@ -13,8 +13,6 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nbdns "github.com/netbirdio/netbird/dns"
) )
const ( const (
@@ -123,10 +121,6 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
if err != nil { if err != nil {
return fmt.Errorf("setting link as default dns router, failed with error: %s", err) return fmt.Errorf("setting link as default dns router, failed with error: %s", err)
} }
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: nbdns.RootZone,
MatchOnly: true,
})
s.routingAll = true s.routingAll = true
} else if s.routingAll { } else if s.routingAll {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)

View File

@@ -371,24 +371,24 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
} }
func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
config := &server.Config{} loadedConfig := &server.Config{}
_, err := util.ReadJson(mgmtConfigPath, config) _, err := util.ReadJson(mgmtConfigPath, loadedConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if mgmtLetsencryptDomain != "" { if mgmtLetsencryptDomain != "" {
config.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain loadedConfig.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
} }
if mgmtDataDir != "" { if mgmtDataDir != "" {
config.Datadir = mgmtDataDir loadedConfig.Datadir = mgmtDataDir
} }
if certKey != "" && certFile != "" { if certKey != "" && certFile != "" {
config.HttpConfig.CertFile = certFile loadedConfig.HttpConfig.CertFile = certFile
config.HttpConfig.CertKey = certKey loadedConfig.HttpConfig.CertKey = certKey
} }
oidcEndpoint := config.HttpConfig.OIDCConfigEndpoint oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint != "" { if oidcEndpoint != "" {
// if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically // if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically
log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint) log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
@@ -399,45 +399,45 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint) log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s", log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
oidcConfig.Issuer, config.HttpConfig.AuthIssuer) oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer)
config.HttpConfig.AuthIssuer = oidcConfig.Issuer loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer
log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s", log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
oidcConfig.JwksURI, config.HttpConfig.AuthKeysLocation) oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
config.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(config.DeviceAuthorizationFlow == nil || strings.ToLower(config.DeviceAuthorizationFlow.Provider) == string(server.NONE)) { if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint) oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s", log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.DeviceAuthEndpoint, config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint) oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
u, err := url.Parse(oidcEndpoint) u, err := url.Parse(oidcEndpoint)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s", log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
u.Host, config.DeviceAuthorizationFlow.ProviderConfig.Domain) u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain)
config.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if config.DeviceAuthorizationFlow.ProviderConfig.Scope == "" { if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
config.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope
} }
} }
if config.PKCEAuthorizationFlow != nil { if loadedConfig.PKCEAuthorizationFlow != nil {
log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint) oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s", log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.AuthorizationEndpoint, config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint) oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
} }
} }
return config, err return loadedConfig, err
} }
// OIDCConfigResponse used for parsing OIDC config response // OIDCConfigResponse used for parsing OIDC config response