From 99195aaae40e046143b72b47b9f730abcc26dece Mon Sep 17 00:00:00 2001 From: Robert Sesek Date: Sun, 28 Dec 2025 16:43:17 -0500 Subject: [PATCH] Move token storage into the OAuthServer --- cmd/mailbox-shuffler/config.go | 26 +++-- cmd/mailbox-shuffler/mailbox-shuffler.go | 57 +--------- cmd/mailbox-shuffler/oauth.go | 133 ++++++++++++++++++++--- 3 files changed, 139 insertions(+), 77 deletions(-) diff --git a/cmd/mailbox-shuffler/config.go b/cmd/mailbox-shuffler/config.go index 4620d37..3446aeb 100644 --- a/cmd/mailbox-shuffler/config.go +++ b/cmd/mailbox-shuffler/config.go @@ -18,6 +18,8 @@ const ( ServerTypeGmail ServerType = "gmail" ) +// ServerConfig stores the connection information for an email +// server. type ServerConfig struct { Type ServerType ServerAddr string @@ -28,22 +30,28 @@ type ServerConfig struct { Password string } +// MonitorConfig controls how to move messages between a source and +// destination email server. type MonitorConfig struct { Source ServerConfig Destination ServerConfig PollInterval time.Duration } -type Config struct { - Monitor []MonitorConfig +// OAuthServerConfig stores the configuration for an OAuth 2.0 +// application for authenticating to GMail. +type OAuthServerConfig struct { + RedirectURL string + ListenAddr string + CredentialsPath string + TokenStore string + TLSCertPath, TLSKeyPath string +} - OAuthServer struct { - RedirectURL string - ListenAddr string - CredentialsPath string - TokenStore string - TLSCertPath, TLSKeyPath string - } +// Config is the top-level config of mailbox-shuffler. +type Config struct { + Monitor []MonitorConfig + OAuthServer OAuthServerConfig } func (c *Config) Validate() error { diff --git a/cmd/mailbox-shuffler/mailbox-shuffler.go b/cmd/mailbox-shuffler/mailbox-shuffler.go index 194faea..ddaa324 100644 --- a/cmd/mailbox-shuffler/mailbox-shuffler.go +++ b/cmd/mailbox-shuffler/mailbox-shuffler.go @@ -11,13 +11,11 @@ import ( "encoding/base64" "encoding/json" "fmt" - "net/http" "os" "src.bluestatic.org/mailpopbox/pkg/version" "go.uber.org/zap" - "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/api/gmail/v1" "google.golang.org/api/option" @@ -74,12 +72,14 @@ func main() { } ctx := context.Background() - token, err := getToken(ctx, log, &config, oauthConfig) - if err != nil { - log.Fatal("Failed to get OAuth token", zap.Error(err)) + oauthServer := RunOAuthServer(ctx, config.OAuthServer, oauthConfig, log) + + resp := <-oauthServer.GetTokenForUser(ctx, config.Monitor[0].Destination.Email) + if resp.Error != nil { + log.Fatal("Failed to get OAuth token", zap.Error(resp.Error)) } - auth := option.WithHTTPClient(oauthConfig.Client(ctx, token)) + auth := option.WithHTTPClient(oauthConfig.Client(ctx, resp.Token)) client, err := gmail.NewService(ctx, auth, option.WithUserAgent("mailbox-shuffler")) if err != nil { log.Fatal("Failed to create GMail client", zap.Error(err)) @@ -94,48 +94,3 @@ func main() { result, err := call.Do() log.Info("Result", zap.Any("result", result), zap.Error(err)) } - -func getToken(ctx context.Context, log *zap.Logger, config *Config, oauthConfig *oauth2.Config) (*oauth2.Token, error) { - var token *oauth2.Token - f, err := os.Open(config.OAuthServer.TokenStore) - if err != nil && !os.IsNotExist(err) { - return nil, err - } else if f != nil { - defer f.Close() - if err = json.NewDecoder(f).Decode(&token); err != nil { - return nil, err - } else { - return token, nil - } - } - - srv := &http.Server{Addr: "localhost:8025"} - oauthConfig.RedirectURL = fmt.Sprintf("http://%s", srv.Addr) - - srvCtx, cancel := context.WithCancel(ctx) - s := RunOAuthServer(srvCtx, srv, oauthConfig, zap.L()) - - authURL, ch := s.AuthorizeToken() - fmt.Printf("Authorize the application at this URL:\n\t%s\n", authURL) - - code := <-ch - cancel() - - log.Info("Got code", zap.String("code", code)) - - token, err = oauthConfig.Exchange(ctx, code) - if err != nil { - return nil, err - } - - f, err = os.Create(config.OAuthServer.TokenStore) - if err != nil { - return token, err - } - defer f.Close() - if err := json.NewEncoder(f).Encode(token); err != nil { - return token, err - } - - return token, nil -} diff --git a/cmd/mailbox-shuffler/oauth.go b/cmd/mailbox-shuffler/oauth.go index 10449cf..14de28a 100644 --- a/cmd/mailbox-shuffler/oauth.go +++ b/cmd/mailbox-shuffler/oauth.go @@ -8,30 +8,87 @@ package main import ( "context" + "encoding/json" "fmt" "math/rand/v2" "net/http" + "os" "sync" "go.uber.org/zap" "golang.org/x/oauth2" ) -type OAuthServer struct { +type GetTokenForUserResult struct { + Token *oauth2.Token + Error error +} + +type OAuthServer interface { + GetTokenForUser(ctx context.Context, id string) <-chan GetTokenForUserResult +} + +type oauthServer struct { log *zap.Logger - c *oauth2.Config + sc OAuthServerConfig + o2c *oauth2.Config mu sync.Mutex tokenReqs map[string]chan<- string } -func RunOAuthServer(ctx context.Context, srv *http.Server, config *oauth2.Config, log *zap.Logger) *OAuthServer { - s := &OAuthServer{c: config, +const tokenStoreVersion = 1 + +type ( + tokenMap map[string]*oauth2.Token + + tokenStore struct { + Version int + Tokens tokenMap + } +) + +func readTokenStore(path string) (*tokenStore, error) { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return &tokenStore{Version: tokenStoreVersion, Tokens: make(tokenMap)}, nil + } + return nil, err + } + defer f.Close() + var ts *tokenStore + if err := json.NewDecoder(f).Decode(&ts); err != nil { + return nil, err + } + if ts.Version != tokenStoreVersion { + return nil, fmt.Errorf("Invalid tokenStore version, got %d, expected %d", ts.Version, tokenStoreVersion) + } + return ts, nil +} + +func (ts *tokenStore) Save(path string) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return json.NewEncoder(f).Encode(ts) +} + +func RunOAuthServer(ctx context.Context, sc OAuthServerConfig, o2c *oauth2.Config, log *zap.Logger) OAuthServer { + o2c.RedirectURL = sc.RedirectURL + s := &oauthServer{ + sc: sc, + o2c: o2c, log: log, tokenReqs: make(map[string]chan<- string), } mux := http.NewServeMux() - mux.HandleFunc("GET /", s.handleRequest) - srv.Handler = mux + mux.HandleFunc("GET /{$}", s.handleRequest) + srv := &http.Server{ + Handler: mux, + Addr: sc.ListenAddr, + } go func() { log.Info("Starting OAuth server", zap.String("addr", srv.Addr)) err := srv.ListenAndServe() @@ -48,28 +105,70 @@ func RunOAuthServer(ctx context.Context, srv *http.Server, config *oauth2.Config return s } -func (s *OAuthServer) AuthorizeToken() (string, <-chan string) { - id := fmt.Sprintf("rd%d", rand.Int64()) - ch := make(chan string) +func (s *oauthServer) GetTokenForUser(ctx context.Context, userID string) <-chan GetTokenForUserResult { + ch := make(chan GetTokenForUserResult) - s.mu.Lock() - s.tokenReqs[id] = ch - s.mu.Unlock() + go func() { + s.mu.Lock() + defer s.mu.Unlock() + + ts, err := readTokenStore(s.sc.TokenStore) + if err != nil { + ch <- GetTokenForUserResult{Error: err} + return + } + token, ok := ts.Tokens[userID] + if ok { + ch <- GetTokenForUserResult{Token: token} + return + } + + // No token is stored, so put in a request. + nonce := fmt.Sprintf("rd%d", rand.Int64()) + codeCh := make(chan string) + s.tokenReqs[nonce] = codeCh + + url := s.o2c.AuthCodeURL(nonce, oauth2.AccessTypeOffline) + s.log.Info("Requesting authorization", zap.String("nonce", nonce), zap.String("url", url)) + + // Drop the lock until the code is received. + s.mu.Unlock() + code := <-codeCh + s.log.Info("Got code", zap.String("code", code)) + token, err = s.o2c.Exchange(ctx, code) + s.mu.Lock() + + if err != nil { + ch <- GetTokenForUserResult{Error: err} + return + } + + ts, err = readTokenStore(s.sc.TokenStore) + if err != nil { + ch <- GetTokenForUserResult{Error: err} + return + } + ts.Tokens[userID] = token + if err := ts.Save(s.sc.TokenStore); err != nil { + ch <- GetTokenForUserResult{Error: err} + return + } + + ch <- GetTokenForUserResult{Token: token} + }() - url := s.c.AuthCodeURL(id) - s.log.Info("Requesting authorization", zap.String("id", id), zap.String("url", url)) - return url, ch + return ch } -func (s *OAuthServer) handleRequest(rw http.ResponseWriter, req *http.Request) { +func (s *oauthServer) handleRequest(rw http.ResponseWriter, req *http.Request) { id := req.FormValue("state") s.mu.Lock() ch, ok := s.tokenReqs[id] if ok { delete(s.tokenReqs, id) + defer close(ch) } s.mu.Unlock() - defer close(ch) log := s.log.With(zap.String("id", id)) -- 2.43.5