TransmittableThreadLocal的简单使用 + 手写源码

之前有一篇文章谈到链路追踪场景下,需要在异步线程之间,实现跨线程的ThreadLocal传递, 简单场景可以用InheritableThreadLocal,但ITL在线程池化场景下不适用,因为ITL是在子线程初始化时,拷贝了父线程的ThreadLocal,但池化场景,子线程是会被多次复用的,但ITL只能在子线程第一次创建时,传递ThreadLocal,之后的复用都无法重新设置ThreadLocal。于是TransmittableThreadLocal出现了,可以解决ThreadLocal在线程池化场景下的传递问题。

使用方法

  1. 引入依赖

    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>transmittable-thread-local</artifactId>
        <version>2.11.4</version>
    </dependency>
    
  2. 创建TransmittableThreadLocal对象

  3. 提交Runnable对象到线程池中时,使用TtlRunnable进行包裹

如此,即可以完成ThreadLocal变量在父子线程之间的传递,并且不会出现InheritableThreadLocal在

线程池化场景下的问题

import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * @author yogurtzzz
 * @date 2020/5/28 15:14
 **/
public class ThreadLocalTest {

    private static ThreadLocal<String> ttl = new TransmittableThreadLocal<>();

    /** 保证只有1个线程,以便观察这个线程被多个Runnable复用时,能否成功完成ThreadLocal的传递 **/
    private static ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
            1,1,0,TimeUnit.SECONDS,new ArrayBlockingQueue<>(10)
    );

    public static void main(String[] args) throws InterruptedException {
        ttl.set("yogurtzzz");

        for (int i = 0; i < 5; i++) {
            if (i == 2) {
                ttl.set("changed");
            }
            TtlRunnable runnable = TtlRunnable.get(() -> {
                System.out.println(Thread.currentThread().getName() + " : " + ttl.get());
            });

            threadPoolExecutor.execute(runnable);
            //CompletableFuture.runAsync(runnable,threadPoolExecutor);
            TimeUnit.MILLISECONDS.sleep(500);
        }
    }
}

结果:
在这里插入图片描述

原理

  1. TTL继承自InheritableThreadLocal
  2. 通过一个Holder,保存了每个线程当前持有的所有ThreadLocal对象
  3. 用TtlRunnable的get方法来包裹一个Runnable对象,包裹对象时,会采用类似SNAPSHOT,快照的机制,通过Holder,捕获父线程当前持有的所有ThreadLocal。随后,子线程启动,在Runnable对象执行run方法之前,从Holder中取出先前捕获到的父线程所持有的ThreadLocal对象,并设置到当前子线程当中,设置之前会保存子线程原有的ThreadLocal作为backUp,当子线程执行结束后,通过backUp恢复其原有的ThreadLocal

手写源码

下面仿造TTL的核心思路,手写了一个简单版的实现,也能解决ThreadLocal在线程间传递的问题

  1. CoreThreadLocal
import java.util.Iterator;
import java.util.WeakHashMap;
/**
 * @author yogurtzzz
 * @date 2020/5/28 15:11
 *
 * 简单版TransmittableThreadLocal
 * 用于处理InheritableThreadLocal在线程池化场景下的问题
 * 需配合CoreRunnable使用
 **/
public class CoreThreadLocal<T> extends InheritableThreadLocal<T> {
    /**
     * 用来保存当前线程持有的threadLocal
     * 它是static的
     * 所有线程都会有一个holder的threadLocal
     * 这个holder是一个map,保存着当前线程所持有的threadLocal
     * **/
    private static InheritableThreadLocal<WeakHashMap<CoreThreadLocal<Object>,?>> holder =
            new InheritableThreadLocal<WeakHashMap<CoreThreadLocal<Object>,?>>() {
                @Override
                protected WeakHashMap<CoreThreadLocal<Object>, ?> childValue(WeakHashMap<CoreThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<>(parentValue);
                }

                @Override
                protected WeakHashMap<CoreThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<>();
                }
            };

    @Override
    public T get() {
        T value = super.get();
        if (null != value) {
            addToHolder();
        }
        return value;
    }

    @Override
    public void set(T value) {
        if (null == value) {
            removeFromHolder();
            super.remove();
        }else {
            super.set(value);
            addToHolder();
        }
    }

    private void addToHolder() {
        if (!holder.get().containsKey(this)) {
            holder.get().put((CoreThreadLocal<Object>)this,null);
        }
    }

    private void removeFromHolder() {
        holder.get().remove(this);
    }

    static class Transmitter {
        /** 捕捉当前父线程的threadLocal **/
        public static SnapShot capture() {
            return new SnapShot(captureCtlValues());
        }

        private static WeakHashMap<CoreThreadLocal<Object>, Object> captureCtlValues() {
            WeakHashMap<CoreThreadLocal<Object>,Object> ctlValues = new WeakHashMap<>();
            /** 从holder中取当前线程持有的threadLocal **/
            for (CoreThreadLocal<Object> ctlItem : holder.get().keySet()) {
                ctlValues.put(ctlItem,ctlItem.get());
            }
            return ctlValues;
        }

        /**
         * 将capture设置到当前线程,并保存当前线程原有的threadLocal,返回
         * **/
        public static CoreThreadLocal.SnapShot replay(CoreThreadLocal.SnapShot snapShot) {
            WeakHashMap<CoreThreadLocal<Object>, Object> capture = snapShot.ctlValue;
            WeakHashMap<CoreThreadLocal<Object>, Object> backValue = new WeakHashMap<>();
            /**
             * 从holder中获取当前线程持有的threadLocal的Map,进行迭代保存
             * **/
            Iterator<CoreThreadLocal<Object>> iterator = holder.get().keySet().iterator();
            while (iterator.hasNext()) {
                CoreThreadLocal<Object> threadLocal = iterator.next();
                backValue.put(threadLocal,threadLocal.get());
                if (!capture.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.remove();
                }
            }
            /**
             * 设置上capture
             * */
            setThreadLocal(capture);
            return new SnapShot(backValue);
        }

        private static void setThreadLocal(WeakHashMap<CoreThreadLocal<Object>, Object> ctlValues) {
            ctlValues.forEach((key, value) -> {
                key.set(value);
            });
        }

        /**
         * 恢复为backUp
         * **/
        public static void restore(CoreThreadLocal.SnapShot backUp) {
            Iterator<CoreThreadLocal<Object>> iterator = holder.get().keySet().iterator();

            while (iterator.hasNext()) {
                CoreThreadLocal<Object> threadLocal = iterator.next();
                if (!backUp.ctlValue.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.remove();
                }
            }
            setThreadLocal(backUp.ctlValue);
        }
    }


    static class SnapShot {
        final WeakHashMap<CoreThreadLocal<Object>,Object> ctlValue;

        private SnapShot(WeakHashMap<CoreThreadLocal<Object>,Object> ctlValue) {
            this.ctlValue = ctlValue;
        }
    }
}
  1. CoreRunnable
public class CoreRunnable implements Runnable{

    private AtomicReference<CoreThreadLocal.SnapShot> captureRef;

    private Runnable runnable;

    private CoreRunnable(Runnable runnable) {
        this.runnable = runnable;
        captureRef = new AtomicReference<>(Transmitter.capture());
    }

    @Override
    public void run() {
        CoreThreadLocal.SnapShot capture = captureRef.get();
        CoreThreadLocal.SnapShot backUp = Transmitter.replay(capture);
        try{
            runnable.run();
        }finally {
            Transmitter.restore(backUp);
        }
    }

    /** 调用这个函数进行包装 **/
    public static CoreRunnable getRunnable(Runnable runnable) {
        return new CoreRunnable(runnable);
    }
}
  1. 测试
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * @author yogurtzzz
 * @date 2020/6/4 13:05
 **/
public class Test {


    private static ThreadLocal<String> ttl = new CoreThreadLocal<>();

    /** 保证只有1个线程,以便观察这个线程被多个Runnable复用时,能否成功完成ThreadLocal的传递 **/
    private static ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
            1,1,0,TimeUnit.SECONDS,new ArrayBlockingQueue<>(10)
    );

    public static void main(String[] args) throws InterruptedException {
        ttl.set("yogurtzzz");

        for (int i = 0; i < 5; i++) {
            if (i == 2) {
                ttl.set("changed");
            }
            CoreRunnable runnable = CoreRunnable.getRunnable(() -> {
                System.out.println(Thread.currentThread().getName() + " : " + ttl.get());
            });

            threadPoolExecutor.execute(runnable);
            //CompletableFuture.runAsync(runnable,threadPoolExecutor);
            TimeUnit.MILLISECONDS.sleep(500);
        }
    }
}

结果:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值