Implement STARTTLS in the SMTP server.
authorRobert Sesek <rsesek@bluestatic.org>
Sun, 18 Dec 2016 01:32:12 +0000 (20:32 -0500)
committerRobert Sesek <rsesek@bluestatic.org>
Sun, 18 Dec 2016 01:32:12 +0000 (20:32 -0500)
smtp.go
smtp/conn.go

diff --git a/smtp.go b/smtp.go
index 24f635ff565c17fd45976668b6343d0a80c84702..d229765d23c4f022ab4a8ed92a6a45fef429a2f5 100644 (file)
--- a/smtp.go
+++ b/smtp.go
@@ -22,11 +22,19 @@ func runSMTPServer(config Config) <-chan error {
 }
 
 type smtpServer struct {
-       config Config
-       rc     chan error
+       config    Config
+       tlsConfig *tls.Config
+
+       rc chan error
 }
 
 func (server *smtpServer) run() {
+       var err error
+       server.tlsConfig, err = server.config.GetTLSConfig()
+       if err != nil {
+               server.rc <- err
+       }
+
        l, err := net.Listen("tcp", fmt.Sprintf(":%d", server.config.SMTPPort))
        if err != nil {
                server.rc <- err
@@ -49,7 +57,7 @@ func (server *smtpServer) Name() string {
 }
 
 func (server *smtpServer) TLSConfig() *tls.Config {
-       return nil
+       return server.tlsConfig
 }
 
 func (server *smtpServer) VerifyAddress(addr mail.Address) smtp.ReplyLine {
index bd9132613e12c1279031e90b75ddd6a936f8e59a..6a1a5a6660415b7e2f74e312e4b66f0045c3c512 100644 (file)
@@ -2,6 +2,7 @@ package smtp
 
 import (
        "crypto/rand"
+       "crypto/tls"
        "fmt"
        "net"
        "net/mail"
@@ -23,11 +24,13 @@ const (
 type connection struct {
        server Server
 
-       tp         *textproto.Conn
+       tp *textproto.Conn
+
+       nc         net.Conn
+       tlsNc      *tls.Conn
        remoteAddr net.Addr
 
        esmtp bool
-       tls   bool
 
        state
        line string
@@ -41,6 +44,7 @@ func AcceptConnection(netConn net.Conn, server Server) error {
        conn := connection{
                server:     server,
                tp:         textproto.NewConn(netConn),
+               nc:         netConn,
                remoteAddr: netConn.RemoteAddr(),
                state:      stateNew,
        }
@@ -73,6 +77,8 @@ func AcceptConnection(netConn net.Conn, server Server) error {
                case "EHLO":
                        conn.esmtp = true
                        conn.doEHLO()
+               case "STARTTLS":
+                       conn.doSTARTTLS()
                case "MAIL":
                        conn.doMAIL()
                case "RCPT":
@@ -97,15 +103,15 @@ func AcceptConnection(netConn net.Conn, server Server) error {
        return err
 }
 
-func (conn *connection) reply(reply ReplyLine) {
-       conn.writeReply(reply.Code, reply.Message)
+func (conn *connection) reply(reply ReplyLine) error {
+       return conn.writeReply(reply.Code, reply.Message)
 }
 
-func (conn *connection) writeReply(code int, msg string) {
+func (conn *connection) writeReply(code int, msg string) error {
        if len(msg) > 0 {
-               conn.tp.PrintfLine("%d %s", code, msg)
+               return conn.tp.PrintfLine("%d %s", code, msg)
        } else {
-               conn.tp.PrintfLine("%d", code)
+               return conn.tp.PrintfLine("%d", code)
        }
 }
 
@@ -136,7 +142,7 @@ func (conn *connection) doEHLO() {
                conn.writeReply(250, fmt.Sprintf("Hello %s [%s]", conn.ehlo, conn.remoteAddr))
        } else {
                conn.tp.PrintfLine("250-Hello %s [%s]", conn.ehlo, conn.remoteAddr)
-               if conn.server.TLSConfig() != nil {
+               if conn.server.TLSConfig() != nil && conn.tlsNc == nil {
                        conn.tp.PrintfLine("250-STARTTLS")
                }
                conn.tp.PrintfLine("250 SIZE %d", 40960000)
@@ -145,6 +151,33 @@ func (conn *connection) doEHLO() {
        conn.state = stateInitial
 }
 
+func (conn *connection) doSTARTTLS() {
+       if conn.state != stateInitial {
+               conn.reply(ReplyBadSequence)
+               return
+       }
+
+       tlsConfig := conn.server.TLSConfig()
+       if !conn.esmtp || tlsConfig == nil {
+               conn.writeReply(500, "unrecognized command")
+               return
+       }
+
+       conn.writeReply(220, "initiate TLS connection")
+
+       newConn := tls.Server(conn.nc, tlsConfig)
+       if err := newConn.Handshake(); err != nil {
+               return
+       }
+
+       conn.tlsNc = newConn
+       conn.tp = textproto.NewConn(conn.tlsNc)
+       conn.state = stateInitial
+
+       conn.writeReply(220, fmt.Sprintf("%s ESMTPS [%s] (mailpopbox)",
+               conn.server.Name(), newConn.LocalAddr()))
+}
+
 func (conn *connection) doMAIL() {
        if conn.state != stateInitial {
                conn.reply(ReplyBadSequence)
@@ -257,7 +290,7 @@ func (conn *connection) getReceivedInfo(envelope Envelope) []byte {
        if conn.esmtp {
                with = "E" + with
        }
-       if conn.tls {
+       if conn.tlsNc != nil {
                with += "S"
        }
        base += fmt.Sprintf("by %s (mailpopbox) with %s id %s\r\n        ", conn.server.Name(), with, envelope.ID)
@@ -265,7 +298,7 @@ func (conn *connection) getReceivedInfo(envelope Envelope) []byte {
        base += fmt.Sprintf("for <%s>\r\n        ", envelope.RcptTo[0].Address)
 
        transport := "PLAINTEXT"
-       if conn.tls {
+       if conn.tlsNc != nil {
                // TODO: TLS version, cipher, bits
        }
        date := envelope.Received.Format(time.RFC1123Z) // Same as RFC 5322 ยง 3.3