基于 @erickson 的答案,我编写了下面的代码。它适用于可继承的 ThreadLocal(即 InheritableThreadLocal):代码采用与 Thread 构造函数相同的方法来构建可继承 ThreadLocal 的列表,当然,这是通过反射实现的。此外,我继承了执行器类,并重写了它的 execute 方法。
public class MyThreadPoolExecutor extends ThreadPoolExecutor
{
@Override
public void execute(Runnable command)
{
super.execute(new Wrapped(command, Thread.currentThread()));
}
}
包装类(Wrapped):
/**
 * Runnable decorator that copies the submitting thread's inheritable
 * thread-locals into the executing thread before running the task, and removes
 * them from the executing thread once the task has finished.
 */
private class Wrapped implements Runnable
{
    /** The task to execute. */
    private final Runnable task;
    /** The thread that submitted the task. */
    private final Thread caller;

    /**
     * @param task   the task to execute
     * @param caller the thread whose inheritable thread-locals are copied
     */
    public Wrapped(Runnable task, Thread caller)
    {
        this.task = task;
        this.caller = caller;
    }

    /**
     * Runs the task with the caller's inheritable thread-locals installed.
     *
     * @throws RuntimeException if the thread-locals cannot be copied
     */
    @Override
    public void run()
    {
        // The task may run on the submitting thread itself (for example when a
        // caller-runs rejection policy is in effect). Copying would then replace
        // that thread's own inheritable map and the cleanup below would remove
        // every entry from it, so just run the task: the values are already there.
        if (Thread.currentThread() == caller)
        {
            task.run();
            return;
        }

        Iterable<ThreadLocal<?>> vars;
        try
        {
            vars = copy(caller);
        }
        catch (Exception e)
        {
            throw new RuntimeException("error when coping Threads", e);
        }

        try
        {
            task.run();
        }
        finally
        {
            // Remove the copied values so they are not left behind on the
            // executing thread after the task completes.
            for (ThreadLocal<?> var : vars)
            {
                var.remove();
            }
        }
    }
}
复制方法:
/**
 * Installs a copy of {@code caller}'s inheritable thread-locals on the current
 * thread and returns the thread-locals that were copied.
 *
 * <p>The copy is produced by reflectively invoking the private
 * {@code ThreadLocal.createInheritedMap} method on the value of the caller's
 * private {@code Thread.inheritableThreadLocals} field; the result replaces the
 * current thread's {@code inheritableThreadLocals}. If the caller has no
 * inheritable thread-locals, the current thread's field is cleared and an empty
 * result is returned.
 *
 * <p>NOTE(review): this relies on private JDK internals. On JDKs that strongly
 * encapsulate internals the {@code setAccessible} calls are expected to fail
 * unless {@code java.base/java.lang} is opened to this code - confirm for the
 * target runtime. The caller's map is also read from whichever thread invokes
 * this method; verify that is acceptable for your usage.
 *
 * @param caller the thread whose inheritable thread-locals are copied
 * @return the thread-locals now present on the current thread; never null
 * @throws Exception if any reflective lookup, access or invocation fails
 */
public static Iterable<ThreadLocal<?>> copy(Thread caller) throws Exception
{
    List<ThreadLocal<?>> threadLocals = new ArrayList<>();

    Class<?> mapClass = Class.forName("java.lang.ThreadLocal$ThreadLocalMap");
    Field inheritableField = Thread.class.getDeclaredField("inheritableThreadLocals");
    inheritableField.setAccessible(true);

    Object callerMap = inheritableField.get(caller);
    if (callerMap == null)
    {
        // The caller has nothing to inherit. createInheritedMap dereferences its
        // argument, so calling it with null would fail; mirror the caller's
        // state on the current thread instead.
        inheritableField.set(Thread.currentThread(), null);
        return threadLocals;
    }

    Method createInheritedMap = ThreadLocal.class.getDeclaredMethod("createInheritedMap", mapClass);
    createInheritedMap.setAccessible(true);
    Object inheritedMap = createInheritedMap.invoke(null, callerMap);
    inheritableField.set(Thread.currentThread(), inheritedMap);

    Field tableField = mapClass.getDeclaredField("table");
    tableField.setAccessible(true);
    Object table = tableField.get(inheritedMap);

    int length = Array.getLength(table);
    for (int i = 0; i < length; i++)
    {
        Object entry = Array.get(table, i);
        if (entry == null)
        {
            continue;
        }
        // Map entries are weak references whose referent is the ThreadLocal key.
        Object key = ((java.lang.ref.Reference<?>) entry).get();
        // The key may have been cleared since the copy was made; returning null
        // here would make the caller's cleanup loop fail.
        if (key != null)
        {
            threadLocals.add((ThreadLocal<?>) key);
        }
    }
    return threadLocals;
}