Do not keep around the old net.Conn when doing SMTP STARTTLS.
[mailpopbox.git] / smtp / conn.go
1 package smtp
2
3 import (
4 "crypto/rand"
5 "crypto/tls"
6 "fmt"
7 "net"
8 "net/mail"
9 "net/textproto"
10 "strings"
11 "time"
12
13 "github.com/uber-go/zap"
14 )
15
// state is the position of a session in the SMTP command sequence.
type state int

const (
	// stateNew: no HELO/EHLO accepted yet. Also re-entered after STARTTLS.
	stateNew state = iota
	// stateInitial: greeted; no mail transaction in progress.
	stateInitial
	// stateMail: MAIL FROM accepted; waiting for RCPT TO.
	stateMail
	// stateRecipient: at least one RCPT TO accepted; DATA is allowed.
	stateRecipient
	// stateData: NOTE(review): not assigned anywhere in this file;
	// presumably reserved for the DATA phase.
	stateData
)
25
26 type connection struct {
27 server Server
28
29 tp *textproto.Conn
30
31 nc net.Conn
32 remoteAddr net.Addr
33
34 esmtp bool
35 tls *tls.ConnectionState
36
37 log zap.Logger
38
39 state
40 line string
41
42 ehlo string
43 mailFrom *mail.Address
44 rcptTo []mail.Address
45 }
46
// AcceptConnection runs one SMTP session on netConn. It sends the 220
// greeting, then reads command lines one at a time and dispatches each to its
// handler until the client sends QUIT or a read fails. The connection is
// closed in both cases before returning.
func AcceptConnection(netConn net.Conn, server Server, log zap.Logger) {
	conn := connection{
		server:     server,
		tp:         textproto.NewConn(netConn),
		nc:         netConn,
		remoteAddr: netConn.RemoteAddr(),
		log:        log.With(zap.Stringer("client", netConn.RemoteAddr())),
		state:      stateNew,
	}

	conn.log.Info("accepted connection")
	conn.writeReply(220, fmt.Sprintf("%s ESMTP [%s] (mailpopbox)",
		server.Name(), netConn.LocalAddr()))

	for {
		var err error
		conn.line, err = conn.tp.ReadLine()
		if err != nil {
			conn.log.Error("ReadLine()", zap.Error(err))
			conn.tp.Close()
			return
		}

		conn.log.Info("ReadLine()", zap.String("line", conn.line))

		var cmd string
		if _, err = fmt.Sscanf(conn.line, "%s", &cmd); err != nil {
			conn.reply(ReplyBadSyntax)
			continue
		}

		// Command verbs are case-insensitive (RFC 5321 §2.4).
		switch strings.ToUpper(cmd) {
		case "QUIT":
			conn.writeReply(221, "Goodbye")
			conn.tp.Close()
			return
		case "HELO":
			// Plain SMTP greeting: no extensions (e.g. STARTTLS) apply.
			// Previously this fell through into the EHLO case, which
			// unconditionally set esmtp back to true.
			conn.esmtp = false
			conn.doEHLO()
		case "EHLO":
			conn.esmtp = true
			conn.doEHLO()
		case "STARTTLS":
			conn.doSTARTTLS()
		case "MAIL":
			conn.doMAIL()
		case "RCPT":
			conn.doRCPT()
		case "DATA":
			conn.doDATA()
		case "RSET":
			conn.doRSET()
		case "VRFY":
			conn.writeReply(252, "I'll do my best")
		case "EXPN":
			conn.writeReply(550, "access denied")
		case "NOOP":
			conn.reply(ReplyOK)
		case "HELP":
			conn.writeReply(250, "https://tools.ietf.org/html/rfc5321")
		default:
			conn.writeReply(500, "unrecognized command")
		}
	}
}
112
113 func (conn *connection) reply(reply ReplyLine) error {
114 return conn.writeReply(reply.Code, reply.Message)
115 }
116
117 func (conn *connection) writeReply(code int, msg string) error {
118 conn.log.Info("writeReply", zap.Int("code", code))
119 var err error
120 if len(msg) > 0 {
121 err = conn.tp.PrintfLine("%d %s", code, msg)
122 } else {
123 err = conn.tp.PrintfLine("%d", code)
124 }
125 if err != nil {
126 conn.log.Error("writeReply",
127 zap.Int("code", code),
128 zap.Error(err))
129 }
130 return err
131 }
132
// parsePath extracts a forward-, reverse-, or return-path from the current
// command line.
//
// command is the ASCII command prefix, e.g. "MAIL FROM:". It is matched
// case-insensitively against the start of the line. The returned path runs
// from just after the prefix up to and including the first '>'. It keeps any
// leading whitespace and everything before the '<'.
//
// Returns (path, ReplyOK) on success. Returns ReplyBadSyntax when the line is
// shorter than the prefix or has no '>'. Returns a 500 reply when the prefix
// does not match.
func (conn *connection) parsePath(command string) (string, ReplyLine) {
	line := conn.line
	if len(line) < len(command) {
		return "", ReplyBadSyntax
	}

	prefix, rest := line[:len(command)], line[len(command):]
	if !strings.EqualFold(prefix, command) {
		return "", ReplyLine{500, "unrecognized command"}
	}

	end := strings.IndexByte(rest, '>')
	if end < 0 {
		return "", ReplyBadSyntax
	}
	return rest[:end+1], ReplyOK
}
150
// doEHLO handles both HELO and EHLO.
//
// It clears any in-progress transaction and reads the client's domain
// argument into conn.ehlo. For HELO it replies with a single 250 line. For
// EHLO it sends the multi-line extension list: STARTTLS is offered only when
// the server has a TLS config and TLS is not already active, and SIZE is
// always listed. On success the session moves to stateInitial. A missing
// argument is answered with ReplyBadSyntax and leaves the state unchanged.
func (conn *connection) doEHLO() {
	conn.resetBuffers()

	var cmd string
	_, err := fmt.Sscanf(conn.line, "%s %s", &cmd, &conn.ehlo)
	if err != nil {
		conn.reply(ReplyBadSyntax)
		return
	}

	// Commands are case-insensitive (RFC 5321 §2.4), and the dispatcher
	// upper-cases them before routing here. Compare the same way so that
	// "helo" does not receive the multi-line EHLO reply.
	if strings.EqualFold(cmd, "HELO") {
		conn.writeReply(250, fmt.Sprintf("Hello %s [%s]", conn.ehlo, conn.remoteAddr))
	} else {
		conn.tp.PrintfLine("250-Hello %s [%s]", conn.ehlo, conn.remoteAddr)
		if conn.server.TLSConfig() != nil && conn.tls == nil {
			conn.tp.PrintfLine("250-STARTTLS")
		}
		conn.tp.PrintfLine("250 SIZE %d", 40960000)
	}

	conn.log.Info("doEHLO()", zap.String("ehlo", conn.ehlo))

	conn.state = stateInitial
}
175
// doSTARTTLS upgrades the session to TLS (RFC 3207).
//
// It is accepted only in stateInitial, on an ESMTP session, when the server
// has a TLS configuration; otherwise the client gets a 503-style sequence
// error or "500 unrecognized command". After the 220 reply it runs the
// server side of the TLS handshake on the current connection.
//
// On success, all further I/O goes through the TLS connection (the plaintext
// textproto.Conn and net.Conn are dropped). Everything learned from the
// client before TLS is discarded, and the session returns to stateNew, so
// the client must send EHLO again. On handshake failure the connection is
// closed, which makes the command loop's next read fail and end the session.
func (conn *connection) doSTARTTLS() {
	if conn.state != stateInitial {
		conn.reply(ReplyBadSequence)
		return
	}

	tlsConfig := conn.server.TLSConfig()
	if !conn.esmtp || tlsConfig == nil {
		conn.writeReply(500, "unrecognized command")
		return
	}

	conn.log.Info("doSTARTTLS()")
	conn.writeReply(220, "initiate TLS connection")

	// Run the handshake explicitly. The server must not send another greeting
	// after negotiation: the client speaks first with EHLO (RFC 3207 §4.2).
	// Writing a 220 here would be read as the reply to that EHLO.
	tlsConn := tls.Server(conn.nc, tlsConfig)
	if err := tlsConn.Handshake(); err != nil {
		conn.log.Error("failed to do TLS handshake", zap.Error(err))
		// The byte stream is now in an unknown state, so it cannot be read as
		// plaintext SMTP. Best-effort close; the read loop then terminates.
		_ = conn.nc.Close()
		return
	}

	// Replace the transport entirely so nothing keeps using the plaintext
	// connection or any bytes buffered by the old reader.
	conn.nc = tlsConn
	conn.tp = textproto.NewConn(tlsConn)

	// RFC 3207 §4.2: discard any knowledge obtained from the client before
	// TLS, such as the EHLO argument.
	conn.state = stateNew
	conn.esmtp = false
	conn.ehlo = ""
	conn.resetBuffers()

	connState := tlsConn.ConnectionState()
	conn.tls = &connState

	conn.log.Info("TLS connection done", zap.String("state", conn.getTransportString()))
}
210
// doMAIL handles "MAIL FROM:<reverse-path>", which starts a new mail
// transaction.
//
// It is accepted only in stateInitial. On success the reverse-path is stored
// in conn.mailFrom, the session moves to stateMail, and 250 is sent. A
// malformed command is answered with parsePath's reply or ReplyBadSyntax, and
// the state is left unchanged.
func (conn *connection) doMAIL() {
	if conn.state != stateInitial {
		conn.reply(ReplyBadSequence)
		return
	}

	// MAIL clears the reverse-path, forward-path, and mail data buffers
	// (RFC 5321 §4.1.1.2). Without this, recipients left over from a previous
	// transaction on the same connection would also receive this message.
	conn.resetBuffers()

	mailFrom, reply := conn.parsePath("MAIL FROM:")
	if reply != ReplyOK {
		conn.reply(reply)
		return
	}

	if strings.TrimSpace(mailFrom) == "<>" {
		// Delivery status notifications (bounces) are sent with a null
		// reverse-path (RFC 5321 §4.5.5). net/mail cannot parse "<>", so
		// record it as an empty address instead of rejecting the bounce.
		conn.mailFrom = &mail.Address{}
	} else {
		var err error
		conn.mailFrom, err = mail.ParseAddress(mailFrom)
		if err != nil || conn.mailFrom == nil {
			conn.reply(ReplyBadSyntax)
			return
		}
	}

	conn.log.Info("doMAIL()", zap.String("address", conn.mailFrom.Address))

	conn.state = stateMail
	conn.reply(ReplyOK)
}
235
// doRCPT handles "RCPT TO:<forward-path>". It is valid once MAIL FROM has
// been accepted and may be repeated.
//
// The address is parsed and checked with server.VerifyAddress. A rejected
// address is logged and answered with the server's reply. An accepted
// address is appended to conn.rcptTo, the session moves to stateRecipient,
// and 250 is sent.
func (conn *connection) doRCPT() {
	inTransaction := conn.state == stateMail || conn.state == stateRecipient
	if !inTransaction {
		conn.reply(ReplyBadSequence)
		return
	}

	path, parseReply := conn.parsePath("RCPT TO:")
	if parseReply != ReplyOK {
		conn.reply(parseReply)
		return
	}

	recipient, err := mail.ParseAddress(path)
	if err != nil {
		conn.reply(ReplyBadSyntax)
		return
	}

	verdict := conn.server.VerifyAddress(*recipient)
	if verdict != ReplyOK {
		conn.log.Warn("invalid address",
			zap.String("address", recipient.Address),
			zap.Stringer("reply", verdict))
		conn.reply(verdict)
		return
	}

	conn.log.Info("doRCPT()", zap.String("address", recipient.Address))

	conn.rcptTo = append(conn.rcptTo, *recipient)
	conn.state = stateRecipient
	conn.reply(ReplyOK)
}
269
// doDATA handles the DATA command. It is valid only after at least one
// recipient has been accepted.
//
// It replies 354, reads the dot-terminated message body, prepends a
// "Received:" trace header, and passes the Envelope to
// server.OnMessageDelivered. A nil result is answered with 250; otherwise the
// returned reply is sent to the client. A read failure is answered with 552.
//
// Once the body phase has started, the transaction is over whatever the
// outcome. End of mail data consumes and clears the reverse-path and
// forward-path buffers (RFC 5321 §4.1.1.4), so the session returns to
// stateInitial with empty buffers. The next message requires a fresh
// MAIL/RCPT sequence instead of inheriting this one's recipients.
func (conn *connection) doDATA() {
	if conn.state != stateRecipient {
		conn.reply(ReplyBadSequence)
		return
	}

	defer func() {
		conn.state = stateInitial
		// resetBuffers allocates a new recipient slice, so env.RcptTo below
		// is not affected.
		conn.resetBuffers()
	}()

	conn.writeReply(354, "Start mail input; end with <CRLF>.<CRLF>")
	conn.log.Info("doDATA()")

	data, err := conn.tp.ReadDotBytes()
	if err != nil {
		conn.log.Error("failed to ReadDotBytes()",
			zap.Error(err),
			zap.String("bytes", fmt.Sprintf("%x", data)))
		conn.writeReply(552, "transaction failed")
		return
	}

	received := time.Now()
	env := Envelope{
		RemoteAddr: conn.remoteAddr,
		EHLO:       conn.ehlo,
		MailFrom:   *conn.mailFrom,
		RcptTo:     conn.rcptTo,
		Received:   received,
		ID:         conn.envelopeID(received),
	}

	conn.log.Info("received message",
		zap.Int("bytes", len(data)),
		zap.Time("date", received),
		zap.String("id", env.ID))

	trace := conn.getReceivedInfo(env)

	env.Data = append(trace, data...)

	if reply := conn.server.OnMessageDelivered(env); reply != nil {
		conn.log.Warn("message was rejected", zap.String("id", env.ID))
		conn.reply(*reply)
		return
	}

	conn.reply(ReplyOK)
}
316
// envelopeID builds a message identifier of the form "m.<unix-nanos>.<hex>".
// The prefix is t in nanoseconds and the suffix is 4 bytes from crypto/rand,
// printed as 8 hex digits.
func (conn *connection) envelopeID(t time.Time) string {
	var nonce [4]byte
	// The read error is deliberately ignored, as before; the timestamp
	// alone still yields a usable ID.
	_, _ = rand.Read(nonce[:])
	stamp := t.UnixNano()
	return fmt.Sprintf("m.%d.%x", stamp, nonce)
}
322
// getReceivedInfo builds the "Received:" trace header that is prepended to
// each delivered message. The header records:
//
//   - the client's EHLO name and IP, with the first reverse-DNS name when
//     the lookup succeeds;
//   - this server's name;
//   - the protocol: SMTP or ESMTP, with an "S" suffix under TLS;
//   - the envelope ID;
//   - the first recipient;
//   - the transport description;
//   - the receipt time.
//
// Lines are CRLF-separated, and continuation lines are folded with a leading
// space.
func (conn *connection) getReceivedInfo(envelope Envelope) []byte {
	rhost, _, err := net.SplitHostPort(conn.remoteAddr.String())
	if err != nil {
		rhost = conn.remoteAddr.String()
	}

	// Use the reverse-DNS name only when the lookup actually returned one;
	// indexing an empty result would panic.
	if names, lookupErr := net.LookupAddr(rhost); lookupErr == nil && len(names) > 0 {
		rhost = fmt.Sprintf("%s [%s]", names[0], rhost)
	}

	base := fmt.Sprintf("Received: from %s (%s)\r\n ", conn.ehlo, rhost)

	with := "SMTP"
	if conn.esmtp {
		with = "E" + with
	}
	if conn.tls != nil {
		with += "S"
	}
	base += fmt.Sprintf("by %s (mailpopbox) with %s id %s\r\n ", conn.server.Name(), with, envelope.ID)

	// The "for" clause is optional in the Received grammar (RFC 5321 §4.4).
	// Omit it rather than panic when there is no recipient.
	if len(envelope.RcptTo) > 0 {
		base += fmt.Sprintf("for <%s>\r\n ", envelope.RcptTo[0].Address)
	}

	transport := conn.getTransportString()
	date := envelope.Received.Format(time.RFC1123Z) // Same as RFC 5322 § 3.3
	base += fmt.Sprintf("(using %s);\r\n %s\r\n", transport, date)

	return []byte(base)
}
353
// getTransportString describes the session transport for logs and the
// Received header.
//
// It returns "PLAINTEXT" when TLS is not active. Otherwise it returns
// "<version> cipher=<suite>", followed by " name=<SNI>" when the client sent
// a server name. Versions and cipher suites missing from the tables below are
// printed as hexadecimal.
func (conn *connection) getTransportString() string {
	cs := conn.tls
	if cs == nil {
		return "PLAINTEXT"
	}

	cipherNames := map[uint16]string{
		tls.TLS_RSA_WITH_RC4_128_SHA:                "TLS_RSA_WITH_RC4_128_SHA",
		tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA:           "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
		tls.TLS_RSA_WITH_AES_128_CBC_SHA:            "TLS_RSA_WITH_AES_128_CBC_SHA",
		tls.TLS_RSA_WITH_AES_256_CBC_SHA:            "TLS_RSA_WITH_AES_256_CBC_SHA",
		tls.TLS_RSA_WITH_AES_128_GCM_SHA256:         "TLS_RSA_WITH_AES_128_GCM_SHA256",
		tls.TLS_RSA_WITH_AES_256_GCM_SHA384:         "TLS_RSA_WITH_AES_256_GCM_SHA384",
		tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA:        "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA",
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:    "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:    "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
		tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA:          "TLS_ECDHE_RSA_WITH_RC4_128_SHA",
		tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA:     "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
		tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:      "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
		tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:      "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:   "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384:   "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
	}
	versionNames := map[uint16]string{
		tls.VersionSSL30: "SSLv3.0",
		tls.VersionTLS10: "TLSv1.0",
		tls.VersionTLS11: "TLSv1.1",
		tls.VersionTLS12: "TLSv1.2",
	}

	version, known := versionNames[cs.Version]
	if !known {
		version = fmt.Sprintf("%x", cs.Version)
	}
	cipher, known := cipherNames[cs.CipherSuite]
	if !known {
		cipher = fmt.Sprintf("%x", cs.CipherSuite)
	}

	description := fmt.Sprintf("%s cipher=%s", version, cipher)
	if cs.ServerName != "" {
		description += fmt.Sprintf(" name=%s", cs.ServerName)
	}
	return description
}
404
405 func (conn *connection) doRSET() {
406 conn.log.Info("doRSET()")
407 conn.state = stateInitial
408 conn.resetBuffers()
409 conn.reply(ReplyOK)
410 }
411
// resetBuffers clears the current transaction. The recipient list becomes a
// fresh, empty (non-nil) slice and the sender is cleared to nil.
func (conn *connection) resetBuffers() {
	conn.rcptTo = []mail.Address{}
	conn.mailFrom = nil
}