// Copyright 2016 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "zcpointer.h" class C { public: ~C() {} void DoThing() {} }; #define EXPECT(expr) do { if (!(expr)) { throw std::logic_error(#expr); } } while(0) #define EXPECT_UAF(expr) do { \ try { \ (expr); \ throw std::logic_error("Expected use-after-free: " #expr); \ } catch (zc::UseAfterFreeError) {} \ } while(0) void TestReset() { zc::owned c(new C()); zc::ref owned = c.get(); zc::ref owned2 = owned; c.reset(); EXPECT_UAF(owned2->DoThing()); } template void TestUnwrap() { zc::owned t(new T()); //T* unwrap = t.get(); zc::ref ref = t.get(); T* unwrap2 = ref; } void TestMove() { zc::owned c(new C()); zc::ref owned = c.get(); zc::owned c2(std::move(c)); owned->DoThing(); c2.reset(); EXPECT_UAF(owned->DoThing()); } void PtrHelper(zc::ref* out) { zc::owned c(new C()); *out = c.get(); } void TestPtr() { zc::ref ref; PtrHelper(&ref); EXPECT_UAF(ref->DoThing()); } void TestEquality() { zc::owned a(new C()); zc::owned b(new C()); EXPECT(a == a); EXPECT(b == b); EXPECT(a != b); zc::ref ra = a.get(); zc::ref rb = b.get(); EXPECT(ra == ra); EXPECT(ra == a.get()); EXPECT(rb == rb); EXPECT(rb == b.get()); EXPECT(rb != ra); zc::ref r = a.get(); EXPECT(r == ra); EXPECT(r == a.get()); zc::owned c; zc::owned c2; zc::ref rc = nullptr; EXPECT(rc == c.get()); EXPECT(c == nullptr); EXPECT(rc == nullptr); EXPECT(a != c); EXPECT(c == c2); } void TestNulls() { zc::owned l; zc::owned r; zc::ref rl = l.get(); zc::ref rr = r.get(); r = std::move(l); rl = rr; EXPECT(l == nullptr); EXPECT(r == nullptr); EXPECT(rl == nullptr); EXPECT(rr == nullptr); } #define TEST_FUNC(fn) { #fn , Test##fn } int main() { struct { const char* name; void (*test)(); } kTests[] = { TEST_FUNC(Reset), TEST_FUNC(Move), TEST_FUNC(Ptr), TEST_FUNC(Equality), TEST_FUNC(Nulls), }; bool passed = true; for (const auto& test : kTests) { std::cout << "=== BEGIN " << test.name << " ===" << std::endl; try { test.test(); std::cout << "=== PASS " << test.name << " ===" << std::endl; } catch (const std::logic_error& e) { passed = false; std::cout << "=== FAIL " << test.name << ": Assertion failure: " << e.what() << " ===" << std::endl; } } return passed ? 0 : 1; }