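TransmittableThreadLocal is abbreviated as TTL in this article; lowercase "ttl" refers to a TTL instance.
InheritableThreadLocal is abbreviated as ITL.
The Thread.inheritableThreadLocals field is abbreviated as itls.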
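Introduction
Why TTL exists
TTL exists to pass ThreadLocal values from a parent thread to the threads that run its tasks when those threads are reused, as in a thread pool. If threads are not reused, ITL is fully sufficient.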
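To make the problem concrete, here is a JDK-only sketch (no TTL involved). It assumes a single-thread pool so that both tasks land on the same worker: the worker copies the ITL value once, when the pool creates it, and never sees the parent's later update.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ItlReuseDemo {
    // JDK InheritableThreadLocal: a child thread copies the value only when the Thread object is constructed
    private static final InheritableThreadLocal<String> ITL = new InheritableThreadLocal<>();

    /** Returns the values seen by two tasks that run on the same pooled worker thread. */
    public static String[] run() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            ITL.set("v1");
            // First submit: the pool creates its worker now, from this thread, so the worker inherits "v1"
            String first = pool.submit(() -> ITL.get()).get();
            ITL.set("v2");
            // Second submit: the existing worker is reused, so nothing is inherited again
            String second = pool.submit(() -> ITL.get()).get();
            return new String[] {first, second};
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            ITL.remove();
        }
    }

    public static void main(String[] args) {
        String[] seen = run();
        System.out.println("first task:  " + seen[0]);
        System.out.println("second task: " + seen[1]); // stale: the submitting thread already holds "v2"
    }
}
```

The second task still sees v1 although the submitting thread already holds v2; this is the gap TTL is designed to close.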
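Difference between ITL and TTL
ITL copies the parent thread's inheritableThreadLocals into the child thread once, when the child Thread is created. TTL captures the values when a task is wrapped and assigns them in the worker thread right before the task runs, so every task gets the submitter's current values even on a reused thread.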
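Overall flow
- Main thread: keep track of the ttls in use. Calling TTL#set or TTL#get adds the ttl to holder.
- Main thread: wrap the task with TtlRunnable#get, which snapshots the ttls in holder via Transmitter#capture.
- Worker thread: before #run, apply the main thread's snapshot: TtlRunnable#run => Transmitter#replay
  - back up the worker thread's current ttl values
  - set the snapshot values in the worker thread
- Worker thread: after #run completes, restore the backup via Transmitter#restore.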
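The four steps can be sketched without the library. This is a hypothetical, single-value simplification: a plain ThreadLocal named CTX stands in for the ttls in holder, and the helper wrap (not a TTL API) plays the role of TtlRunnable#get together with capture, replay and restore. It replays the TtlTest2 scenario from the demo further down on one reused worker.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class MiniTransmitDemo {
    // A plain ThreadLocal stands in for a ttl; wrap() below does the transmitting
    static final ThreadLocal<String> CTX = new ThreadLocal<>();

    /** Hypothetical helper mirroring TtlRunnable#get: capture now, replay and restore around run(). */
    static Runnable wrap(Runnable task) {
        final String captured = CTX.get();            // capture: runs in the submitting thread
        return () -> {
            final String backup = CTX.get();          // replay, step 1: back up the worker's own value
            CTX.set(captured);                        // replay, step 2: install the snapshot
            try {
                task.run();
            } finally {
                // restore: put the worker's previous value back, even if the task threw
                if (backup == null) {
                    CTX.remove();
                } else {
                    CTX.set(backup);
                }
            }
        };
    }

    /** Runs the scenario on a single reused worker and returns the log lines. */
    public static List<String> run() {
        List<String> log = Collections.synchronizedList(new ArrayList<>());
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            CTX.set("main");
            pool.submit(wrap(() -> {
                log.add("log1 " + CTX.get());         // value captured from the submitter
                CTX.set("thread");
                log.add("log2 " + CTX.get());
            })).get();                                // get() waits for the task to finish
            log.add("log3 " + CTX.get());             // the submitter's own value is untouched
            // An unwrapped task on the same worker shows what restore left behind
            log.add("log4 " + pool.submit(() -> CTX.get()).get());
            return log;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            CTX.remove();
        }
    }

    public static void main(String[] args) {
        run().forEach(System.out::println);
    }
}
```

The try/finally is the important design choice: restore runs even when the task throws, so a pooled thread never leaks one task's context into the next.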
Flow analysis
Core classes
The TTL class
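- TTL extends ITL and overrides #set and #get, which additionally call #addThisToHolder.
- #addThisToHolder: puts the ttl object into holder as a key, so that it can be captured later for child threads.
- Adds the #beforeExecute and #afterExecute callback methods, invoked around task execution from replay and restore.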
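The holder field
- holder is itself an ITL whose per-thread value is a map keyed by ttl objects; it is declared private static final.
- holder records the ttl objects the current thread has used; like any ThreadLocal, its contents are per-thread.
  - It tracks only ttls that have actually been used, i.e. TTL#set or TTL#get was called.
  - It tracks only TTL instances, not plain ThreadLocals.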
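A simplified, hypothetical model of the holder bookkeeping follows. MiniTtl is not the library class: it keeps, per thread, a weakly keyed set of the instances that were set (or read with a non-null value), which is what a capture step would iterate. The WeakHashMap-backed set and the childValue copy are assumptions chosen so that unused instances are not pinned in memory and child threads get their own set.

```java
import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;

/** Hypothetical, simplified stand-in for TransmittableThreadLocal's holder bookkeeping. */
public class MiniTtl<T> extends InheritableThreadLocal<T> {
    // Per-thread set of MiniTtl instances this thread has used (weak keys)
    private static final InheritableThreadLocal<Set<MiniTtl<?>>> HOLDER =
            new InheritableThreadLocal<Set<MiniTtl<?>>>() {
                @Override
                protected Set<MiniTtl<?>> initialValue() {
                    return Collections.newSetFromMap(new WeakHashMap<MiniTtl<?>, Boolean>());
                }

                @Override
                protected Set<MiniTtl<?>> childValue(Set<MiniTtl<?>> parentValue) {
                    // A child thread gets its own copy instead of sharing the parent's set
                    Set<MiniTtl<?>> copy = Collections.newSetFromMap(new WeakHashMap<MiniTtl<?>, Boolean>());
                    copy.addAll(parentValue);
                    return copy;
                }
            };

    @Override
    public void set(T value) {
        super.set(value);
        HOLDER.get().add(this); // the "addThisToHolder" step
    }

    @Override
    public T get() {
        T value = super.get();
        if (value != null) {
            HOLDER.get().add(this); // simplification: only track instances that actually hold a value
        }
        return value;
    }

    @Override
    public void remove() {
        HOLDER.get().remove(this);
        super.remove();
    }

    /** Instances the current thread has used; this is what a capture step would iterate. */
    public static Set<MiniTtl<?>> usedInCurrentThread() {
        return Collections.unmodifiableSet(HOLDER.get());
    }

    public static void main(String[] args) {
        MiniTtl<String> used = new MiniTtl<>();
        MiniTtl<String> unused = new MiniTtl<>();
        used.set("main");
        System.out.println(usedInCurrentThread().contains(used));   // true
        System.out.println(usedInCurrentThread().contains(unused)); // false
    }
}
```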
TtlRunnable
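- Factory method: TtlRunnable#get creates the TtlRunnable and snapshots the ttls recorded in holder.
- Decorator pattern: it enhances the wrapped Runnable.
- Before #run, it sets the captured ttl values into the worker thread's itls.
- After #run, it restores the worker thread's original ttl values.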
Transmitter
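- Handles the hand-off between threads: snapshot the values (capture), apply them in the worker (replay), and put the worker's previous values back (restore).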
A demo first
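import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class TtlTest2 {
    private static final ThreadLocal<String> ttl1 = new TransmittableThreadLocal<>();
    private static final CountDownLatch countDownLatch = new CountDownLatch(1);
    public static void main(String[] args) throws InterruptedException {
        ttl1.set("main");
        ExecutorService pool = Executors.newFixedThreadPool(20);
        // Every task passed to ttlExecutor is wrapped in a TtlRunnable
        Executor ttlExecutor = TtlExecutors.getTtlExecutor(pool);
        ttlExecutor.execute(() -> {
            System.out.println("log1-> ttl: " + ttl1.get());
            ttl1.set("thread");
            System.out.println("log2-> ttl: " + ttl1.get());
            countDownLatch.countDown();
        });
        countDownLatch.await();
        System.out.println("log3-> ttl: " + ttl1.get());
        // Pool threads are non-daemon; shut the pool down so the JVM can exit
        pool.shutdown();
    }
}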
Output:
log1-> ttl: main
log2-> ttl: thread
log3-> ttl: main
- #capture
public static Object capture() {
return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}
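/**
* Snapshots the ttl objects recorded in the current thread's holder.
* Returned Map: the key is the ttl object, the value is its value. By default this is a shallow copy:
* copyValue() delegates to copy(), which returns the same reference unless overridden.
* In the demo, the main thread set ttl1 to "main", so the Map returned here is <ttl1, main>.
*/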
private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<TransmittableThreadLocal<Object>, Object>();
for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
ttl2Value.put(threadLocal, threadLocal.copyValue());
}
return ttl2Value;
}
/**
* Plain ThreadLocals registered in threadLocalHolder are captured as well, each through its own TtlCopier.
*/
private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<ThreadLocal<Object>, Object>();
for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
final ThreadLocal<Object> threadLocal = entry.getKey();
final TtlCopier<Object> copier = entry.getValue();
threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
}
return threadLocal2Value;
}
- #replay
public static Object replay(@NonNull Object captured) {
final Snapshot capturedSnapshot = (Snapshot) captured;
return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}
/**
* Backs up the ttl values currently held by this thread, then applies the snapshot <ttl1, main> to it.
*/
private static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<TransmittableThreadLocal<Object>, Object>();
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// back up the ttl's current value in this thread
backup.put(threadLocal, threadLocal.get());
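// clear ttls that are not in the snapshot, so the task sees only the submitter's context;
// their values are already in backup, and restore puts them back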
if (!captured.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
// apply the snapshot <ttl1, main>: ttl1 is set to main in this thread
setTtlValuesTo(captured);
// invoke the #beforeExecute callbacks
doExecuteCallback(true);
return backup;
}
private static void setTtlValuesTo(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
TransmittableThreadLocal<Object> threadLocal = entry.getKey();
threadLocal.set(entry.getValue());
}
}
private static HashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> captured) {
final HashMap<ThreadLocal<Object>, Object> backup = new HashMap<ThreadLocal<Object>, Object>();
for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
final ThreadLocal<Object> threadLocal = entry.getKey();
backup.put(threadLocal, threadLocal.get());
final Object value = entry.getValue();
if (value == threadLocalClearMark) threadLocal.remove();
else threadLocal.set(value);
}
return backup;
}
- #restore
public static void restore(@NonNull Object backup) {
final Snapshot backupSnapshot = (Snapshot) backup;
restoreTtlValues(backupSnapshot.ttl2Value);
restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}
private static void restoreTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
// invoke the #afterExecute callbacks
doExecuteCallback(false);
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
// in the worker thread, ttl1's value has already changed from main to thread
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// clear the TTL values that are not in backup
// to avoid extra TTL values after restore
if (!backup.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
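// restore ttl1 to its backup value, thread -> main (in the demo the worker inherited main through ITL when the pool created it)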
setTtlValuesTo(backup);
}
private static void restoreThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> backup) {
for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
final ThreadLocal<Object> threadLocal = entry.getKey();
threadLocal.set(entry.getValue());
}
}
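The clearing step that replay and restore share is easier to see on plain data. In this hypothetical model, a thread's ttl state is a Map from a ttl's name to its value; replay and restore below mirror replayTtlValues and restoreTtlValues in spirit but are not the library code.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical model: a thread's ttl state is a Map from ttl name to value. */
public class ReplayRestoreModel {

    /** replay: back up everything the worker holds, drop entries absent from the snapshot, apply the snapshot. */
    public static Map<String, Object> replay(Map<String, Object> workerState, Map<String, Object> captured) {
        Map<String, Object> backup = new HashMap<>(workerState); // backup of ALL current values
        workerState.keySet().retainAll(captured.keySet());       // clear ttls not in the snapshot
        workerState.putAll(captured);                            // like setTtlValuesTo(captured)
        return backup;
    }

    /** restore: drop entries that are not in the backup (e.g. added by the task), then put the backup back. */
    public static void restore(Map<String, Object> workerState, Map<String, Object> backup) {
        workerState.keySet().retainAll(backup.keySet());
        workerState.putAll(backup);                              // like setTtlValuesTo(backup)
    }

    public static void main(String[] args) {
        Map<String, Object> worker = new HashMap<>();
        worker.put("traceId", "stale");   // left behind by an earlier task on this pooled thread
        Map<String, Object> captured = new HashMap<>();
        captured.put("user", "main");     // snapshot taken in the submitting thread

        Map<String, Object> backup = replay(worker, captured);
        System.out.println("during task: " + worker);  // only the submitter's context
        worker.put("user", "thread");
        worker.put("tmp", 1);
        restore(worker, backup);
        System.out.println("after task:  " + worker);  // exactly the previous state
    }
}
```

Starting from a worker that still holds traceId from an earlier task, the task sees only the submitter's user, and afterwards the worker is back to exactly its previous state.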
Usage notes:
Tasks must be enhanced (wrapped) Runnables.
Obtain the thread pool via TtlExecutors#getTtlExecutorService, or subclass ThreadPoolTaskExecutor and call TtlRunnable#get yourself:
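import com.alibaba.ttl.TtlRunnable;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
public class MyThreadPoolTaskExecutor extends ThreadPoolTaskExecutor {
    @Override
    public void execute(@Nullable Runnable task) {
        // Wrap the task so the caller's ttl values are captured now and replayed in the pool thread
        Runnable runnable = TtlRunnable.get(task);
        assert runnable != null;
        super.execute(runnable);
    }
    // Note: ThreadPoolTaskExecutor#submit hands tasks to the underlying ThreadPoolExecutor
    // without going through execute(), so the submit variants would need the same wrapping.
}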
Deep vs. shallow copy
Demo of the problem
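import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
// User: a simple bean with name/age fields, getters and setters; toString prints User(name=..., age=...)
public class TtlTest3 {
    private static final ThreadLocal<User> ttl1 = new TransmittableThreadLocal<>();
    private static final CountDownLatch countDownLatch = new CountDownLatch(1);
    public static void main(String[] args) throws InterruptedException {
        ttl1.set(new User("张三", 20));
        ExecutorService pool = Executors.newFixedThreadPool(20);
        Executor ttlExecutor = TtlExecutors.getTtlExecutor(pool);
        ttlExecutor.execute(() -> {
            System.out.println("log1-> ttl: " + ttl1.get());
            User user = ttl1.get();
            // Mutates the same object that the main thread's ttl1 points to
            user.setName("李四");
            ttl1.set(user);
            System.out.println("log2-> ttl: " + ttl1.get());
            countDownLatch.countDown();
        });
        countDownLatch.await();
        System.out.println("log3-> ttl: " + ttl1.get());
        pool.shutdown();
    }
}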
Output:
log1-> ttl: User(name=张三, age=20)
log2-> ttl: User(name=李四, age=20)
log3-> ttl: User(name=李四, age=20)
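Because capture copies only the reference, the worker and the main thread share one User object, so the worker's setName("李四") also shows up in log3.
Solution: override TTL#copy. It is called when the task is wrapped, so the snapshot can hold a separate object:
import com.alibaba.ttl.TransmittableThreadLocal;
public class UserTtl extends TransmittableThreadLocal<User> {
    @Override
    public User copy(User parentValue) {
        // Hand the task its own User instead of the caller's instance
        return parentValue == null ? null : new User(parentValue.getName(), parentValue.getAge());
    }
}
Then declare ttl1 in TtlTest3 as new UserTtl().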
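The effect of overriding copy can be reduced to a few lines without threads. In this hypothetical sketch, the copier argument plays the role of TTL#copy at capture time, User is a minimal stand-in for the article's class, and ASCII names replace 张三/李四.

```java
import java.util.function.UnaryOperator;

public class CopyOnCaptureDemo {
    /** Minimal mutable stand-in for the article's User. */
    static final class User {
        String name;
        final int age;

        User(String name, int age) {
            this.name = name;
            this.age = age;
        }
    }

    /**
     * Simulates capture plus a task that renames the captured user, then
     * returns the name the submitting side sees afterwards.
     */
    static String submitterNameAfterTask(UnaryOperator<User> copier) {
        User submitterValue = new User("Zhang San", 20);
        User snapshot = copier.apply(submitterValue); // capture: the TTL#copy role
        snapshot.name = "Li Si";                      // the task's setName("李四")
        return submitterValue.name;
    }

    public static void main(String[] args) {
        // Default copy returns the same reference, so the submitter's object is mutated
        System.out.println(submitterNameAfterTask(u -> u));                       // Li Si
        // A copying override gives the task its own instance
        System.out.println(submitterNameAfterTask(u -> new User(u.name, u.age))); // Zhang San
    }
}
```

Note that the override only protects against mutation of the top-level object; if User held further mutable fields, copy would have to duplicate those too.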