Add file-line information to test failure messages.
[zcpointer.git] / test.cc
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <exception>
#include <iostream>
#include <stdexcept>
#include <typeinfo>
#include <utility>
#include <vector>

#include "zcpointer.h"
20
// Minimal payload type for the smart-pointer tests. Its members do nothing:
// the tests only care whether a call made through a zc::ref is reached (or is
// rejected as a use-after-free), not what the call does.
class C {
 public:
  // Explicit, user-provided destructor, so C is not trivially destructible.
  // NOTE(review): presumably deliberate; confirm before switching to
  // `= default`, which would make it trivial.
  ~C() {}

  // No-op member used as the dereference probe (`ref->DoThing()`).
  void DoThing() {}
};
27
// Exception thrown by EXPECT / EXPECT_UAF when a check fails. It adds nothing
// to std::logic_error except a distinct type, so main() can tell assertion
// failures apart. All of logic_error's message constructors (const char* and
// const std::string&) are inherited unchanged, including their `explicit`.
struct TestFailure : std::logic_error {
  using logic_error::logic_error;
};
32
// Two-step stringification: STRING expands its argument first (e.g. turning
// __LINE__ into a number), then QUOTE turns the result into a string literal.
#define STRING(x) QUOTE(x)
#define QUOTE(x) #x

// String literal " @ <file>:<line>", appended to failure messages. Because it
// is expanded inside the EXPECT* macros, __LINE__ refers to the line of the
// EXPECT* invocation.
#define AT_FILE_LINE " @ " __FILE__ ":" STRING(__LINE__)

// Evaluates `expr` once. If it is false, throws TestFailure whose message is
// the source text of `expr` followed by AT_FILE_LINE.
#define EXPECT(expr)                         \
  do {                                       \
    if (!(expr)) {                           \
      throw TestFailure(#expr AT_FILE_LINE); \
    }                                        \
  } while (0)
38
#if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS

// Evaluates `expr`, which is expected to dereference a dangling zc::ref.
// - If `expr` throws zc::UseAfterFreeError, the check passes.
// - If `expr` completes normally, throws TestFailure naming the expression and
//   the location of the EXPECT_UAF.
// - Any other exception propagates to the caller.
// The error is caught by const reference rather than by value, so the
// exception object is not copied or sliced.
#define EXPECT_UAF(expr) do { \
    try { \
      (expr); \
      throw TestFailure("Expected use-after-free: " #expr AT_FILE_LINE); \
    } catch (const zc::UseAfterFreeError&) {} \
  } while(0)

#else

// Fallback when reference tracking is disabled, so there is nothing to assert.
// Prints a notice, evaluates `expr` anyway, and reports any std::logic_error
// it throws, with its dynamic type name, without failing the test.
// NOTE(review): in this file's tests `expr` dereferences a ref whose owner is
// gone, so running it without tracking is presumably undefined behavior.
// `typeid` requires <typeinfo> to be included before this macro is expanded.
#define EXPECT_UAF(expr) do { \
    std::cout << ">>> ZCPOINTER_TRACK_REFS not enabled, cannot catch UAF" << std::endl; \
    try { \
      (expr); \
    } catch (const std::logic_error& e) { \
      std::cout << ">>> Caught error: " << typeid(e).name() << ": " << e.what() << std::endl; \
    } \
  } while(0)

#endif
60
61 void TestReset() {
62 zc::owned<C> c(new C());
63 zc::ref<C> owned = c.get();
64 zc::ref<C> owned2 = owned;
65 c.reset();
66 EXPECT_UAF(owned2->DoThing());
67 }
68
69 template <typename T>
70 void TestUnwrap() {
71 zc::owned<T> t(new T());
72 //T* unwrap = t.get();
73
74 zc::ref<T> ref = t.get();
75 T* unwrap2 = ref;
76 }
77
78 void TestMove() {
79 zc::owned<C> c(new C());
80 zc::ref<C> owned = c.get();
81
82 zc::owned<C> c2(std::move(c));
83 owned->DoThing();
84
85 c2.reset();
86 EXPECT_UAF(owned->DoThing());
87 }
88
89 void PtrHelper(zc::ref<C>* out) {
90 zc::owned<C> c(new C());
91 *out = c.get();
92 }
93
// A ref returned through an out-parameter points at an object whose owner was
// local to PtrHelper. Dereferencing it here is therefore expected to be
// flagged as a use-after-free.
void TestPtr() {
  zc::ref<C> ref;
  PtrHelper(&ref);
  EXPECT_UAF(ref->DoThing());
}
99
// Exercises == and != on zc::owned and zc::ref. It asserts that:
// - an owner equals itself and differs from another owner;
// - a ref equals itself, other refs to the same object, and its owner's get();
// - refs to different objects compare unequal;
// - empty owners and a ref initialized from nullptr compare equal to nullptr
//   and to each other, and unequal to a non-empty owner.
void TestEquality() {
  zc::owned<C> a(new C());
  zc::owned<C> b(new C());

  // Owner vs. owner.
  EXPECT(a == a);
  EXPECT(b == b);
  EXPECT(a != b);

  zc::ref<C> ra = a.get();
  zc::ref<C> rb = b.get();

  // Ref vs. itself and vs. its owner's get().
  EXPECT(ra == ra);
  EXPECT(ra == a.get());
  EXPECT(rb == rb);
  EXPECT(rb == b.get());

  // Refs to distinct objects.
  EXPECT(rb != ra);

  // A second, independently obtained ref to the same object.
  zc::ref<C> r = a.get();
  EXPECT(r == ra);
  EXPECT(r == a.get());

  // Default-constructed owners and a ref initialized from nullptr.
  zc::owned<C> c;
  zc::owned<C> c2;
  zc::ref<C> rc = nullptr;

  EXPECT(rc == c.get());
  EXPECT(c == nullptr);
  EXPECT(rc == nullptr);
  EXPECT(a != c);
  EXPECT(c == c2);
}
132
// Moving between empty owners and assigning between their refs must leave
// every handle null. In particular, the moved-from owner `l` is asserted to
// compare equal to nullptr after the move.
void TestNulls() {
  zc::owned<C> l;
  zc::owned<C> r;

  zc::ref<C> rl = l.get();
  zc::ref<C> rr = r.get();

  // Move an empty owner into another empty owner; copy a null ref.
  r = std::move(l);
  rl = rr;

  EXPECT(l == nullptr);
  EXPECT(r == nullptr);
  EXPECT(rl == nullptr);
  EXPECT(rr == nullptr);
}
148
149 void TestVector() {
150 zc::owned<C> c;
151 std::vector<zc::ref<C>> vec{
152 c.get(),
153 c.get(),
154 c.get()
155 };
156
157 for (const auto& r : vec) {
158 EXPECT(r == c.get());
159 }
160
161 zc::ref<C> ref;
162 {
163 std::vector<zc::owned<C>> vec;
164 vec.push_back(std::move(zc::owned<C>(new C())));
165 vec.push_back(std::move(zc::owned<C>(new C())));
166 vec.push_back(std::move(zc::owned<C>(new C())));
167 ref = vec[1].get();
168 }
169 EXPECT_UAF(ref->DoThing());
170 }
171
172 #define TEST_FUNC(fn) { #fn , Test##fn }
173
174 int main() {
175 struct {
176 const char* name;
177 void (*test)();
178 } kTests[] = {
179 TEST_FUNC(Reset),
180 TEST_FUNC(Move),
181 TEST_FUNC(Ptr),
182 TEST_FUNC(Equality),
183 TEST_FUNC(Nulls),
184 TEST_FUNC(Vector),
185 };
186
187 bool passed = true;
188 for (const auto& test : kTests) {
189 std::cout << "=== BEGIN " << test.name << " ===" << std::endl;
190 try {
191 test.test();
192 std::cout << "+++ PASS " << test.name << " +++" << std::endl;
193 } catch (const TestFailure& e) {
194 passed = false;
195 std::cout << "!!! FAIL " << test.name
196 << ": Assertion failure: " << e.what() << " ===" << std::endl;
197 }
198 }
199
200 return passed ? 0 : 1;
201 }