In tests, do not expect zero-cost mode to catch UAF.
[zcpointer.git] / zcpointer.h
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <cstddef>
#include <cstdint>
#include <forward_list>
#include <limits>
#include <memory>
#include <stdexcept>
#include <utility>
19
20 namespace zc {
21
// Exception type for accesses through a reference whose object has already
// been deleted. It is a std::logic_error: such an access is a programming bug.
// Construct it from a message, exactly like std::logic_error.
class UseAfterFreeError : public std::logic_error {
 public:
  explicit UseAfterFreeError(const std::string& what_arg)
      : std::logic_error(what_arg) {}

  explicit UseAfterFreeError(const char* what_arg)
      : std::logic_error(what_arg) {}
};
26
27 #if defined(ZCPOINTER_TRACK_REFS) && ZCPOINTER_TRACK_REFS
28
29 template <typename T> class ref;
30
31 namespace internal {
32
33 template <typename T>
34 class OwnedPtrDeleter {
35 public:
36 OwnedPtrDeleter() {}
37 ~OwnedPtrDeleter() {}
38
39 OwnedPtrDeleter(OwnedPtrDeleter&& other) : refs_(std::move(other.refs_)) {
40 }
41
42 void operator=(const OwnedPtrDeleter& o) {
43 refs_ = o.refs_;
44 }
45
46 void operator()(T* t) const {
47 for (auto& ref : refs_) {
48 ref->MarkDeleted();
49 }
50 delete t;
51 }
52
53 protected:
54 friend class ref<T>;
55
56 void AddRef(ref<T>* ref) {
57 refs_.push_front(ref);
58 }
59
60 void RemoveRef(ref<T>* ref) {
61 refs_.remove(ref);
62 }
63
64 private:
65 std::forward_list<ref<T>*> refs_;
66 };
67
// Reports an access through a ref<T> whose object has been deleted; `error`
// is the message to report. The definition is not in this header (presumably
// it throws UseAfterFreeError — confirm against the definition). The standard
// [[noreturn]] attribute replaces the GNU-only __attribute__((noreturn)) so the
// header also builds on compilers without GNU attribute syntax.
[[noreturn]] void RaiseUseAfterFree(const char* error);
69
70 } // namespace internal
71
72 template <typename T>
73 class owned : public std::unique_ptr<T, internal::OwnedPtrDeleter<T>> {
74 private:
75 using Deleter = internal::OwnedPtrDeleter<T>;
76
77 public:
78 using std::unique_ptr<T, Deleter>::unique_ptr;
79
80 ref<T> get() {
81 return ref<T>(*this);
82 }
83
84 private:
85 T* get() const {
86 return this->std::unique_ptr<T, Deleter>::get();
87 }
88 };
89
90 template <typename T>
91 class ref {
92 public:
93 ref() : ptr_(nullptr) {}
94
95 ref(std::nullptr_t) : ref() {}
96
97 explicit ref(owned<T>& o) : ptr_(nullptr) {
98 if (o != nullptr) {
99 ptr_ = &o;
100 ptr_->get_deleter().AddRef(this);
101 }
102 }
103
104 ref(const ref<T>& r) {
105 *this = r;
106 }
107
108 ref<T>& operator=(const ref<T>& o) {
109 ptr_ = o.ptr_;
110 if (ptr_ != nullptr && !IsDeleted()) {
111 ptr_->get_deleter().AddRef(this);
112 }
113 return *this;
114 }
115
116 ~ref() {
117 if (ptr_ != nullptr && !IsDeleted()) {
118 ptr_->get_deleter().RemoveRef(this);
119 }
120 MarkDeleted();
121 }
122
123 T* operator->() const {
124 CheckDeleted();
125 return ptr_->operator->();
126 }
127
128 bool operator==(const ref<T>& r) const {
129 if (ptr_ == nullptr) {
130 return r.ptr_ == nullptr;
131 } else {
132 return ptr_ == r.ptr_ && *ptr_ == *r.ptr_;
133 }
134 }
135
136 bool operator==(std::nullptr_t) const {
137 return ptr_ == nullptr;
138 }
139
140 bool operator!=(const ref<T>& r) const {
141 return !(*this == r);
142 }
143
144 protected:
145 friend class internal::OwnedPtrDeleter<T>;
146
147 void MarkDeleted() {
148 ptr_ = DeletedSentinel();
149 }
150
151 private:
152 void CheckDeleted() const {
153 if (IsDeleted()) {
154 internal::RaiseUseAfterFree("attempt to access deleted pointer");
155 }
156 }
157
158 bool IsDeleted() const {
159 return ptr_ == DeletedSentinel();
160 }
161
162 inline static owned<T>* DeletedSentinel() {
163 return reinterpret_cast<owned<T>*>(std::numeric_limits<uintptr_t>::max());
164 }
165
166 owned<T>* ptr_;
167 };
168
169 #else
170
// Zero-cost mode: owned<T> is exactly std::unique_ptr<T> (spelled with its
// default deleter) and ref<T> is a plain T*. No reference tracking is done, so
// a dangling ref is an ordinary dangling pointer.
template <typename T>
using owned = std::unique_ptr<T, std::default_delete<T>>;

template <typename T>
using ref = T*;
176
177 #endif
178
179 } // namespace zc