In tests, do not expect zero-cost mode to catch UAF.
[zcpointer.git] / test.cc
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <exception>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

#include "zcpointer.h"
19
// Minimal pointee type for the tests in this file. It holds no state; the
// tests only need an object to own, reference, and invoke a method on.
class C {
 public:
  // User-provided destructor with an empty body.
  ~C() {}

  // Intentionally does nothing. Tests call it through a zc::ref<C> so that
  // the reference is actually dereferenced.
  void DoThing() {}
};
26
// Test assertion. Evaluates |expr| once; if it is false, throws a
// std::logic_error (declared in <stdexcept>) whose what() string is the
// source text of |expr|. The do/while(0) wrapper lets the macro be used as a
// single statement followed by a semicolon.
#define EXPECT(expr)                 \
  do {                               \
    if (!(expr)) {                   \
      throw std::logic_error(#expr); \
    }                                \
  } while (0)
28
#if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS

// Evaluates |expr| and requires it to throw zc::UseAfterFreeError. If |expr|
// finishes without throwing, a std::logic_error naming the expression is
// thrown instead; it is not matched by the handler below and so propagates
// to the caller as a test failure.
#define EXPECT_UAF(expr)                                         \
  do {                                                           \
    try {                                                        \
      (expr);                                                    \
      throw std::logic_error("Expected use-after-free: " #expr); \
    } catch (const zc::UseAfterFreeError&) {                     \
      /* Expected outcome; caught by const reference so the   */ \
      /* exception object is neither copied nor sliced.       */ \
    }                                                            \
  } while (0)

#else

// ZCPOINTER_TRACK_REFS is not enabled: no UseAfterFreeError is expected, so
// this variant only prints a notice and then still evaluates |expr|.
// NOTE(review): at the call sites in this file |expr| dereferences a ref
// whose owner has already been destroyed -- confirm that evaluating it in
// this mode is intended.
#define EXPECT_UAF(expr)                                                  \
  do {                                                                    \
    std::cout << ">>> ZCPOINTER_TRACK_REFS not enabled, cannot catch UAF" \
              << std::endl;                                               \
    (expr);                                                               \
  } while (0)

#endif
46
47 void TestReset() {
48 zc::owned<C> c(new C());
49 zc::ref<C> owned = c.get();
50 zc::ref<C> owned2 = owned;
51 c.reset();
52 EXPECT_UAF(owned2->DoThing());
53 }
54
55 template <typename T>
56 void TestUnwrap() {
57 zc::owned<T> t(new T());
58 //T* unwrap = t.get();
59
60 zc::ref<T> ref = t.get();
61 T* unwrap2 = ref;
62 }
63
64 void TestMove() {
65 zc::owned<C> c(new C());
66 zc::ref<C> owned = c.get();
67
68 zc::owned<C> c2(std::move(c));
69 owned->DoThing();
70
71 c2.reset();
72 EXPECT_UAF(owned->DoThing());
73 }
74
75 void PtrHelper(zc::ref<C>* out) {
76 zc::owned<C> c(new C());
77 *out = c.get();
78 }
79
80 void TestPtr() {
81 zc::ref<C> ref;
82 PtrHelper(&ref);
83 EXPECT_UAF(ref->DoThing());
84 }
85
86 void TestEquality() {
87 zc::owned<C> a(new C());
88 zc::owned<C> b(new C());
89
90 EXPECT(a == a);
91 EXPECT(b == b);
92 EXPECT(a != b);
93
94 zc::ref<C> ra = a.get();
95 zc::ref<C> rb = b.get();
96
97 EXPECT(ra == ra);
98 EXPECT(ra == a.get());
99 EXPECT(rb == rb);
100 EXPECT(rb == b.get());
101
102 EXPECT(rb != ra);
103
104 zc::ref<C> r = a.get();
105 EXPECT(r == ra);
106 EXPECT(r == a.get());
107
108 zc::owned<C> c;
109 zc::owned<C> c2;
110 zc::ref<C> rc = nullptr;
111
112 EXPECT(rc == c.get());
113 EXPECT(c == nullptr);
114 EXPECT(rc == nullptr);
115 EXPECT(a != c);
116 EXPECT(c == c2);
117 }
118
119 void TestNulls() {
120 zc::owned<C> l;
121 zc::owned<C> r;
122
123 zc::ref<C> rl = l.get();
124 zc::ref<C> rr = r.get();
125
126 r = std::move(l);
127 rl = rr;
128
129 EXPECT(l == nullptr);
130 EXPECT(r == nullptr);
131 EXPECT(rl == nullptr);
132 EXPECT(rr == nullptr);
133 }
134
135 void TestVector() {
136 zc::owned<C> c;
137 std::vector<zc::ref<C>> vec{
138 c.get(),
139 c.get(),
140 c.get()
141 };
142
143 for (const auto& r : vec) {
144 EXPECT(r == c.get());
145 }
146
147 zc::ref<C> ref;
148 {
149 std::vector<zc::owned<C>> vec;
150 vec.push_back(std::move(zc::owned<C>(new C())));
151 vec.push_back(std::move(zc::owned<C>(new C())));
152 vec.push_back(std::move(zc::owned<C>(new C())));
153 ref = vec[1].get();
154 }
155 EXPECT_UAF(ref->DoThing());
156 }
157
// Expands to one brace-initialized test-table entry: the stringified short
// name |fn| followed by the function named Test<fn>.
#define TEST_FUNC(fn) \
  { #fn, Test##fn }
159
160 int main() {
161 struct {
162 const char* name;
163 void (*test)();
164 } kTests[] = {
165 TEST_FUNC(Reset),
166 TEST_FUNC(Move),
167 TEST_FUNC(Ptr),
168 TEST_FUNC(Equality),
169 TEST_FUNC(Nulls),
170 TEST_FUNC(Vector),
171 };
172
173 bool passed = true;
174 for (const auto& test : kTests) {
175 std::cout << "=== BEGIN " << test.name << " ===" << std::endl;
176 try {
177 test.test();
178 std::cout << "+++ PASS " << test.name << " +++" << std::endl;
179 } catch (const std::logic_error& e) {
180 passed = false;
181 std::cout << "!!! FAIL " << test.name
182 << ": Assertion failure: " << e.what() << " ===" << std::endl;
183 }
184 }
185
186 return passed ? 0 : 1;
187 }