From 7090804f9fc681c1097db825b4c4d4a3fc508b49 Mon Sep 17 00:00:00 2001 From: Robert Sesek Date: Sun, 18 Dec 2016 16:47:34 -0500 Subject: [PATCH] Do not keep around the old net.Conn when doing SMTP STARTTLS. --- smtp/conn.go | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/smtp/conn.go b/smtp/conn.go index a8a9a1f..ca24c10 100644 --- a/smtp/conn.go +++ b/smtp/conn.go @@ -29,10 +29,10 @@ type connection struct { tp *textproto.Conn nc net.Conn - tlsNc *tls.Conn remoteAddr net.Addr esmtp bool + tls *tls.ConnectionState log zap.Logger @@ -54,7 +54,9 @@ func AcceptConnection(netConn net.Conn, server Server, log zap.Logger) { state: stateNew, } - conn.writeReply(220, fmt.Sprintf("%s ESMTP [%s] (mailpopbox)", server.Name(), netConn.LocalAddr())) + conn.log.Info("accepted connection") + conn.writeReply(220, fmt.Sprintf("%s ESMTP [%s] (mailpopbox)", + server.Name(), netConn.LocalAddr())) for { var err error @@ -160,7 +162,7 @@ func (conn *connection) doEHLO() { conn.writeReply(250, fmt.Sprintf("Hello %s [%s]", conn.ehlo, conn.remoteAddr)) } else { conn.tp.PrintfLine("250-Hello %s [%s]", conn.ehlo, conn.remoteAddr) - if conn.server.TLSConfig() != nil && conn.tlsNc == nil { + if conn.server.TLSConfig() != nil && conn.tls == nil { conn.tp.PrintfLine("250-STARTTLS") } conn.tp.PrintfLine("250 SIZE %d", 40960000) @@ -187,18 +189,23 @@ func (conn *connection) doSTARTTLS() { conn.writeReply(220, "initiate TLS connection") newConn := tls.Server(conn.nc, tlsConfig) - if err := newConn.Handshake(); err != nil { + tp := textproto.NewConn(newConn) + + err := tp.PrintfLine("220 %s ESMTPS [%s] (mailpopbox)", + conn.server.Name(), newConn.LocalAddr()) + if err != nil { + conn.log.Error("failed to do TLS handshake", zap.Error(err)) return } - conn.tlsNc = newConn - conn.tp = textproto.NewConn(conn.tlsNc) - conn.state = stateInitial + conn.nc = newConn + conn.tp = tp + conn.state = stateNew - conn.log.Info("HELO again") + connState := newConn.ConnectionState() + conn.tls = &connState - conn.writeReply(220, fmt.Sprintf("%s ESMTPS [%s] (mailpopbox)", - conn.server.Name(), newConn.LocalAddr())) + conn.log.Info("TLS connection done", zap.String("state", conn.getTransportString())) } func (conn *connection) doMAIL() { @@ -330,7 +337,7 @@ func (conn *connection) getReceivedInfo(envelope Envelope) []byte { if conn.esmtp { with = "E" + with } - if conn.tlsNc != nil { + if conn.tls != nil { with += "S" } base += fmt.Sprintf("by %s (mailpopbox) with %s id %s\r\n ", conn.server.Name(), with, envelope.ID) @@ -345,7 +352,7 @@ func (conn *connection) getReceivedInfo(envelope Envelope) []byte { } func (conn *connection) getTransportString() string { - if conn.tlsNc == nil { + if conn.tls == nil { return "PLAINTEXT" } @@ -375,7 +382,7 @@ func (conn *connection) getTransportString() string { tls.VersionTLS12: "TLSv1.2", } - state := conn.tlsNc.ConnectionState() + state := conn.tls version := versions[state.Version] cipher := ciphers[state.CipherSuite] -- 2.22.5