diff --git a/util/file.go b/util/file.go index ff9a989b1..022841947 100644 --- a/util/file.go +++ b/util/file.go @@ -11,8 +11,7 @@ import ( // The output JSON is pretty-formatted func WriteJson(file string, obj interface{}) error { - configDir, configFileName := filepath.Split(file) - err := os.MkdirAll(configDir, 0750) + configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } @@ -100,3 +99,13 @@ func CopyFileContents(src, dst string) (err error) { err = out.Sync() return } + +func prepareConfigFileDir(file string) (string, string, error) { + configDir, configFileName := filepath.Split(file) + if configDir == "" { + return filepath.Dir(file), configFileName, nil + } + + err := os.MkdirAll(configDir, 0750) + return configDir, configFileName, err +} diff --git a/util/file_test.go b/util/file_test.go index 99a2708fb..3de7db49b 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -3,11 +3,13 @@ package util_test import ( "crypto/md5" "encoding/hex" - "github.com/netbirdio/netbird/util" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" "io" "os" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/netbirdio/netbird/util" ) var _ = Describe("Client", func() { @@ -102,4 +104,23 @@ var _ = Describe("Client", func() { }) }) }) + + Describe("Handle config file without full path", func() { + Context("config file handling", func() { + It("should be successful", func() { + written := &TestConfig{ + SomeField: 123, + } + cfgFile := "test_cfg.json" + defer os.Remove(cfgFile) + + err := util.WriteJson(cfgFile, written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(cfgFile, &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + }) + }) + }) })