之前有一篇文章谈到链路追踪场景下,需要在异步线程之间,实现跨线程的ThreadLocal传递, 简单场景可以用InheritableThreadLocal,但ITL在线程池化场景下不适用,因为ITL是在子线程初始化时,拷贝了父线程的ThreadLocal,但池化场景,子线程是会被多次复用的,但ITL只能在子线程第一次创建时,传递ThreadLocal,之后的复用都无法重新设置ThreadLocal。于是TransmittableThreadLocal出现了,可以解决ThreadLocal在线程池化场景下的传递问题。
使用方法
-
引入依赖
<dependency> <groupId>com.alibaba</groupId> <artifactId>transmittable-thread-local</artifactId> <version>2.11.4</version> </dependency> -
创建TransmittableThreadLocal对象
-
提交Runnable对象到线程池中时,使用TtlRunnable进行包裹
如此,即可以完成ThreadLocal变量在父子线程之间的传递,并且不会出现InheritableThreadLocal在
线程池化场景下的问题
import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
* @author yogurtzzz
* @date 2020/5/28 15:14
**/
public class ThreadLocalTest {
private static ThreadLocal<String> ttl = new TransmittableThreadLocal<>();
/** 保证只有1个线程,以便观察这个线程被多个Runnable复用时,能否成功完成ThreadLocal的传递 **/
private static ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
1,1,0,TimeUnit.SECONDS,new ArrayBlockingQueue<>(10)
);
public static void main(String[] args) throws InterruptedException {
ttl.set("yogurtzzz");
for (int i = 0; i < 5; i++) {
if (i == 2) {
ttl.set("changed");
}
TtlRunnable runnable = TtlRunnable.get(() -> {
System.out.println(Thread.currentThread().getName() + " : " + ttl.get());
});
threadPoolExecutor.execute(runnable);
//CompletableFuture.runAsync(runnable,threadPoolExecutor);
TimeUnit.MILLISECONDS.sleep(500);
}
}
}
结果:

原理
- TTL继承自InheritableThreadLocal
- 通过一个Holder,保存了每个线程当前持有的所有ThreadLocal对象
- 用TtlRunnable的get方法来包裹一个Runnable对象,包裹对象时,会采用类似SNAPSHOT,快照的机制,通过Holder,捕获父线程当前持有的所有ThreadLocal。随后,子线程启动,在Runnable对象执行run方法之前,从Holder中取出先前捕获到的父线程所持有的ThreadLocal对象,并设置到当前子线程当中,设置之前会保存子线程原有的ThreadLocal作为backUp,当子线程执行结束后,通过backUp恢复其原有的ThreadLocal
手写源码
下面仿造TTL的核心思路,手写了一个简单版的实现,也能解决ThreadLocal在线程间传递的问题
- CoreThreadLocal
import java.util.Iterator;
import java.util.WeakHashMap;
/**
* @author yogurtzzz
* @date 2020/5/28 15:11
*
* 简单版TransmittableThreadLocal
* 用于处理InheritableThreadLocal在线程池化场景下的问题
* 需配合CoreRunnable使用
**/
public class CoreThreadLocal<T> extends InheritableThreadLocal<T> {
/**
* 用来保存当前线程持有的threadLocal
* 它是static的
* 所有线程都会有一个holder的threadLocal
* 这个holder是一个map,保存着当前线程所持有的threadLocal
* **/
private static InheritableThreadLocal<WeakHashMap<CoreThreadLocal<Object>,?>> holder =
new InheritableThreadLocal<WeakHashMap<CoreThreadLocal<Object>,?>>() {
@Override
protected WeakHashMap<CoreThreadLocal<Object>, ?> childValue(WeakHashMap<CoreThreadLocal<Object>, ?> parentValue) {
return new WeakHashMap<>(parentValue);
}
@Override
protected WeakHashMap<CoreThreadLocal<Object>, ?> initialValue() {
return new WeakHashMap<>();
}
};
@Override
public T get() {
T value = super.get();
if (null != value) {
addToHolder();
}
return value;
}
@Override
public void set(T value) {
if (null == value) {
removeFromHolder();
super.remove();
}else {
super.set(value);
addToHolder();
}
}
private void addToHolder() {
if (!holder.get().containsKey(this)) {
holder.get().put((CoreThreadLocal<Object>)this,null);
}
}
private void removeFromHolder() {
holder.get().remove(this);
}
static class Transmitter {
/** 捕捉当前父线程的threadLocal **/
public static SnapShot capture() {
return new SnapShot(captureCtlValues());
}
private static WeakHashMap<CoreThreadLocal<Object>, Object> captureCtlValues() {
WeakHashMap<CoreThreadLocal<Object>,Object> ctlValues = new WeakHashMap<>();
/** 从holder中取当前线程持有的threadLocal **/
for (CoreThreadLocal<Object> ctlItem : holder.get().keySet()) {
ctlValues.put(ctlItem,ctlItem.get());
}
return ctlValues;
}
/**
* 将capture设置到当前线程,并保存当前线程原有的threadLocal,返回
* **/
public static CoreThreadLocal.SnapShot replay(CoreThreadLocal.SnapShot snapShot) {
WeakHashMap<CoreThreadLocal<Object>, Object> capture = snapShot.ctlValue;
WeakHashMap<CoreThreadLocal<Object>, Object> backValue = new WeakHashMap<>();
/**
* 从holder中获取当前线程持有的threadLocal的Map,进行迭代保存
* **/
Iterator<CoreThreadLocal<Object>> iterator = holder.get().keySet().iterator();
while (iterator.hasNext()) {
CoreThreadLocal<Object> threadLocal = iterator.next();
backValue.put(threadLocal,threadLocal.get());
if (!capture.containsKey(threadLocal)) {
iterator.remove();
threadLocal.remove();
}
}
/**
* 设置上capture
* */
setThreadLocal(capture);
return new SnapShot(backValue);
}
private static void setThreadLocal(WeakHashMap<CoreThreadLocal<Object>, Object> ctlValues) {
ctlValues.forEach((key, value) -> {
key.set(value);
});
}
/**
* 恢复为backUp
* **/
public static void restore(CoreThreadLocal.SnapShot backUp) {
Iterator<CoreThreadLocal<Object>> iterator = holder.get().keySet().iterator();
while (iterator.hasNext()) {
CoreThreadLocal<Object> threadLocal = iterator.next();
if (!backUp.ctlValue.containsKey(threadLocal)) {
iterator.remove();
threadLocal.remove();
}
}
setThreadLocal(backUp.ctlValue);
}
}
static class SnapShot {
final WeakHashMap<CoreThreadLocal<Object>,Object> ctlValue;
private SnapShot(WeakHashMap<CoreThreadLocal<Object>,Object> ctlValue) {
this.ctlValue = ctlValue;
}
}
}
- CoreRunnable
public class CoreRunnable implements Runnable{
private AtomicReference<CoreThreadLocal.SnapShot> captureRef;
private Runnable runnable;
private CoreRunnable(Runnable runnable) {
this.runnable = runnable;
captureRef = new AtomicReference<>(Transmitter.capture());
}
@Override
public void run() {
CoreThreadLocal.SnapShot capture = captureRef.get();
CoreThreadLocal.SnapShot backUp = Transmitter.replay(capture);
try{
runnable.run();
}finally {
Transmitter.restore(backUp);
}
}
/** 调用这个函数进行包装 **/
public static CoreRunnable getRunnable(Runnable runnable) {
return new CoreRunnable(runnable);
}
}
- 测试
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
* @author yogurtzzz
* @date 2020/6/4 13:05
**/
public class Test {
private static ThreadLocal<String> ttl = new CoreThreadLocal<>();
/** 保证只有1个线程,以便观察这个线程被多个Runnable复用时,能否成功完成ThreadLocal的传递 **/
private static ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
1,1,0,TimeUnit.SECONDS,new ArrayBlockingQueue<>(10)
);
public static void main(String[] args) throws InterruptedException {
ttl.set("yogurtzzz");
for (int i = 0; i < 5; i++) {
if (i == 2) {
ttl.set("changed");
}
CoreRunnable runnable = CoreRunnable.getRunnable(() -> {
System.out.println(Thread.currentThread().getName() + " : " + ttl.get());
});
threadPoolExecutor.execute(runnable);
//CompletableFuture.runAsync(runnable,threadPoolExecutor);
TimeUnit.MILLISECONDS.sleep(500);
}
}
}
结果:


5965

被折叠的 条评论
为什么被折叠?



