diff --git a/generate/go_client.tpl b/generate/go_client.tpl index 3f076aeb..cbdd14fb 100644 --- a/generate/go_client.tpl +++ b/generate/go_client.tpl @@ -2,8 +2,6 @@ package client import ( - "sync" - "connectrpc.com/connect" compress "github.com/klauspost/connect-compress/v2" @@ -22,8 +20,6 @@ type ( config *DialConfig interceptors []connect.Interceptor - - sync.Mutex } {{ range $name, $api := . -}} {{ $name | title }} interface { @@ -55,6 +51,11 @@ func New(config *DialConfig) (Client, error) { if config.Token != "" { authInterceptor := &authInterceptor{config: config} c.interceptors = append(c.interceptors, authInterceptor) + + if config.TokenRenewal != nil { + tokenRenewingInterceptor := &tokenRenewingInterceptor{config: config, client: c} + c.interceptors = append(c.interceptors, tokenRenewingInterceptor) + } } if config.Log != nil { loggingInterceptor := &loggingInterceptor{config: config} @@ -62,9 +63,6 @@ func New(config *DialConfig) (Client, error) { } c.interceptors = append(c.interceptors, config.Interceptors...) - // TODO convert to interceptor - go c.startTokenRenewal() - return c, nil } diff --git a/go/client/client-interceptors.go b/go/client/client-interceptors.go index 0d9c678e..71782fe1 100644 --- a/go/client/client-interceptors.go +++ b/go/client/client-interceptors.go @@ -2,8 +2,14 @@ package client import ( "context" + "fmt" + "log/slog" + "sync" + "sync/atomic" + "time" "connectrpc.com/connect" + apiv2models "github.com/metal-stack/api/go/metalstack/api/v2" ) // authinterceptor adds the required auth headers @@ -65,3 +71,79 @@ func (i *loggingInterceptor) WrapStreamingClient(next connect.StreamingClientFun func (i *loggingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return next } + +type tokenRenewingInterceptor struct { + config *DialConfig + client *client + + renewing atomic.Bool + + sync.Mutex +} + +func (i *tokenRenewingInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { + err := i.renewTokenIfNeeded() + if err != nil { + return nil, err + } + return next(ctx, request) + }) +} + +func (i *tokenRenewingInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return next +} + +func (i *tokenRenewingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return next +} + +func (i *tokenRenewingInterceptor) renewTokenIfNeeded() error { + if i.config.expiresAt.IsZero() { + return nil + } + if i.renewing.Load() { + return nil + } + if i.config.Log == nil { + i.config.Log = slog.Default() + } + + replaceBefore := i.config.expiresAt.Sub(i.config.issuedAt) / tokenRenewChecksDuringLifetime + + if time.Until(i.config.expiresAt) > replaceBefore { + return nil + } + + i.renewing.Store(true) + defer i.renewing.Store(false) + + i.config.Log.Info("call token refresh, current token expires soon", "expires", i.config.expiresAt.String()) + + i.Lock() + defer i.Unlock() + + resp, err := i.client.Apiv2().Token().Refresh(context.Background(), &apiv2models.TokenServiceRefreshRequest{}) + if err != nil { + return fmt.Errorf("unable to refresh token %w", err) + } + + i.config.Token = resp.Secret + err = i.config.parse() + if err != nil { + return fmt.Errorf("unable to parse token %w", err) + } + + if i.config.TokenRenewal.PersistTokenFn == nil { + return nil + } + + err = i.config.TokenRenewal.PersistTokenFn(i.config.Token) + if err != nil { + return fmt.Errorf("unable to persist token %w", err) + } + + i.config.Log.Info("token refreshed, new token expires in", "expires", i.config.expiresAt.String()) + return nil +} diff --git a/go/client/client.go b/go/client/client.go index 374b5cc2..e4fcd1ce 100755 --- a/go/client/client.go +++ b/go/client/client.go @@ -2,8 +2,6 @@ package client import ( - "sync" - "connectrpc.com/connect" compress "github.com/klauspost/connect-compress/v2" @@ -22,8 +20,6 @@ type ( config *DialConfig interceptors []connect.Interceptor - - sync.Mutex } Adminv2 interface { Filesystem() adminv2connect.FilesystemServiceClient @@ -116,6 +112,11 @@ func New(config *DialConfig) (Client, error) { if config.Token != "" { authInterceptor := &authInterceptor{config: config} c.interceptors = append(c.interceptors, authInterceptor) + + if config.TokenRenewal != nil { + tokenRenewingInterceptor := &tokenRenewingInterceptor{config: config, client: c} + c.interceptors = append(c.interceptors, tokenRenewingInterceptor) + } } if config.Log != nil { loggingInterceptor := &loggingInterceptor{config: config} @@ -123,9 +124,6 @@ func New(config *DialConfig) (Client, error) { } c.interceptors = append(c.interceptors, config.Interceptors...) - // TODO convert to interceptor - go c.startTokenRenewal() - return c, nil } diff --git a/go/client/client_test.go b/go/client/client_test.go index c5b1e3e4..e08e73ca 100644 --- a/go/client/client_test.go +++ b/go/client/client_test.go @@ -41,7 +41,7 @@ func Test_Client(t *testing.T) { server.Close() }() - tokenString, err := generateToken(1 * time.Second) + tokenString, err := generateToken(2 * time.Second) require.NoError(t, err) c, err := client.New(&client.DialConfig{ @@ -50,6 +50,7 @@ func Test_Client(t *testing.T) { Transport: server.Client().Transport, TokenRenewal: &client.TokenRenewal{ PersistTokenFn: func(token string) error { + ts.token = token t.Log("token persisted:", token) return nil }, @@ -64,7 +65,7 @@ func Test_Client(t *testing.T) { require.False(t, ts.wasCalled) require.Equal(t, tokenString, vs.token) - time.Sleep(300 * time.Millisecond) + time.Sleep(1 * time.Second) v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{}) require.NoError(t, err) require.NotNil(t, v) @@ -79,7 +80,7 @@ func Test_Client(t *testing.T) { require.Equal(t, "1.0", v.Version.Version) require.True(t, ts.wasCalled) - require.NotEqual(t, tokenString, vs.token, "token must have changed") + require.NotEqual(t, tokenString, ts.token, "token must have changed") } func generateToken(duration time.Duration) (string, error) { @@ -121,6 +122,7 @@ func (m *mockVersionService) Get(ctx context.Context, req *apiv2.VersionServiceG type mockTokenService struct { wasCalled bool + token string } // Create implements apiv2connect.TokenServiceHandler. diff --git a/go/client/conn.go b/go/client/conn.go index 38a4f2db..3c5aecb4 100644 --- a/go/client/conn.go +++ b/go/client/conn.go @@ -1,7 +1,6 @@ package client import ( - "context" "errors" "fmt" "log/slog" @@ -10,7 +9,6 @@ import ( "connectrpc.com/connect" "github.com/golang-jwt/jwt/v5" - api "github.com/metal-stack/api/go/metalstack/api/v2" ) const tokenRenewChecksDuringLifetime = 4 @@ -84,70 +82,3 @@ func (dc *DialConfig) parse() error { } return nil } - -func (c *client) startTokenRenewal() { - if c.config.TokenRenewal == nil { - return - } - if c.config.expiresAt.IsZero() { - return - } - if c.config.Log == nil { - c.config.Log = slog.Default() - } - - replaceBefore := c.config.expiresAt.Sub(c.config.issuedAt) / tokenRenewChecksDuringLifetime - - err := c.renewTokenIfNeeded(replaceBefore) - if err != nil { - c.config.Log.Error("unable to renew token", "error", err) - } - - ticker := time.NewTicker(replaceBefore) - defer ticker.Stop() - done := make(chan bool) - for { - select { - case <-done: - return - case <-ticker.C: - err := c.renewTokenIfNeeded(replaceBefore) - if err != nil { - c.config.Log.Error("unable to renew token", "error", err) - } - } - } -} - -func (c *client) renewTokenIfNeeded(replaceBefore time.Duration) error { - if time.Until(c.config.expiresAt) > replaceBefore { - return nil - } - c.config.Log.Info("call token refresh, current token expires soon", "expires", c.config.expiresAt.String()) - - c.Lock() - defer c.Unlock() - - resp, err := c.Apiv2().Token().Refresh(context.Background(), &api.TokenServiceRefreshRequest{}) - if err != nil { - return fmt.Errorf("unable to refresh token %w", err) - } - - c.config.Token = resp.Secret - err = c.config.parse() - if err != nil { - return fmt.Errorf("unable to parse token %w", err) - } - - if c.config.TokenRenewal.PersistTokenFn == nil { - return nil - } - - err = c.config.TokenRenewal.PersistTokenFn(c.config.Token) - if err != nil { - return fmt.Errorf("unable to persist token %w", err) - } - - c.config.Log.Info("token refreshed, new token expires in", "expires", c.config.expiresAt.String()) - return nil -}