1 // Copyright 2016 Google Inc. All rights reserved.
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
19 #include "test_helpers.h"
20 #include "zcpointer.h"
22 class TestFailure
: public std
::logic_error
{
24 using std
::logic_error
::logic_error
;
// Two-level stringification helper: forwarding through a second macro lets
// the argument be macro-expanded before it is stringified (e.g. __LINE__ ->
// "123"). NOTE(review): assumes QUOTE(x) is defined as #x; its definition is
// outside this view.
27 #define STRING(x) QUOTE(x)
// String-literal suffix of the form " @ <file>:<line>", built by adjacent
// string-literal concatenation. __FILE__ and __LINE__ are expanded where the
// macro is used, so inside EXPECT they describe the assertion's call site.
29 #define AT_FILE_LINE " @ " __FILE__ ":" STRING(__LINE__)
// Test assertion: evaluates `expr` once and, if it is false, throws
// TestFailure (a std::logic_error subclass) whose message is the expression's
// source text followed by its file/line location. The do { } while(0) wrapper
// makes the macro behave as a single statement (safe after an unbraced if).
31 #define EXPECT(expr) do { if (!(expr)) { throw TestFailure(#expr AT_FILE_LINE); } } while(0)
33 #if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS
35 #define EXPECT_UAF(expr) do { \
38 throw TestFailure("Expected use-after-free: " #expr AT_FILE_LINE); \
39 } catch (zc::UseAfterFreeError) {} \
44 #define EXPECT_UAF(expr) do { \
45 std::cout << ">>> ZCPOINTER_TRACK_REFS not enabled, cannot catch UAF" << std::endl; \
48 } catch (std::logic_error& e) { \
49 std::cout << ">>> Caught error: " << typeid(e).name() << ": " << e.what() << std::endl; \
56 zc
::owned
<C
> c(new C());
57 zc
::ref
<C
> owned
= c
.get();
58 zc
::ref
<C
> owned2
= owned
;
60 EXPECT_UAF(owned2
->DoThing());
65 zc
::owned
<T
> t(new T());
66 //T* unwrap = t.get();
68 zc
::ref
<T
> ref
= t
.get();
73 zc
::owned
<C
> c(new C());
74 zc
::ref
<C
> owned
= c
.get();
76 zc
::owned
<C
> c2(std
::move(c
));
80 EXPECT_UAF(owned
->DoThing());
83 void PtrHelper(zc
::ref
<C
>* out
) {
84 zc
::owned
<C
> c(new C());
91 EXPECT_UAF(ref
->DoThing());
95 zc
::owned
<C
> a(new C());
96 zc
::owned
<C
> b(new C());
102 zc
::ref
<C
> ra
= a
.get();
103 zc
::ref
<C
> rb
= b
.get();
106 EXPECT(ra
== a
.get());
108 EXPECT(rb
== b
.get());
112 zc
::ref
<C
> r
= a
.get();
114 EXPECT(r
== a
.get());
118 zc
::ref
<C
> rc
= nullptr;
120 EXPECT(rc
== c
.get());
121 EXPECT(c
== nullptr);
122 EXPECT(rc
== nullptr);
131 zc
::ref
<C
> rl
= l
.get();
132 zc
::ref
<C
> rr
= r
.get();
137 EXPECT(l
== nullptr);
138 EXPECT(r
== nullptr);
139 EXPECT(rl
== nullptr);
140 EXPECT(rr
== nullptr);
145 std
::vector
<zc
::ref
<C
>> vec
{
151 for (const auto& r
: vec
) {
152 EXPECT(r
== c
.get());
157 std
::vector
<zc
::owned
<C
>> vec
;
158 vec
.push_back(std
::move(zc
::owned
<C
>(new C())));
159 vec
.push_back(std
::move(zc
::owned
<C
>(new C())));
160 vec
.push_back(std
::move(zc
::owned
<C
>(new C())));
163 EXPECT_UAF(ref
->DoThing());
174 EXPECT_UAF(rc
->DoThing());
179 zc
::ref
<std
::vector
<C
>> vec_ref
;
185 vec_ref
->push_back(C());
186 vec_ref
->push_back(C());
188 vec_ref
->at(1).DoThing();
190 EXPECT_UAF(ref
->DoThing());
191 EXPECT_UAF(vec_ref
->at(1).DoThing());
194 zc
::member
<X
> x("foo bar");
197 EXPECT_UAF(ref
->DoThing());
// Expands to a brace initializer pairing the test's stringified name with the
// function Test<fn> (token-pasted). Presumably used to build the kTests table
// whose entries expose a `.name` member; that table's definition is not
// visible here.
200 #define TEST_FUNC(fn) { #fn , Test##fn }
218 for (const auto& test
: kTests
) {
219 std
::cout
<< "=== BEGIN " << test
.name
<< " ===" << std
::endl
;
222 std
::cout
<< "+++ PASS " << test
.name
<< " +++" << std
::endl
;
223 } catch (const TestFailure
& e
) {
225 std
::cout
<< "!!! FAIL " << test
.name
226 << ": Assertion failure: " << e
.what() << " ===" << std
::endl
;
230 return passed ?
0 : 1;