// limitations under the License.
#include <iostream>
+#include <stdexcept>
#include <vector>
+#include "test_helpers.h"
#include "zcpointer.h"
-class C {
+class TestFailure : public std::logic_error {
public:
- ~C() {}
-
- void DoThing() {}
+ using std::logic_error::logic_error;
};
-#define EXPECT(expr) do { if (!(expr)) { throw std::logic_error(#expr); } } while(0)
+#define STRING(x) QUOTE(x)
+#define QUOTE(x) #x
+#define AT_FILE_LINE " @ " __FILE__ ":" STRING(__LINE__)
+
+#define EXPECT(expr) do { if (!(expr)) { throw TestFailure(#expr AT_FILE_LINE); } } while(0)
+
+#if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS
#define EXPECT_UAF(expr) do { \
try { \
(expr); \
- throw std::logic_error("Expected use-after-free: " #expr); \
+ throw TestFailure("Expected use-after-free: " #expr AT_FILE_LINE); \
} catch (zc::UseAfterFreeError) {} \
} while(0)
+#else
+
+#define EXPECT_UAF(expr) do { \
+ std::cout << ">>> ZCPOINTER_TRACK_REFS not enabled, cannot catch UAF" << std::endl; \
+ try { \
+ (expr); \
+ } catch (std::logic_error& e) { \
+ std::cout << ">>> Caught error: " << typeid(e).name() << ": " << e.what() << std::endl; \
+ } \
+ } while(0)
+
+#endif
+
void TestReset() {
zc::owned<C> c(new C());
zc::ref<C> owned = c.get();
template <typename T>
void TestUnwrap() {
zc::owned<T> t(new T());
- //T* unwrap = t.get();
+ T* unwrap = t.get();
zc::ref<T> ref = t.get();
T* unwrap2 = ref;
+
+ zc::member<T> tm;
+ T* tp = &tm;
}
void TestMove() {
EXPECT_UAF(ref->DoThing());
}
+void TestStack() {
+ zc::ref<C> rc;
+ {
+ zc::member<C> c;
+ rc = &c;
+ EXPECT(rc == &c);
+ c.DoThing();
+ }
+ EXPECT_UAF(rc->DoThing());
+}
+
+void TestMember() {
+ zc::ref<C> ref;
+ zc::ref<std::vector<C>> vec_ref;
+ {
+ X x("hello world");
+ ref = x.c();
+ vec_ref = x.vec_c();
+
+ vec_ref->push_back(C());
+ vec_ref->push_back(C());
+
+ vec_ref->at(1).DoThing();
+ }
+ EXPECT_UAF(ref->DoThing());
+ EXPECT_UAF(vec_ref->at(1).DoThing());
+
+ {
+ zc::member<X> x("foo bar");
+ ref = x.c();
+ }
+ EXPECT_UAF(ref->DoThing());
+}
+
#define TEST_FUNC(fn) { #fn , Test##fn }
int main() {
TEST_FUNC(Equality),
TEST_FUNC(Nulls),
TEST_FUNC(Vector),
+ TEST_FUNC(Stack),
+ TEST_FUNC(Member),
};
bool passed = true;
std::cout << "=== BEGIN " << test.name << " ===" << std::endl;
try {
test.test();
- std::cout << "=== PASS " << test.name << " ===" << std::endl;
- } catch (const std::logic_error& e) {
+ std::cout << "+++ PASS " << test.name << " +++" << std::endl;
+ } catch (const TestFailure& e) {
passed = false;
- std::cout << "=== FAIL " << test.name
+ std::cout << "!!! FAIL " << test.name
<< ": Assertion failure: " << e.what() << " ===" << std::endl;
}
}