TransmittableThreadLocal使用踩坑-(主线程set,异步线程get)
背景:为了获取相关字段方便,项目里使用了TransmittableThreadLocal上下文,在异步逻辑中get值时发现并非当前请求的值,且是偶发状况(并发问题)。
发现:TransmittableThreadLocal是阿里开源的可以实现父子线程值传递的工具,其子线程必须使用TtlRunnable\TtlCallable修饰或者线程池使用TtlExecutors修饰(防止数据“污染”),如果没有使用装饰后的线程池,那么使用TransmittableThreadLocal上下文,就有可能出现线程不安全的问题。
话不多说上代码:
封装的上下文,成员变量RequestHeader
package org.example.ttl.threadLocal; import com.alibaba.ttl.TransmittableThreadLocal; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; import lombok.ToString; import org.apache.commons.lang3.ObjectUtils; /** * description: * author: JohnsonLiu * create at: 2021/12/24 23:19 */ @Data @AllArgsConstructor @NoArgsConstructor public class RequestContext { private static final ThreadLocal<RequestContext> transmittableThreadLocal = new TransmittableThreadLocal(); private static final RequestContext INSTANCE = new RequestContext(); private RequestHeader requestHeader; public static void create(RequestHeader requestHeader) { transmittableThreadLocal.set(new RequestContext(requestHeader)); } public static RequestContext current() { return ObjectUtils.defaultIfNull(transmittableThreadLocal.get(), INSTANCE); } public static RequestHeader get() { return current().getRequestHeader(); } public static void remove() { transmittableThreadLocal.set(null); } @Data @AllArgsConstructor @NoArgsConstructor @ToString static class RequestHeader { private String requestUrl; private String requestType; } }
获取上下文内容的case:
package org.example.ttl.threadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
* description: TransmittableThreadLocal正确使用
* author: JohnsonLiu
* create at: 2021/12/24 22:24
* 验证结论:
* 1.线程池必须使用TtlExecutors修饰,或者Runnable\Callable必须使用TtlRunnable\TtlCallable修饰
* ---->原因:子线程复用,子线程拥有的上下文内容会对下次使用造成“污染”,而修饰后的子线程在执行run方法后会进行“回放”,防止污染
*/
public class TransmittableThreadLocalCase2 {
// 为达到线程100%复用便于测试,线程池核心数1
private static final Executor TTL_EXECUTOR = TtlExecutors.getTtlExecutor(new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000)));
// 如果使用一般的线程池或者Runnable\Callable时,会存在线程“污染”,比如线程池中线程会复用,复用的线程会“污染”该线程执行下一次任务
private static final Executor EXECUTOR = new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000));
public static void main(String[] args) {
RequestContext.create(new RequestContext.RequestHeader("url", "get"));
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前 同步):" + RequestContext.get());
// 模拟另一个线程修改上下文内容
EXECUTOR.execute(() -> {
RequestContext.create(new RequestContext.RequestHeader("url", "put"));
});
// 保证上面子线程修改成功
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}
// 异步获取上下文内容
TTL_EXECUTOR.execute(() -> {
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前 异步):" + RequestContext.get());
});
// 主线程修改上下文内容
RequestContext.create(new RequestContext.RequestHeader("url", "post"));
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前 同步<reCreate>):" + RequestContext.get());
// 主线程remove
RequestContext.remove();
// 子线程获取remove后的上下文内容
TTL_EXECUTOR.execute(() -> {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " 子线程(rm之后 异步):" + RequestContext.get());
});
}
}
使用一般线程池结果:
使用修饰后的线程池结果:
这种问题的解决办法:
如果大家跟我一样存在这样的使用,那么也会低概率存在这样的问题,正确的使用方式是:
子线程必须使用TtlRunnable\TtlCallable修饰或者线程池使用TtlExecutors修饰,这一点很容易被遗漏,比如上下文和异步逻辑不是同一个人开发的,那么异步逻辑的开发者就很可能直接在异步逻辑中使用上下文,而忽略装饰线程池,造成线程复用时的“数据污染”。
另外还有一种不同于上面的上下文用法,同样使用不当也会存在线程安全问题:
上代码样例
package org.example.ttl.threadLocal; import com.alibaba.ttl.TransmittableThreadLocal; import java.util.LinkedHashMap; import java.util.Map; /** * description: TransmittableThreadLocal正确使用 * author: JohnsonLiu * create at: 2021/12/24 23:19 */ public class ServiceContext { private static final ThreadLocal<Map<Integer, Integer>> transmittableThreadLocal = new TransmittableThreadLocal() { /** * 如果使用的是TtlExecutors装饰的线程池或者TtlRunnable、TtlCallable装饰的任务 * 重写copy方法且重新赋值给新的LinkedHashMap,不然会导致父子线程都是持有同一个引用,只要有修改取值都会变化。引用值线程不安全 * parentValue是父线程执行子任务那个时刻的快照值,后续父线程再次set值也不会影响子线程get,因为已经不是同一个引用 * @param parentValue * @return */ @Override public Object copy(Object parentValue) { if (parentValue instanceof Map) { System.out.println("copy"); return new LinkedHashMap<Integer, Integer>((Map) parentValue); } return null; } /** * 如果使用普通线程池执行异步任务,重写childValue即可实现子线程获取的是父线程执行任务那个时刻的快照值,重新赋值给新的LinkedHashMap,父线程修改不会影响子线程(非共享) * 但是如果使用的是TtlExecutors装饰的线程池或者TtlRunnable、TtlCallable装饰的任务,此时就会变成引用共享,必须得重写copy方法才能实现非共享 * @param parentValue * @return */ @Override protected Object childValue(Object parentValue) { if (parentValue instanceof Map) { System.out.println("childValue"); return new LinkedHashMap<Integer, Integer>((Map) parentValue); } return null; } /** * 初始化,每次get时都会进行初始化 * @return */ @Override protected Object initialValue() { System.out.println("initialValue"); return new LinkedHashMap<Integer, Integer>(); } }; public static void set(Integer key, Integer value) { transmittableThreadLocal.get().put(key, value); } public static Map<Integer, Integer> get() { return transmittableThreadLocal.get(); } public static void remove() { transmittableThreadLocal.remove(); } }
使用case:
package org.example.ttl.threadLocal; import com.alibaba.ttl.threadpool.TtlExecutors; import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; /** * description: TransmittableThreadLocal正确使用 * author: JohnsonLiu * create at: 2021/12/24 22:24 */ public class TransmittableThreadLocalCase { private static final Executor executor = TtlExecutors.getTtlExecutor(new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000))); // private static final Executor executor = new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000)); static int i = 0; public static void main(String[] args) { ServiceContext.set(++i, i); executor.execute(() -> { try { Thread.sleep(3000); } catch (InterruptedException e) { e.printStackTrace(); } System.out.println(Thread.currentThread().getName() + " 子线程(rm之前):" + ServiceContext.get()); }); ServiceContext.set(++i, i); ServiceContext.remove(); executor.execute(() -> { try { Thread.sleep(3000); } catch (InterruptedException e) { e.printStackTrace(); } System.out.println(Thread.currentThread().getName() + " 子线程(rm之后):" + ServiceContext.get()); }); } }
代码中已有详细的说明,大家可自行copy代码验证。
以上仅代表个人理解,如有错误之处,还请留言指教,进行更正!