Implement AUTH PLAIN authentication extensions in SMTP.
[mailpopbox.git] / smtp / conn.go
1 package smtp
2
3 import (
4 "crypto/rand"
5 "crypto/tls"
6 "encoding/base64"
7 "fmt"
8 "net"
9 "net/mail"
10 "net/textproto"
11 "strings"
12 "time"
13
14 "github.com/uber-go/zap"
15 )
16
// state is the position of an SMTP session in the command sequence; the
// command handlers compare against it to reject out-of-order commands.
type state int

const (
	// stateNew: no HELO/EHLO seen yet. Also re-entered after STARTTLS.
	stateNew state = 0
	// stateInitial: greeted, no mail transaction in progress.
	stateInitial state = 1
	// stateMail: MAIL accepted, waiting for RCPT.
	stateMail state = 2
	// stateRecipient: at least one RCPT accepted; DATA is allowed.
	stateRecipient state = 3
	// stateData: declared, but not assigned by any handler in this file.
	stateData state = 4
)
26
// connection holds the per-client state of one SMTP session run by
// AcceptConnection.
type connection struct {
	// server supplies the host name, TLS configuration, credential checks,
	// recipient verification and message delivery for this session.
	server Server

	// tp reads command lines and writes replies. doSTARTTLS replaces it with
	// one built over the TLS stream.
	tp *textproto.Conn

	// nc is the current transport; doSTARTTLS swaps in the *tls.Conn.
	nc         net.Conn
	remoteAddr net.Addr

	// esmtp is set by the HELO/EHLO dispatch in AcceptConnection; it gates
	// STARTTLS and selects the ESMTP protocol keyword in Received headers.
	esmtp bool
	// tls is the negotiated TLS state; nil until STARTTLS completes.
	tls *tls.ConnectionState

	// The authcid from a PLAIN SASL login. Non-empty iff tls is non-nil and
	// doAUTH() succeeded.
	authc string

	log zap.Logger

	// Embedded session state (see the state constants).
	state
	// line is the most recent command line read from the client.
	line string

	// ehlo is the client name given with HELO/EHLO.
	ehlo string
	// mailFrom and rcptTo form the envelope of the current mail transaction;
	// resetBuffers clears them.
	mailFrom *mail.Address
	rcptTo   []mail.Address
}
51
52 func AcceptConnection(netConn net.Conn, server Server, log zap.Logger) {
53 conn := connection{
54 server: server,
55 tp: textproto.NewConn(netConn),
56 nc: netConn,
57 remoteAddr: netConn.RemoteAddr(),
58 log: log.With(zap.Stringer("client", netConn.RemoteAddr())),
59 state: stateNew,
60 }
61
62 conn.log.Info("accepted connection")
63 conn.writeReply(220, fmt.Sprintf("%s ESMTP [%s] (mailpopbox)",
64 server.Name(), netConn.LocalAddr()))
65
66 for {
67 var err error
68 conn.line, err = conn.tp.ReadLine()
69 if err != nil {
70 conn.log.Error("ReadLine()", zap.Error(err))
71 conn.tp.Close()
72 return
73 }
74
75 conn.log.Info("ReadLine()", zap.String("line", conn.line))
76
77 var cmd string
78 if _, err = fmt.Sscanf(conn.line, "%s", &cmd); err != nil {
79 conn.reply(ReplyBadSyntax)
80 continue
81 }
82
83 switch strings.ToUpper(cmd) {
84 case "QUIT":
85 conn.writeReply(221, "Goodbye")
86 conn.tp.Close()
87 return
88 case "HELO":
89 conn.esmtp = false
90 fallthrough
91 case "EHLO":
92 conn.esmtp = true
93 conn.doEHLO()
94 case "STARTTLS":
95 conn.doSTARTTLS()
96 case "AUTH":
97 conn.doAUTH()
98 case "MAIL":
99 conn.doMAIL()
100 case "RCPT":
101 conn.doRCPT()
102 case "DATA":
103 conn.doDATA()
104 case "RSET":
105 conn.doRSET()
106 case "VRFY":
107 conn.writeReply(252, "I'll do my best")
108 case "EXPN":
109 conn.writeReply(550, "access denied")
110 case "NOOP":
111 conn.reply(ReplyOK)
112 case "HELP":
113 conn.writeReply(250, "https://tools.ietf.org/html/rfc5321")
114 default:
115 conn.writeReply(500, "unrecognized command")
116 }
117 }
118 }
119
120 func (conn *connection) reply(reply ReplyLine) error {
121 return conn.writeReply(reply.Code, reply.Message)
122 }
123
124 func (conn *connection) writeReply(code int, msg string) error {
125 conn.log.Info("writeReply", zap.Int("code", code))
126 var err error
127 if len(msg) > 0 {
128 err = conn.tp.PrintfLine("%d %s", code, msg)
129 } else {
130 err = conn.tp.PrintfLine("%d", code)
131 }
132 if err != nil {
133 conn.log.Error("writeReply",
134 zap.Int("code", code),
135 zap.Error(err))
136 }
137 return err
138 }
139
140 // parsePath parses out either a forward-, reverse-, or return-path from the
141 // current connection line. Returns a (valid-path, ReplyOK) if it was
142 // successfully parsed.
143 func (conn *connection) parsePath(command string) (string, ReplyLine) {
144 if len(conn.line) < len(command) {
145 return "", ReplyBadSyntax
146 }
147 if strings.ToUpper(command) != strings.ToUpper(conn.line[:len(command)]) {
148 return "", ReplyLine{500, "unrecognized command"}
149 }
150 params := conn.line[len(command):]
151 idx := strings.Index(params, ">")
152 if idx == -1 {
153 return "", ReplyBadSyntax
154 }
155 return strings.ToLower(params[:idx+1]), ReplyOK
156 }
157
158 func (conn *connection) doEHLO() {
159 conn.resetBuffers()
160
161 var cmd string
162 _, err := fmt.Sscanf(conn.line, "%s %s", &cmd, &conn.ehlo)
163 if err != nil {
164 conn.reply(ReplyBadSyntax)
165 return
166 }
167
168 if cmd == "HELO" {
169 conn.writeReply(250, fmt.Sprintf("Hello %s [%s]", conn.ehlo, conn.remoteAddr))
170 } else {
171 conn.tp.PrintfLine("250-Hello %s [%s]", conn.ehlo, conn.remoteAddr)
172 if conn.server.TLSConfig() != nil && conn.tls == nil {
173 conn.tp.PrintfLine("250-STARTTLS")
174 }
175 if conn.tls != nil {
176 conn.tp.PrintfLine("250-AUTH PLAIN")
177 }
178 conn.tp.PrintfLine("250 SIZE %d", 40960000)
179 }
180
181 conn.log.Info("doEHLO()", zap.String("ehlo", conn.ehlo))
182
183 conn.state = stateInitial
184 }
185
186 func (conn *connection) doSTARTTLS() {
187 if conn.state != stateInitial {
188 conn.reply(ReplyBadSequence)
189 return
190 }
191
192 tlsConfig := conn.server.TLSConfig()
193 if !conn.esmtp || tlsConfig == nil {
194 conn.writeReply(500, "unrecognized command")
195 return
196 }
197
198 conn.log.Info("doSTARTTLS()")
199 conn.writeReply(220, "initiate TLS connection")
200
201 tlsConn := tls.Server(conn.nc, tlsConfig)
202 if err := tlsConn.Handshake(); err != nil {
203 conn.log.Error("failed to do TLS handshake", zap.Error(err))
204 return
205 }
206
207 conn.nc = tlsConn
208 conn.tp = textproto.NewConn(tlsConn)
209 conn.state = stateNew
210
211 connState := tlsConn.ConnectionState()
212 conn.tls = &connState
213
214 conn.log.Info("TLS connection done", zap.String("state", conn.getTransportString()))
215 }
216
217 func (conn *connection) doAUTH() {
218 if conn.state != stateInitial || conn.tls == nil {
219 conn.reply(ReplyBadSequence)
220 return
221 }
222
223 if conn.authc != "" {
224 conn.writeReply(503, "already authenticated")
225 return
226 }
227
228 var cmd, authType string
229 _, err := fmt.Sscanf(conn.line, "%s %s", &cmd, &authType)
230 if err != nil {
231 conn.reply(ReplyBadSyntax)
232 return
233 }
234
235 if authType != "PLAIN" {
236 conn.writeReply(504, "unrecognized auth type")
237 return
238 }
239
240 conn.log.Info("doAUTH()")
241
242 conn.writeReply(334, " ")
243
244 authLine, err := conn.tp.ReadLine()
245 if err != nil {
246 conn.log.Error("failed to read auth line", zap.Error(err))
247 conn.reply(ReplyBadSyntax)
248 return
249 }
250
251 authBytes, err := base64.StdEncoding.DecodeString(authLine)
252 if err != nil {
253 conn.reply(ReplyBadSyntax)
254 return
255 }
256
257 authParts := strings.Split(string(authBytes), "\x00")
258 if len(authParts) != 3 {
259 conn.log.Error("bad auth line syntax")
260 conn.reply(ReplyBadSyntax)
261 return
262 }
263
264 if !conn.server.Authenticate(authParts[0], authParts[1], authParts[2]) {
265 conn.log.Error("failed to authenticate", zap.String("authc", authParts[1]))
266 conn.writeReply(535, "invalid credentials")
267 return
268 }
269
270 conn.log.Info("authenticated", zap.String("authz", authParts[0]), zap.String("authc", authParts[1]))
271 conn.authc = authParts[1]
272 conn.reply(ReplyOK)
273 }
274
275 func (conn *connection) doMAIL() {
276 if conn.state != stateInitial {
277 conn.reply(ReplyBadSequence)
278 return
279 }
280
281 mailFrom, reply := conn.parsePath("MAIL FROM:")
282 if reply != ReplyOK {
283 conn.reply(reply)
284 return
285 }
286
287 var err error
288 conn.mailFrom, err = mail.ParseAddress(mailFrom)
289 if err != nil || conn.mailFrom == nil {
290 conn.reply(ReplyBadSyntax)
291 return
292 }
293
294 conn.log.Info("doMAIL()", zap.String("address", conn.mailFrom.Address))
295
296 conn.state = stateMail
297 conn.reply(ReplyOK)
298 }
299
300 func (conn *connection) doRCPT() {
301 if conn.state != stateMail && conn.state != stateRecipient {
302 conn.reply(ReplyBadSequence)
303 return
304 }
305
306 rcptTo, reply := conn.parsePath("RCPT TO:")
307 if reply != ReplyOK {
308 conn.reply(reply)
309 return
310 }
311
312 address, err := mail.ParseAddress(rcptTo)
313 if err != nil {
314 conn.reply(ReplyBadSyntax)
315 return
316 }
317
318 if reply := conn.server.VerifyAddress(*address); reply != ReplyOK {
319 conn.log.Warn("invalid address",
320 zap.String("address", address.Address),
321 zap.Stringer("reply", reply))
322 conn.reply(reply)
323 return
324 }
325
326 conn.log.Info("doRCPT()", zap.String("address", address.Address))
327
328 conn.rcptTo = append(conn.rcptTo, *address)
329
330 conn.state = stateRecipient
331 conn.reply(ReplyOK)
332 }
333
334 func (conn *connection) doDATA() {
335 if conn.state != stateRecipient {
336 conn.reply(ReplyBadSequence)
337 return
338 }
339
340 conn.writeReply(354, "Start mail input; end with <CRLF>.<CRLF>")
341 conn.log.Info("doDATA()")
342
343 data, err := conn.tp.ReadDotBytes()
344 if err != nil {
345 conn.log.Error("failed to ReadDotBytes()",
346 zap.Error(err),
347 zap.String("bytes", fmt.Sprintf("%x", data)))
348 conn.writeReply(552, "transaction failed")
349 return
350 }
351
352 received := time.Now()
353 env := Envelope{
354 RemoteAddr: conn.remoteAddr,
355 EHLO: conn.ehlo,
356 MailFrom: *conn.mailFrom,
357 RcptTo: conn.rcptTo,
358 Received: received,
359 ID: conn.envelopeID(received),
360 }
361
362 conn.log.Info("received message",
363 zap.Int("bytes", len(data)),
364 zap.Time("date", received),
365 zap.String("id", env.ID))
366
367 trace := conn.getReceivedInfo(env)
368
369 env.Data = append(trace, data...)
370
371 if reply := conn.server.OnMessageDelivered(env); reply != nil {
372 conn.log.Warn("message was rejected", zap.String("id", env.ID))
373 conn.reply(*reply)
374 return
375 }
376
377 conn.state = stateInitial
378 conn.reply(ReplyOK)
379 }
380
381 func (conn *connection) envelopeID(t time.Time) string {
382 var idBytes [4]byte
383 rand.Read(idBytes[:])
384 return fmt.Sprintf("m.%d.%x", t.UnixNano(), idBytes)
385 }
386
387 func (conn *connection) getReceivedInfo(envelope Envelope) []byte {
388 rhost, _, err := net.SplitHostPort(conn.remoteAddr.String())
389 if err != nil {
390 rhost = conn.remoteAddr.String()
391 }
392
393 rhosts, err := net.LookupAddr(rhost)
394 if err == nil {
395 rhost = fmt.Sprintf("%s [%s]", rhosts[0], rhost)
396 }
397
398 base := fmt.Sprintf("Received: from %s (%s)\r\n ", conn.ehlo, rhost)
399
400 with := "SMTP"
401 if conn.esmtp {
402 with = "E" + with
403 }
404 if conn.tls != nil {
405 with += "S"
406 }
407 base += fmt.Sprintf("by %s (mailpopbox) with %s id %s\r\n ", conn.server.Name(), with, envelope.ID)
408
409 base += fmt.Sprintf("for <%s>\r\n ", envelope.RcptTo[0].Address)
410
411 transport := conn.getTransportString()
412 date := envelope.Received.Format(time.RFC1123Z) // Same as RFC 5322 ยง 3.3
413 base += fmt.Sprintf("(using %s);\r\n %s\r\n", transport, date)
414
415 return []byte(base)
416 }
417
418 func (conn *connection) getTransportString() string {
419 if conn.tls == nil {
420 return "PLAINTEXT"
421 }
422
423 ciphers := map[uint16]string{
424 tls.TLS_RSA_WITH_RC4_128_SHA: "TLS_RSA_WITH_RC4_128_SHA",
425 tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
426 tls.TLS_RSA_WITH_AES_128_CBC_SHA: "TLS_RSA_WITH_AES_128_CBC_SHA",
427 tls.TLS_RSA_WITH_AES_256_CBC_SHA: "TLS_RSA_WITH_AES_256_CBC_SHA",
428 tls.TLS_RSA_WITH_AES_128_GCM_SHA256: "TLS_RSA_WITH_AES_128_GCM_SHA256",
429 tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "TLS_RSA_WITH_AES_256_GCM_SHA384",
430 tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA",
431 tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
432 tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
433 tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "TLS_ECDHE_RSA_WITH_RC4_128_SHA",
434 tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
435 tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
436 tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
437 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
438 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
439 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
440 tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
441 }
442 versions := map[uint16]string{
443 tls.VersionSSL30: "SSLv3.0",
444 tls.VersionTLS10: "TLSv1.0",
445 tls.VersionTLS11: "TLSv1.1",
446 tls.VersionTLS12: "TLSv1.2",
447 }
448
449 state := conn.tls
450
451 version := versions[state.Version]
452 cipher := ciphers[state.CipherSuite]
453
454 if version == "" {
455 version = fmt.Sprintf("%x", state.Version)
456 }
457 if cipher == "" {
458 cipher = fmt.Sprintf("%x", state.CipherSuite)
459 }
460
461 name := ""
462 if state.ServerName != "" {
463 name = fmt.Sprintf(" name=%s", state.ServerName)
464 }
465
466 return fmt.Sprintf("%s cipher=%s%s", version, cipher, name)
467 }
468
469 func (conn *connection) doRSET() {
470 conn.log.Info("doRSET()")
471 conn.state = stateInitial
472 conn.resetBuffers()
473 conn.reply(ReplyOK)
474 }
475
476 func (conn *connection) resetBuffers() {
477 conn.mailFrom = nil
478 conn.rcptTo = make([]mail.Address, 0)
479 }