In OwnedPtrDeleter<T>, store a vptr for a behavior function.
[zcpointer.git] / test.cc
diff --git a/test.cc b/test.cc
index 2d43d3f82875c92289e09cab1995e658931be869..d28d5912be1adfd582126a9585286cd432529d3d 100644 (file)
--- a/test.cc
+++ b/test.cc
 // limitations under the License.
 
 #include <iostream>
+#include <stdexcept>
 #include <vector>
 
+#include "test_helpers.h"
 #include "zcpointer.h"
 
-class C {
+class TestFailure : public std::logic_error {
  public:
-  ~C() {}
-
-  void DoThing() {}
+  using std::logic_error::logic_error;
 };
 
-#define EXPECT(expr) do { if (!(expr)) { throw std::logic_error(#expr); } } while(0)
+#define STRING(x) QUOTE(x)
+#define QUOTE(x) #x
+#define AT_FILE_LINE " @ " __FILE__ ":" STRING(__LINE__)
+
+#define EXPECT(expr) do { if (!(expr)) { throw TestFailure(#expr AT_FILE_LINE); } } while(0)
+
+#if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS
 
 #define EXPECT_UAF(expr) do { \
     try { \
       (expr); \
-      throw std::logic_error("Expected use-after-free: " #expr); \
+      throw TestFailure("Expected use-after-free: " #expr AT_FILE_LINE); \
     } catch (zc::UseAfterFreeError) {} \
   } while(0)
 
+#else
+
+#define EXPECT_UAF(expr) do { \
+    std::cout << ">>> ZCPOINTER_TRACK_REFS not enabled, cannot catch UAF" << std::endl; \
+    try { \
+      (expr); \
+    } catch (std::logic_error& e) { \
+      std::cout << ">>> Caught error: " << typeid(e).name() << ": " << e.what() << std::endl; \
+    } \
+  } while(0)
+
+#endif
+
 void TestReset() {
   zc::owned<C> c(new C());
   zc::ref<C> owned = c.get();
@@ -44,10 +63,13 @@ void TestReset() {
 template <typename T>
 void TestUnwrap() {
   zc::owned<T> t(new T());
-  //T* unwrap = t.get();
+  T* unwrap = t.get();
 
   zc::ref<T> ref = t.get();
   T* unwrap2 = ref;
+
+  zc::member<T> tm;
+  T* tp = &tm;
 }
 
 void TestMove() {
@@ -144,6 +166,40 @@ void TestVector() {
   EXPECT_UAF(ref->DoThing());
 }
 
+void TestStack() {
+  zc::ref<C> rc;
+  {
+    zc::member<C> c;
+    rc = &c;
+    EXPECT(rc == &c);
+    c.DoThing();
+  }
+  EXPECT_UAF(rc->DoThing());
+}
+
+void TestMember() {
+  zc::ref<C> ref;
+  zc::ref<std::vector<C>> vec_ref;
+  {
+    X x("hello world");
+    ref = x.c();
+    vec_ref = x.vec_c();
+
+    vec_ref->push_back(C());
+    vec_ref->push_back(C());
+
+    vec_ref->at(1).DoThing();
+  }
+  EXPECT_UAF(ref->DoThing());
+  EXPECT_UAF(vec_ref->at(1).DoThing());
+
+  {
+    zc::member<X> x("foo bar");
+    ref = x.c();
+  }
+  EXPECT_UAF(ref->DoThing());
+}
+
 #define TEST_FUNC(fn) { #fn , Test##fn }
 
 int main() {
@@ -157,6 +213,8 @@ int main() {
     TEST_FUNC(Equality),
     TEST_FUNC(Nulls),
     TEST_FUNC(Vector),
+    TEST_FUNC(Stack),
+    TEST_FUNC(Member),
   };
 
   bool passed = true;
@@ -164,10 +222,10 @@ int main() {
     std::cout << "=== BEGIN " << test.name << " ===" << std::endl;
     try {
       test.test();
-      std::cout << "=== PASS " << test.name << " ===" << std::endl;
-    } catch (const std::logic_error& e) {
+      std::cout << "+++ PASS " << test.name << " +++" << std::endl;
+    } catch (const TestFailure& e) {
       passed = false;
-      std::cout << "=== FAIL " << test.name
+      std::cout << "!!! FAIL " << test.name
                 << ": Assertion failure: " << e.what() << " ===" << std::endl;
     }
   }