使用ThreadLocal实现的计数器

本文介绍了一种利用ThreadLocal实现计数器的方法,并解决了进程回收导致计数丢失的问题。通过不同版本迭代优化,最终实现了稳定可靠的全局计数器。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

今天脑子里闪过使用ThreadLocal实现计数器的念头,百度了一下,没有讲到怎么聚合所有进程各自的计数器值。所以自己实现一个,代码如下。

import java.util.WeakHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

public class ThradeLocalTest {
    static class Counter {
        private static class Entry {
            long q0, q1, q2, q3, q4, q5, q6, q7, q8, q9, qa, qb, qc, qd, qe;

            long count = 0;
        }

        private static Lock lock = new ReentrantLock();
        private static WeakHashMap<Thread, Entry> map = new WeakHashMap<>();

        private static ThreadLocal<Entry> local = ThreadLocal.withInitial(Entry::new);

        public void increase() {
            Entry entry = local.get();
            long count = entry.count;
            if (count == 0) {
                lock.lock();
                try {
                    map.put(Thread.currentThread(), entry);
                } finally {
                    lock.unlock();
                }
            }
            local.get().count = count + 1;
        }

        public long getAll() {
            return map.entrySet().stream().map(entry->entry.getValue().count).reduce(0L, Long::sum);
        }
    }

    public static void main(String[] args) throws InterruptedException {
        Counter counter = new Counter();
        int number = 100;
        Thread[] threads = new Thread[number];
        for (int i = 0; i < number; i++) {
            threads[i] = new Thread(()->{
                for (int j = 0; j < 100_000_000; j++) {
                    counter.increase();
                }
            });
        }
        for (Thread thread1 : threads) {
            thread1.start();
        }
        System.out.println(counter.getAll());
        for (Thread thread2 : threads) {
            thread2.join();
        }
        System.out.println(counter.getAll());
    }
}


该代码有个问题,就是如果Thread被回收了,对应的计数就丢失了,所以需要自己实现一下存储计数使用的map,改完之后发代码。

import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

public class ThreadLocalTest {
    static class Counter {
        private static class Entry {
            long q0, q1, q2, q3, q4, q5, q6, q7, q8, q9, qa, qb, qc, qd, qe;

            long count = 0;
        }

        private volatile long removedTotal = 0;
        long q0, q1, q2, q3, q4, q5, q6, q7, q8, q9, qa, qb, qc, qd, qe;

        private Lock lock = new ReentrantLock();
        private Map<WeakReference<Thread>, Entry> map = new HashMap<>();
        private ReferenceQueue<Object> queue = new ReferenceQueue<>();

        private final ThreadLocal<Entry> local = ThreadLocal.withInitial(Entry::new);

        public void increase() {
            Entry entry = local.get();
            long count = entry.count;
            if (count == 0) {
                lock.lock();
                try {
                    map.put(new WeakReference<>(Thread.currentThread(), queue), entry);
                } finally {
                    lock.unlock();
                }
            }
            expunge();
            local.get().count = count + 1;
        }

        private void expunge() {
            for (Object x; (x = queue.poll()) != null; ) {
                synchronized (queue) {
                    lock.lock();
                    try {
                        Entry e = map.get(x);
                        map.remove(x);
                        removedTotal += e.count;
                    } finally {
                        lock.unlock();
                    }
                }
            }
        }

        public long getAll() {
            return map.entrySet().stream().map(entry->entry.getValue().count).reduce(0L, Long::sum) + removedTotal;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        Counter counter = new Counter();
        int number = 4;
        Thread[] threads = new Thread[number];
        for (int i = 0; i < number; i++) {
            threads[i] = new Thread(()->{
                for (int j = 0; j < 100_000_000; j++) {
                    counter.increase();
                }
            });
        }
        for (Thread thread1 : threads) {
            thread1.start();
        }
        for (Thread thread2 : threads) {
            thread2.join();
        }
        for (int i = 0; i < number; i++) {
            threads[i] = new Thread(()->{
                for (int j = 0; j < 100_000_000; j++) {
                    counter.increase();
                }
            });
        }
        long start = System.currentTimeMillis();
        for (Thread thread1 : threads) {
            thread1.start();
        }
        for (Thread thread2 : threads) {
            thread2.join();
        }
        System.out.println(System.currentTimeMillis() - start);
        System.out.println(counter.getAll());

        start = System.currentTimeMillis();
        for (long i = 0; i < 800_000_000L; i++) {}
        System.out.println(System.currentTimeMillis() - start);
    }
}

极端情况啥的还没测试,没有触及gc。

又修改了一个版本,不使用ThreadLocal了,创建一个subCounter用于每个线程计数。

import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

public class ThreadLocalTest {
    static class Counter {
        private static class Entry {
            long count = 0;
        }

        public interface SubCounter {
            void increase();
        }

        private class SubCounterImpl implements SubCounter {
            private final Entry entry;

            SubCounterImpl(Entry entry) {
                this.entry = entry;
            }

            public void increase() {
                entry.count++;
            }
        }

        private volatile long removedTotal = 0;
        long q0, q1, q2, q3, q4, q5, q6, q7, q8, q9, qa, qb, qc, qd, qe;

        private final Lock lock = new ReentrantLock();
        private final Map<WeakReference<SubCounter>, Entry> map = new HashMap<>();
        private final ReferenceQueue<Object> queue = new ReferenceQueue<>();

        public SubCounter createSubCounter() {
            expunge();
            Entry entry = new Entry();
            SubCounter subCounter = new SubCounterImpl(entry);
            lock.lock();
            try {
                map.put(new WeakReference<>(subCounter, queue), entry);
            } finally {
                lock.unlock();
            }
            return subCounter;
        }

        private void expunge() {
            for (Object x; (x = queue.poll()) != null; ) {
                synchronized (queue) {
                    lock.lock();
                    try {
                        Entry e = map.get(x);
                        map.remove(x);
                        removedTotal += e.count;
                    } finally {
                        lock.unlock();
                    }
                }
            }
        }

        public long getAll() {
            expunge();
            return map.entrySet().stream().map(entry->entry.getValue().count).reduce(0L, Long::sum) + removedTotal;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        Counter counter = new Counter();
        int number = 4;
        Thread[] threads = new Thread[number];
        for (int i = 0; i < number; i++) {
            threads[i] = new Thread(new Runnable() {
                Counter.SubCounter subCounter = counter.createSubCounter();
                @Override
                public void run() {
                    for (int j = 0; j < 100_000_000; j++) {
                        subCounter.increase();
                    }
                }
            }
            );
        }
        for (Thread thread1 : threads) {
            thread1.start();
        }
        for (Thread thread2 : threads) {
            thread2.join();
        }
        for (int i = 0; i < number; i++) {
            threads[i] = new Thread(new Runnable() {
                Counter.SubCounter subCounter = counter.createSubCounter();
                @Override
                public void run() {
                    for (int j = 0; j < 100_000_000; j++) {
                        subCounter.increase();
                    }
                }
            }
            );
        }
        long start = System.currentTimeMillis();
        for (Thread thread1 : threads) {
            thread1.start();
        }
        for (Thread thread2 : threads) {
            thread2.join();
        }
        System.out.printf("Time:%d\n", System.currentTimeMillis() - start);
        System.out.printf("Count:%d\n", counter.getAll());

        start = System.currentTimeMillis();
        int intCount = 0;
        for (long i = 0; i < 800_000_000L; i++) {
            intCount++;
        }
        System.out.printf("Time:%d\n", System.currentTimeMillis() - start);
        System.out.printf("Count:%d\n", intCount);

        // 测试垃圾回收
        Counter counter2 = new Counter();
        start = System.currentTimeMillis();
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 200_000; j++) {
                counter2.createSubCounter().increase();
            }
        }
        System.out.printf("Time:%d\n", System.currentTimeMillis() - start);
        System.out.printf("Count:%d\n", counter2.getAll());
        System.out.printf("counter size:%d\n", counter.map.size());
        System.out.printf("counter2 size:%d\n", counter2.map.size());
    }
}

配置最大运行内存为2M,运行结果如下:

Time:29
Count:800000000
Time:1193
Count:800000000
Time:5945
Count:400000
counter size:5
counter2 size:504


### Java `ThreadLocal` 工作原理 `ThreadLocal` 是一种特殊的变量容器,允许每个线程拥有该变量的一个独立副本。这意味着不同线程之间无法相互干扰彼此持有的 `ThreadLocal` 变量实例[^1]。 具体来说,在 JVM 中每一个线程都维护着一个名为 `ThreadLocalMap` 的哈希表结构来保存这些特定于当前线程的数据项。每当调用 `set()` 方法向某个 `ThreadLocal` 对象赋值时,实际上是在对应的线程内部创建了一个键值对记录;当通过 `get()` 获取值的时候,则是从这个映射关系里查找并返回相应的对象引用[^2]。 对于继承自父级线程属性的情况(即 `InheritableThreadLocal`),JVM 还会复制一份来自父进程的相关配置给新启动的孩子们使用[^3]。 ```java public class ThreadLocalExample { private static final ThreadLocal<Integer> threadLocalValue = ThreadLocal.withInitial(() -> 0); public void increment() { Integer value = threadLocalValue.get(); threadLocalValue.set(value + 1); } public int getValue() { return threadLocalValue.get(); } } ``` 这段代码展示了如何定义一个初始值为零 (`withInitial`) 的整数类型的 `ThreadLocal` 实例,并实现了两个方法用于操作其上的数值——增加计数器(`increment`) 和读取当前值(`getValue`)。 ### 应用场景 #### 数据库连接池管理 在一个典型的Web应用服务器环境中,多个请求可能并发执行数据库查询任务。为了提高性能和资源利用率,通常采用连接池技术预先分配一定数量的数据库链接供后续重用。此时就可以利用 `ThreadLocal` 来确保每次HTTP请求处理过程中所使用的都是同一个持久化的Session对象,从而简化事务管理和错误恢复逻辑[^4]。 #### 用户上下文传递 假设存在一个多租户SaaS平台架构下的微服务组件通信链路,其中涉及到跨层传播认证信息的需求。借助 `ThreadLocal` ,可以在进入业务流程之初就将登录用户的唯一标识符存入内存空间内,之后无论经过多少次函数调用跳转都不会丢失这条重要线索,直到整个交易结束为止自动清除掉临时缓存的内容。 #### 并发计算中的状态保持 考虑这样一个例子:有一个复杂的算法需要分成若干个小部分由不同的工作单元协同完成。如果希望各个子任务能够共享某些中间结果而不必担心同步竞争条件带来的麻烦的话,那么完全可以把它们封装成静态字段形式并通过 `ThreadLocal` 方式来进行分发控制。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值