From 51e57f75c4c64a28d5ab534b0ab5664c90cc8142 Mon Sep 17 00:00:00 2001 From: Robert Sesek Date: Sun, 9 Nov 2025 22:57:25 -0500 Subject: [PATCH] Refactor the OAuth server into a type --- cmd/mailbox-shuffler/mailbox-shuffler.go | 46 +++--------- cmd/mailbox-shuffler/oauth.go | 89 ++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 38 deletions(-) create mode 100644 cmd/mailbox-shuffler/oauth.go diff --git a/cmd/mailbox-shuffler/mailbox-shuffler.go b/cmd/mailbox-shuffler/mailbox-shuffler.go index 91544df..a6c4b96 100644 --- a/cmd/mailbox-shuffler/mailbox-shuffler.go +++ b/cmd/mailbox-shuffler/mailbox-shuffler.go @@ -12,10 +12,10 @@ import ( "encoding/json" "fmt" "log" - "math/rand/v2" "net/http" "os" + "go.uber.org/zap" "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/api/gmail/v1" @@ -78,48 +78,18 @@ func getToken(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) } } - nonce := fmt.Sprintf("rd%d", rand.Int64()) - ch := make(chan string) + srv := &http.Server{Addr: "localhost:8025"} + config.RedirectURL = fmt.Sprintf("http://%s", srv.Addr) - mux := http.NewServeMux() - mux.HandleFunc("GET /", func(rw http.ResponseWriter, req *http.Request) { - if req.FormValue("state") != nonce { - log.Printf("Nonce mismatch, got %#v", req) - http.Error(rw, "", http.StatusBadRequest) - return - } - if code := req.FormValue("code"); code != "" { - fmt.Fprintln(rw, "

Authorized!

") - ch <- code - return - } - log.Printf("Invalid request - missing code: %#v", req) - http.Error(rw, "", http.StatusBadRequest) - }) - - const listen = "localhost:8025" - srv := http.Server{ - Addr: listen, - Handler: mux, - } - - config.RedirectURL = fmt.Sprintf("http://%s", listen) - authURL := config.AuthCodeURL(nonce) + srvCtx, cancel := context.WithCancel(ctx) + s := RunOAuthServer(srvCtx, srv, config, zap.L()) + authURL, ch := s.AuthorizeToken() log.Printf("Authorize the application at this URL:\n\t%s", authURL) - go func() { - log.Print("Starting OAuth token receiver") - err := srv.ListenAndServe() - if err == http.ErrServerClosed { - log.Print("Server stopped") - } else { - log.Printf("Error with server: %v", err) - } - }() - code := <-ch - srv.Shutdown(ctx) + cancel() + log.Printf("Got code: %q", code) token, err = config.Exchange(ctx, code) diff --git a/cmd/mailbox-shuffler/oauth.go b/cmd/mailbox-shuffler/oauth.go new file mode 100644 index 0000000..10449cf --- /dev/null +++ b/cmd/mailbox-shuffler/oauth.go @@ -0,0 +1,89 @@ +// mailpopbox +// Copyright 2025 Blue Static +// This program is free software licensed under the GNU General Public License, +// version 3.0. The full text of the license can be found in LICENSE.txt. +// SPDX-License-Identifier: GPL-3.0-only + +package main + +import ( + "context" + "fmt" + "math/rand/v2" + "net/http" + "sync" + + "go.uber.org/zap" + "golang.org/x/oauth2" +) + +type OAuthServer struct { + log *zap.Logger + c *oauth2.Config + mu sync.Mutex + tokenReqs map[string]chan<- string +} + +func RunOAuthServer(ctx context.Context, srv *http.Server, config *oauth2.Config, log *zap.Logger) *OAuthServer { + s := &OAuthServer{c: config, + log: log, + tokenReqs: make(map[string]chan<- string), + } + mux := http.NewServeMux() + mux.HandleFunc("GET /", s.handleRequest) + srv.Handler = mux + go func() { + log.Info("Starting OAuth server", zap.String("addr", srv.Addr)) + err := srv.ListenAndServe() + if err == http.ErrServerClosed { + log.Info("Stopping OAuth server") + } else { + log.Error("ListenAndServe", zap.Error(err)) + } + }() + go func() { + <-ctx.Done() + srv.Close() + }() + return s +} + +func (s *OAuthServer) AuthorizeToken() (string, <-chan string) { + id := fmt.Sprintf("rd%d", rand.Int64()) + ch := make(chan string) + + s.mu.Lock() + s.tokenReqs[id] = ch + s.mu.Unlock() + + url := s.c.AuthCodeURL(id) + s.log.Info("Requesting authorization", zap.String("id", id), zap.String("url", url)) + return url, ch +} + +func (s *OAuthServer) handleRequest(rw http.ResponseWriter, req *http.Request) { + id := req.FormValue("state") + s.mu.Lock() + ch, ok := s.tokenReqs[id] + if ok { + delete(s.tokenReqs, id) + } + s.mu.Unlock() + defer close(ch) + + log := s.log.With(zap.String("id", id)) + + if !ok { + log.Error("No channel for token", zap.String("id", id)) + http.Error(rw, "Invalid State", http.StatusBadRequest) + return + } + if code := req.FormValue("code"); code != "" { + fmt.Fprintln(rw, "

Authorized!

") + log.Info("Received authorization code", zap.String("id", id)) + ch <- code + return + } + log.Error("Invalid request - missing code", zap.String("id", id)) + http.Error(rw, "Invalid Code", http.StatusBadRequest) +} -- 2.43.5