From bdfbbec54a6a0dbaa9f865017fbcba24c696b370 Mon Sep 17 00:00:00 2001 From: Robert Sesek Date: Sat, 17 Dec 2016 20:32:12 -0500 Subject: [PATCH] Implement STARTTLS in the SMTP server. --- smtp.go | 14 +++++++++++--- smtp/conn.go | 53 ++++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/smtp.go b/smtp.go index 24f635f..d229765 100644 --- a/smtp.go +++ b/smtp.go @@ -22,11 +22,19 @@ func runSMTPServer(config Config) <-chan error { } type smtpServer struct { - config Config - rc chan error + config Config + tlsConfig *tls.Config + + rc chan error } func (server *smtpServer) run() { + var err error + server.tlsConfig, err = server.config.GetTLSConfig() + if err != nil { + server.rc <- err + } + l, err := net.Listen("tcp", fmt.Sprintf(":%d", server.config.SMTPPort)) if err != nil { server.rc <- err @@ -49,7 +57,7 @@ func (server *smtpServer) Name() string { } func (server *smtpServer) TLSConfig() *tls.Config { - return nil + return server.tlsConfig } func (server *smtpServer) VerifyAddress(addr mail.Address) smtp.ReplyLine { diff --git a/smtp/conn.go b/smtp/conn.go index bd91326..6a1a5a6 100644 --- a/smtp/conn.go +++ b/smtp/conn.go @@ -2,6 +2,7 @@ package smtp import ( "crypto/rand" + "crypto/tls" "fmt" "net" "net/mail" @@ -23,11 +24,13 @@ const ( type connection struct { server Server - tp *textproto.Conn + tp *textproto.Conn + + nc net.Conn + tlsNc *tls.Conn remoteAddr net.Addr esmtp bool - tls bool state line string @@ -41,6 +44,7 @@ func AcceptConnection(netConn net.Conn, server Server) error { conn := connection{ server: server, tp: textproto.NewConn(netConn), + nc: netConn, remoteAddr: netConn.RemoteAddr(), state: stateNew, } @@ -73,6 +77,8 @@ func AcceptConnection(netConn net.Conn, server Server) error { case "EHLO": conn.esmtp = true conn.doEHLO() + case "STARTTLS": + conn.doSTARTTLS() case "MAIL": conn.doMAIL() case "RCPT": @@ -97,15 +103,15 @@ func AcceptConnection(netConn net.Conn, server Server) error { return err } -func (conn *connection) reply(reply ReplyLine) { - conn.writeReply(reply.Code, reply.Message) +func (conn *connection) reply(reply ReplyLine) error { + return conn.writeReply(reply.Code, reply.Message) } -func (conn *connection) writeReply(code int, msg string) { +func (conn *connection) writeReply(code int, msg string) error { if len(msg) > 0 { - conn.tp.PrintfLine("%d %s", code, msg) + return conn.tp.PrintfLine("%d %s", code, msg) } else { - conn.tp.PrintfLine("%d", code) + return conn.tp.PrintfLine("%d", code) } } @@ -136,7 +142,7 @@ func (conn *connection) doEHLO() { conn.writeReply(250, fmt.Sprintf("Hello %s [%s]", conn.ehlo, conn.remoteAddr)) } else { conn.tp.PrintfLine("250-Hello %s [%s]", conn.ehlo, conn.remoteAddr) - if conn.server.TLSConfig() != nil { + if conn.server.TLSConfig() != nil && conn.tlsNc == nil { conn.tp.PrintfLine("250-STARTTLS") } conn.tp.PrintfLine("250 SIZE %d", 40960000) @@ -145,6 +151,33 @@ func (conn *connection) doEHLO() { conn.state = stateInitial } +func (conn *connection) doSTARTTLS() { + if conn.state != stateInitial { + conn.reply(ReplyBadSequence) + return + } + + tlsConfig := conn.server.TLSConfig() + if !conn.esmtp || tlsConfig == nil { + conn.writeReply(500, "unrecognized command") + return + } + + conn.writeReply(220, "initiate TLS connection") + + newConn := tls.Server(conn.nc, tlsConfig) + if err := newConn.Handshake(); err != nil { + return + } + + conn.tlsNc = newConn + conn.tp = textproto.NewConn(conn.tlsNc) + conn.state = stateInitial + + conn.writeReply(220, fmt.Sprintf("%s ESMTPS [%s] (mailpopbox)", + conn.server.Name(), newConn.LocalAddr())) +} + func (conn *connection) doMAIL() { if conn.state != stateInitial { conn.reply(ReplyBadSequence) @@ -257,7 +290,7 @@ func (conn *connection) getReceivedInfo(envelope Envelope) []byte { if conn.esmtp { with = "E" + with } - if conn.tls { + if conn.tlsNc != nil { with += "S" } base += fmt.Sprintf("by %s (mailpopbox) with %s id %s\r\n ", conn.server.Name(), with, envelope.ID) @@ -265,7 +298,7 @@ func (conn *connection) getReceivedInfo(envelope Envelope) []byte { base += fmt.Sprintf("for <%s>\r\n ", envelope.RcptTo[0].Address) transport := "PLAINTEXT" - if conn.tls { + if conn.tlsNc != nil { // TODO: TLS version, cipher, bits } date := envelope.Received.Format(time.RFC1123Z) // Same as RFC 5322 § 3.3 -- 2.22.5