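Need help understanding non-thread-safe code in Python
I'm working through a tutorial using Python 3.12.2. I've reached a section where the tutorial demonstrates unsafe code, that is, code that can misbehave when it runs on multiple threads. The tutorial says the code below produces unpredictable results, but for me the results are entirely predictable. The code is: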
# when no thread synchronization used
from threading import Thread as Thread

def inc():
    global x
    for _ in range(1000000):
        x+=1

#global variable
x = 0
counter = 0
while counter < 10:
    # creating threads
    threads = [Thread(target=inc) for _ in range(10)]
    # start the threads
    for thread in threads:
        thread.start()
    #wait for the threads
    for thread in threads:
        thread. Join()
    print("Pass ", counter, "final value of x:", f"{x:,}")
    x = 0
    counter += 1
Running it produces the following output:
PS D:\PythonDev> python .\thread3a.py
Pass 0 final value of x: 10,000,000
Pass 1 final value of x: 10,000,000
Pass 2 final value of x: 10,000,000
Pass 3 final value of x: 10,000,000
Pass 4 final value of x: 10,000,000
Pass 5 final value of x: 10,000,000
Pass 6 final value of x: 10,000,000
Pass 7 final value of x: 10,000,000
Pass 8 final value of x: 10,000,000
Pass 9 final value of x: 10,000,000
PS D:\PythonDev>
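I modified the code to add the outer loop so that I don't have to rerun it from the command line each time. According to the tutorial, a correct run would give 10,000,000 on every pass, but the actual result should be unpredictable and less than 10,000,000. My output is neither unpredictable nor below 10,000,000. What am I doing wrong?
My environment: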
O/S: MS Windows 10 Home 22H2
RAM: 16 GB
CPU: Intel Core I7-2860QM
Terminal Session: PowerShell 7.4.1
Python version: 3.12.2
2 Answers
The tutorial is right, because the operation x += 1 is not atomic; it consists of several Python bytecode instructions (see offsets 12 through 18 below):
>>> import dis
>>> def inc():
...     global x
...     for _ in range(1000000):
...         x+=1
...
>>> dis.dis(inc)
  3           0 LOAD_GLOBAL              0 (range)
              2 LOAD_CONST               1 (1000000)
              4 CALL_FUNCTION            1
              6 GET_ITER
        >>    8 FOR_ITER                12 (to 22)
             10 STORE_FAST               0 (_)

  4          12 LOAD_GLOBAL              1 (x)
             14 LOAD_CONST               2 (1)
             16 INPLACE_ADD
             18 STORE_GLOBAL             1 (x)
             20 JUMP_ABSOLUTE            8
        >>   22 LOAD_CONST               0 (None)
             24 RETURN_VALUE
>>>
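Opcode names and offsets differ between Python versions, so the listing above may not match what your interpreter prints. As a rough, version-independent check, the sketch below lists the opcode names for a single x += 1 and confirms that the load and the store of the global are separate instructions. The helper name opnames is made up for this illustration.

```python
import dis

def inc():
    global x
    x += 1

def opnames(func):
    """Return the opcode names of func in execution order (hypothetical helper)."""
    return [ins.opname for ins in dis.get_instructions(func)]

names = opnames(inc)
print(names)

# The read and the write of the global are distinct instructions,
# so a thread switch can in principle land between them.
load_at = names.index("LOAD_GLOBAL")
store_at = names.index("STORE_GLOBAL")
print("instructions between load and store:", store_at - load_at - 1)
```

Whatever the exact opcode in the middle is called on your version, the gap between the load and the store is where another thread's update can be lost.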
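Because of the Global Interpreter Lock (GIL), threads cannot execute their bytecode in parallel. So unless a thread performs I/O or network activity that releases the GIL and lets other threads run, it keeps executing its own bytecode until its time slice expires. If your computer is "fast", it can get through all 1,000,000 iterations within a single time slice. My machine is apparently not "fast": when I run your code (after replacing thread. Join() with thread.join()), I get: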
Pass 0 final value of x: 6,688,134
Pass 1 final value of x: 6,096,719
Pass 2 final value of x: 6,250,393
Pass 3 final value of x: 6,116,210
Pass 4 final value of x: 6,686,225
Pass 5 final value of x: 4,912,244
Pass 6 final value of x: 4,965,819
Pass 7 final value of x: 6,301,143
Pass 8 final value of x: 6,549,947
Pass 9 final value of x: 7,321,995
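The "time slice" in this answer corresponds to what CPython calls the switch interval, which can be read and changed through the sys module. The sketch below only inspects and temporarily shortens it; whether a shorter interval actually makes the race show up on a given machine and Python version is something to try out, not something this snippet guarantees. The value 0.0001 is an arbitrary choice for illustration.

```python
import sys

# How long a thread may run before the interpreter asks it to give up the GIL.
default_interval = sys.getswitchinterval()
print("default switch interval (seconds):", default_interval)

# Temporarily request much more frequent switches (arbitrary example value).
sys.setswitchinterval(0.0001)
shortened = sys.getswitchinterval()
print("shortened switch interval (seconds):", shortened)

# Restore the original setting so the rest of the program is unaffected.
sys.setswitchinterval(default_interval)
```

Placing the setswitchinterval call before the threads are started in the question's script is one way to experiment with how often the threads are interleaved.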
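The totals fall short because each thread can be interrupted at an arbitrary point in its code. If a thread is interrupted partway through the bytecode for x += 1, that is, after it has loaded x but before it has stored the new value, any increments other threads make in the meantime are overwritten when it resumes, and the result is less than 10,000,000. However, if I reduce the loop count to just 100,000, one time slice is enough and I get 1,000,000 every time. You can try increasing the loop count by a factor of ten each time until you see the problem appear.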
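To make the lost update concrete without depending on real thread timing, here is a small deterministic sketch. It is not how the interpreter schedules threads; it just splits one increment into load, add, and store steps using a generator and interleaves two such "workers" by hand. The names increment_steps and run, and the schedule strings, are invented for this example.

```python
def increment_steps(state):
    """One x += 1, split into the same three phases as the bytecode."""
    loaded = state["x"]       # phase 1: load the global
    yield
    result = loaded + 1       # phase 2: add one
    yield
    state["x"] = result       # phase 3: store the global
    yield

def run(schedule):
    """Advance workers A and B one phase at a time in the given order."""
    state = {"x": 0}
    workers = {"A": increment_steps(state), "B": increment_steps(state)}
    for name in schedule:
        next(workers[name])
    return state["x"]

# A finishes all three phases before B starts: both increments count.
print(run("AAABBB"))  # 2

# A and B both load 0 before either stores: one increment is lost.
print(run("ABABAB"))  # 1
```

The second schedule is the situation described above: both workers read the same old value, so the later store overwrites the earlier one instead of adding to it.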
If you change the code to ...
from threading import Thread as Thread, Lock
lock = Lock()
def inc():
    global x
    for _ in range(1000000):
        with lock:
            x+=1
...
... you will get 10,000,000 on every pass, but the code will run much more slowly.
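To get a feel for the cost of the lock, the following sketch times one pass with and one pass without it. The thread and iteration counts are deliberately smaller than in the question so that it finishes quickly, and the function name timed_pass is an assumption of this example. The timings depend entirely on the machine and Python version, so no particular ratio is implied.

```python
import time
from threading import Lock, Thread

def timed_pass(use_lock, n_threads=4, n_iters=50_000):
    """Run one pass and return (final count, elapsed seconds)."""
    state = {"x": 0}
    lock = Lock()

    def inc_plain():
        for _ in range(n_iters):
            state["x"] += 1

    def inc_locked():
        for _ in range(n_iters):
            with lock:              # only one thread at a time may increment
                state["x"] += 1

    target = inc_locked if use_lock else inc_plain
    threads = [Thread(target=target) for _ in range(n_threads)]
    start = time.perf_counter()
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return state["x"], time.perf_counter() - start

plain_total, plain_seconds = timed_pass(use_lock=False)
locked_total, locked_seconds = timed_pass(use_lock=True)
print(f"no lock:   {plain_total:,} in {plain_seconds:.3f}s")
print(f"with lock: {locked_total:,} in {locked_seconds:.3f}s")
```

Only the locked total is guaranteed to equal n_threads * n_iters; the unlocked total may or may not come up short on any given run.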
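As @BooBoo said, "the tutorial is right, because the operation x += 1 is not atomic", but on a fast system the threads are much less likely to interfere with each other. Even after raising the range to 100 million I never saw a failure. I was honestly surprised: a billion increments per pass and not a single failure. To construct a worst case, I used the following version, which deliberately slows down the increment by forcing a context switch between the read, modify, and store steps: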
from threading import Thread, Lock
import time
lock = Lock()
def inc():
    global x
    for _ in range(10000):
        #with lock:    # uncomment (and indent the four lines below) for the locked run
        a = x          # capture current value
        time.sleep(0)  # gives up time slice explicitly
        a += 1         # increment it
        x = a          # store it back in global.
Without the lock:
Pass 0 final value of x: 10,096
Pass 1 final value of x: 10,107
Pass 2 final value of x: 10,095
Pass 3 final value of x: 10,112
Pass 4 final value of x: 10,110
Pass 5 final value of x: 10,105
Pass 6 final value of x: 10,111
Pass 7 final value of x: 10,105
Pass 8 final value of x: 10,092
Pass 9 final value of x: 10,107
With the lock:
Pass 0 final value of x: 100,000
Pass 1 final value of x: 100,000
Pass 2 final value of x: 100,000
Pass 3 final value of x: 100,000
Pass 4 final value of x: 100,000
Pass 5 final value of x: 100,000
Pass 6 final value of x: 100,000
Pass 7 final value of x: 100,000
Pass 8 final value of x: 100,000
Pass 9 final value of x: 100,000
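For anyone who wants to reproduce both runs without editing the code in between, here is a self-contained sketch of the same idea with the lock switched by a parameter. It is a variation on the snippet above rather than the exact script that produced those numbers: the function name run_pass, the use of contextlib.nullcontext as a stand-in for "no lock", and the smaller iteration count are all choices made for this example.

```python
import time
from contextlib import nullcontext
from threading import Lock, Thread

def run_pass(use_lock, n_threads=10, n_iters=1_000):
    """Run one pass of the slowed-down increment and return the final count."""
    state = {"x": 0}
    # nullcontext() does nothing on enter/exit, so it models the unlocked case.
    guard = Lock() if use_lock else nullcontext()

    def inc():
        for _ in range(n_iters):
            with guard:
                a = state["x"]     # capture current value
                time.sleep(0)      # invite a switch to another thread
                a += 1             # increment the stale copy
                state["x"] = a     # store it back

    threads = [Thread(target=inc) for _ in range(n_threads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return state["x"]

unlocked = run_pass(use_lock=False)
locked = run_pass(use_lock=True)
print(f"without lock: {unlocked:,}")
print(f"with lock:    {locked:,}")
```

With the lock the total is always n_threads * n_iters; without it the total can never exceed that figure and, as in the runs above, is usually far below it.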