In OwnedPtrDeleter<T>, store a vptr for a behavior function.
[zcpointer.git] / test.cc
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <iostream>
#include <stdexcept>
#include <typeinfo>
#include <vector>

#include "test_helpers.h"
#include "zcpointer.h"
21
// Exception thrown by EXPECT / EXPECT_UAF when a test assertion does not hold.
// It reuses std::logic_error's (explicit) constructors, so it can be built from
// either a C string or a std::string message.
struct TestFailure : std::logic_error {
  using std::logic_error::logic_error;
};
26
// Two-level stringization: QUOTE stringizes its argument verbatim, while
// STRING first lets the argument macro-expand (so STRING(__LINE__) yields the
// line number as a string literal rather than the text "__LINE__").
#define QUOTE(x) #x
#define STRING(x) QUOTE(x)

// String literal " @ <file>:<line>" naming the point where it is expanded.
#define AT_FILE_LINE " @ " __FILE__ ":" STRING(__LINE__)

// Throws TestFailure carrying the source text of `expr` plus its location
// when `expr` evaluates to false.
#define EXPECT(expr)                                 \
  do {                                               \
    if (!(expr)) {                                   \
      throw TestFailure(#expr AT_FILE_LINE);         \
    }                                                \
  } while (0)
32
#if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS

// With reference tracking enabled: evaluates `expr` and requires that it throws
// zc::UseAfterFreeError. If `expr` completes normally, a TestFailure naming the
// expression and its location is thrown instead.
// The error is caught by const reference rather than by value, avoiding an
// unnecessary copy (and any slicing) of the exception object.
#define EXPECT_UAF(expr) do { \
    try { \
      (expr); \
      throw TestFailure("Expected use-after-free: " #expr AT_FILE_LINE); \
    } catch (const zc::UseAfterFreeError&) {} \
  } while(0)

#else

// Without reference tracking the use-after-free cannot be detected. The
// expression is still evaluated, and any std::logic_error it throws is
// reported and swallowed so the remaining assertions in the test can run.
// std::endl is kept deliberately: flushing before evaluating `expr` keeps the
// diagnostic visible even if the (presumably undefined) access crashes.
// typeid requires <typeinfo> to be included.
#define EXPECT_UAF(expr) do { \
    std::cout << ">>> ZCPOINTER_TRACK_REFS not enabled, cannot catch UAF" << std::endl; \
    try { \
      (expr); \
    } catch (const std::logic_error& e) { \
      std::cout << ">>> Caught error: " << typeid(e).name() << ": " << e.what() << std::endl; \
    } \
  } while(0)

#endif
54
55 void TestReset() {
56 zc::owned<C> c(new C());
57 zc::ref<C> owned = c.get();
58 zc::ref<C> owned2 = owned;
59 c.reset();
60 EXPECT_UAF(owned2->DoThing());
61 }
62
63 template <typename T>
64 void TestUnwrap() {
65 zc::owned<T> t(new T());
66 T* unwrap = t.get();
67
68 zc::ref<T> ref = t.get();
69 T* unwrap2 = ref;
70
71 zc::member<T> tm;
72 T* tp = &tm;
73 }
74
75 void TestMove() {
76 zc::owned<C> c(new C());
77 zc::ref<C> owned = c.get();
78
79 zc::owned<C> c2(std::move(c));
80 owned->DoThing();
81
82 c2.reset();
83 EXPECT_UAF(owned->DoThing());
84 }
85
86 void PtrHelper(zc::ref<C>* out) {
87 zc::owned<C> c(new C());
88 *out = c.get();
89 }
90
// A ref filled in through an out-pointer by PtrHelper outlives the owned<C>
// that PtrHelper created, so dereferencing it is expected to be a
// use-after-free.
void TestPtr() {
  zc::ref<C> ref;
  PtrHelper(&ref);
  // PtrHelper's owner has already been destroyed at this point.
  EXPECT_UAF(ref->DoThing());
}
96
// Exercises operator== / operator!= among owned<C>, ref<C>, raw get() results
// and nullptr, for both non-null and null (default-constructed) owners.
// Local names are referenced by EXPECT's stringized failure messages.
void TestEquality() {
  zc::owned<C> a(new C());
  zc::owned<C> b(new C());

  // owned vs owned: identity and distinct objects.
  EXPECT(a == a);
  EXPECT(b == b);
  EXPECT(a != b);

  zc::ref<C> ra = a.get();
  zc::ref<C> rb = b.get();

  // ref vs ref and ref vs the owner's get().
  EXPECT(ra == ra);
  EXPECT(ra == a.get());
  EXPECT(rb == rb);
  EXPECT(rb == b.get());

  EXPECT(rb != ra);

  // A second ref to the same object compares equal to the first.
  zc::ref<C> r = a.get();
  EXPECT(r == ra);
  EXPECT(r == a.get());

  // Null cases: default-constructed owners and a ref initialized from nullptr
  // are all expected to compare equal to each other and to nullptr.
  zc::owned<C> c;
  zc::owned<C> c2;
  zc::ref<C> rc = nullptr;

  EXPECT(rc == c.get());
  EXPECT(c == nullptr);
  EXPECT(rc == nullptr);
  EXPECT(a != c);
  EXPECT(c == c2);
}
129
// Move-assigning a null owned and assigning a null ref should leave every
// participant null, including the moved-from owner `l`.
void TestNulls() {
  zc::owned<C> l;
  zc::owned<C> r;

  zc::ref<C> rl = l.get();
  zc::ref<C> rr = r.get();

  r = std::move(l);
  rl = rr;

  // Checking `l` after the move is deliberate: the test expects a moved-from
  // owned to compare equal to nullptr.
  EXPECT(l == nullptr);
  EXPECT(r == nullptr);
  EXPECT(rl == nullptr);
  EXPECT(rr == nullptr);
}
145
146 void TestVector() {
147 zc::owned<C> c;
148 std::vector<zc::ref<C>> vec{
149 c.get(),
150 c.get(),
151 c.get()
152 };
153
154 for (const auto& r : vec) {
155 EXPECT(r == c.get());
156 }
157
158 zc::ref<C> ref;
159 {
160 std::vector<zc::owned<C>> vec;
161 vec.push_back(std::move(zc::owned<C>(new C())));
162 vec.push_back(std::move(zc::owned<C>(new C())));
163 vec.push_back(std::move(zc::owned<C>(new C())));
164 ref = vec[1].get();
165 }
166 EXPECT_UAF(ref->DoThing());
167 }
168
// A ref taken to a block-scoped zc::member<C> is expected to report
// use-after-free once that member goes out of scope.
void TestStack() {
  zc::ref<C> rc;
  {
    zc::member<C> c;
    // Taking the address of a member<C> is assignable to a ref<C>.
    rc = &c;
    EXPECT(rc == &c);
    c.DoThing();
  }
  // `c` has been destroyed; `rc` now refers to it only dangling-ly.
  EXPECT_UAF(rc->DoThing());
}
179
180 void TestMember() {
181 zc::ref<C> ref;
182 zc::ref<std::vector<C>> vec_ref;
183 {
184 X x("hello world");
185 ref = x.c();
186 vec_ref = x.vec_c();
187
188 vec_ref->push_back(C());
189 vec_ref->push_back(C());
190
191 vec_ref->at(1).DoThing();
192 }
193 EXPECT_UAF(ref->DoThing());
194 EXPECT_UAF(vec_ref->at(1).DoThing());
195
196 {
197 zc::member<X> x("foo bar");
198 ref = x.c();
199 }
200 EXPECT_UAF(ref->DoThing());
201 }
202
// Builds one {name, function} initializer for main()'s test table:
// TEST_FUNC(Foo) expands to { "Foo", TestFoo }.
#define TEST_FUNC(fn) {#fn, Test##fn}
204
205 int main() {
206 struct {
207 const char* name;
208 void (*test)();
209 } kTests[] = {
210 TEST_FUNC(Reset),
211 TEST_FUNC(Move),
212 TEST_FUNC(Ptr),
213 TEST_FUNC(Equality),
214 TEST_FUNC(Nulls),
215 TEST_FUNC(Vector),
216 TEST_FUNC(Stack),
217 TEST_FUNC(Member),
218 };
219
220 bool passed = true;
221 for (const auto& test : kTests) {
222 std::cout << "=== BEGIN " << test.name << " ===" << std::endl;
223 try {
224 test.test();
225 std::cout << "+++ PASS " << test.name << " +++" << std::endl;
226 } catch (const TestFailure& e) {
227 passed = false;
228 std::cout << "!!! FAIL " << test.name
229 << ": Assertion failure: " << e.what() << " ===" << std::endl;
230 }
231 }
232
233 return passed ? 0 : 1;
234 }