Add an SMTP test for STARTTLS.
[mailpopbox.git] / smtp / conn_test.go
1 package smtp
2
3 import (
4 "crypto/tls"
5 "fmt"
6 "net"
7 "net/mail"
8 "net/textproto"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "testing"
13 "time"
14
15 "github.com/uber-go/zap"
16 )
17
// _fl returns a "[file.go:line]" tag for a location on the call stack, used to
// prefix test failure messages. depth 0 names the function that called _fl,
// depth 1 that function's caller, and so on.
func _fl(depth int) string {
	_, path, lineNo, _ := runtime.Caller(1 + depth)
	name := filepath.Base(path)
	return fmt.Sprintf("[%s:%d]", name, lineNo)
}
22
23 func ok(t testing.TB, err error) {
24 if err != nil {
25 t.Errorf("%s unexpected error: %v", _fl(1), err)
26 }
27 }
28
29 func readCodeLine(t testing.TB, conn *textproto.Conn, code int) string {
30 _, message, err := conn.ReadCodeLine(code)
31 if err != nil {
32 t.Errorf("%s ReadCodeLine error: %v", _fl(1), err)
33 }
34 return message
35 }
36
37 // runServer creates a TCP socket, runs a listening server, and returns the connection.
38 // The server exits when the Conn is closed.
39 func runServer(t *testing.T, server Server) net.Listener {
40 l, err := net.Listen("tcp", "localhost:0")
41 if err != nil {
42 t.Fatal(err)
43 return nil
44 }
45
46 go func() {
47 for {
48 conn, err := l.Accept()
49 if err != nil {
50 return
51 }
52 go AcceptConnection(conn, server, zap.New(zap.NullEncoder()))
53 }
54 }()
55
56 return l
57 }
58
59 type testServer struct {
60 EmptyServerCallbacks
61 blockList []string
62 tlsConfig *tls.Config
63 }
64
65 func (s *testServer) Name() string {
66 return "Test-Server"
67 }
68
69 func (s *testServer) TLSConfig() *tls.Config {
70 return s.tlsConfig
71 }
72
73 func (s *testServer) VerifyAddress(addr mail.Address) ReplyLine {
74 for _, block := range s.blockList {
75 if strings.ToLower(block) == addr.Address {
76 return ReplyBadMailbox
77 }
78 }
79 return ReplyOK
80 }
81
// createClient dials addr and wraps the connection in a textproto.Conn for
// line-oriented SMTP exchanges. A dial failure aborts the test.
func createClient(t *testing.T, addr net.Addr) *textproto.Conn {
	c, err := textproto.Dial(addr.Network(), addr.String())
	if err != nil {
		t.Fatal(err)
	}
	return c
}
90
// requestResponse is one scripted step for runTableTest. Tests build these
// with positional literals, so the field order is significant.
type requestResponse struct {
	// request is the command line sent to the server (PrintfLine adds CRLF).
	request string
	// responseCode is the single-line reply code expected for request;
	// it is ignored when handler is set.
	responseCode int
	// handler, when non-nil, reads and checks the reply itself, e.g. for
	// multi-line EHLO replies or multi-step exchanges such as DATA.
	handler func(testing.TB, *textproto.Conn)
}
96
97 func runTableTest(t testing.TB, conn *textproto.Conn, seq []requestResponse) {
98 for i, rr := range seq {
99 t.Logf("%s case %d", _fl(1), i)
100 ok(t, conn.PrintfLine(rr.request))
101 if rr.handler != nil {
102 rr.handler(t, conn)
103 } else {
104 readCodeLine(t, conn, rr.responseCode)
105 }
106 }
107 }
108
109 // RFC 5321 ยง D.1
110 func TestScenarioTypical(t *testing.T) {
111 s := testServer{
112 blockList: []string{"Green@foo.com"},
113 }
114 l := runServer(t, &s)
115 defer l.Close()
116
117 conn := createClient(t, l.Addr())
118
119 message := readCodeLine(t, conn, 220)
120 if !strings.HasPrefix(message, s.Name()) {
121 t.Errorf("Greeting does not have server name, got %q", message)
122 }
123
124 greet := "greeting.TestScenarioTypical"
125 ok(t, conn.PrintfLine("EHLO "+greet))
126
127 _, message, err := conn.ReadResponse(250)
128 ok(t, err)
129 if !strings.Contains(message, greet) {
130 t.Errorf("EHLO response does not contain greeting, got %q", message)
131 }
132
133 ok(t, conn.PrintfLine("MAIL FROM:<Smith@bar.com>"))
134 readCodeLine(t, conn, 250)
135
136 ok(t, conn.PrintfLine("RCPT TO:<Jones@foo.com>"))
137 readCodeLine(t, conn, 250)
138
139 ok(t, conn.PrintfLine("RCPT TO:<Green@foo.com>"))
140 readCodeLine(t, conn, 550)
141
142 ok(t, conn.PrintfLine("RCPT TO:<Brown@foo.com>"))
143 readCodeLine(t, conn, 250)
144
145 ok(t, conn.PrintfLine("DATA"))
146 readCodeLine(t, conn, 354)
147
148 ok(t, conn.PrintfLine("Blah blah blah..."))
149 ok(t, conn.PrintfLine("...etc. etc. etc."))
150 ok(t, conn.PrintfLine("."))
151 readCodeLine(t, conn, 250)
152
153 ok(t, conn.PrintfLine("QUIT"))
154 readCodeLine(t, conn, 221)
155 }
156
157 func TestVerifyAddress(t *testing.T) {
158 s := testServer{
159 blockList: []string{"banned@test.mail"},
160 }
161 l := runServer(t, &s)
162 defer l.Close()
163
164 conn := createClient(t, l.Addr())
165 readCodeLine(t, conn, 220)
166
167 runTableTest(t, conn, []requestResponse{
168 {"EHLO test", 0, func(t testing.TB, conn *textproto.Conn) { conn.ReadResponse(250) }},
169 {"VRFY banned@test.mail", 252, nil},
170 {"VRFY allowed@test.mail", 252, nil},
171 {"MAIL FROM:<sender@example.com>", 250, nil},
172 {"RCPT TO:<banned@test.mail>", 550, nil},
173 {"QUIT", 221, nil},
174 })
175 }
176
177 func TestBadAddress(t *testing.T) {
178 l := runServer(t, &testServer{})
179 defer l.Close()
180
181 conn := createClient(t, l.Addr())
182 readCodeLine(t, conn, 220)
183
184 runTableTest(t, conn, []requestResponse{
185 {"EHLO test", 0, func(t testing.TB, conn *textproto.Conn) { conn.ReadResponse(250) }},
186 {"MAIL FROM:<sender>", 501, nil},
187 {"MAIL FROM:<sender@foo.com> SIZE=2163", 250, nil},
188 {"RCPT TO:<banned.net>", 501, nil},
189 {"QUIT", 221, nil},
190 })
191 }
192
193 func TestCaseSensitivty(t *testing.T) {
194 s := &testServer{}
195 s.blockList = []string{"reject@mail.com"}
196 l := runServer(t, s)
197 defer l.Close()
198
199 conn := createClient(t, l.Addr())
200 readCodeLine(t, conn, 220)
201
202 runTableTest(t, conn, []requestResponse{
203 {"nOoP", 250, nil},
204 {"ehLO test.TEST", 0, func(t testing.TB, conn *textproto.Conn) { conn.ReadResponse(250) }},
205 {"mail FROM:<sender@example.com>", 250, nil},
206 {"RcPT tO:<receive@mail.com>", 250, nil},
207 {"RCPT TO:<reject@MAIL.com>", 550, nil},
208 {"RCPT TO:<reject@mail.com>", 550, nil},
209 {"DATa", 0, func(t testing.TB, conn *textproto.Conn) {
210 readCodeLine(t, conn, 354)
211
212 ok(t, conn.PrintfLine("."))
213 readCodeLine(t, conn, 250)
214 }},
215 {"MAIL FR:", 501, nil},
216 {"QUiT", 221, nil},
217 })
218 }
219
220 func TestGetReceivedInfo(t *testing.T) {
221 conn := connection{
222 server: &testServer{},
223 remoteAddr: &net.IPAddr{net.IPv4(127, 0, 0, 1), ""},
224 }
225
226 now := time.Now()
227
228 const crlf = "\r\n"
229 const line1 = "Received: from remote.test. (localhost [127.0.0.1])" + crlf
230 const line2 = "by Test-Server (mailpopbox) with "
231 const msgId = "abcdef.hijk"
232 lineLast := now.Format(time.RFC1123Z) + crlf
233
234 type params struct {
235 ehlo string
236 esmtp bool
237 tls bool
238 address string
239 }
240
241 tests := []struct {
242 params params
243
244 expect []string
245 }{
246 {params{"remote.test.", true, false, "foo@bar.com"},
247 []string{line1,
248 line2 + "ESMTP id " + msgId + crlf,
249 "for <foo@bar.com>" + crlf,
250 "(using PLAINTEXT);" + crlf,
251 lineLast, ""}},
252 }
253
254 for _, test := range tests {
255 t.Logf("%#v", test.params)
256
257 conn.ehlo = test.params.ehlo
258 conn.esmtp = test.params.esmtp
259 //conn.tls = test.params.tls
260
261 envelope := Envelope{
262 RcptTo: []mail.Address{{"", test.params.address}},
263 Received: now,
264 ID: msgId,
265 }
266
267 actual := conn.getReceivedInfo(envelope)
268 actualLines := strings.SplitAfter(string(actual), crlf)
269
270 if len(actualLines) != len(test.expect) {
271 t.Errorf("wrong numbber of lines, expected %d, got %d", len(test.expect), len(actualLines))
272 continue
273 }
274
275 for i, line := range actualLines {
276 expect := test.expect[i]
277 if expect != strings.TrimLeft(line, " ") {
278 t.Errorf("Expected equal string %q, got %q", expect, line)
279 }
280 }
281 }
282
283 }
284
// getTLSConfig loads the test key pair from ../testtls and returns a config
// that is used for both ends of the connection in these tests. It carries the
// certificate for the server side and sets InsecureSkipVerify so the client
// side does not verify that certificate. A load failure aborts the test.
func getTLSConfig(t *testing.T) *tls.Config {
	const (
		certFile = "../testtls/domain.crt"
		keyFile  = "../testtls/domain.key"
	)
	pair, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		t.Fatal(err)
	}
	cfg := &tls.Config{
		ServerName:         "localhost",
		Certificates:       []tls.Certificate{pair},
		InsecureSkipVerify: true,
	}
	return cfg
}
297
298 func TestTLS(t *testing.T) {
299 l := runServer(t, &testServer{tlsConfig: getTLSConfig(t)})
300 defer l.Close()
301
302 nc, err := net.Dial(l.Addr().Network(), l.Addr().String())
303 ok(t, err)
304
305 conn := textproto.NewConn(nc)
306 readCodeLine(t, conn, 220)
307
308 ok(t, conn.PrintfLine("EHLO test-tls"))
309 _, resp, err := conn.ReadResponse(250)
310 ok(t, err)
311 if !strings.Contains(resp, "STARTTLS\n") {
312 t.Errorf("STARTTLS not advertised")
313 }
314
315 ok(t, conn.PrintfLine("STARTTLS"))
316 readCodeLine(t, conn, 220)
317
318 tc := tls.Client(nc, getTLSConfig(t))
319 err = tc.Handshake()
320 ok(t, err)
321
322 conn = textproto.NewConn(tc)
323
324 ok(t, conn.PrintfLine("EHLO test-tls-started"))
325 _, resp, err = conn.ReadResponse(250)
326 ok(t, err)
327 if strings.Contains(resp, "STARTTLS\n") {
328 t.Errorf("STARTTLS advertised when already started")
329 }
330 }