将 ThreadLocal 传播到从执行器服务提取的新线程

我正在使用 ExecutorService 和 Future(此处的示例代码)在具有超时的单独线程中运行一个进程(线程“生成”发生在 AOP Aspect 中)。

现在,主线程是一个令人不安的请求。Resteasy 使用一个或多个 ThreadLocal 变量来存储一些上下文信息,我需要在 Rest 方法调用中的某个时刻检索这些信息。问题是,由于 Resteasy 线程在新线程中运行,因此 ThreadLocal 变量丢失。

将 Resteasy 使用的任何 ThreadLocal 变量“传播”到新线程的最佳方式是什么?似乎Resteasy使用多个ThreadLocal变量来跟踪上下文信息,我想“盲目地”将所有信息传输到新线程。

我已经研究了子类化和使用beforeExecute方法将当前线程传递到池,但我找不到将ThreadLocal变量传递到池的方法。ThreadPoolExecutor

有什么建议吗?

谢谢


答案 1

与线程关联的实例集保存在每个 的私有成员中。您列举这些内容的唯一机会就是对 ;这样,您可以覆盖线程字段的访问限制。ThreadLocalThreadThread

一旦你可以获得 的集合,你就可以使用 和 hooks 在后台线程中复制,或者通过为任务创建一个包装器来拦截调用以设置未设置必要的实例。实际上,后一种技术可能效果更好,因为它会为您提供一个方便的位置来存储任务排队时的值。ThreadLocalbeforeExecute()afterExecute()ThreadPoolExecutorRunnablerun()ThreadLocalThreadLocal


更新:下面是第二种方法的更具体说明。与我的原始描述相反,包装器中存储的所有内容都是调用线程,在执行任务时会询问该线程。

static Runnable wrap(Runnable task)
{
  Thread caller = Thread.currentThread();
  return () -> {
    Iterable<ThreadLocal<?>> vars = copy(caller);
    try {
      task.run();
    }
    finally {
      for (ThreadLocal<?> var : vars)
        var.remove();
    }
  };
}

/**
 * For each {@code ThreadLocal} in the specified thread, copy the thread's 
 * value to the current thread.  
 * 
 * @param caller the calling thread
 * @return all of the {@code ThreadLocal} instances that are set on current thread
 */
private static Collection<ThreadLocal<?>> copy(Thread caller)
{
  /* Use a nasty bunch of reflection to do this. */
  throw new UnsupportedOperationException();
}

答案 2

基于@erickson答案,我编写了此代码。它适用于可继承的ThreadLocals。它使用与线程构造函数中使用的相同方法构建可继承的ThreadLocals列表。当然,我用反射来做到这一点。此外,我重写了执行器类。

public class MyThreadPoolExecutor extends ThreadPoolExecutor
{
   @Override
   public void execute(Runnable command)
   {
      super.execute(new Wrapped(command, Thread.currentThread()));
   }
}

包装纸:

   private class Wrapped implements Runnable
   {
      private final Runnable task;

      private final Thread caller;

      public Wrapped(Runnable task, Thread caller)
      {
         this.task = task;
         this.caller = caller;
      }

      public void run()
      {
         Iterable<ThreadLocal<?>> vars = null;
         try
         {
            vars = copy(caller);
         }
         catch (Exception e)
         {
            throw new RuntimeException("error when coping Threads", e);
         }
         try {
            task.run();
         }
         finally {
            for (ThreadLocal<?> var : vars)
               var.remove();
         }
      }
   }

复制方法:

public static Iterable<ThreadLocal<?>> copy(Thread caller) throws Exception
   {
      List<ThreadLocal<?>> threadLocals = new ArrayList<>();
      Field field = Thread.class.getDeclaredField("inheritableThreadLocals");
      field.setAccessible(true);
      Object map = field.get(caller);
      Field table = Class.forName("java.lang.ThreadLocal$ThreadLocalMap").getDeclaredField("table");
      table.setAccessible(true);

      Method method = ThreadLocal.class
              .getDeclaredMethod("createInheritedMap", Class.forName("java.lang.ThreadLocal$ThreadLocalMap"));
      method.setAccessible(true);
      Object o = method.invoke(null, map);

      Field field2 = Thread.class.getDeclaredField("inheritableThreadLocals");
      field2.setAccessible(true);
      field2.set(Thread.currentThread(), o);

      Object tbl = table.get(o);
      int length = Array.getLength(tbl);
      for (int i = 0; i < length; i++)
      {
         Object entry = Array.get(tbl, i);
         Object value = null;
         if (entry != null)
         {
            Method referentField = Class.forName("java.lang.ThreadLocal$ThreadLocalMap$Entry").getMethod(
                    "get");
            referentField.setAccessible(true);
            value = referentField.invoke(entry);
            threadLocals.add((ThreadLocal<?>) value);
         }
      }
      return threadLocals;
   }