在字符串中找到完全平方数

1 投票
4 回答
1595 浏览
提问于 2025-04-17 15:52

一个完美的平方数用二进制表示时,有些位被替换成了“?”,比如说1??,这个数字就是4。(或者是1????000???0000)

我需要找到这个完美的平方数。(这个数字只有一个可能的结果)

在这个字符串中,“?”的数量是n。

为了找到这个数字,我的做法是遍历2的n次方个数字(比如111、110、101、100),然后检查它们是否是完美的平方数。我使用了以下的函数来检查是否是完美的平方数。

bool issqr(int n){
   int d=(int)(sqrt(n));
   if(d*d==n) return true;
   else return false;
}

虽然我在Python中实现了这个功能,但花费了很多时间,所以我转向了C++,只用位运算来生成2的n次方个数字(这样比Python版本快多了)

但是如果数字超过64位,这个方法就不行了。

怎么才能避免这个问题呢?如果一个数字有120位,我该怎么做呢?

(10100110???1?1?01?1?011000?1100?00101000?1?11001101100110001010111?0?1??0110?110?01?1100?1?0110?1?10111?01?0111000?10??101?01)

4 个回答

0

你不需要去检查所有的2的n次方个数字来找到完美平方,其实只需要进行一次简单的平方运算就可以了:

假设你有一个整数n,你想找到一个最大的完美平方数,这个数小于或等于n,我们称它为m。

然后:
d = (int)sqrt(n);
m = d*d;

解释:

假设存在一个完美平方数m',它比m大,这就意味着有一个整数d',使得:d' > d,并且d'*d' = m'。

但是d' >= d + 1,而(d + 1)*(d + 1) > n,所以m' > n,这和我们要求的m' <= n是矛盾的。

现在来回答你的问题:

为了找到完美平方,只需把所有的“?”都改成“1”,然后检查这个完美平方是否符合你的字符串。如果符合,那你就找到了想要的数字;如果不符合,就从最高位开始,把足够的“?”改成“0”,这样得到的数字就会小于或等于你刚找到的完美平方,然后继续这个过程,直到找到完美平方或者没有更多的选择。

2

根据我的理解,给定一个整数 n,你想找到一个平方数 sq,它满足以下条件:

2n - 1 < sq < 2n+1 - 1

这个条件可以理解为“我的数字必须是1后面跟着n个问号的形式”。

首先,如果 n 是偶数,那么 2n 是一个完美的平方数,并且符合你的条件(在二进制中,它的表示是1000...000,后面有n个零)。

如果 n 是奇数(比如 n = 2.p + 1),那么 2n+1 也是一个完美的平方数(即 (2p+1)2)。接下来计算的数字会给你一个完美的平方数:

(2p+1 - 1)2

为了满足第一个不等式,p 需要满足:

2n - 1 < (2p+1 - 1)2

0 < 2n+1 - 2p+2 + 1 - 2n + 1,

最后,

2n + 2 - 2p+2 > 0
或者
22p - 2p+1 + 1 > 0

如果我们考虑一个函数,它将 p 和 f(p) 关联起来,定义为:

f(p) = 22p - 2p+1 + 1

这个函数对于每个正实数都是定义好的,并且是严格递增的。此外,f(0) = 0。最后,当 p > 0 时,初始条件是满足的!对于 p = 0 - 或者 n = 1 -,这个问题没有有效的解决方案。

3

与其重新用C++编写代码,不如先考虑优化你的算法。最小的可能答案是把原始值的平方根算出来,把所有的'?'替换成0,然后向上取整;最大的可能答案是把'?'替换成1后算平方根,然后向下取整。找到这两个值后,逐个检查它们的平方是否符合模式。

这样做更快,因为你只需要检查的数字少得多,而且在循环中不需要计算平方根:平方运算要简单得多。

你不需要通过比较字符串来检查是否匹配:

mask = int(pattern.replace('0', '1').replace('?', '0'), 2)
test = int(pattern.replace('?', '0'), 2)

def is_match(n):
    return (n&mask)==test

所以把这些内容整合在一起:

def int_sqrt(x):
    if x < 0:
        raise ValueError('square root not defined for negative numbers')
    n = int(x)
    if n == 0:
        return 0
    a, b = divmod(n.bit_length(), 2)
    x = 2**(a+b)
    while True:
        y = (x + n//x)//2
        if y >= x:
            return x
        x = y

def find_match(pattern):
    lowest = int(pattern.replace('?', '0'), 2)
    highest = int(pattern.replace('?', '1'), 2)
    mask = int(pattern.replace('0', '1').replace('?', '0'), 2)
    lowsqrt = int_sqrt(lowest)
    if lowsqrt*lowsqrt != lowest:
            lowsqrt += 1
    highsqrt = int_sqrt(highest)
    for n in range(lowsqrt, highsqrt+1):
        if (n*n & mask)==lowest:
            return n*n

print(find_match('1??1??1'))
print(find_match('1??0??1'))
print(find_match('1??????????????????????????????????????????????????????????????????????1??0??1'))

输出:

121
81
151115727461209345152081

注意:这只适用于Python 3.x,最后的测试在Python 2.x中会导致range溢出。

撰写回答