在Python中评估不安全用户输入的数学方程式

29 投票
2 回答
3048 浏览
提问于 2025-04-29 00:27

我有一个网站,用户可以在上面输入数学公式,然后这些公式会根据网站提供的数据(常量)进行计算。需要的数学操作包括符号、算术运算,还有一些基本的函数,比如 min()max()。举个例子,一个公式可能是:

max(a * b + 100, a / b - 200)

我们可以用 Python 的 eval() 函数来直接计算这个公式,但大家都知道这样做会让网站面临安全风险。那么,有什么安全的方法来进行数学公式的计算呢?

如果选择用 Python 来计算这个公式,是否有一些 Python 沙箱可以限制 Python 的功能,只允许用户提供的运算符和函数使用?像定义函数这样的完整 Python 功能应该完全禁用。可以使用子进程(可以参考 PyPy 沙箱)。特别是,循环和其他可能被利用的内存和 CPU 使用漏洞应该被关闭。

暂无标签

2 个回答

9

声明:我是另一个回答中提到的Alexer。老实说,我半开玩笑地提到用字节码解析的方法,因为我正好有99%的代码是为一个无关的项目准备的,所以我能在几分钟内快速做出一个初步的演示。不过,这种方法本身没有什么问题,只是这个任务需要的工具比较复杂。实际上,你完全可以只对代码进行反汇编(检查操作码是否在一个白名单上),确认常量和名称是有效的,然后用普通的、危险的eval来执行它。这样做的话,你就失去了在执行过程中插入额外检查的能力。(另一个声明:我还是不太敢用eval)

不过,我有点无聊,所以我写了一些代码,用更聪明的方法来做这件事;使用抽象语法树(AST)而不是字节码。只需要在compile()中加一个额外的标志。(或者直接用ast.parse(),反正你也需要模块中的类型)

import ast
import operator

_operations = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.div,
        ast.Pow: operator.pow,
}

def _safe_eval(node, variables, functions):
        if isinstance(node, ast.Num):
                return node.n
        elif isinstance(node, ast.Name):
                return variables[node.id] # KeyError -> Unsafe variable
        elif isinstance(node, ast.BinOp):
                op = _operations[node.op.__class__] # KeyError -> Unsafe operation
                left = _safe_eval(node.left, variables, functions)
                right = _safe_eval(node.right, variables, functions)
                if isinstance(node.op, ast.Pow):
                        assert right < 100
                return op(left, right)
        elif isinstance(node, ast.Call):
                assert not node.keywords and not node.starargs and not node.kwargs
                assert isinstance(node.func, ast.Name), 'Unsafe function derivation'
                func = functions[node.func.id] # KeyError -> Unsafe function
                args = [_safe_eval(arg, variables, functions) for arg in node.args]
                return func(*args)

        assert False, 'Unsafe operation'

def safe_eval(expr, variables={}, functions={}):
        node = ast.parse(expr, '<string>', 'eval').body
        return _safe_eval(node, variables, functions)

if __name__ == '__main__':
        import math

        print safe_eval('sin(a*pi/b)', dict(a=1, b=2, pi=math.pi), dict(sin=math.sin))

这和字节码版本是一样的;如果你检查操作是否在白名单上,并确认名称和值是有效的,那么你应该可以在AST上调用eval。(不过我还是不建议这么做。因为我比较谨慎。而在使用eval时,谨慎是很重要的)

6

在Python中,有一种相对简单的方法可以做到这一点,而且不需要使用第三方包。

  • 可以使用 compile() 准备一个单行的Python表达式,将其转换为字节码,以便后续使用 eval()

  • 不是通过 eval() 来运行字节码,而是通过你自己定制的操作码循环来运行,只实现你真正需要的操作码。例如,不使用内置函数,不访问属性,这样就不会让沙盒环境逃逸。

不过,这里有一些需要注意的地方,比如要准备应对CPU和内存的耗尽,这些问题并不是这个方法特有的,其他方法也会遇到。

这里有一篇关于这个主题的完整博客文章这里有一个相关的代码片段。下面是简化的示例代码。

""""

    The orignal author: Alexer / #python.fi

"""

import opcode
import dis
import sys
import multiprocessing
import time

# Python 3 required
assert sys.version_info[0] == 3, "No country for old snakes"


class UnknownSymbol(Exception):
    """ There was a function or constant in the expression we don't support. """


class BadValue(Exception):
    """ The user tried to input dangerously big value. """

    MAX_ALLOWED_VALUE = 2**63


class BadCompilingInput(Exception):
    """ The user tried to input something which might cause compiler to slow down. """


def disassemble(co):
    """ Loop through Python bytecode and match instructions  with our internal opcodes.

    :param co: Python code object
    """
    code = co.co_code
    n = len(code)
    i = 0
    extended_arg = 0
    result = []
    while i < n:
        op = code[i]

        curi = i
        i = i+1
        if op >= dis.HAVE_ARGUMENT:
            # Python 2
            # oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
            oparg = code[i] + code[i+1] * 256 + extended_arg
            extended_arg = 0
            i = i+2
            if op == dis.EXTENDED_ARG:
                # Python 2
                #extended_arg = oparg*65536L
                extended_arg = oparg*65536
        else:
            oparg = None

        # print(opcode.opname[op])

        opv = globals()[opcode.opname[op].replace('+', '_')](co, curi, i, op, oparg)

        result.append(opv)

    return result

# For the opcodes see dis.py
# (Copy-paste)
# https://docs.python.org/2/library/dis.html

class Opcode:
    """ Base class for out internal opcodes. """
    args = 0
    pops = 0
    pushes = 0
    def __init__(self, co, i, nexti, op, oparg):
        self.co = co
        self.i = i
        self.nexti = nexti
        self.op = op
        self.oparg = oparg

    def get_pops(self):
        return self.pops

    def get_pushes(self):
        return self.pushes

    def touch_value(self, stack, frame):
        assert self.pushes == 0
        for i in range(self.pops):
            stack.pop()


class OpcodeArg(Opcode):
    args = 1


class OpcodeConst(OpcodeArg):
    def get_arg(self):
        return self.co.co_consts[self.oparg]


class OpcodeName(OpcodeArg):
    def get_arg(self):
        return self.co.co_names[self.oparg]


class POP_TOP(Opcode):
    """Removes the top-of-stack (TOS) item."""
    pops = 1
    def touch_value(self, stack, frame):
        stack.pop()


class DUP_TOP(Opcode):
    """Duplicates the reference on top of the stack."""
    # XXX: +-1
    pops = 1
    pushes = 2
    def touch_value(self, stack, frame):
        stack[-1:] = 2 * stack[-1:]


class ROT_TWO(Opcode):
    """Swaps the two top-most stack items."""
    pops = 2
    pushes = 2
    def touch_value(self, stack, frame):
        stack[-2:] = stack[-2:][::-1]


class ROT_THREE(Opcode):
    """Lifts second and third stack item one position up, moves top down to position three."""
    pops = 3
    pushes = 3
    direct = True
    def touch_value(self, stack, frame):
        v3, v2, v1 = stack[-3:]
        stack[-3:] = [v1, v3, v2]


class ROT_FOUR(Opcode):
    """Lifts second, third and forth stack item one position up, moves top down to position four."""
    pops = 4
    pushes = 4
    direct = True
    def touch_value(self, stack, frame):
        v4, v3, v2, v1 = stack[-3:]
        stack[-3:] = [v1, v4, v3, v2]


class UNARY(Opcode):
    """Unary Operations take the top of the stack, apply the operation, and push the result back on the stack."""
    pops = 1
    pushes = 1


class UNARY_POSITIVE(UNARY):
    """Implements TOS = +TOS."""
    def touch_value(self, stack, frame):
        stack[-1] = +stack[-1]


class UNARY_NEGATIVE(UNARY):
    """Implements TOS = -TOS."""
    def touch_value(self, stack, frame):
        stack[-1] = -stack[-1]


class BINARY(Opcode):
    """Binary operations remove the top of the stack (TOS) and the second top-most stack item (TOS1) from the stack. They perform the operation, and put the result back on the stack."""
    pops = 2
    pushes = 1


class BINARY_POWER(BINARY):
    """Implements TOS = TOS1 ** TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        print(TOS1, TOS)
        if abs(TOS1) > BadValue.MAX_ALLOWED_VALUE or abs(TOS) > BadValue.MAX_ALLOWED_VALUE:
            raise BadValue("The value for exponent was too big")

        stack[-2:] = [TOS1 ** TOS]


class BINARY_MULTIPLY(BINARY):
    """Implements TOS = TOS1 * TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 * TOS]


class BINARY_DIVIDE(BINARY):
    """Implements TOS = TOS1 / TOS when from __future__ import division is not in effect."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 / TOS]


class BINARY_MODULO(BINARY):
    """Implements TOS = TOS1 % TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 % TOS]


class BINARY_ADD(BINARY):
    """Implements TOS = TOS1 + TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 + TOS]


class BINARY_SUBTRACT(BINARY):
    """Implements TOS = TOS1 - TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 - TOS]


class BINARY_FLOOR_DIVIDE(BINARY):
    """Implements TOS = TOS1 // TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 // TOS]


class BINARY_TRUE_DIVIDE(BINARY):
    """Implements TOS = TOS1 / TOS when from __future__ import division is in effect."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 / TOS]


class BINARY_LSHIFT(BINARY):
    """Implements TOS = TOS1 << TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 << TOS]


class BINARY_RSHIFT(BINARY):
    """Implements TOS = TOS1 >> TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 >> TOS]


class BINARY_AND(BINARY):
    """Implements TOS = TOS1 & TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 & TOS]


class BINARY_XOR(BINARY):
    """Implements TOS = TOS1 ^ TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 ^ TOS]


class BINARY_OR(BINARY):
    """Implements TOS = TOS1 | TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 | TOS]


class RETURN_VALUE(Opcode):
    """Returns with TOS to the caller of the function."""
    pops = 1
    final = True
    def touch_value(self, stack, frame):
        value = stack.pop()
        return value


class LOAD_CONST(OpcodeConst):
    """Pushes co_consts[consti] onto the stack.""" # consti
    pushes = 1
    def touch_value(self, stack, frame):
        # XXX moo: Validate type
        value = self.get_arg()
        assert isinstance(value, (int, float))
        stack.append(value)


class LOAD_NAME(OpcodeName):
    """Pushes the value associated with co_names[namei] onto the stack.""" # namei
    pushes = 1
    def touch_value(self, stack, frame):
        # XXX moo: Get name from dict of valid variables/functions
        name = self.get_arg()
        if name not in frame:
            raise UnknownSymbol("Does not know symbol {}".format(name))
        stack.append(frame[name])


class CALL_FUNCTION(OpcodeArg):
    """Calls a function. The low byte of argc indicates the number of positional parameters, the high byte the number of keyword parameters. On the stack, the opcode finds the keyword parameters first. For each keyword argument, the value is on top of the key. Below the keyword parameters, the positional parameters are on the stack, with the right-most parameter on top. Below the parameters, the function object to call is on the stack. Pops all function arguments, and the function itself off the stack, and pushes the return value.""" # argc
    pops = None
    pushes = 1

    def get_pops(self):
        args = self.oparg & 0xff
        kwargs = (self.oparg >> 8) & 0xff
        return 1 + args + 2 * kwargs

    def touch_value(self, stack, frame):
        argc = self.oparg & 0xff
        kwargc = (self.oparg >> 8) & 0xff
        assert kwargc == 0
        if argc > 0:
            args = stack[-argc:]
            stack[:] = stack[:-argc]
        else:
            args = []
        func = stack.pop()

        assert func in frame.values(), "Uh-oh somebody injected bad function. This does not happen."

        result = func(*args)
        stack.append(result)


def check_for_pow(expr):
    """ Python evaluates power operator during the compile time if its on constants.

    You can do CPU / memory burning attack with ``2**999999999999999999999**9999999999999``.
    We mainly care about memory now, as we catch timeoutting in any case.

    We just disable pow and do not care about it.
    """
    if "**" in expr:
        raise BadCompilingInput("Power operation is not allowed")


def _safe_eval(expr, functions_and_constants={}, check_compiling_input=True):
    """ Evaluate a Pythonic math expression and return the output as a string.

    The expr is limited to 1024 characters / 1024 operations
    to prevent CPU burning or memory stealing.

    :param functions_and_constants: Supplied "built-in" data for evaluation
    """

    # Some safety checks
    assert len(expr) < 1024

    # Check for potential bad compiler input
    if check_compiling_input:
        check_for_pow(expr)

    # Compile Python source code to Python code for eval()
    code = compile(expr, '', 'eval')

    # Dissect bytecode back to Python opcodes
    ops = disassemble(code)
    assert len(ops) < 1024

    stack = []
    for op in ops:
        value = op.touch_value(stack, functions_and_constants)

    return value

撰写回答