mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Feature/upload bundle (#3734)
Add an upload bundle option with the flag --upload-bundle; by default, the upload will use a NetBird address, which can be replaced using the flag --upload-bundle-url. The upload server is available under the /upload-server path. The release change will push a docker image to netbirdio/upload image repository. The server supports using s3 with pre-signed URL for direct upload and local file for storing bundles.
This commit is contained in:
124
upload-server/server/local.go
Normal file
124
upload-server/server/local.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDir = "/var/lib/netbird"
|
||||
putHandler = "/{dir}/{file}"
|
||||
)
|
||||
|
||||
type local struct {
|
||||
url string
|
||||
dir string
|
||||
}
|
||||
|
||||
func configureLocalHandlers(mux *http.ServeMux) error {
|
||||
envURL, ok := os.LookupEnv("SERVER_URL")
|
||||
if !ok {
|
||||
return fmt.Errorf("SERVER_URL environment variable is required")
|
||||
}
|
||||
_, err := url.Parse(envURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SERVER_URL environment variable is invalid: %w", err)
|
||||
}
|
||||
|
||||
dir := defaultDir
|
||||
envDir, ok := os.LookupEnv("STORE_DIR")
|
||||
if ok {
|
||||
if !filepath.IsAbs(envDir) {
|
||||
return fmt.Errorf("STORE_DIR environment variable should point to an absolute path, e.g. /tmp")
|
||||
}
|
||||
log.Infof("Using local directory: %s", envDir)
|
||||
dir = envDir
|
||||
}
|
||||
|
||||
l := &local{
|
||||
url: envURL,
|
||||
dir: dir,
|
||||
}
|
||||
mux.HandleFunc(types.GetURLPath, l.handlerGetUploadURL)
|
||||
mux.HandleFunc(putURLPath+putHandler, l.handlePutRequest)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *local) handlerGetUploadURL(w http.ResponseWriter, r *http.Request) {
|
||||
if !isValidRequest(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
objectKey := getObjectKey(w, r)
|
||||
if objectKey == "" {
|
||||
return
|
||||
}
|
||||
|
||||
uploadURL, err := l.getUploadURL(objectKey)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get upload URL", http.StatusInternalServerError)
|
||||
log.Errorf("Failed to get upload URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
respondGetRequest(w, uploadURL, objectKey)
|
||||
}
|
||||
|
||||
func (l *local) getUploadURL(objectKey string) (string, error) {
|
||||
parsedUploadURL, err := url.Parse(l.url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse upload URL: %w", err)
|
||||
}
|
||||
newURL := parsedUploadURL.JoinPath(parsedUploadURL.Path, putURLPath, objectKey)
|
||||
return newURL.String(), nil
|
||||
}
|
||||
|
||||
func (l *local) handlePutRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to read body: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
uploadDir := r.PathValue("dir")
|
||||
if uploadDir == "" {
|
||||
http.Error(w, "missing dir path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
uploadFile := r.PathValue("file")
|
||||
if uploadFile == "" {
|
||||
http.Error(w, "missing file name", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
dirPath := filepath.Join(l.dir, uploadDir)
|
||||
err = os.MkdirAll(dirPath, 0750)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to create upload dir", http.StatusInternalServerError)
|
||||
log.Errorf("Failed to create upload dir: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
file := filepath.Join(dirPath, uploadFile)
|
||||
if err := os.WriteFile(file, body, 0600); err != nil {
|
||||
http.Error(w, "failed to write file", http.StatusInternalServerError)
|
||||
log.Errorf("Failed to write file %s: %v", file, err)
|
||||
return
|
||||
}
|
||||
log.Infof("Uploading file %s", file)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
65
upload-server/server/local_test.go
Normal file
65
upload-server/server/local_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
func Test_LocalHandlerGetUploadURL(t *testing.T) {
|
||||
mockURL := "http://localhost:8080"
|
||||
t.Setenv("SERVER_URL", mockURL)
|
||||
t.Setenv("STORE_DIR", t.TempDir())
|
||||
|
||||
mux := http.NewServeMux()
|
||||
err := configureLocalHandlers(mux)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, types.GetURLPath+"?id=test-file", nil)
|
||||
req.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response types.GetURLResponse
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, response.URL, "test-file/")
|
||||
require.NotEmpty(t, response.Key)
|
||||
require.Contains(t, response.Key, "test-file/")
|
||||
|
||||
}
|
||||
|
||||
func Test_LocalHandlePutRequest(t *testing.T) {
|
||||
mockDir := t.TempDir()
|
||||
mockURL := "http://localhost:8080"
|
||||
t.Setenv("SERVER_URL", mockURL)
|
||||
t.Setenv("STORE_DIR", mockDir)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
err := configureLocalHandlers(mux)
|
||||
require.NoError(t, err)
|
||||
|
||||
fileContent := []byte("test file content")
|
||||
req := httptest.NewRequest(http.MethodPut, putURLPath+"/uploads/test.txt", bytes.NewReader(fileContent))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
expectedFilePath := filepath.Join(mockDir, "uploads", "test.txt")
|
||||
createdFileContent, err := os.ReadFile(expectedFilePath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fileContent, createdFileContent)
|
||||
}
|
||||
69
upload-server/server/s3.go
Normal file
69
upload-server/server/s3.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
type sThree struct {
|
||||
ctx context.Context
|
||||
bucket string
|
||||
presignClient *s3.PresignClient
|
||||
}
|
||||
|
||||
func configureS3Handlers(mux *http.ServeMux) error {
|
||||
bucket := os.Getenv(bucketVar)
|
||||
region, ok := os.LookupEnv("AWS_REGION")
|
||||
if !ok {
|
||||
return fmt.Errorf("AWS_REGION environment variable is required")
|
||||
}
|
||||
ctx := context.Background()
|
||||
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to load SDK config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(cfg)
|
||||
|
||||
handler := &sThree{
|
||||
ctx: ctx,
|
||||
bucket: bucket,
|
||||
presignClient: s3.NewPresignClient(client),
|
||||
}
|
||||
mux.HandleFunc(types.GetURLPath, handler.handlerGetUploadURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sThree) handlerGetUploadURL(w http.ResponseWriter, r *http.Request) {
|
||||
if !isValidRequest(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
objectKey := getObjectKey(w, r)
|
||||
if objectKey == "" {
|
||||
return
|
||||
}
|
||||
|
||||
req, err := s.presignClient.PresignPutObject(s.ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(objectKey),
|
||||
}, s3.WithPresignExpires(15*time.Minute))
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "failed to presign URL", http.StatusInternalServerError)
|
||||
log.Errorf("Presign error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
respondGetRequest(w, req.URL, objectKey)
|
||||
}
|
||||
103
upload-server/server/s3_test.go
Normal file
103
upload-server/server/s3_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
func Test_S3HandlerGetUploadURL(t *testing.T) {
|
||||
if runtime.GOOS != "linux" && os.Getenv("CI") == "true" {
|
||||
t.Skip("Skipping test on non-Linux and CI environment due to docker dependency")
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping test on Windows due to potential docker dependency")
|
||||
}
|
||||
|
||||
awsEndpoint := "http://127.0.0.1:4566"
|
||||
awsRegion := "us-east-1"
|
||||
|
||||
ctx := context.Background()
|
||||
containerRequest := testcontainers.ContainerRequest{
|
||||
Image: "localstack/localstack:s3-latest",
|
||||
ExposedPorts: []string{"4566:4566/tcp"},
|
||||
WaitingFor: wait.ForLog("Ready"),
|
||||
}
|
||||
|
||||
c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: containerRequest,
|
||||
Started: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer func(c testcontainers.Container, ctx context.Context) {
|
||||
if err := c.Terminate(ctx); err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}(c, ctx)
|
||||
|
||||
t.Setenv("AWS_REGION", awsRegion)
|
||||
t.Setenv("AWS_ENDPOINT_URL", awsEndpoint)
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "test")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "test")
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion), config.WithBaseEndpoint(awsEndpoint))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
|
||||
o.UsePathStyle = true
|
||||
o.BaseEndpoint = cfg.BaseEndpoint
|
||||
})
|
||||
|
||||
bucketName := "test"
|
||||
if _, err := client.CreateBucket(ctx, &s3.CreateBucketInput{
|
||||
Bucket: &bucketName,
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
list, err := client.ListBuckets(ctx, &s3.ListBucketsInput{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, len(list.Buckets), 1)
|
||||
assert.Equal(t, *list.Buckets[0].Name, bucketName)
|
||||
|
||||
t.Setenv(bucketVar, bucketName)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
err = configureS3Handlers(mux)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, types.GetURLPath+"?id=test-file", nil)
|
||||
req.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response types.GetURLResponse
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, response.URL, "test-file/")
|
||||
require.NotEmpty(t, response.Key)
|
||||
require.Contains(t, response.Key, "test-file/")
|
||||
}
|
||||
109
upload-server/server/server.go
Normal file
109
upload-server/server/server.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
putURLPath = "/upload"
|
||||
bucketVar = "BUCKET"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
srv *http.Server
|
||||
}
|
||||
|
||||
func NewServer() *Server {
|
||||
address := os.Getenv("SERVER_ADDRESS")
|
||||
if address == "" {
|
||||
log.Infof("SERVER_ADDRESS environment variable was not set, using 0.0.0.0:8080")
|
||||
address = "0.0.0.0:8080"
|
||||
}
|
||||
mux := http.NewServeMux()
|
||||
err := configureMux(mux)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to configure server: %v", err)
|
||||
}
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
})
|
||||
|
||||
return &Server{
|
||||
srv: &http.Server{Addr: address, Handler: mux},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
log.Infof("Starting upload server on %s", s.srv.Addr)
|
||||
return s.srv.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
if s.srv != nil {
|
||||
log.Infof("Stopping upload server on %s", s.srv.Addr)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return s.srv.Shutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func configureMux(mux *http.ServeMux) error {
|
||||
_, ok := os.LookupEnv(bucketVar)
|
||||
if ok {
|
||||
return configureS3Handlers(mux)
|
||||
} else {
|
||||
return configureLocalHandlers(mux)
|
||||
}
|
||||
}
|
||||
|
||||
func getObjectKey(w http.ResponseWriter, r *http.Request) string {
|
||||
id := r.URL.Query().Get("id")
|
||||
if id == "" {
|
||||
http.Error(w, "id query param required", http.StatusBadRequest)
|
||||
return ""
|
||||
}
|
||||
|
||||
return id + "/" + uuid.New().String()
|
||||
}
|
||||
|
||||
func isValidRequest(w http.ResponseWriter, r *http.Request) bool {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return false
|
||||
}
|
||||
|
||||
if r.Header.Get(types.ClientHeader) != types.ClientHeaderValue {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
func respondGetRequest(w http.ResponseWriter, uploadURL string, objectKey string) {
|
||||
response := types.GetURLResponse{
|
||||
URL: uploadURL,
|
||||
Key: objectKey,
|
||||
}
|
||||
|
||||
rdata, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to marshal response", http.StatusInternalServerError)
|
||||
log.Errorf("Marshal error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err = w.Write(rdata)
|
||||
if err != nil {
|
||||
log.Errorf("Write error: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user