Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions generate/go_client.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
package client

import (
"sync"

"connectrpc.com/connect"
compress "github.com/klauspost/connect-compress/v2"

Expand All @@ -22,8 +20,6 @@ type (
config *DialConfig

interceptors []connect.Interceptor

sync.Mutex
}
{{ range $name, $api := . -}}
{{ $name | title }} interface {
Expand Down Expand Up @@ -55,16 +51,18 @@ func New(config *DialConfig) (Client, error) {
if config.Token != "" {
authInterceptor := &authInterceptor{config: config}
c.interceptors = append(c.interceptors, authInterceptor)

if config.TokenRenewal != nil {
tokenRenewingInterceptor := &tokenRenewingInterceptor{config: config, client: c}
c.interceptors = append(c.interceptors, tokenRenewingInterceptor)
}
}
if config.Log != nil {
loggingInterceptor := &loggingInterceptor{config: config}
c.interceptors = append(c.interceptors, loggingInterceptor)
}
c.interceptors = append(c.interceptors, config.Interceptors...)

// TODO convert to interceptor
go c.startTokenRenewal()

return c, nil
}

Expand Down
82 changes: 82 additions & 0 deletions go/client/client-interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ package client

import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"

"connectrpc.com/connect"
apiv2models "github.com/metal-stack/api/go/metalstack/api/v2"
)

// authinterceptor adds the required auth headers
Expand Down Expand Up @@ -65,3 +71,79 @@ func (i *loggingInterceptor) WrapStreamingClient(next connect.StreamingClientFun
func (i *loggingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next
}

type tokenRenewingInterceptor struct {
config *DialConfig
client *client

renewing atomic.Bool

sync.Mutex
}

func (i *tokenRenewingInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
err := i.renewTokenIfNeeded()
if err != nil {
return nil, err
}
return next(ctx, request)
})
}

func (i *tokenRenewingInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next
}

func (i *tokenRenewingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next
}

func (i *tokenRenewingInterceptor) renewTokenIfNeeded() error {
if i.config.expiresAt.IsZero() {
return nil
}
if i.renewing.Load() {
return nil
}
if i.config.Log == nil {
i.config.Log = slog.Default()
}

replaceBefore := i.config.expiresAt.Sub(i.config.issuedAt) / tokenRenewChecksDuringLifetime

if time.Until(i.config.expiresAt) > replaceBefore {
return nil
}

i.renewing.Store(true)
defer i.renewing.Store(false)

i.config.Log.Info("call token refresh, current token expires soon", "expires", i.config.expiresAt.String())

i.Lock()
defer i.Unlock()

resp, err := i.client.Apiv2().Token().Refresh(context.Background(), &apiv2models.TokenServiceRefreshRequest{})
if err != nil {
return fmt.Errorf("unable to refresh token %w", err)
}

i.config.Token = resp.Secret
err = i.config.parse()
if err != nil {
return fmt.Errorf("unable to parse token %w", err)
}

if i.config.TokenRenewal.PersistTokenFn == nil {
return nil
}

err = i.config.TokenRenewal.PersistTokenFn(i.config.Token)
if err != nil {
return fmt.Errorf("unable to persist token %w", err)
}

i.config.Log.Info("token refreshed, new token expires in", "expires", i.config.expiresAt.String())
return nil
}
12 changes: 5 additions & 7 deletions go/client/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions go/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func Test_Client(t *testing.T) {
server.Close()
}()

tokenString, err := generateToken(1 * time.Second)
tokenString, err := generateToken(2 * time.Second)
require.NoError(t, err)

c, err := client.New(&client.DialConfig{
Expand All @@ -50,6 +50,7 @@ func Test_Client(t *testing.T) {
Transport: server.Client().Transport,
TokenRenewal: &client.TokenRenewal{
PersistTokenFn: func(token string) error {
ts.token = token
t.Log("token persisted:", token)
return nil
},
Expand All @@ -64,7 +65,7 @@ func Test_Client(t *testing.T) {
require.False(t, ts.wasCalled)
require.Equal(t, tokenString, vs.token)

time.Sleep(300 * time.Millisecond)
time.Sleep(1 * time.Second)
v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{})
require.NoError(t, err)
require.NotNil(t, v)
Expand All @@ -79,7 +80,7 @@ func Test_Client(t *testing.T) {
require.Equal(t, "1.0", v.Version.Version)

require.True(t, ts.wasCalled)
require.NotEqual(t, tokenString, vs.token, "token must have changed")
require.NotEqual(t, tokenString, ts.token, "token must have changed")
}

func generateToken(duration time.Duration) (string, error) {
Expand Down Expand Up @@ -121,6 +122,7 @@ func (m *mockVersionService) Get(ctx context.Context, req *apiv2.VersionServiceG

type mockTokenService struct {
wasCalled bool
token string
}

// Create implements apiv2connect.TokenServiceHandler.
Expand Down
69 changes: 0 additions & 69 deletions go/client/conn.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package client

import (
"context"
"errors"
"fmt"
"log/slog"
Expand All @@ -10,7 +9,6 @@ import (

"connectrpc.com/connect"
"github.com/golang-jwt/jwt/v5"
api "github.com/metal-stack/api/go/metalstack/api/v2"
)

const tokenRenewChecksDuringLifetime = 4
Expand Down Expand Up @@ -84,70 +82,3 @@ func (dc *DialConfig) parse() error {
}
return nil
}

func (c *client) startTokenRenewal() {
if c.config.TokenRenewal == nil {
return
}
if c.config.expiresAt.IsZero() {
return
}
if c.config.Log == nil {
c.config.Log = slog.Default()
}

replaceBefore := c.config.expiresAt.Sub(c.config.issuedAt) / tokenRenewChecksDuringLifetime

err := c.renewTokenIfNeeded(replaceBefore)
if err != nil {
c.config.Log.Error("unable to renew token", "error", err)
}

ticker := time.NewTicker(replaceBefore)
defer ticker.Stop()
done := make(chan bool)
for {
select {
case <-done:
return
case <-ticker.C:
err := c.renewTokenIfNeeded(replaceBefore)
if err != nil {
c.config.Log.Error("unable to renew token", "error", err)
}
}
}
}

func (c *client) renewTokenIfNeeded(replaceBefore time.Duration) error {
if time.Until(c.config.expiresAt) > replaceBefore {
return nil
}
c.config.Log.Info("call token refresh, current token expires soon", "expires", c.config.expiresAt.String())

c.Lock()
defer c.Unlock()

resp, err := c.Apiv2().Token().Refresh(context.Background(), &api.TokenServiceRefreshRequest{})
if err != nil {
return fmt.Errorf("unable to refresh token %w", err)
}

c.config.Token = resp.Secret
err = c.config.parse()
if err != nil {
return fmt.Errorf("unable to parse token %w", err)
}

if c.config.TokenRenewal.PersistTokenFn == nil {
return nil
}

err = c.config.TokenRenewal.PersistTokenFn(c.config.Token)
if err != nil {
return fmt.Errorf("unable to persist token %w", err)
}

c.config.Log.Info("token refreshed, new token expires in", "expires", c.config.expiresAt.String())
return nil
}
Loading