Make member<T> not heap-allocated with ZCPOINTER_TRACK_REFS.
[zcpointer.git] / test.cc
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <exception>
#include <iostream>
#include <stdexcept>
#include <typeinfo>
#include <vector>

#include "test_helpers.h"
#include "zcpointer.h"
21
// Exception used to signal a failed test assertion.
// It derives from std::logic_error so it carries a what() message describing
// the failed expression and its location.
struct TestFailure : std::logic_error {
  // Inherit std::logic_error's (const std::string&) and (const char*) ctors.
  using std::logic_error::logic_error;
};
26
// Two-level stringification.
// QUOTE(x) stringizes its argument exactly as written. STRING(x) lets the
// argument be macro-expanded first (e.g. __LINE__ -> 42) and then stringizes
// the expanded result.
#define QUOTE(x) #x
#define STRING(x) QUOTE(x)

// String literal of the form " @ <file>:<line>" for the place it is expanded.
#define AT_FILE_LINE " @ " __FILE__ ":" STRING(__LINE__)

// Throws TestFailure, whose message is the source text of `expr` followed by
// AT_FILE_LINE, whenever `expr` is false.
// The do/while(0) wrapper makes the macro act as a single statement.
#define EXPECT(expr)                         \
  do {                                       \
    if (!(expr)) {                           \
      throw TestFailure(#expr AT_FILE_LINE); \
    }                                        \
  } while (0)
32
// EXPECT_UAF(expr): evaluates `expr`, which is expected to go through a
// zc::ref whose target has already been destroyed.
#if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS

// With reference tracking enabled, the dangling access is expected to throw
// zc::UseAfterFreeError. If `expr` completes normally, a TestFailure is thrown.
// The error is caught by const reference rather than by value. This avoids
// copying the exception and slicing a derived error type.
#define EXPECT_UAF(expr) do { \
  try { \
    (expr); \
    throw TestFailure("Expected use-after-free: " #expr AT_FILE_LINE); \
  } catch (const zc::UseAfterFreeError&) {} \
} while(0)

#else

// Without tracking, the use-after-free cannot be detected. `expr` is simply
// evaluated (NOTE(review): presumably undefined behavior when the target is
// gone), and any std::logic_error it throws is reported instead of propagated.
// std::endl is deliberate: it flushes the banner before `expr` runs, in case
// `expr` crashes. typeid requires <typeinfo> to be included.
#define EXPECT_UAF(expr) do { \
  std::cout << ">>> ZCPOINTER_TRACK_REFS not enabled, cannot catch UAF" << std::endl; \
  try { \
    (expr); \
  } catch (const std::logic_error& e) { \
    std::cout << ">>> Caught error: " << typeid(e).name() << ": " << e.what() << std::endl; \
  } \
} while(0)

#endif
54
55 void TestReset() {
56 zc::owned<C> c(new C());
57 zc::ref<C> owned = c.get();
58 zc::ref<C> owned2 = owned;
59 c.reset();
60 EXPECT_UAF(owned2->DoThing());
61 }
62
63 template <typename T>
64 void TestUnwrap() {
65 zc::owned<T> t(new T());
66 //T* unwrap = t.get();
67
68 zc::ref<T> ref = t.get();
69 T* unwrap2 = ref;
70 }
71
72 void TestMove() {
73 zc::owned<C> c(new C());
74 zc::ref<C> owned = c.get();
75
76 zc::owned<C> c2(std::move(c));
77 owned->DoThing();
78
79 c2.reset();
80 EXPECT_UAF(owned->DoThing());
81 }
82
83 void PtrHelper(zc::ref<C>* out) {
84 zc::owned<C> c(new C());
85 *out = c.get();
86 }
87
88 void TestPtr() {
89 zc::ref<C> ref;
90 PtrHelper(&ref);
91 EXPECT_UAF(ref->DoThing());
92 }
93
94 void TestEquality() {
95 zc::owned<C> a(new C());
96 zc::owned<C> b(new C());
97
98 EXPECT(a == a);
99 EXPECT(b == b);
100 EXPECT(a != b);
101
102 zc::ref<C> ra = a.get();
103 zc::ref<C> rb = b.get();
104
105 EXPECT(ra == ra);
106 EXPECT(ra == a.get());
107 EXPECT(rb == rb);
108 EXPECT(rb == b.get());
109
110 EXPECT(rb != ra);
111
112 zc::ref<C> r = a.get();
113 EXPECT(r == ra);
114 EXPECT(r == a.get());
115
116 zc::owned<C> c;
117 zc::owned<C> c2;
118 zc::ref<C> rc = nullptr;
119
120 EXPECT(rc == c.get());
121 EXPECT(c == nullptr);
122 EXPECT(rc == nullptr);
123 EXPECT(a != c);
124 EXPECT(c == c2);
125 }
126
127 void TestNulls() {
128 zc::owned<C> l;
129 zc::owned<C> r;
130
131 zc::ref<C> rl = l.get();
132 zc::ref<C> rr = r.get();
133
134 r = std::move(l);
135 rl = rr;
136
137 EXPECT(l == nullptr);
138 EXPECT(r == nullptr);
139 EXPECT(rl == nullptr);
140 EXPECT(rr == nullptr);
141 }
142
143 void TestVector() {
144 zc::owned<C> c;
145 std::vector<zc::ref<C>> vec{
146 c.get(),
147 c.get(),
148 c.get()
149 };
150
151 for (const auto& r : vec) {
152 EXPECT(r == c.get());
153 }
154
155 zc::ref<C> ref;
156 {
157 std::vector<zc::owned<C>> vec;
158 vec.push_back(std::move(zc::owned<C>(new C())));
159 vec.push_back(std::move(zc::owned<C>(new C())));
160 vec.push_back(std::move(zc::owned<C>(new C())));
161 ref = vec[1].get();
162 }
163 EXPECT_UAF(ref->DoThing());
164 }
165
166 void TestStack() {
167 zc::ref<C> rc;
168 {
169 zc::member<C> c;
170 rc = &c;
171 EXPECT(rc == &c);
172 c.DoThing();
173 }
174 EXPECT_UAF(rc->DoThing());
175 }
176
177 void TestMember() {
178 zc::ref<C> ref;
179 zc::ref<std::vector<C>> vec_ref;
180 {
181 X x("hello world");
182 ref = x.c();
183 vec_ref = x.vec_c();
184
185 vec_ref->push_back(C());
186 vec_ref->push_back(C());
187
188 vec_ref->at(1).DoThing();
189 }
190 EXPECT_UAF(ref->DoThing());
191 EXPECT_UAF(vec_ref->at(1).DoThing());
192
193 {
194 zc::member<X> x("foo bar");
195 ref = x.c();
196 }
197 EXPECT_UAF(ref->DoThing());
198 }
199
200 #define TEST_FUNC(fn) { #fn , Test##fn }
201
202 int main() {
203 struct {
204 const char* name;
205 void (*test)();
206 } kTests[] = {
207 TEST_FUNC(Reset),
208 TEST_FUNC(Move),
209 TEST_FUNC(Ptr),
210 TEST_FUNC(Equality),
211 TEST_FUNC(Nulls),
212 TEST_FUNC(Vector),
213 TEST_FUNC(Stack),
214 TEST_FUNC(Member),
215 };
216
217 bool passed = true;
218 for (const auto& test : kTests) {
219 std::cout << "=== BEGIN " << test.name << " ===" << std::endl;
220 try {
221 test.test();
222 std::cout << "+++ PASS " << test.name << " +++" << std::endl;
223 } catch (const TestFailure& e) {
224 passed = false;
225 std::cout << "!!! FAIL " << test.name
226 << ": Assertion failure: " << e.what() << " ===" << std::endl;
227 }
228 }
229
230 return passed ? 0 : 1;
231 }