成员变量:
- 原始对象指针
- 引用计数指针
实现接口:
- 构造与析构
- 拷贝构造
- 移动构造
- 赋值运算符
- 移动赋值运算符
- 解引用运算符
- 箭头运算符
- 获取原始指针
- 获取引用计数
- 重置函数
- 释放资源函数
#include <atomic>
template<typename T>
class shared_ptr
{
private:
T* data;
std::atomic<size_t>* ref_count; // 原子计数器指针确保在堆中可共享
void release(){
if(ref_count){
// fetch_sub,原子减1后返回原值
// std::memory_order_acq_rel,当前线程的读或写内存不能被重排到此存储之前或之后
ref_count->fetch_sub(1, std::memory_order_acq_rel);
if(*ref_count == 1){
delete data; // 释放托管对象
delete ref_count; // 释放计数器
}
}
}
public:
shared_ptr() : data(nullptr), ref_count(nullptr){}
~shared_ptr(){ release(); }
// 防止隐式类型转换
explicit shared_ptr(T* p) : data(p), ref_count(p ? new std::atomic<std::size_t>(1) : nullptr){}
// 拷贝构造
shared_ptr(const shared_ptr<T>& other) : data(other->data) {
if(ref_count){
// fetch_add,原子加1后返回原值
// std::memory_order_relaxed,不要求定序或同步
ref_count->fetch_add(1, std::memory_order_relaxed);
}
}
// 移动构造函数。noexcept,表面该函数不会抛出异常
shared_ptr(const shared_ptr<T>&& other) noexcept : data(other->data), ref_count(other->ref_count){
other->data = nullptr;
other->ref_count = nullptr;
}
// 拷贝赋值运算符重载
shared_ptr<T>& operator=(const shared_ptr<T>& other) {
// 通过地址比较判断是否为同一对象
if(this != &other){
release(); // 释放当前资源
this->data = other->data;
this->ref_count = other->ref_count;
}
if(ref_count){
// fetch_add,原子加1后返回原值
// std::memory_order_relaxed,不要求定序或同步
ref_count->fetch_add(1, std::memory_order_relaxed);
}
return *this; // 返回当前对象的左值引用
}
// 移动赋值运算符重载
shared_ptr<T>& operator=(const shared_ptr&& other){
// 通过地址比较判断是否为同一对象
if(this != &other){
release(); // 释放当前资源
this->data = other.data;
this->ref_count = other.ref_count;
other->data = nullptr;
other->ref_count = nullptr;
}
return *this; // 返回当前对象的左值引用
}
// 解引用运算符重载,返回左值引用,避免拷贝
T& operator*() const {
return *data; // 获取指针指向的值
}
// 箭头运算符,返回原生指针,指向实际管理的对象
T* operator->() const {
return data;
}
// 获取原始指针
T* get() const {
return data;
}
// 获取引用计数,原子读取
std::size_t get_count() const {
return ref_count ? ref_count->load(std::memory_order_acquire) : 0;
}
// 重置指针或置空
void reset(T* p = nullptr){
release();
this->data = p->data;
this->ref_count = p ? new std::atomic<std::size_t>(1) : nullptr;
}
};
测试用例:
编译命令g++ -std=c++20 testsp.cpp -o testsp -I. -lpthread
#include <iostream>
#include <thread>
#include <chrono>
#include "shared_ptr.h"
int main(){
// 创建共享指针
std::shared_ptr<int> ptr(new int(42));
// 创建多个线程操作引用计数
const int num_threads = 10;
std::vector<std::thread> threads;
for (int i = 0; i < num_threads; ++i) {
threads.emplace_back([&ptr]() {
for (int j = 0; j < 10000; ++j) {
// 创建局部 shared_ptr 增加引用计数
std::shared_ptr<int> local_ptr(ptr);
// 短暂暂停增加线程切换概率
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
});
}
// 等待所有线程完成
for (auto& thread : threads) {
thread.join();
}
// 验证引用计数
std::cout << "use_count: " << ptr.use_count() << std::endl;
if (ptr.use_count() == 1) {
std::cout << "Test passed: shared_ptr is thread-safe!" << std::endl;
} else {
std::cout << "Test failed: shared_ptr is not thread-safe!" << std::endl;
}
}