Add basic EXPECT() and EXPECT_UAF() macros to test.cc.
[zcpointer.git] / test.cc
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <iostream>
#include <stdexcept>
#include <utility>

#include "zcpointer.h"
18
// Minimal payload type used by the tests: an empty object whose no-op
// DoThing() is invoked through zc::ref to exercise dereference checks.
// The destructor is user-provided (not defaulted), so C is not trivially
// destructible.
class C {
 public:
  void DoThing() {}

  ~C() {}
};
25
// Test assertion. If `expr` evaluates to false, throws std::logic_error whose
// message is the stringized expression; main() catches it and reports FAIL.
// Wrapped in do/while so the macro behaves as a single statement.
#define EXPECT(expr)                 \
  do {                               \
    if (!(expr))                     \
      throw std::logic_error(#expr); \
  } while (false)
27
// Evaluates `expr` and requires it to throw zc::UseAfterFreeError.
// - If zc::UseAfterFreeError is thrown, it is caught and the check passes.
// - If `expr` completes normally, throws std::logic_error
//   "Expected use-after-free: <expr>", which main() reports as a failure.
// - Any other exception propagates unchanged.
// The error is caught by const reference to avoid copying/slicing the
// exception object.
#define EXPECT_UAF(expr)                                               \
  do {                                                                 \
    try {                                                              \
      static_cast<void>(expr);                                         \
      throw std::logic_error("Expected use-after-free: " #expr);       \
    } catch (const zc::UseAfterFreeError&) {                           \
    }                                                                  \
  } while (0)
34
35 void TestReset() {
36 zc::owned<C> c(new C());
37 zc::ref<C> owned = c.get();
38 zc::ref<C> owned2 = owned;
39 c.reset();
40 EXPECT_UAF(owned2->DoThing());
41 }
42
43 template <typename T>
44 void TestUnwrap() {
45 zc::owned<T> t(new T());
46 //T* unwrap = t.get();
47
48 zc::ref<T> ref = t.get();
49 T* unwrap2 = ref;
50 }
51
52 void TestMove() {
53 zc::owned<C> c(new C());
54 zc::ref<C> owned = c.get();
55
56 zc::owned<C> c2(std::move(c));
57 owned->DoThing();
58
59 c2.reset();
60 EXPECT_UAF(owned->DoThing());
61 }
62
63 void PtrHelper(zc::ref<C>* out) {
64 zc::owned<C> c(new C());
65 *out = c.get();
66 }
67
68 void TestPtr() {
69 zc::ref<C> ref;
70 PtrHelper(&ref);
71 EXPECT_UAF(ref->DoThing());
72 }
73
// Builds one entry of main()'s test table: TEST_FUNC(Foo) expands to the
// aggregate initializer {"Foo", TestFoo}.
#define TEST_FUNC(fn) {#fn, Test##fn}
75
76 int main() {
77 struct {
78 const char* name;
79 void (*test)();
80 } kTests[] = {
81 TEST_FUNC(Reset),
82 TEST_FUNC(Move),
83 TEST_FUNC(Ptr),
84 };
85
86 bool passed = true;
87 for (const auto& test : kTests) {
88 std::cout << "=== BEGIN " << test.name << " ===" << std::endl;
89 try {
90 test.test();
91 std::cout << "=== PASS " << test.name << " ===" << std::endl;
92 } catch (const std::logic_error& e) {
93 passed = false;
94 std::cout << "=== FAIL " << test.name
95 << ": Assertion failure: " << e.what() << " ===" << std::endl;
96 }
97 }
98
99 return passed ? 0 : 1;
100 }