In OwnedPtrDeleter<T>, store a vptr for a behavior function.
[zcpointer.git] / zcpointer.h
index 2edf05fb93c8b04acb27ef6783fb1ceaa4777495..2301ef190e3fda8bca83a74ad1487de78f8476bb 100644 (file)
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifndef ZCPOINTER_ZCPOINTER_H_
+#define ZCPOINTER_ZCPOINTER_H_
+
+#include <limits>
 #include <memory>
 #include <forward_list>
 #include <stdexcept>
+#include <utility>
 
 namespace zc {
 
+class UseAfterFreeError : public std::logic_error {
+ public:
+  using std::logic_error::logic_error;
+};
+
 #if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS
 
 template <typename T> class ref;
 
 namespace internal {
 
+enum class OwnershipBehavior {
+  DELETE_POINTER,
+  BORROW_POINTER,
+};
+
 template <typename T>
 class OwnedPtrDeleter {
  public:
-  OwnedPtrDeleter() {}
+  OwnedPtrDeleter()
+      : refs_(),
+        finalizer_(&OwnedPtrDeleter<T>::HandleDeletePointer) {}
+
   ~OwnedPtrDeleter() {}
 
-  OwnedPtrDeleter(OwnedPtrDeleter&& other) : refs_(std::move(other.refs_)) {
+  explicit OwnedPtrDeleter(OwnershipBehavior behavior)
+      : refs_(),
+        finalizer_(behavior == OwnershipBehavior::BORROW_POINTER
+                       ? &OwnedPtrDeleter<T>::HandleBorrowPointer
+                       : &OwnedPtrDeleter<T>::HandleDeletePointer) {
+  }
+
+  OwnedPtrDeleter(OwnedPtrDeleter&& other)
+      : refs_(std::move(other.refs_)),
+        finalizer_(other.finalizer_) {
+  }
+
+  void operator=(const OwnedPtrDeleter& o) {
+    refs_ = o.refs_;
+    finalizer_ = o.finalizer_;
   }
 
   void operator()(T* t) const {
     for (auto& ref : refs_) {
       ref->MarkDeleted();
     }
-    delete t;
+    (this->finalizer_)(t);
   }
 
  protected:
@@ -51,13 +83,20 @@ class OwnedPtrDeleter {
     refs_.remove(ref);
   }
 
+  static void HandleDeletePointer(T* t) {
+    delete t;
+  }
+
+  static void HandleBorrowPointer(T* t) {}
+
  private:
+  void (*finalizer_)(T*);
   std::forward_list<ref<T>*> refs_;
 };
 
-}  // namespace internal
+void RaiseUseAfterFree() __attribute__((noreturn));
 
-using UseAfterFreeException = std::logic_error;
+}  // namespace internal
 
 template <typename T>
 class owned : public std::unique_ptr<T, internal::OwnedPtrDeleter<T>> {
@@ -71,13 +110,6 @@ class owned : public std::unique_ptr<T, internal::OwnedPtrDeleter<T>> {
     return ref<T>(*this);
   }
 
- protected:
-  friend class ref<T>;
-
-  T* GetRawPointer() const {
-    return get();
-  }
-
  private:
   T* get() const {
     return this->std::unique_ptr<T, Deleter>::get();
@@ -87,66 +119,95 @@ class owned : public std::unique_ptr<T, internal::OwnedPtrDeleter<T>> {
 template <typename T>
 class ref {
  public:
-  ref(owned<T>& o) : ptr_(o.GetRawPointer()), deleter_(o.get_deleter()) {
-    deleter_.AddRef(this);
-  }
+  ref() : ptr_(nullptr) {}
+
+  ref(std::nullptr_t) : ref() {}
 
-  ref(const ref<T>& o) : ptr_(o.ptr_), deleter_(o.deleter_), deleted_(o.deleted_) {
-    if (!deleted_) {
-      deleter_.AddRef(this);
+  explicit ref(owned<T>& o) : ptr_(nullptr) {
+    if (o != nullptr) {
+      ptr_ = &o;
+      ptr_->get_deleter().AddRef(this);
     }
   }
 
-  ref<T>& operator=(ref<T> o) {
+  ref(const ref<T>& r) {
+    *this = r;
+  }
+
+  ref<T>& operator=(const ref<T>& o) {
     ptr_ = o.ptr_;
-    deleter_ = o.deleter_;
-    deleted_ = o.deleted_;
-    if (!deleted_) {
-      deleter_.AddRef(this);
+    if (ptr_ != nullptr && !IsDeleted()) {
+      ptr_->get_deleter().AddRef(this);
     }
     return *this;
   }
 
   ~ref() {
-    deleter_.RemoveRef(this);
+    if (ptr_ != nullptr && !IsDeleted()) {
+      ptr_->get_deleter().RemoveRef(this);
+    }
+    MarkDeleted();
   }
 
-#if 0
-  operator T*() const {
+  T* operator->() const {
     CheckDeleted();
-    return ptr_;
+    return ptr_->operator->();
   }
-#endif
 
-  T* operator->() const {
-    CheckDeleted();
-    return ptr_;
+  bool operator==(const ref<T>& r) const {
+    if (ptr_ == nullptr) {
+      return r.ptr_ == nullptr;
+    } else {
+      return ptr_ == r.ptr_ && *ptr_ == *r.ptr_;
+    }
   }
 
-#if 0
-  T* get() {
-    CheckDeleted();
-    return ptr_;
+  bool operator==(std::nullptr_t) const {
+    return ptr_ == nullptr;
+  }
+
+  bool operator!=(const ref<T>& r) const {
+    return !(*this == r);
   }
-#endif
 
  protected:
   friend class internal::OwnedPtrDeleter<T>;
 
   void MarkDeleted() {
-    deleted_ = true;
+    ptr_ = DeletedSentinel();
   }
 
  private:
   void CheckDeleted() const {
-    if (deleted_) {
-      throw UseAfterFreeException("attempt to access deleted pointer");
+    if (IsDeleted()) {
+      internal::RaiseUseAfterFree();
     }
   }
 
-  T* ptr_;
-  internal::OwnedPtrDeleter<T>& deleter_;
-  bool deleted_ = false;
+  bool IsDeleted() const {
+    return ptr_ == DeletedSentinel();
+  }
+
+  inline static owned<T>* DeletedSentinel() {
+    return reinterpret_cast<owned<T>*>(std::numeric_limits<uintptr_t>::max());
+  }
+
+  owned<T>* ptr_;
+};
+
+template <typename T>
+class member : public T {
+ public:
+  using T::T;
+
+  ref<T> operator&() {
+    return ptr_.get();
+  }
+
+ private:
+  owned<T> ptr_ = owned<T>(this,
+                           internal::OwnedPtrDeleter<T>(
+                               internal::OwnershipBehavior::BORROW_POINTER));
 };
 
 #else
@@ -157,6 +218,11 @@ using owned = std::unique_ptr<T>;
 template <typename T>
 using ref = T*;
 
+template <typename T>
+using member = T;
+
 #endif
 
 }  // namespace zc
+
+#endif  // ZCPOINTER_ZCPOINTER_H_