In OwnedPtrDeleter<T>, store a vptr for a behavior function.
[zcpointer.git] / test.cc
diff --git a/test.cc b/test.cc
index 23c310a5bc28236d413f3f64e4f23f871c6fb7d5..d28d5912be1adfd582126a9585286cd432529d3d 100644 (file)
--- a/test.cc
+++ b/test.cc
 // limitations under the License.
 
 #include <iostream>
+#include <stdexcept>
+#include <vector>
 
+#include "test_helpers.h"
 #include "zcpointer.h"
 
-class C {
+class TestFailure : public std::logic_error {
  public:
-  ~C() {
-    std::cout << "~C" << std::endl;
-  }
-
-  void DoThing() {
-    std::cout << "DoThing" << std::endl;
-  }
+  using std::logic_error::logic_error;
 };
 
+#define STRING(x) QUOTE(x)
+#define QUOTE(x) #x
+#define AT_FILE_LINE " @ " __FILE__ ":" STRING(__LINE__)
+
+#define EXPECT(expr) do { if (!(expr)) { throw TestFailure(#expr AT_FILE_LINE); } } while(0)
+
+#if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS
+
+#define EXPECT_UAF(expr) do { \
+    try { \
+      (expr); \
+      throw TestFailure("Expected use-after-free: " #expr AT_FILE_LINE); \
+    } catch (zc::UseAfterFreeError) {} \
+  } while(0)
+
+#else
+
+#define EXPECT_UAF(expr) do { \
+    std::cout << ">>> ZCPOINTER_TRACK_REFS not enabled, cannot catch UAF" << std::endl; \
+    try { \
+      (expr); \
+    } catch (std::logic_error& e) { \
+      std::cout << ">>> Caught error: " << typeid(e).name() << ": " << e.what() << std::endl; \
+    } \
+  } while(0)
+
+#endif
+
 void TestReset() {
   zc::owned<C> c(new C());
   zc::ref<C> owned = c.get();
   zc::ref<C> owned2 = owned;
   c.reset();
-  owned2->DoThing();
+  EXPECT_UAF(owned2->DoThing());
 }
 
 template <typename T>
 void TestUnwrap() {
   zc::owned<T> t(new T());
-  //T* unwrap = t.get();
+  T* unwrap = t.get();
 
   zc::ref<T> ref = t.get();
   T* unwrap2 = ref;
+
+  zc::member<T> tm;
+  T* tp = &tm;
 }
 
 void TestMove() {
@@ -52,7 +80,7 @@ void TestMove() {
   owned->DoThing();
 
   c2.reset();
-  owned->DoThing();
+  EXPECT_UAF(owned->DoThing());
 }
 
 void PtrHelper(zc::ref<C>* out) {
@@ -63,7 +91,113 @@ void PtrHelper(zc::ref<C>* out) {
 void TestPtr() {
   zc::ref<C> ref;
   PtrHelper(&ref);
-  ref->DoThing();
+  EXPECT_UAF(ref->DoThing());
+}
+
+void TestEquality() {
+  zc::owned<C> a(new C());
+  zc::owned<C> b(new C());
+
+  EXPECT(a == a);
+  EXPECT(b == b);
+  EXPECT(a != b);
+
+  zc::ref<C> ra = a.get();
+  zc::ref<C> rb = b.get();
+
+  EXPECT(ra == ra);
+  EXPECT(ra == a.get());
+  EXPECT(rb == rb);
+  EXPECT(rb == b.get());
+
+  EXPECT(rb != ra);
+
+  zc::ref<C> r = a.get();
+  EXPECT(r == ra);
+  EXPECT(r == a.get());
+
+  zc::owned<C> c;
+  zc::owned<C> c2;
+  zc::ref<C> rc = nullptr;
+
+  EXPECT(rc == c.get());
+  EXPECT(c == nullptr);
+  EXPECT(rc == nullptr);
+  EXPECT(a != c);
+  EXPECT(c == c2);
+}
+
+void TestNulls() {
+  zc::owned<C> l;
+  zc::owned<C> r;
+
+  zc::ref<C> rl = l.get();
+  zc::ref<C> rr = r.get();
+
+  r = std::move(l);
+  rl = rr;
+
+  EXPECT(l == nullptr);
+  EXPECT(r == nullptr);
+  EXPECT(rl == nullptr);
+  EXPECT(rr == nullptr);
+}
+
+void TestVector() {
+  zc::owned<C> c;
+  std::vector<zc::ref<C>> vec{
+    c.get(),
+    c.get(),
+    c.get()
+  };
+
+  for (const auto& r : vec) {
+    EXPECT(r == c.get());
+  }
+
+  zc::ref<C> ref;
+  {
+    std::vector<zc::owned<C>> vec;
+    vec.push_back(std::move(zc::owned<C>(new C())));
+    vec.push_back(std::move(zc::owned<C>(new C())));
+    vec.push_back(std::move(zc::owned<C>(new C())));
+    ref = vec[1].get();
+  }
+  EXPECT_UAF(ref->DoThing());
+}
+
+void TestStack() {
+  zc::ref<C> rc;
+  {
+    zc::member<C> c;
+    rc = &c;
+    EXPECT(rc == &c);
+    c.DoThing();
+  }
+  EXPECT_UAF(rc->DoThing());
+}
+
+void TestMember() {
+  zc::ref<C> ref;
+  zc::ref<std::vector<C>> vec_ref;
+  {
+    X x("hello world");
+    ref = x.c();
+    vec_ref = x.vec_c();
+
+    vec_ref->push_back(C());
+    vec_ref->push_back(C());
+
+    vec_ref->at(1).DoThing();
+  }
+  EXPECT_UAF(ref->DoThing());
+  EXPECT_UAF(vec_ref->at(1).DoThing());
+
+  {
+    zc::member<X> x("foo bar");
+    ref = x.c();
+  }
+  EXPECT_UAF(ref->DoThing());
 }
 
 #define TEST_FUNC(fn) { #fn , Test##fn }
@@ -76,16 +210,25 @@ int main() {
     TEST_FUNC(Reset),
     TEST_FUNC(Move),
     TEST_FUNC(Ptr),
+    TEST_FUNC(Equality),
+    TEST_FUNC(Nulls),
+    TEST_FUNC(Vector),
+    TEST_FUNC(Stack),
+    TEST_FUNC(Member),
   };
 
+  bool passed = true;
   for (const auto& test : kTests) {
     std::cout << "=== BEGIN " << test.name << " ===" << std::endl;
     try {
       test.test();
-      std::cout << "=== FAIL " << test.name
-                << ": Did not receive UseAfterFreeException ===" << std::endl;
-    } catch (zc::UseAfterFreeError) {
-      std::cout << "=== PASS " << test.name << " ===" << std::endl;
+      std::cout << "+++ PASS " << test.name << " +++" << std::endl;
+    } catch (const TestFailure& e) {
+      passed = false;
+      std::cout << "!!! FAIL " << test.name
+                << ": Assertion failure: " << e.what() << " ===" << std::endl;
     }
   }
+
+  return passed ? 0 : 1;
 }