shared_ptr自定义实现

247 阅读1分钟

表示引用计数的封装类

/**
 * 智能指针中的引用计数类。
 */
class RefCount {
public:
    RefCount() : count_(new size_t(1)) {}

    // 拷贝构造
    RefCount(const RefCount &rhs) : count_(rhs.count_) {
        (*count_)++;
    }

    // 移动构造
    RefCount(RefCount &&rhs) : count_(rhs.count_) {
        rhs.count_ = nullptr;
    }

    // 拷贝赋值
    RefCount &operator=(const RefCount &rhs) {
        tryRelease();
        count_ = rhs.count_;
        (*count_)++;
        return *this;
    }

    // 移动赋值
    RefCount &operator=(RefCount &&rhs)
    {
        tryRelease();
        count_ = rhs.count_;
        rhs.count_ = nullptr;
        return *this;
    }

    ~RefCount()
    {
        if (*count_ == 1) {
            delete count_;
            count_ = nullptr;
        } else {
            (*count_)--;
        }
    }

    size_t IsOnly()
    {
        return *count_ == 1;
    }

    // 尝试减少引用计数并释放内存
    void tryRelease()
    {
        if (count_ == nullptr) {
            return;
        }

        if (*count_ == 1) {
            delete count_;
        } else {
            (*count_)--;
        }
    }

private:
    size_t* count_;
};

智能指针封装类

template <typename T>
class SmartPtr {
public:
    // 默认构造函数
    SmartPtr() : SmartPtr(nullptr) {}

    // 原始构造函数
    explicit SmartPtr(T* ptr = nullptr) : ptr_(ptr) {}

    // 拷贝构造函数
    SmartPtr(const SmartPtr& rhs)
    {
        ptr_ = rhs.ptr_;
        count_ = rhs.count_;
    }

    // 拷贝赋值运算符
    SmartPtr& operator=(const SmartPtr& rhs)
    {
        if (count_.IsOnly()) {
            delete ptr_;
        } else {
            count_.tryRelease();
        }

        ptr_ = rhs.ptr_;
        count_ = rhs.count_;
        return *this;
    }

    SmartPtr(SmartPtr&& rhs) noexcept : ptr_(rhs.ptr_), count_(rhs.count_)
    {
        rhs.count_ = rhs.ptr_ = nullptr;
    }

    SmartPtr& operator=(const SmartPtr&& rhs)
    {
        if (this == &rhs) {
            return *this;
        }

        count_ = std::move(rhs.count_);
        ptr_ = rhs.ptr_;
        return *this;
    }

    // 析构函数
    ~SmartPtr()
    {
        std::cout << "SmartPtr deconstructor" << std::endl;
        if (count_.IsOnly()) {
            delete ptr_;
            ptr_ = nullptr;
        }
        count_.tryRelease();
    }

    // 解引用运算符*
    T& operator*() const
    {
        return *ptr_;
    }

    // 成员访问运算符->
    T* operator->() const
    {
        return ptr_;
    }

private:
    T* ptr_; // 实际维护的地址
    RefCount count_;
};