Handle initial responses for AUTH PLAIN.
[mailpopbox.git] / smtp / conn.go
1 package smtp
2
3 import (
4 "bytes"
5 "crypto/rand"
6 "crypto/tls"
7 "encoding/base64"
8 "fmt"
9 "net"
10 "net/mail"
11 "net/textproto"
12 "strings"
13 "time"
14
15 "github.com/uber-go/zap"
16 )
17
// state tracks where a session is in the SMTP command sequence; the command
// handlers consult it to reject commands that arrive out of order.
type state int

// Session states, in protocol order.
const (
	stateNew       state = iota // Before EHLO (also re-entered after STARTTLS).
	stateInitial                // Greeted; a mail transaction may begin.
	stateMail                   // MAIL FROM accepted; awaiting RCPT.
	stateRecipient              // A RCPT was accepted; DATA is allowed.
	stateData
)
27
// delivery classifies a mail transaction by the direction the message travels
// relative to this server.
type delivery int

const (
	deliverUnknown  delivery = iota // Not yet classified by MAIL FROM.
	deliverInbound                  // Mail is not from one of this server's domains.
	deliverOutbound                 // Mail IS from one of this server's domains.
)

// String returns a lower-case name for d for use in log output. A value
// outside the declared constants is rendered as "delivery(N)" rather than
// panicking, so an unexpected value cannot bring down the server from inside
// a logging call.
func (d delivery) String() string {
	switch d {
	case deliverUnknown:
		return "unknown"
	case deliverInbound:
		return "inbound"
	case deliverOutbound:
		return "outbound"
	}
	return fmt.Sprintf("delivery(%d)", int(d))
}
47
48 type connection struct {
49 server Server
50
51 tp *textproto.Conn
52
53 nc net.Conn
54 remoteAddr net.Addr
55
56 esmtp bool
57 tls *tls.ConnectionState
58
59 // The authcid from a PLAIN SASL login. Non-empty iff tls is non-nil and
60 // doAUTH() succeeded.
61 authc string
62
63 log zap.Logger
64
65 state
66 line string
67
68 delivery
69 // For deliverOutbound, replaces the From and Reply-To values.
70 sendAs *mail.Address
71
72 ehlo string
73 mailFrom *mail.Address
74 rcptTo []mail.Address
75 }
76
77 func AcceptConnection(netConn net.Conn, server Server, log zap.Logger) {
78 conn := connection{
79 server: server,
80 tp: textproto.NewConn(netConn),
81 nc: netConn,
82 remoteAddr: netConn.RemoteAddr(),
83 log: log.With(zap.Stringer("client", netConn.RemoteAddr())),
84 state: stateNew,
85 }
86
87 conn.log.Info("accepted connection")
88 conn.writeReply(220, fmt.Sprintf("%s ESMTP [%s] (mailpopbox)",
89 server.Name(), netConn.LocalAddr()))
90
91 for {
92 var err error
93 conn.line, err = conn.tp.ReadLine()
94 if err != nil {
95 conn.log.Error("ReadLine()", zap.Error(err))
96 conn.tp.Close()
97 return
98 }
99
100 lineForLog := conn.line
101 const authPlain = "AUTH PLAIN "
102 if strings.HasPrefix(conn.line, authPlain) {
103 lineForLog = authPlain + "[redacted]"
104 }
105 conn.log.Info("ReadLine()", zap.String("line", lineForLog))
106
107 var cmd string
108 if _, err = fmt.Sscanf(conn.line, "%s", &cmd); err != nil {
109 conn.reply(ReplyBadSyntax)
110 continue
111 }
112
113 switch strings.ToUpper(cmd) {
114 case "QUIT":
115 conn.writeReply(221, "Goodbye")
116 conn.tp.Close()
117 return
118 case "HELO":
119 conn.esmtp = false
120 fallthrough
121 case "EHLO":
122 conn.esmtp = true
123 conn.doEHLO()
124 case "STARTTLS":
125 conn.doSTARTTLS()
126 case "AUTH":
127 conn.doAUTH()
128 case "MAIL":
129 conn.doMAIL()
130 case "RCPT":
131 conn.doRCPT()
132 case "DATA":
133 conn.doDATA()
134 case "RSET":
135 conn.doRSET()
136 case "VRFY":
137 conn.writeReply(252, "I'll do my best")
138 case "EXPN":
139 conn.writeReply(550, "access denied")
140 case "NOOP":
141 conn.reply(ReplyOK)
142 case "HELP":
143 conn.writeReply(250, "https://tools.ietf.org/html/rfc5321")
144 default:
145 conn.writeReply(500, "unrecognized command")
146 }
147 }
148 }
149
150 func (conn *connection) reply(reply ReplyLine) error {
151 return conn.writeReply(reply.Code, reply.Message)
152 }
153
154 func (conn *connection) writeReply(code int, msg string) error {
155 conn.log.Info("writeReply", zap.Int("code", code))
156 var err error
157 if len(msg) > 0 {
158 err = conn.tp.PrintfLine("%d %s", code, msg)
159 } else {
160 err = conn.tp.PrintfLine("%d", code)
161 }
162 if err != nil {
163 conn.log.Error("writeReply",
164 zap.Int("code", code),
165 zap.Error(err))
166 }
167 return err
168 }
169
170 // parsePath parses out either a forward-, reverse-, or return-path from the
171 // current connection line. Returns a (valid-path, ReplyOK) if it was
172 // successfully parsed.
173 func (conn *connection) parsePath(command string) (string, ReplyLine) {
174 if len(conn.line) < len(command) {
175 return "", ReplyBadSyntax
176 }
177 if strings.ToUpper(command) != strings.ToUpper(conn.line[:len(command)]) {
178 return "", ReplyLine{500, "unrecognized command"}
179 }
180 params := conn.line[len(command):]
181 idx := strings.Index(params, ">")
182 if idx == -1 {
183 return "", ReplyBadSyntax
184 }
185 return strings.ToLower(params[:idx+1]), ReplyOK
186 }
187
188 func (conn *connection) doEHLO() {
189 conn.resetBuffers()
190
191 var cmd string
192 _, err := fmt.Sscanf(conn.line, "%s %s", &cmd, &conn.ehlo)
193 if err != nil {
194 conn.reply(ReplyBadSyntax)
195 return
196 }
197
198 if cmd == "HELO" {
199 conn.writeReply(250, fmt.Sprintf("Hello %s [%s]", conn.ehlo, conn.remoteAddr))
200 } else {
201 conn.tp.PrintfLine("250-Hello %s [%s]", conn.ehlo, conn.remoteAddr)
202 if conn.server.TLSConfig() != nil && conn.tls == nil {
203 conn.tp.PrintfLine("250-STARTTLS")
204 }
205 if conn.tls != nil {
206 conn.tp.PrintfLine("250-AUTH PLAIN")
207 }
208 conn.tp.PrintfLine("250 SIZE %d", 40960000)
209 }
210
211 conn.log.Info("doEHLO()", zap.String("ehlo", conn.ehlo))
212
213 conn.state = stateInitial
214 }
215
216 func (conn *connection) doSTARTTLS() {
217 if conn.state != stateInitial {
218 conn.reply(ReplyBadSequence)
219 return
220 }
221
222 tlsConfig := conn.server.TLSConfig()
223 if !conn.esmtp || tlsConfig == nil {
224 conn.writeReply(500, "unrecognized command")
225 return
226 }
227
228 conn.log.Info("doSTARTTLS()")
229 conn.writeReply(220, "initiate TLS connection")
230
231 tlsConn := tls.Server(conn.nc, tlsConfig)
232 if err := tlsConn.Handshake(); err != nil {
233 conn.log.Error("failed to do TLS handshake", zap.Error(err))
234 return
235 }
236
237 conn.nc = tlsConn
238 conn.tp = textproto.NewConn(tlsConn)
239 conn.state = stateNew
240
241 connState := tlsConn.ConnectionState()
242 conn.tls = &connState
243
244 conn.log.Info("TLS connection done", zap.String("state", conn.getTransportString()))
245 }
246
247 func (conn *connection) doAUTH() {
248 if conn.state != stateInitial || conn.tls == nil {
249 conn.reply(ReplyBadSequence)
250 return
251 }
252
253 if conn.authc != "" {
254 conn.writeReply(503, "already authenticated")
255 return
256 }
257
258 var cmd, authType, authString string
259 n, err := fmt.Sscanf(conn.line, "%s %s %s", &cmd, &authType, &authString)
260 if n < 2 {
261 conn.reply(ReplyBadSyntax)
262 return
263 }
264
265 if authType != "PLAIN" {
266 conn.writeReply(504, "unrecognized auth type")
267 return
268 }
269
270 // If only 2 tokens were scanned, then an initial response was not provided.
271 if n == 2 && conn.line[len(conn.line)-1] != ' ' {
272 conn.reply(ReplyBadSyntax)
273 return
274 }
275
276 conn.log.Info("doAUTH()")
277
278 if authString == "" {
279 conn.writeReply(334, " ")
280
281 authString, err = conn.tp.ReadLine()
282 if err != nil {
283 conn.log.Error("failed to read auth line", zap.Error(err))
284 conn.reply(ReplyBadSyntax)
285 return
286 }
287 }
288
289 authBytes, err := base64.StdEncoding.DecodeString(authString)
290 if err != nil {
291 conn.reply(ReplyBadSyntax)
292 return
293 }
294
295 authParts := strings.Split(string(authBytes), "\x00")
296 if len(authParts) != 3 {
297 conn.log.Error("bad auth line syntax")
298 conn.reply(ReplyBadSyntax)
299 return
300 }
301
302 if !conn.server.Authenticate(authParts[0], authParts[1], authParts[2]) {
303 conn.log.Error("failed to authenticate", zap.String("authc", authParts[1]))
304 conn.writeReply(535, "invalid credentials")
305 return
306 }
307
308 conn.log.Info("authenticated", zap.String("authz", authParts[0]), zap.String("authc", authParts[1]))
309 conn.authc = authParts[1]
310 conn.reply(ReplyOK)
311 }
312
313 func (conn *connection) doMAIL() {
314 if conn.state != stateInitial {
315 conn.reply(ReplyBadSequence)
316 return
317 }
318
319 mailFrom, reply := conn.parsePath("MAIL FROM:")
320 if reply != ReplyOK {
321 conn.reply(reply)
322 return
323 }
324
325 var err error
326 conn.mailFrom, err = mail.ParseAddress(mailFrom)
327 if err != nil || conn.mailFrom == nil {
328 conn.reply(ReplyBadSyntax)
329 return
330 }
331
332 if conn.server.VerifyAddress(*conn.mailFrom) == ReplyOK {
333 // Message is being sent from a domain that this is an MTA for. Ultimate
334 // handling of the outbound message requires knowing the recipient.
335 domain := DomainForAddress(*conn.mailFrom)
336 // TODO: better way to authenticate this?
337 if !strings.HasSuffix(conn.authc, "@"+domain) {
338 conn.writeReply(550, "not authenticated")
339 return
340 }
341 conn.delivery = deliverOutbound
342 } else {
343 conn.delivery = deliverInbound
344 }
345
346 conn.log.Info("doMAIL()", zap.String("address", conn.mailFrom.Address))
347
348 conn.state = stateMail
349 conn.reply(ReplyOK)
350 }
351
352 func (conn *connection) doRCPT() {
353 if conn.state != stateMail && conn.state != stateRecipient {
354 conn.reply(ReplyBadSequence)
355 return
356 }
357
358 rcptTo, reply := conn.parsePath("RCPT TO:")
359 if reply != ReplyOK {
360 conn.reply(reply)
361 return
362 }
363
364 address, err := mail.ParseAddress(rcptTo)
365 if err != nil {
366 conn.reply(ReplyBadSyntax)
367 return
368 }
369
370 if reply := conn.server.VerifyAddress(*address); reply == ReplyOK {
371 // Message is addressed to this server. If it's outbound, only support
372 // the special send-as addressing.
373 if conn.delivery == deliverOutbound {
374 if !strings.HasPrefix(address.Address, SendAsAddress) {
375 conn.log.Error("internal relay addressing not supported",
376 zap.String("address", address.Address))
377 conn.reply(ReplyBadMailbox)
378 return
379 }
380 address.Address = strings.TrimPrefix(address.Address, SendAsAddress)
381 if DomainForAddress(*address) != DomainForAddressString(conn.authc) {
382 conn.log.Error("not authenticated for send-as",
383 zap.String("address", address.Address),
384 zap.String("authc", conn.authc))
385 conn.reply(ReplyBadMailbox)
386 return
387 }
388 if conn.sendAs != nil {
389 conn.log.Error("sendAs already specified",
390 zap.String("address", address.Address),
391 zap.String("sendAs", conn.sendAs.Address))
392 conn.reply(ReplyMailboxUnallowed)
393 return
394 }
395 conn.log.Info("doRCPT()",
396 zap.String("sendAs", address.Address))
397 conn.sendAs = address
398 conn.state = stateRecipient
399 conn.reply(ReplyOK)
400 return
401 }
402 } else {
403 // Message is not addressed to this server, so the delivery must be outbound.
404 if conn.delivery == deliverInbound {
405 conn.log.Warn("invalid address",
406 zap.String("address", address.Address),
407 zap.Stringer("reply", reply))
408 conn.reply(reply)
409 return
410 }
411 }
412
413 conn.log.Info("doRCPT()",
414 zap.String("address", address.Address),
415 zap.String("delivery", conn.delivery.String()))
416
417 conn.rcptTo = append(conn.rcptTo, *address)
418
419 conn.state = stateRecipient
420 conn.reply(ReplyOK)
421 }
422
423 func (conn *connection) doDATA() {
424 if conn.state != stateRecipient {
425 conn.reply(ReplyBadSequence)
426 return
427 }
428
429 conn.writeReply(354, "Start mail input; end with <CRLF>.<CRLF>")
430 conn.log.Info("doDATA()")
431
432 data, err := conn.tp.ReadDotBytes()
433 if err != nil {
434 conn.log.Error("failed to ReadDotBytes()",
435 zap.Error(err),
436 zap.String("bytes", fmt.Sprintf("%x", data)))
437 conn.writeReply(552, "transaction failed")
438 return
439 }
440
441 conn.handleSendAs(&data)
442
443 received := time.Now()
444 env := Envelope{
445 RemoteAddr: conn.remoteAddr,
446 EHLO: conn.ehlo,
447 MailFrom: *conn.mailFrom,
448 RcptTo: conn.rcptTo,
449 Received: received,
450 ID: conn.envelopeID(received),
451 }
452
453 conn.log.Info("received message",
454 zap.Int("bytes", len(data)),
455 zap.Time("date", received),
456 zap.String("id", env.ID),
457 zap.String("delivery", conn.delivery.String()))
458
459 trace := conn.getReceivedInfo(env)
460
461 env.Data = append(trace, data...)
462
463 if conn.delivery == deliverInbound {
464 if reply := conn.server.OnMessageDelivered(env); reply != nil {
465 conn.log.Warn("message was rejected", zap.String("id", env.ID))
466 conn.reply(*reply)
467 return
468 }
469 } else if conn.delivery == deliverOutbound {
470 conn.server.RelayMessage(env)
471 }
472
473 conn.state = stateInitial
474 conn.resetBuffers()
475 conn.reply(ReplyOK)
476 }
477
478 func (conn *connection) handleSendAs(data *[]byte) {
479 if conn.delivery != deliverOutbound || conn.sendAs == nil {
480 return
481 }
482
483 conn.mailFrom = conn.sendAs
484
485 // Find the separator between the message header and body.
486 headerIdx := bytes.Index(*data, []byte("\n\n"))
487 if headerIdx == -1 {
488 conn.log.Error("send-as: could not find headers index")
489 return
490 }
491
492 fromPrefix := []byte("From: ")
493 fromIdx := bytes.Index(*data, fromPrefix)
494 if fromIdx == -1 || fromIdx >= headerIdx {
495 conn.log.Error("send-as: could not find From header")
496 return
497 }
498 if fromIdx != 0 {
499 if (*data)[fromIdx-1] != '\n' {
500 conn.log.Error("send-as: could not find From header")
501 return
502 }
503 }
504
505 fromEndIdx := bytes.IndexByte((*data)[fromIdx:], '\n')
506 if fromIdx == -1 {
507 conn.log.Error("send-as: could not find end of From header")
508 return
509 }
510 fromEndIdx += fromIdx
511
512 newData := (*data)[:fromIdx]
513 newData = append(newData, fromPrefix...)
514 newData = append(newData, []byte(conn.sendAs.String())...)
515 newData = append(newData, (*data)[fromEndIdx:]...)
516
517 *data = newData
518 }
519
520 func (conn *connection) envelopeID(t time.Time) string {
521 var idBytes [4]byte
522 rand.Read(idBytes[:])
523 return fmt.Sprintf("m.%d.%x", t.UnixNano(), idBytes)
524 }
525
526 func (conn *connection) getReceivedInfo(envelope Envelope) []byte {
527 rhost, _, err := net.SplitHostPort(conn.remoteAddr.String())
528 if err != nil {
529 rhost = conn.remoteAddr.String()
530 }
531
532 rhosts, err := net.LookupAddr(rhost)
533 if err == nil {
534 rhost = fmt.Sprintf("%s [%s]", rhosts[0], rhost)
535 }
536
537 base := fmt.Sprintf("Received: from %s (%s)\r\n ", conn.ehlo, rhost)
538
539 with := "SMTP"
540 if conn.esmtp {
541 with = "E" + with
542 }
543 if conn.tls != nil {
544 with += "S"
545 }
546 base += fmt.Sprintf("by %s (mailpopbox) with %s id %s\r\n ", conn.server.Name(), with, envelope.ID)
547
548 base += fmt.Sprintf("for <%s>\r\n ", envelope.RcptTo[0].Address)
549
550 transport := conn.getTransportString()
551 date := envelope.Received.Format(time.RFC1123Z) // Same as RFC 5322 ยง 3.3
552 base += fmt.Sprintf("(using %s);\r\n %s\r\n", transport, date)
553
554 return []byte(base)
555 }
556
557 func (conn *connection) getTransportString() string {
558 if conn.tls == nil {
559 return "PLAINTEXT"
560 }
561
562 ciphers := map[uint16]string{
563 tls.TLS_RSA_WITH_RC4_128_SHA: "TLS_RSA_WITH_RC4_128_SHA",
564 tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
565 tls.TLS_RSA_WITH_AES_128_CBC_SHA: "TLS_RSA_WITH_AES_128_CBC_SHA",
566 tls.TLS_RSA_WITH_AES_256_CBC_SHA: "TLS_RSA_WITH_AES_256_CBC_SHA",
567 tls.TLS_RSA_WITH_AES_128_GCM_SHA256: "TLS_RSA_WITH_AES_128_GCM_SHA256",
568 tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "TLS_RSA_WITH_AES_256_GCM_SHA384",
569 tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA",
570 tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
571 tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
572 tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "TLS_ECDHE_RSA_WITH_RC4_128_SHA",
573 tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
574 tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
575 tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
576 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
577 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
578 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
579 tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
580 }
581 versions := map[uint16]string{
582 tls.VersionSSL30: "SSLv3.0",
583 tls.VersionTLS10: "TLSv1.0",
584 tls.VersionTLS11: "TLSv1.1",
585 tls.VersionTLS12: "TLSv1.2",
586 }
587
588 state := conn.tls
589
590 version := versions[state.Version]
591 cipher := ciphers[state.CipherSuite]
592
593 if version == "" {
594 version = fmt.Sprintf("%x", state.Version)
595 }
596 if cipher == "" {
597 cipher = fmt.Sprintf("%x", state.CipherSuite)
598 }
599
600 name := ""
601 if state.ServerName != "" {
602 name = fmt.Sprintf(" name=%s", state.ServerName)
603 }
604
605 return fmt.Sprintf("%s cipher=%s%s", version, cipher, name)
606 }
607
608 func (conn *connection) doRSET() {
609 conn.log.Info("doRSET()")
610 conn.state = stateInitial
611 conn.resetBuffers()
612 conn.reply(ReplyOK)
613 }
614
615 func (conn *connection) resetBuffers() {
616 conn.delivery = deliverUnknown
617 conn.sendAs = nil
618 conn.mailFrom = nil
619 conn.rcptTo = make([]mail.Address, 0)
620 }