一、ThreadLocal
概念
(个人理解)ThreadLocal是解决多线程问题的一种解决方案
线程安全问题是因为多个线程,可以共同操作同一个公共变量,而导致的公共变量数据安全问题。
解决方案:
- 加锁,同一个时间内,只能一个线程修改公共变量。
- 把公共变量放到每个线程内,线程独享。不存在操作公共变量的问题了。(ThreadLocal就是这种)
源码
ThreadLocal
案例
1. dynamic-datasource存储当前线程的数据源
(dynamic-datasource的源码)
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//
package com.baomidou.dynamic.datasource.toolkit;
import java.util.ArrayDeque;
import java.util.Deque;
import org.springframework.util.StringUtils;
public final class DynamicDataSourceContextHolder {
private static final ThreadLocal<Deque<String>> LOOKUP_KEY_HOLDER = new ThreadLocal() {
protected Object initialValue() {
return new ArrayDeque();
}
};
private DynamicDataSourceContextHolder() {
}
public static String peek() {
return (String)((Deque)LOOKUP_KEY_HOLDER.get()).peek();
}
public static void push(String ds) {
((Deque)LOOKUP_KEY_HOLDER.get()).push(StringUtils.isEmpty(ds) ? "" : ds);
}
public static void poll() {
Deque<String> deque = (Deque)LOOKUP_KEY_HOLDER.get();
deque.poll();
if (deque.isEmpty()) {
LOOKUP_KEY_HOLDER.remove();
}
}
public static void clear() {
LOOKUP_KEY_HOLDER.remove();
}
}
2. 自定义存储当前数据源
public class DataSourceContextHolder {
/** 环境 */
private static final ThreadLocal<EnvTypeEnum> ENV_CONTEXT_HOLDER = new ThreadLocal<>();
/** 数据源 */
private static final ThreadLocal<DatabaseType> DATABASE_CONTEXT_HOLDER = new ThreadLocal<DatabaseType>();
public static void setEnvType(EnvTypeEnum envType){
ENV_CONTEXT_HOLDER.set(envType);
}
public static EnvTypeEnum getEnvType(){
return ENV_CONTEXT_HOLDER.get();
}
public static void clearEnvType(){
ENV_CONTEXT_HOLDER.remove();
}
public static void setDatabaseType(DatabaseType type) {
DATABASE_CONTEXT_HOLDER.set(type);
}
public static DatabaseType getDatabaseType() {
return DATABASE_CONTEXT_HOLDER.get();
}
}
问题:子线程问题(如果使用了多线程)
ThreadLocal而言,子线程是拿不到父线程的TL对象的。
但是Thread里还维护了一个inheritableThreadLocals用来存放父线程的TL对象。

二、InheritableThreadlocal
概念
InheritableThreadlocal 是ThreadLocal的一个子类
原理:
在线程初始化的时候,会将当前线程的ThreadLocal放到InheritableThreadLocalMap。
然后在InheritableThreadLocal中,重写了childValue和getMap方法,用来返回inheritableThreadLocals,而不是threadLocals。即返回了线程初始化时,父线程塞进去的inheritableThreadLocals;


Thread初始化时的代码:
private void init(ThreadGroup g, Runnable target, String name,
long stackSize, AccessControlContext acc,
boolean inheritThreadLocals) {
// 省略...
Thread parent = currentThread();
// 省略...
if (inheritThreadLocals && parent.inheritableThreadLocals != null)
this.inheritableThreadLocals =
ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
//省略...
}
流程:
伪代码
// 新建一个InheritableThreadLocal对象
ThreadLocal itl = new InheritableThreadLocal<String>()
main-thread 线程start:
// 此时,itl.getMap() 返回了 main线程的createInheritedMap对象,然后把“v1”字符串set了进去。
// 即,现在main线程的inheritableThreadLocals是ThreadLocalMap,且这个map里有一个key为InheritableThreadLocal对象(我们new的itl对象的threadLocalHashCode),value为“v1”字符串的值。
itl.set("v1");
// 新建一个线程,child_thead_1
new Thread(){
name = "child_thead_1";
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
init(){
// 拿到main-thread的线程对象
Thread parent = currentThread();
// 复制一份main线程的inheritableThreadLocals(ThreadLocalMap)到自己的 inheritableThreadLocals(ThreadLocalMap)
this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
}
}.start(){
run(){
// child_thead_1 start
// InheritableThreadLocal.getMap,返回了 child_thead_1 的 inheritableThreadLocals ;
// 返回的Map里 去get key=InheritableThreadLocal(我们new的itl对象)的值,肯定会返回“v1”
itl.get();
}
}
综上,就可以在子线程中,拿到父线程的ITL对象的value了。
InheritableThreadlocal
可继承的ThreadLocal;
Inheritable是在Thread对象中单独存储的一个Map,与ThreadLocalMap一样。
在Thread初始化的时候,会把父线程(当前线程)的ThreadLocal放到InheritableThreadLocalMap中。
一旦Thread初始化之后,Map里的ITL就不会再被赋值了。所以线程池里的线程只初始化一次,就会导致,ITL在线程池中失效。
问题:InheritableThreadLocal遇线程池失效
原因
我们综上可知,InheritableThreadLocal是在Thread对象初始化的时候,才将父线程的InheritableThreadLocalMap放到子线程的InheritableThreadLocalMap里。那么线程池的线程都是新建之后,可以复用的。所以肯定会有问题。
如:一旦子线程初始化了InheritableThreadLocalMap,就再也不会改变InheritableThreadLocalMap的值了。
问题描述:https://zhuanlan.zhihu.com/p/473659523
上边帖子的解决方案: 线程池不用new Thread,但是肯定需要new Runnable吧,所以对Runnable进行封装,在初始化Runnable的时候,通过【反射】对当前线程(即父线程)的ITL进行备份,然后在run的时候,拿到这个ITL 并塞到当前线程(即子线程)里。
解决
简书的github加速: https://gitcode.net/mirrors/alibaba/transmittable-thread-local?utm_source=csdn_github_accelerator
-
方法1:自定义Runable实现类
https://www.jianshu.com/p/29f4034f4250 -
方法2:用阿里的TransmittableThreadLocal
https://www.cnblogs.com/sweetchildomine/p/8807059.html
三、transmittable-thread-local 使用
Transmittable-ThreadLocal 是阿里的
ThreadLocal:https://www.cnblogs.com/hama1993/p/10382523.html
InheritableThreadLocal:https://www.cnblogs.com/hama1993/p/10400265.html
TransmittableThreadLocal:https://www.cnblogs.com/hama1993/p/10409740.html
https://www.cnblogs.com/cb1186512739/p/14214302.html
https://baijiahao.baidu.com/s?id=1702913441482159358&wfr=spider&for=pc
https://zhuanlan.zhihu.com/p/302371614
transmittableThreadLocal 继承了InheritableThreadLocal
TransmittableThreadLocal工作流程简介
ThreadLocal ttl= new TransmittableThreadLocal<String>();
// 即:不管每次set还是get都会将当前的ttl对象放到holder里存起来;且key是ttl对象,value是null,当一个set使用。
并且,这个holder是static的,即所有TransmittableThreadLocal对象,公用同一个holder,里边存所有的ttl对象
// 即在InheritableThreadLocal里,又维护了一个InheritableThreadLocal,并且这个InheritableThreadLocal里放的是当前对象TransmittableThreadLocal。 (维护了一个线程级别的InheritableThreadLocal,可以记录到每个线程的TransmittableThreadLocal对象。)
private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder = new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
return new WeakHashMap();
}
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
return new WeakHashMap(parentValue);
}
};
并且TransmittableThreadLocal 里有个很重要的Transmitter内部类,用来维护ThreadLocal的值。

-
capture方法:
capture方法,在ttlRunable类里调用,即父线程调用。 遍历holder里所有的ttl,然后调用ttl的copyValue时,会调当前线程的ttl值。即ttl.get(),并且复制一份,存起来。 存到ttlRunable的成员属性captureRef里。 -
replay方法:
在run方法里调用,即子线程。 replay方法遍历holder,拿子线程的ttl值,放到backup里,然后遍历所有父线程的ttl对象和值,都去调用对应的ttl的set方法,把父线程的值放到子线程的ttl对象里。(描述不准确,其实是放到子Thread的inheritableThreadLocals里。为了便于理解,表述为放到ttl对象里) 然后子线程再get,就会拿到父线程里的值了。 -
restore方法
与replay类似,但是操作相反,删除现在所有的ttl对象的值,然后把backup里的值都恢复进来。
注意:必须使用Ttl包装类,来包装线程池的runable或者callable,否则TTL与ITL无异。
自定义 TtlRunnable 实现 Runnable,TtlRunnable初始化方法中保持当前线程中已有的TransmittableThreadLocal。


使用案例
场景
两个环境,uat和pro,在uat发布数据到pro的时候,客户端涉及到这些同步数据的数据库表都不让访问。
并且发布的时候,只某一个projectId下的所有数据,其它projectId的用户正常访问。
解决方案
难点:需要知道当前访问的projectId是什么,又不能修改每个方法去传递这个参数。
在springmvc拦截器中拦截请求头,projectId,然后放到ThreadLocal里,然后后续的所有当前线程,和当前子线程都可以访问到当前的projectId。
然后在所有的mybatis拦截器中,拿到所有的表名,再根据当前的projectId,来判断是否需要拒绝访问。
代码
全局的holder; 使用TransmittableThreadLocal;
防止在service中使用线程池,导致ThreadLocal的值拿不到。
import com.alibaba.ttl.TransmittableThreadLocal;
/**
* 全局ContextHolder
* @author ZY
* @date 2022/4/12 11:09
*/
public final class CommonContextHolder {
/* projectId start */
private static final ThreadLocal<Integer> PROJECT_HOLDER = new TransmittableThreadLocal<>();
public static void setProjectId(int projectId) {
PROJECT_HOLDER.set(projectId);
}
public static Integer getProjectId() {
return PROJECT_HOLDER.get();
}
public static void clearProjectId() {
PROJECT_HOLDER.remove();
}
/* projectId end */
/* client start */
private static final ThreadLocal<String> CLIENT_HOLDER = new TransmittableThreadLocal<>();
public static void setClient(String client) {
CLIENT_HOLDER.set(client);
}
public static String getClient() {
return CLIENT_HOLDER.get();
}
public static void clearClient() {
CLIENT_HOLDER.remove();
}
/* client end */
}
springmvc拦截器
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.OutputStream;
/**
* 全局拦截器
*
* @date 2022/4/12 10:53
*/
@Component
public class CommonInterceptor implements HandlerInterceptor {
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
response.setCharacterEncoding("UTF-8");
response.setContentType("text/html;charset=utf-8");
response.setHeader("Access-Control-Allow-Origin", "*");
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Access-Control-Allow-Headers", "*");
OutputStream outputStream = response.getOutputStream();
String projectId = request.getHeader("project");
// 记录当前请求的项目id
Integer i = formatProjectId(projectId);
if (i == null) {
outputStream.write(JSONObject.toJSONString(ResponseEntity.invalidParams()).getBytes());
return false;
}
CommonContextHolder.setProjectId(i);
return true;
}
private Integer formatProjectId(String projectId) {
if (StringUtils.isEmpty(projectId)) return null;
try {
return Integer.parseInt(projectId);
} catch (NumberFormatException e) {
return null;
}
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
CommonContextHolder.clearProjectId();
}
}
mybatis拦截器
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.dubbo.common.utils.CollectionUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/**
* 在发布模块的时候,锁定某些模块。
* @date 2022/4/6 18:09
*/
@Component
@Intercepts({
@Signature(type= Executor.class,method="update",args={MappedStatement.class,Object.class}),
@Signature(type = Executor.class, method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
})
@Slf4j
public class ModuleLockMybatisInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
if (!needVerify(ms)) {
return invocation.proceed();
}
Object parameterObject = args[1];
BoundSql boundSql = ms.getBoundSql(parameterObject);
String sql = boundSql.getSql();
List<String> tableNameBySql = getTableNameBySql(sql);
log.debug("解析sql:{} \n表名:{}", sql, tableNameBySql);
if(isLock(tableNameBySql)){
// 如果用Exception,mybatis获取不到msg
throw new RuntimeException(ResponseCodeEnum.DENY_ACCESS_BY_RELEASING.getMsg());
}
return invocation.proceed();
}
private boolean needVerify(MappedStatement ms){
Integer projectId = CommonContextHolder.getProjectId();
if(projectId == null) return false;
return ModuleLockContext.needVerify(projectId);
}
private boolean isLock(List<String> tableNameList){
Integer projectId = CommonContextHolder.getProjectId();
if(projectId == null) return false;
log.debug(Thread.currentThread().getName() + "-------sql--project-holder::::" + projectId);
return ModuleLockContext.isLock(projectId, tableNameList);
}
/**
* 通过druid工具通过sql解析表名
* @param sql
* @return java.util.List<java.lang.String>
* @author ZY
* @date 2022/4/7 9:54
*/
private static List<String> getTableNameBySql(String sql) {
List<String> fromCache = SqlTableCache.getFromCache(sql);
if(fromCache != null) return fromCache;
String dbType = JdbcConstants.MYSQL;
try {
List<String> allTableNameList = new ArrayList<>();
/*//格式化输出
String sqlResult = SQLUtils.format(sql, dbType);
log.debug("格式化sql={}"+sqlResult);*/
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, dbType);
if (CollectionUtils.isEmpty(stmtList)) {
log.debug("stmtList为空无需获取");
return Collections.emptyList();
}
for (SQLStatement sqlStatement : stmtList) {
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
sqlStatement.accept(visitor);
Map<TableStat.Name, TableStat> tables = visitor.getTables();
log.debug("druid解析sql的结果集:{}",tables);
Set<TableStat.Name> tableNameSet = tables.keySet();
allTableNameList.addAll(tableNameSet.stream().map(TableStat.Name::getName).filter(StringUtils::isNotBlank).collect(Collectors.toList()));
}
log.debug("解析sql后的表名:{}",allTableNameList);
SqlTableCache.addToCache(sql, allTableNameList);
return allTableNameList;
} catch (Exception e) {
log.error("解析sql异常,sql=" + sql, e);
}
return Collections.emptyList();
}
}
2600

被折叠的 条评论
为什么被折叠?



