Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions api/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
package api

import (
"fmt"
"net/http"

"github.com/gin-gonic/gin"
"github.com/go-vela/worker/router/middleware/token"
"github.com/golang-jwt/jwt/v5"
)

// swagger:operation POST /register system Register
Expand Down Expand Up @@ -39,6 +41,13 @@ import (
// channel of the worker. This will unblock operation if the worker has not been
// registered and the provided registration token is valid.
func Register(c *gin.Context) {
// extract the worker hostname that was packed into gin context
w, ok := c.Get("worker-hostname")
if !ok {
c.JSON(http.StatusInternalServerError, "no worker hostname in the context")
return
}

// extract the register token channel that was packed into gin context
v, ok := c.Get("register-token")
if !ok {
Expand Down Expand Up @@ -68,8 +77,53 @@ func Register(c *gin.Context) {
return
}

// extract the subject from the token
sub, err := getSubjectFromToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, err)
return
}

// make sure we configured the hostname properly
hostname, ok := w.(string)
if !ok {
c.JSON(http.StatusInternalServerError, "worker hostname in the context is the wrong type")
return
}

// if the subject doesn't match the worker hostname return an error
if sub != hostname {
c.JSON(http.StatusUnauthorized, "worker hostname is invalid")
return
}

// write registration token to auth token channel
rChan <- token

c.JSON(http.StatusOK, "successfully passed token to worker")
}

// getSubjectFromToken is a helper function to extract
// the subject from the token claims.
func getSubjectFromToken(token string) (string, error) {
// create a new JWT parser
j := jwt.NewParser()

// parse the payload
t, _, err := j.ParseUnverified(token, jwt.MapClaims{})
if err != nil {
return "", fmt.Errorf("unable to parse token")
}

sub, err := t.Claims.GetSubject()
if err != nil {
return "", fmt.Errorf("unable to get subject from token")
}

// make sure there was a subject defined
if len(sub) == 0 {
return "", fmt.Errorf("no subject defined in token")
}

return sub, nil
}
1 change: 1 addition & 0 deletions cmd/vela-worker/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func (w *Worker) server() (http.Handler, *tls.Config) {
_server := router.Load(
middleware.RequestVersion,
middleware.ServerAddress(w.Config.Server.Address),
middleware.WorkerHostname(w.Config.API.Address.Hostname()),
middleware.Executors(w.Executors),
middleware.Logger(logrus.StandardLogger(), time.RFC3339, true),
middleware.RegisterToken(w.RegisterToken),
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/go-vela/sdk-go v0.19.0-rc1
github.com/go-vela/server v0.19.0-rc1
github.com/go-vela/types v0.19.0-rc1
github.com/golang-jwt/jwt/v5 v5.0.0
github.com/google/go-cmp v0.5.9
github.com/joho/godotenv v1.5.1
github.com/opencontainers/image-spec v1.0.2
Expand Down Expand Up @@ -59,7 +60,6 @@ require (
github.com/go-playground/validator/v10 v10.11.2 // indirect
github.com/goccy/go-json v0.10.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v5 v5.0.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/gnostic v0.5.7-v3refs // indirect
github.com/google/go-github/v51 v51.0.0 // indirect
Expand Down
18 changes: 18 additions & 0 deletions router/middleware/worker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) 2023 Target Brands, Inc. All rights reserved.
//
// Use of this source code is governed by the LICENSE file in this repository.

package middleware

import (
"github.com/gin-gonic/gin"
)

// WorkerHostname is a middleware function that attaches the
// worker hostname to the context of every http.Request.
func WorkerHostname(name string) gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("worker-hostname", name)
c.Next()
}
}
46 changes: 46 additions & 0 deletions router/middleware/worker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) 2023 Target Brands, Inc. All rights reserved.
//
// Use of this source code is governed by the LICENSE file in this repository.

package middleware

import (
"net/http"
"net/http/httptest"
"reflect"
"testing"

"github.com/gin-gonic/gin"
)

func TestMiddleware_WorkerHostname(t *testing.T) {
// setup types
got := ""
want := "foobar"

// setup context
gin.SetMode(gin.TestMode)

resp := httptest.NewRecorder()
context, engine := gin.CreateTestContext(resp)
context.Request, _ = http.NewRequest(http.MethodGet, "/health", nil)

// setup mock server
engine.Use(WorkerHostname(want))
engine.GET("/health", func(c *gin.Context) {
got = c.Value("worker-hostname").(string)

c.Status(http.StatusOK)
})

// run test
engine.ServeHTTP(context.Writer, context.Request)

if resp.Code != http.StatusOK {
t.Errorf("WorkerHostname returned %v, want %v", resp.Code, http.StatusOK)
}

if !reflect.DeepEqual(got, want) {
t.Errorf("WorkerHostname is %v, want %v", got, want)
}
}