C++ shared_ptr的尝试实现

81 阅读3分钟

成员变量

  • 原始对象指针
  • 引用计数指针

实现接口

  • 构造与析构
  • 拷贝构造
  • 移动构造
  • 赋值运算符
  • 移动赋值运算符
  • 解引用运算符
  • 箭头运算符
  • 获取原始指针
  • 获取引用计数
  • 重置函数
  • 释放资源函数
#include <atomic>

template<typename T>
class shared_ptr 
{
private:
    T* data;
    std::atomic<size_t>* ref_count; // 原子计数器指针确保在堆中可共享

    void release(){        
        if(ref_count){
            // fetch_sub,原子减1后返回原值
            // std::memory_order_acq_rel,当前线程的读或写内存不能被重排到此存储之前或之后
            ref_count->fetch_sub(1, std::memory_order_acq_rel);
            if(*ref_count == 1){
                delete data; // 释放托管对象
                delete ref_count; // 释放计数器
            }
        }
    }

public:
    shared_ptr() : data(nullptr),  ref_count(nullptr){}
    ~shared_ptr(){  release();  }

    // 防止隐式类型转换
    explicit shared_ptr(T* p) : data(p), ref_count(p ? new std::atomic<std::size_t>(1) : nullptr){}

    // 拷贝构造
    shared_ptr(const shared_ptr<T>& other) : data(other->data) {
        if(ref_count){
            // fetch_add,原子加1后返回原值
            // std::memory_order_relaxed,不要求定序或同步
            ref_count->fetch_add(1, std::memory_order_relaxed);
        }
    }

    // 移动构造函数。noexcept,表面该函数不会抛出异常
    shared_ptr(const shared_ptr<T>&& other) noexcept : data(other->data), ref_count(other->ref_count){
        other->data = nullptr;
        other->ref_count = nullptr;
    }

    // 拷贝赋值运算符重载
    shared_ptr<T>& operator=(const shared_ptr<T>& other) {
        // 通过地址比较判断是否为同一对象
        if(this != &other){
            release(); // 释放当前资源
            this->data = other->data;
            this->ref_count = other->ref_count;
        }
        if(ref_count){
            // fetch_add,原子加1后返回原值
            // std::memory_order_relaxed,不要求定序或同步
            ref_count->fetch_add(1, std::memory_order_relaxed);
        }
        return *this; // 返回当前对象的左值引用
    }    

    // 移动赋值运算符重载
    shared_ptr<T>& operator=(const shared_ptr&& other){
        // 通过地址比较判断是否为同一对象
        if(this != &other){
            release(); // 释放当前资源
            this->data = other.data;
            this->ref_count = other.ref_count;
            other->data = nullptr;
            other->ref_count = nullptr;
        }
        return *this; // 返回当前对象的左值引用
    }

    // 解引用运算符重载,返回左值引用,避免拷贝
    T& operator*() const {
        return *data; // 获取指针指向的值
    }

    // 箭头运算符,返回原生指针,指向实际管理的对象
    T* operator->() const {
        return data;
    }

    // 获取原始指针
    T* get() const {
        return data;
    }

    // 获取引用计数,原子读取
    std::size_t get_count() const {
        return ref_count ? ref_count->load(std::memory_order_acquire) : 0;
    }

    // 重置指针或置空
    void reset(T* p = nullptr){
        release();
        this->data = p->data;
        this->ref_count = p ? new std::atomic<std::size_t>(1) : nullptr;
    }

};

测试用例

编译命令g++ -std=c++20 testsp.cpp -o testsp -I. -lpthread

#include <iostream>
#include <thread>
#include <chrono>

#include "shared_ptr.h"

int main(){
    // 创建共享指针
    std::shared_ptr<int> ptr(new int(42));

    // 创建多个线程操作引用计数
    const int num_threads = 10;
    std::vector<std::thread> threads;

    for (int i = 0; i < num_threads; ++i) {
        threads.emplace_back([&ptr]() {
            for (int j = 0; j < 10000; ++j) {
                // 创建局部 shared_ptr 增加引用计数
                std::shared_ptr<int> local_ptr(ptr);
                
                // 短暂暂停增加线程切换概率
                std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
        });
    }

    // 等待所有线程完成
    for (auto& thread : threads) {
        thread.join();
    }

    // 验证引用计数
    std::cout << "use_count: " << ptr.use_count() << std::endl;
    if (ptr.use_count() == 1) {
        std::cout << "Test passed: shared_ptr is thread-safe!" << std::endl;
    } else {
        std::cout << "Test failed: shared_ptr is not thread-safe!" << std::endl;
    }
}