Implement AUTH PLAIN authentication extensions in SMTP.
authorRobert Sesek <rsesek@bluestatic.org>
Wed, 5 Sep 2018 04:59:11 +0000 (00:59 -0400)
committerRobert Sesek <rsesek@bluestatic.org>
Wed, 5 Sep 2018 04:59:11 +0000 (00:59 -0400)
This adds support for the AUTH extension (RFC 2554) and a single
SASL mechanism of PLAIN. This must be done after STARTTLS.

README.md
smtp.go
smtp/conn.go
smtp/conn_test.go
smtp/server.go
smtp_test.go [new file with mode: 0644]

index e424860261601dd2fefcfa06ff0259898611e493..03a343f7f6ed3aa3a15db331754b30c087ce59d1 100644 (file)
--- a/README.md
+++ b/README.md
@@ -16,4 +16,6 @@ This server implements the following RFCs:
 - [Post Office Protocol - Version 3, RFC 1939](https://tools.ietf.org/html/rfc1939)
 - [Simple Mail Transfer Protocol, RFC 5321](https://tools.ietf.org/html/rfc5321)
 - [SMTP Service Extension for Secure SMTP over Transport Layer Security, RFC 3207](https://tools.ietf.org/html/rfc3207)
+- [SMTP Service Extension for Authentication, RFC 2554](https://tools.ietf.org/html/rfc2554)
+- [The PLAIN Simple Authentication and Security Layer (SASL) Mechanism, RFC 4616](https://tools.ietf.org/html/rfc4616)
 - [POP3 Extension Mechanism, RFC 2449](https://tools.ietf.org/html/rfc2449)
diff --git a/smtp.go b/smtp.go
index 47747f3b602a9210948d892d947c90d166e61163..572e1f63bbbfec63cd47c2bee363986ecb527415 100644 (file)
--- a/smtp.go
+++ b/smtp.go
@@ -95,6 +95,30 @@ func (server *smtpServer) VerifyAddress(addr mail.Address) smtp.ReplyLine {
        return smtp.ReplyOK
 }
 
+func (server *smtpServer) Authenticate(authz, authc, passwd string) bool {
+       authcAddr, err := mail.ParseAddress(authc)
+       if err != nil {
+               return false
+       }
+
+       authzAddr, err := mail.ParseAddress(authz)
+       if authz != "" && err != nil {
+               return false
+       }
+
+       domain := smtp.DomainForAddress(*authcAddr)
+       for _, s := range server.config.Servers {
+               if domain == s.Domain {
+                       authOk := authc == MailboxAccount+s.Domain && passwd == s.MailboxPassword
+                       if authzAddr != nil {
+                               authOk = authOk && smtp.DomainForAddress(*authzAddr) == domain
+                       }
+                       return authOk
+               }
+       }
+       return false
+}
+
 func (server *smtpServer) OnMessageDelivered(en smtp.Envelope) *smtp.ReplyLine {
        maildrop := server.maildropForAddress(en.RcptTo[0])
        if maildrop == "" {
index f8ef4c15c2c1cab905bea2c8fa6901e64f4c196b..fd133d79c6bc1c6178b8c21bda4fff184da9b82c 100644 (file)
@@ -3,6 +3,7 @@ package smtp
 import (
        "crypto/rand"
        "crypto/tls"
+       "encoding/base64"
        "fmt"
        "net"
        "net/mail"
@@ -34,6 +35,10 @@ type connection struct {
        esmtp bool
        tls   *tls.ConnectionState
 
+       // The authcid from a PLAIN SASL login. Non-empty iff tls is non-nil and
+       // doAUTH() succeeded.
+       authc string
+
        log zap.Logger
 
        state
@@ -88,6 +93,8 @@ func AcceptConnection(netConn net.Conn, server Server, log zap.Logger) {
                        conn.doEHLO()
                case "STARTTLS":
                        conn.doSTARTTLS()
+               case "AUTH":
+                       conn.doAUTH()
                case "MAIL":
                        conn.doMAIL()
                case "RCPT":
@@ -165,6 +172,9 @@ func (conn *connection) doEHLO() {
                if conn.server.TLSConfig() != nil && conn.tls == nil {
                        conn.tp.PrintfLine("250-STARTTLS")
                }
+               if conn.tls != nil {
+                       conn.tp.PrintfLine("250-AUTH PLAIN")
+               }
                conn.tp.PrintfLine("250 SIZE %d", 40960000)
        }
 
@@ -204,6 +214,64 @@ func (conn *connection) doSTARTTLS() {
        conn.log.Info("TLS connection done", zap.String("state", conn.getTransportString()))
 }
 
+func (conn *connection) doAUTH() {
+       if conn.state != stateInitial || conn.tls == nil {
+               conn.reply(ReplyBadSequence)
+               return
+       }
+
+       if conn.authc != "" {
+               conn.writeReply(503, "already authenticated")
+               return
+       }
+
+       var cmd, authType string
+       _, err := fmt.Sscanf(conn.line, "%s %s", &cmd, &authType)
+       if err != nil {
+               conn.reply(ReplyBadSyntax)
+               return
+       }
+
+       if authType != "PLAIN" {
+               conn.writeReply(504, "unrecognized auth type")
+               return
+       }
+
+       conn.log.Info("doAUTH()")
+
+       conn.writeReply(334, " ")
+
+       authLine, err := conn.tp.ReadLine()
+       if err != nil {
+               conn.log.Error("failed to read auth line", zap.Error(err))
+               conn.reply(ReplyBadSyntax)
+               return
+       }
+
+       authBytes, err := base64.StdEncoding.DecodeString(authLine)
+       if err != nil {
+               conn.reply(ReplyBadSyntax)
+               return
+       }
+
+       authParts := strings.Split(string(authBytes), "\x00")
+       if len(authParts) != 3 {
+               conn.log.Error("bad auth line syntax")
+               conn.reply(ReplyBadSyntax)
+               return
+       }
+
+       if !conn.server.Authenticate(authParts[0], authParts[1], authParts[2]) {
+               conn.log.Error("failed to authenticate", zap.String("authc", authParts[1]))
+               conn.writeReply(535, "invalid credentials")
+               return
+       }
+
+       conn.log.Info("authenticated", zap.String("authz", authParts[0]), zap.String("authc", authParts[1]))
+       conn.authc = authParts[1]
+       conn.reply(ReplyOK)
+}
+
 func (conn *connection) doMAIL() {
        if conn.state != stateInitial {
                conn.reply(ReplyBadSequence)
index 85dcb45a2dbd784f76208ccb9b02ed20d133c1dc..d6275c9c15ed1f556779fb249d076bc063bb0d61 100644 (file)
@@ -2,6 +2,7 @@ package smtp
 
 import (
        "crypto/tls"
+       "encoding/base64"
        "fmt"
        "net"
        "net/mail"
@@ -56,10 +57,15 @@ func runServer(t *testing.T, server Server) net.Listener {
        return l
 }
 
+type userAuth struct {
+       authz, authc, passwd string
+}
+
 type testServer struct {
        EmptyServerCallbacks
        blockList []string
        tlsConfig *tls.Config
+       *userAuth
 }
 
 func (s *testServer) Name() string {
@@ -79,6 +85,12 @@ func (s *testServer) VerifyAddress(addr mail.Address) ReplyLine {
        return ReplyOK
 }
 
+func (s *testServer) Authenticate(authz, authc, passwd string) bool {
+       return s.userAuth.authz == authz &&
+               s.userAuth.authc == authc &&
+               s.userAuth.passwd == passwd
+}
+
 func createClient(t *testing.T, addr net.Addr) *textproto.Conn {
        conn, err := textproto.Dial(addr.Network(), addr.String())
        if err != nil {
@@ -295,11 +307,8 @@ func getTLSConfig(t *testing.T) *tls.Config {
        }
 }
 
-func TestTLS(t *testing.T) {
-       l := runServer(t, &testServer{tlsConfig: getTLSConfig(t)})
-       defer l.Close()
-
-       nc, err := net.Dial(l.Addr().Network(), l.Addr().String())
+func setupTLSClient(t *testing.T, addr net.Addr) *textproto.Conn {
+       nc, err := net.Dial(addr.Network(), addr.String())
        ok(t, err)
 
        conn := textproto.NewConn(nc)
@@ -327,4 +336,62 @@ func TestTLS(t *testing.T) {
        if strings.Contains(resp, "STARTTLS\n") {
                t.Errorf("STARTTLS advertised when already started")
        }
+
+       return conn
+}
+
+func TestTLS(t *testing.T) {
+       l := runServer(t, &testServer{tlsConfig: getTLSConfig(t)})
+       defer l.Close()
+
+       setupTLSClient(t, l.Addr())
+}
+
+func TestAuthWithoutTLS(t *testing.T) {
+       l := runServer(t, &testServer{})
+       defer l.Close()
+
+       conn := createClient(t, l.Addr())
+       readCodeLine(t, conn, 220)
+
+       ok(t, conn.PrintfLine("EHLO test"))
+       _, resp, err := conn.ReadResponse(250)
+       ok(t, err)
+
+       if strings.Contains(resp, "AUTH") {
+               t.Errorf("AUTH should not be advertised over plaintext")
+       }
+}
+
+func TestAuth(t *testing.T) {
+       l := runServer(t, &testServer{
+               tlsConfig: getTLSConfig(t),
+               userAuth: &userAuth{
+                       authz:  "-authz-",
+                       authc:  "-authc-",
+                       passwd: "goats",
+               },
+       })
+       defer l.Close()
+
+       conn := setupTLSClient(t, l.Addr())
+
+       b64enc := func(s string) string {
+               return string(base64.StdEncoding.EncodeToString([]byte(s)))
+       }
+
+       runTableTest(t, conn, []requestResponse{
+               {"AUTH", 501, nil},
+               {"AUTH OAUTHBEARER", 504, nil},
+               {"AUTH PLAIN", 334, nil},
+               {b64enc("abc\x00def\x00ghf"), 535, nil},
+               {"AUTH PLAIN", 334, nil},
+               {b64enc("\x00"), 501, nil},
+               {"AUTH PLAIN", 334, nil},
+               {"this isn't base 64", 501, nil},
+               {"AUTH PLAIN", 334, nil},
+               {b64enc("-authz-\x00-authc-\x00goats"), 250, nil},
+               {"AUTH PLAIN", 503, nil}, // already authenticated
+               {"NOOP", 250, nil},
+       })
 }
index 30a65bd0fc7a4107df04198dadeb9a9d41161927..f9ecb1d2110a9a86b8152040b3229dde560b3408 100644 (file)
@@ -54,6 +54,8 @@ type Server interface {
        Name() string
        TLSConfig() *tls.Config
        VerifyAddress(mail.Address) ReplyLine
+       // Verify that the authc+passwd identity can send mail as authz.
+       Authenticate(authz, authc, passwd string) bool
        OnMessageDelivered(Envelope) *ReplyLine
 }
 
@@ -67,6 +69,10 @@ func (*EmptyServerCallbacks) VerifyAddress(mail.Address) ReplyLine {
        return ReplyOK
 }
 
+func (*EmptyServerCallbacks) Authenticate(authz, authc, passwd string) bool {
+       return false
+}
+
 func (*EmptyServerCallbacks) OnMessageDelivered(Envelope) *ReplyLine {
        return nil
 }
diff --git a/smtp_test.go b/smtp_test.go
new file mode 100644 (file)
index 0000000..2eef868
--- /dev/null
@@ -0,0 +1,43 @@
+package main
+
+import (
+       "testing"
+)
+
+var testConfig = Config{
+       Servers: []Server{
+               Server{
+                       Domain:          "domain1.net",
+                       MailboxPassword: "d1",
+               },
+               Server{
+                       Domain:          "domain2.xyz",
+                       MailboxPassword: "d2",
+               },
+       },
+}
+
+func TestAuthenticate(t *testing.T) {
+       server := smtpServer{config: testConfig}
+
+       authTests := []struct {
+               authz, authc, passwd string
+               ok                   bool
+       }{
+               {"foo@domain1.net", "mailbox@domain1.net", "d1", true},
+               {"", "mailbox@domain1.net", "d1", true},
+               {"foo@domain2.xyz", "mailbox@domain1.xyz", "d1", false},
+               {"foo@domain2.xyz", "mailbox@domain1.xyz", "d2", false},
+               {"foo@domain2.xyz", "mailbox@domain2.xyz", "d2", true},
+               {"invalid", "mailbox@domain2.xyz", "d2", false},
+               {"", "mailbox@domain2.xyz", "d2", true},
+               {"", "", "", false},
+       }
+
+       for i, test := range authTests {
+               actual := server.Authenticate(test.authz, test.authc, test.passwd)
+               if actual != test.ok {
+                       t.Errorf("Test %d, got %v, expected %v", i, actual, test.ok)
+               }
+       }
+}