标签归档:optimization

如何将数据集分割/划分为训练和测试数据集,例如进行交叉验证?

问题:如何将数据集分割/划分为训练和测试数据集,例如进行交叉验证?

将NumPy数组随机分为训练和测试/验证数据集的好方法是什么?与Matlab中的cvpartitioncrossvalind函数类似。

What is a good way to split a NumPy array randomly into training and testing/validation dataset? Something similar to the cvpartition or crossvalind functions in Matlab.


回答 0

如果要将数据集分成两半,可以使用numpy.random.shuffle,或者numpy.random.permutation需要跟踪索引:

import numpy
# x is your dataset
x = numpy.random.rand(100, 5)
numpy.random.shuffle(x)
training, test = x[:80,:], x[80:,:]

要么

import numpy
# x is your dataset
x = numpy.random.rand(100, 5)
indices = numpy.random.permutation(x.shape[0])
training_idx, test_idx = indices[:80], indices[80:]
training, test = x[training_idx,:], x[test_idx,:]

有多种方法可以重复分区同一数据集以进行交叉验证。一种策略是从数据集中重复采样:

import numpy
# x is your dataset
x = numpy.random.rand(100, 5)
training_idx = numpy.random.randint(x.shape[0], size=80)
test_idx = numpy.random.randint(x.shape[0], size=20)
training, test = x[training_idx,:], x[test_idx,:]

最后,sklearn包含几种交叉验证方法(k折,nave -n-out等)。它还包括更高级的“分层抽样”方法,该方法创建了针对某些功能平衡的数据分区,例如,确保训练和测试集中的正例和负例比例相同。

If you want to split the data set once in two halves, you can use numpy.random.shuffle, or numpy.random.permutation if you need to keep track of the indices:

import numpy
# x is your dataset
x = numpy.random.rand(100, 5)
numpy.random.shuffle(x)
training, test = x[:80,:], x[80:,:]

or

import numpy
# x is your dataset
x = numpy.random.rand(100, 5)
indices = numpy.random.permutation(x.shape[0])
training_idx, test_idx = indices[:80], indices[80:]
training, test = x[training_idx,:], x[test_idx,:]

There are many ways to repeatedly partition the same data set for cross validation. One strategy is to resample from the dataset, with repetition:

import numpy
# x is your dataset
x = numpy.random.rand(100, 5)
training_idx = numpy.random.randint(x.shape[0], size=80)
test_idx = numpy.random.randint(x.shape[0], size=20)
training, test = x[training_idx,:], x[test_idx,:]

Finally, sklearn contains several cross validation methods (k-fold, leave-n-out, …). It also includes more advanced “stratified sampling” methods that create a partition of the data that is balanced with respect to some features, for example to make sure that there is the same proportion of positive and negative examples in the training and test set.


回答 1

还有另一个选择就是需要使用scikit-learn。如scikit的Wiki所述,您可以按照以下说明进行操作:

from sklearn.model_selection import train_test_split

data, labels = np.arange(10).reshape((5, 2)), range(5)

data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.20, random_state=42)

这样,您就可以将要拆分为训练和测试的数据的标签保持同步。

There is another option that just entails using scikit-learn. As scikit’s wiki describes, you can just use the following instructions:

from sklearn.model_selection import train_test_split

data, labels = np.arange(10).reshape((5, 2)), range(5)

data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.20, random_state=42)

This way you can keep in sync the labels for the data you’re trying to split into training and test.


回答 2

请注意。如果您想要训练,测试和AND验证集,则可以执行以下操作:

from sklearn.cross_validation import train_test_split

X = get_my_X()
y = get_my_y()
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.5)

这些参数将为训练提供70%,为测试和验证集各提供15%。希望这可以帮助。

Just a note. In case you want train, test, AND validation sets, you can do this:

from sklearn.cross_validation import train_test_split

X = get_my_X()
y = get_my_y()
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.5)

These parameters will give 70 % to training, and 15 % each to test and val sets. Hope this helps.


回答 3

由于sklearn.cross_validation模块被弃用,你可以使用:

import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(10).reshape((5, 2)), range(5)

X_trn, X_tst, y_trn, y_tst = train_test_split(X, y, test_size=0.2, random_state=42)

As sklearn.cross_validation module was deprecated, you can use:

import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(10).reshape((5, 2)), range(5)

X_trn, X_tst, y_trn, y_tst = train_test_split(X, y, test_size=0.2, random_state=42)

回答 4

您还可以考虑将训练和测试集进行分层。初始除法还会随机生成训练集和测试集,但要保留原始Class的比例。这使得训练和测试集更好地反映了原始数据集的属性。

import numpy as np  

def get_train_test_inds(y,train_proportion=0.7):
    '''Generates indices, making random stratified split into training set and testing sets
    with proportions train_proportion and (1-train_proportion) of initial sample.
    y is any iterable indicating classes of each observation in the sample.
    Initial proportions of classes inside training and 
    testing sets are preserved (stratified sampling).
    '''

    y=np.array(y)
    train_inds = np.zeros(len(y),dtype=bool)
    test_inds = np.zeros(len(y),dtype=bool)
    values = np.unique(y)
    for value in values:
        value_inds = np.nonzero(y==value)[0]
        np.random.shuffle(value_inds)
        n = int(train_proportion*len(value_inds))

        train_inds[value_inds[:n]]=True
        test_inds[value_inds[n:]]=True

    return train_inds,test_inds

y = np.array([1,1,2,2,3,3])
train_inds,test_inds = get_train_test_inds(y,train_proportion=0.5)
print y[train_inds]
print y[test_inds]

此代码输出:

[1 2 3]
[1 2 3]

You may also consider stratified division into training and testing set. Startified division also generates training and testing set randomly but in such a way that original class proportions are preserved. This makes training and testing sets better reflect the properties of the original dataset.

import numpy as np  

def get_train_test_inds(y,train_proportion=0.7):
    '''Generates indices, making random stratified split into training set and testing sets
    with proportions train_proportion and (1-train_proportion) of initial sample.
    y is any iterable indicating classes of each observation in the sample.
    Initial proportions of classes inside training and 
    testing sets are preserved (stratified sampling).
    '''

    y=np.array(y)
    train_inds = np.zeros(len(y),dtype=bool)
    test_inds = np.zeros(len(y),dtype=bool)
    values = np.unique(y)
    for value in values:
        value_inds = np.nonzero(y==value)[0]
        np.random.shuffle(value_inds)
        n = int(train_proportion*len(value_inds))

        train_inds[value_inds[:n]]=True
        test_inds[value_inds[n:]]=True

    return train_inds,test_inds

y = np.array([1,1,2,2,3,3])
train_inds,test_inds = get_train_test_inds(y,train_proportion=0.5)
print y[train_inds]
print y[test_inds]

This code outputs:

[1 2 3]
[1 2 3]

回答 5

我为自己的项目编写了一个函数来执行此操作(尽管它不使用numpy):

def partition(seq, chunks):
    """Splits the sequence into equal sized chunks and them as a list"""
    result = []
    for i in range(chunks):
        chunk = []
        for element in seq[i:len(seq):chunks]:
            chunk.append(element)
        result.append(chunk)
    return result

如果您希望将块随机化,则在将列表传递之前先对其进行随机排序。

I wrote a function for my own project to do this (it doesn’t use numpy, though):

def partition(seq, chunks):
    """Splits the sequence into equal sized chunks and them as a list"""
    result = []
    for i in range(chunks):
        chunk = []
        for element in seq[i:len(seq):chunks]:
            chunk.append(element)
        result.append(chunk)
    return result

If you want the chunks to be randomized, just shuffle the list before passing it in.


回答 6

这是一个以分层方式将数据分成n = 5折的代码

% X = data array
% y = Class_label
from sklearn.cross_validation import StratifiedKFold
skf = StratifiedKFold(y, n_folds=5)
for train_index, test_index in skf:
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

Here is a code to split the data into n=5 folds in a stratified manner

% X = data array
% y = Class_label
from sklearn.cross_validation import StratifiedKFold
skf = StratifiedKFold(y, n_folds=5)
for train_index, test_index in skf:
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

回答 7

感谢pberkes的回答。我只是对其进行了修改,以避免(1)在训练和测试中采样(2)重复的实例时进行替换:

training_idx = np.random.choice(X.shape[0], int(np.round(X.shape[0] * 0.8)),replace=False)
training_idx = np.random.permutation(np.arange(X.shape[0]))[:np.round(X.shape[0] * 0.8)]
    test_idx = np.setdiff1d( np.arange(0,X.shape[0]), training_idx)

Thanks pberkes for your answer. I just modified it to avoid (1) replacement while sampling (2) duplicated instances occurred in both training and testing:

training_idx = np.random.choice(X.shape[0], int(np.round(X.shape[0] * 0.8)),replace=False)
training_idx = np.random.permutation(np.arange(X.shape[0]))[:np.round(X.shape[0] * 0.8)]
    test_idx = np.setdiff1d( np.arange(0,X.shape[0]), training_idx)

回答 8

在进行了一些阅读并考虑了(许多..)将数据拆分以进行训练和测试的不同方式之后,我决定计时了!

我使用了4种不同的方法(其中没有一种使用的是sklearn库,我相信它将得到最好的结果,因为它是经过精心设计和测试的代码):

  1. 洗净整个矩阵arr,然后拆分数据以进行训练和测试
  2. 随机排列索引,然后将其分配给x和y以拆分数据
  3. 与方法2相同,但以更有效的方式进行
  4. 使用熊猫数据框进行拆分

方法3赢得的时间最短,仅次于方法1,而方法2和4的确效率很低。

我计时的4种不同方法的代码:

import numpy as np
arr = np.random.rand(100, 3)
X = arr[:,:2]
Y = arr[:,2]
spl = 0.7
N = len(arr)
sample = int(spl*N)

#%% Method 1:  shuffle the whole matrix arr and then split
np.random.shuffle(arr)
x_train, x_test, y_train, y_test = X[:sample,:], X[sample:, :], Y[:sample, ], Y[sample:,]

#%% Method 2: shuffle the indecies and then shuffle and apply to X and Y
train_idx = np.random.choice(N, sample)
Xtrain = X[train_idx]
Ytrain = Y[train_idx]

test_idx = [idx for idx in range(N) if idx not in train_idx]
Xtest = X[test_idx]
Ytest = Y[test_idx]

#%% Method 3: shuffle indicies without a for loop
idx = np.random.permutation(arr.shape[0])  # can also use random.shuffle
train_idx, test_idx = idx[:sample], idx[sample:]
x_train, x_test, y_train, y_test = X[train_idx,:], X[test_idx,:], Y[train_idx,], Y[test_idx,]

#%% Method 4: using pandas dataframe to split
import pandas as pd
df = pd.read_csv(file_path, header=None) # Some csv file (I used some file with 3 columns)

train = df.sample(frac=0.7, random_state=200)
test = df.drop(train.index)

在这段时间内,执行1000次循环的3次重复的最短时间为:

  • 方法1:0.35883826200006297秒
  • 方法2:1.7157016959999964秒
  • 方法3:1.7876616719995582秒
  • 方法4:0.07562861499991413秒

希望对您有所帮助!

After doing some reading and taking into account the (many..) different ways of splitting the data to train and test, I decided to timeit!

I used 4 different methods (non of them are using the library sklearn, which I’m sure will give the best results, giving that it is well designed and tested code):

  1. shuffle the whole matrix arr and then split the data to train and test
  2. shuffle the indices and then assign it x and y to split the data
  3. same as method 2, but in a more efficient way to do it
  4. using pandas dataframe to split

method 3 won by far with the shortest time, after that method 1, and method 2 and 4 discovered to be really inefficient.

The code for the 4 different methods I timed:

import numpy as np
arr = np.random.rand(100, 3)
X = arr[:,:2]
Y = arr[:,2]
spl = 0.7
N = len(arr)
sample = int(spl*N)

#%% Method 1:  shuffle the whole matrix arr and then split
np.random.shuffle(arr)
x_train, x_test, y_train, y_test = X[:sample,:], X[sample:, :], Y[:sample, ], Y[sample:,]

#%% Method 2: shuffle the indecies and then shuffle and apply to X and Y
train_idx = np.random.choice(N, sample)
Xtrain = X[train_idx]
Ytrain = Y[train_idx]

test_idx = [idx for idx in range(N) if idx not in train_idx]
Xtest = X[test_idx]
Ytest = Y[test_idx]

#%% Method 3: shuffle indicies without a for loop
idx = np.random.permutation(arr.shape[0])  # can also use random.shuffle
train_idx, test_idx = idx[:sample], idx[sample:]
x_train, x_test, y_train, y_test = X[train_idx,:], X[test_idx,:], Y[train_idx,], Y[test_idx,]

#%% Method 4: using pandas dataframe to split
import pandas as pd
df = pd.read_csv(file_path, header=None) # Some csv file (I used some file with 3 columns)

train = df.sample(frac=0.7, random_state=200)
test = df.drop(train.index)

And for the times, the minimum time to execute out of 3 repetitions of 1000 loops is:

  • Method 1: 0.35883826200006297 seconds
  • Method 2: 1.7157016959999964 seconds
  • Method 3: 1.7876616719995582 seconds
  • Method 4: 0.07562861499991413 seconds

I hope that’s helpful!


回答 9

可能您不仅需要拆分训练和测试,而且还需要交叉验证以确保模型能够概括。在这里,我假设70%的训练数据,20%的验证和10%的坚持/测试数据。

查看np.split

如果indexs_or_sections是一维排序的整数数组,则条目指示沿轴在哪里拆分该数组。例如,对于轴= 0,[2,3]将导致

ary [:2] ary [2:3] ary [3:]

t, v, h = np.split(df.sample(frac=1, random_state=1), [int(0.7*len(df)), int(0.9*len(df))]) 

Likely you will not only need to split into train and test, but also cross validation to make sure your model generalizes. Here I am assuming 70% training data, 20% validation and 10% holdout/test data.

Check out the np.split:

If indices_or_sections is a 1-D array of sorted integers, the entries indicate where along axis the array is split. For example, [2, 3] would, for axis=0, result in

ary[:2] ary[2:3] ary[3:]

t, v, h = np.split(df.sample(frac=1, random_state=1), [int(0.7*len(df)), int(0.9*len(df))]) 

回答 10

分为火车测试并有效

x =np.expand_dims(np.arange(100), -1)


print(x)

indices = np.random.permutation(x.shape[0])

training_idx, test_idx, val_idx = indices[:int(x.shape[0]*.9)], indices[int(x.shape[0]*.9):int(x.shape[0]*.95)],  indices[int(x.shape[0]*.9):int(x.shape[0]*.95)]


training, test, val = x[training_idx,:], x[test_idx,:], x[val_idx,:]

print(training, test, val)

Split into train test and valid

x =np.expand_dims(np.arange(100), -1)


print(x)

indices = np.random.permutation(x.shape[0])

training_idx, test_idx, val_idx = indices[:int(x.shape[0]*.9)], indices[int(x.shape[0]*.9):int(x.shape[0]*.95)],  indices[int(x.shape[0]*.9):int(x.shape[0]*.95)]


training, test, val = x[training_idx,:], x[test_idx,:], x[val_idx,:]

print(training, test, val)

在Python中处理非常大的数字

问题:在Python中处理非常大的数字

我一直在考虑使用Python快速评估手牌。在我看来,加快处理速度的一种方法是将所有牌面和西服表示为质数,然后将它们相乘以表示手。白衣:

class PokerCard:
    faces = '23456789TJQKA'
    suits = 'cdhs'
    facePrimes = [11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 53, 59, 61]
    suitPrimes = [2, 3, 5, 7]

    def HashVal(self):
      return PokerCard.facePrimes[self.cardFace] * PokerCard.suitPrimes[self.cardSuit]

这将为每只手提供一个数值,通过模可以告诉我手中有多少个国王或有多少个心。例如,任何有五个或更多球杆的手都会平均除以2 ^ 5;任何有四位国王的手将平均除以59 ^ 4,依此类推。

问题在于,像AcAdAhAsKdKhKs这样的七张牌手的散列值约为62.7万亿次,这需要超过32位才能在内部进行表示。有没有一种方法可以在Python中存储如此大的数字,从而允许我对其执行算术运算?

I’ve been considering fast poker hand evaluation in Python. It occurred to me that one way to speed the process up would be to represent all the card faces and suits as prime numbers and multiply them together to represent the hands. To whit:

class PokerCard:
    faces = '23456789TJQKA'
    suits = 'cdhs'
    facePrimes = [11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 53, 59, 61]
    suitPrimes = [2, 3, 5, 7]

AND

    def HashVal(self):
      return PokerCard.facePrimes[self.cardFace] * PokerCard.suitPrimes[self.cardSuit]

This would give each hand a numeric value that, through modulo could tell me how many kings are in the hand or how many hearts. For example, any hand with five or more clubs in it would divide evenly by 2^5; any hand with four kings would divide evenly by 59^4, etc.

The problem is that a seven-card hand like AcAdAhAsKdKhKs has a hash value of approximately 62.7 quadrillion, which would take considerably more than 32 bits to represent internally. Is there a way to store such large numbers in Python that will allow me to perform arithmetic operations on it?


回答 0

Python支持“ bignum”整数类型,该整数类型可以处理任意大数。在Python 2.5+中,将调用此类型long并将其与该int类型分开,但是解释器将自动使用更合适的那个类型。在Python 3.0+中,int类型已完全删除。

不过,这只是实现细节-只要您拥有2.5版或更高版本,就只需执行标准数学运算,并且任何超出32位数学界限的数字都将自动(透明地)转换为bignum。

您可以在PEP 0237中找到所有血腥细节。

Python supports a “bignum” integer type which can work with arbitrarily large numbers. In Python 2.5+, this type is called long and is separate from the int type, but the interpreter will automatically use whichever is more appropriate. In Python 3.0+, the int type has been dropped completely.

That’s just an implementation detail, though — as long as you have version 2.5 or better, just perform standard math operations and any number which exceeds the boundaries of 32-bit math will be automatically (and transparently) converted to a bignum.

You can find all the gory details in PEP 0237.


回答 1

python 自然支持任意整数

例:

>>>10 ** 1000 100000000000000000000000000000000000000000000000000000000000000000000000000000000000000 000000000000000000000000000000000000000000000000000000000000000000000000000000000000 000000000000000000000000000000000000000000000000000000000000000000000000000000百万000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

例如,您甚至可以得到一个巨大的整数值fib(4000000)。

但是它仍然(目前)支持任意大的浮点数

如果您需要一个大的,大的浮子,那么请检查小数模块。在这些论坛上有一些使用示例:OverflowError:(34,“结果太大”)

另一个参考:http : //docs.python.org/2/library/decimal.html

如果需要提高速度,甚至可以使用gmpy模块(这可能是您感兴趣的):在代码中处理大数

另一种参考:https : //code.google.com/p/gmpy/

python supports arbitrarily large integers naturally:

example:

>>> 10**1000
10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

You could even get, for example of a huge integer value, fib(4000000).

But still it does not (for now) supports an arbitrarily large float !!

If you need one big, large, float then check up on the decimal Module. There are examples of use on these foruns: OverflowError: (34, ‘Result too large’)

Another reference: http://docs.python.org/2/library/decimal.html

You can even using the gmpy module if you need a speed-up (which is likely to be of your interest): Handling big numbers in code

Another reference: https://code.google.com/p/gmpy/


回答 2

您可以这样做很有趣,但除此之外,这不是一个好主意。它不会加快我能想到的任何速度。

  • 拿到手中的卡将是一个整数分解操作,这比访问数组要昂贵得多。

  • 添加卡将是乘法运算,而删除卡是除法运算,这两个字都是大型的多字数字,这比在列表中添加或删除元素要昂贵得多。

  • 手的实际数字值不会告诉您任何信息。您将需要考虑素数,并遵循扑克规则比较两只手。对于这样的手,h1 <h2毫无意义。

You could do this for the fun of it, but other than that it’s not a good idea. It would not speed up anything I can think of.

  • Getting the cards in a hand will be an integer factoring operation which is much more expensive than just accessing an array.

  • Adding cards would be multiplication, and removing cards division, both of large multi-word numbers, which are more expensive operations than adding or removing elements from lists.

  • The actual numeric value of a hand will tell you nothing. You will need to factor the primes and follow the Poker rules to compare two hands. h1 < h2 for such hands means nothing.


回答 3

python自然支持任意大整数:

In [1]: 59**3*61**4*2*3*5*7*3*5*7
Out[1]: 62702371781194950
In [2]: _ % 61**4
Out[2]: 0

python supports arbitrarily large integers naturally:

In [1]: 59**3*61**4*2*3*5*7*3*5*7
Out[1]: 62702371781194950
In [2]: _ % 61**4
Out[2]: 0

回答 4

python解释器将为您处理它,您只需执行操作(+,-,*,/),它便会正常工作。

int值是无限的。

进行除法时要小心,默认情况下,商会变成float,但float不支持这么大的数。如果收到错误消息说不float支持这么大的数字,则意味着商太大而无法存储在其中float,则必须使用下限除法(//)。

它忽略小数点后的任何小数,这样结果将为int,因此您可以得到大数的结果。

10//3 产出 3

10//4 输出 2

The python interpreter will handle it for you, you just have to do your operations (+, -, *, /), and it will work as normal.

The int value is unlimited.

Careful when doing division, by default the quotient is turned into float, but float does not support such large numbers. If you get an error message saying float does not support such large numbers, then it means the quotient is too large to be stored in float you’ll have to use floor division (//).

It ignores any decimal that comes after the decimal point, this way, the result will be int, so you can have a large number result.

>>>10//3
3

>>>10//4
2

为什么提早归还比其他要慢?

问题:为什么提早归还比其他要慢?

这是我几天前给出的答案的后续问题。编辑:该问题的OP似乎已经使用了我发布给他的代码来问同样的问题,但是我没有意识到。道歉。提供的答案是不同的!

我基本上观察到:

>>> def without_else(param=False):
...     if param:
...         return 1
...     return 0
>>> def with_else(param=False):
...     if param:
...         return 1
...     else:
...         return 0
>>> from timeit import Timer as T
>>> T(lambda : without_else()).repeat()
[0.3011460304260254, 0.2866089344024658, 0.2871549129486084]
>>> T(lambda : with_else()).repeat()
[0.27536892890930176, 0.2693932056427002, 0.27011704444885254]
>>> T(lambda : without_else(True)).repeat()
[0.3383951187133789, 0.32756996154785156, 0.3279120922088623]
>>> T(lambda : with_else(True)).repeat()
[0.3305950164794922, 0.32186388969421387, 0.3209099769592285]

…或者换句话说:else无论if是否触发条件,使子句更快。

我认为这与两者生成的不同字节码有关,但是有人能详细确认/解释吗?

编辑:似乎不是每个人都可以重现我的时间安排,所以我认为在我的系统上提供一些信息可能有用。我正在运行安装了默认python的Ubuntu 11.10 64位。python生成以下版本信息:

Python 2.7.2+ (default, Oct  4 2011, 20:06:09) 
[GCC 4.6.1] on linux2

以下是Python 2.7中反汇编的结果:

>>> dis.dis(without_else)
  2           0 LOAD_FAST                0 (param)
              3 POP_JUMP_IF_FALSE       10

  3           6 LOAD_CONST               1 (1)
              9 RETURN_VALUE        

  4     >>   10 LOAD_CONST               2 (0)
             13 RETURN_VALUE        
>>> dis.dis(with_else)
  2           0 LOAD_FAST                0 (param)
              3 POP_JUMP_IF_FALSE       10

  3           6 LOAD_CONST               1 (1)
              9 RETURN_VALUE        

  5     >>   10 LOAD_CONST               2 (0)
             13 RETURN_VALUE        
             14 LOAD_CONST               0 (None)
             17 RETURN_VALUE        

This is a follow-up question to an answer I gave a few days back. Edit: it seems that the OP of that question already used the code I posted to him to ask the same question, but I was unaware of it. Apologies. The answers provided are different though!

Substantially I observed that:

>>> def without_else(param=False):
...     if param:
...         return 1
...     return 0
>>> def with_else(param=False):
...     if param:
...         return 1
...     else:
...         return 0
>>> from timeit import Timer as T
>>> T(lambda : without_else()).repeat()
[0.3011460304260254, 0.2866089344024658, 0.2871549129486084]
>>> T(lambda : with_else()).repeat()
[0.27536892890930176, 0.2693932056427002, 0.27011704444885254]
>>> T(lambda : without_else(True)).repeat()
[0.3383951187133789, 0.32756996154785156, 0.3279120922088623]
>>> T(lambda : with_else(True)).repeat()
[0.3305950164794922, 0.32186388969421387, 0.3209099769592285]

…or in other words: having the else clause is faster regardless of the if condition being triggered or not.

I assume it has to do with different bytecode generated by the two, but is anybody able to confirm/explain in detail?

EDIT: Seems not everybody is able to reproduce my timings, so I thought it might be useful to give some info on my system. I’m running Ubuntu 11.10 64 bit with the default python installed. python generates the following version information:

Python 2.7.2+ (default, Oct  4 2011, 20:06:09) 
[GCC 4.6.1] on linux2

Here are the results of the disassembly in Python 2.7:

>>> dis.dis(without_else)
  2           0 LOAD_FAST                0 (param)
              3 POP_JUMP_IF_FALSE       10

  3           6 LOAD_CONST               1 (1)
              9 RETURN_VALUE        

  4     >>   10 LOAD_CONST               2 (0)
             13 RETURN_VALUE        
>>> dis.dis(with_else)
  2           0 LOAD_FAST                0 (param)
              3 POP_JUMP_IF_FALSE       10

  3           6 LOAD_CONST               1 (1)
              9 RETURN_VALUE        

  5     >>   10 LOAD_CONST               2 (0)
             13 RETURN_VALUE        
             14 LOAD_CONST               0 (None)
             17 RETURN_VALUE        

回答 0

这纯粹是猜测,我还没有找到一种简单的方法来检查它是否正确,但是我有一个理论适合您。

我尝试了您的代码并获得了相同的结果,但without_else()反复地比with_else()

>>> T(lambda : without_else()).repeat()
[0.42015745017874906, 0.3188967452567226, 0.31984281521812363]
>>> T(lambda : with_else()).repeat()
[0.36009842032996175, 0.28962249392031936, 0.2927151355828528]
>>> T(lambda : without_else(True)).repeat()
[0.31709728471076915, 0.3172671387005721, 0.3285821242644147]
>>> T(lambda : with_else(True)).repeat()
[0.30939889008243426, 0.3035132258429485, 0.3046679117038593]

考虑到字节码是相同的,唯一的区别是函数的名称。特别是时序测试会在全局名称上进行查找。尝试重命名without_else(),区别消失:

>>> def no_else(param=False):
    if param:
        return 1
    return 0

>>> T(lambda : no_else()).repeat()
[0.3359846013948413, 0.29025818923918223, 0.2921801513879245]
>>> T(lambda : no_else(True)).repeat()
[0.3810395594970828, 0.2969634408842694, 0.2960104566362247]

我的猜测是,这without_else与其他内容发生了哈希冲突,globals()因此全局名称查找稍微慢一些。

编辑:具有7个或8个键的字典可能具有32个插槽,因此在此基础上,without_else哈希与__builtins__

>>> [(k, hash(k) % 32) for k in globals().keys() ]
[('__builtins__', 8), ('with_else', 9), ('__package__', 15), ('without_else', 8), ('T', 21), ('__name__', 25), ('no_else', 28), ('__doc__', 29)]

要阐明哈希如何工作:

__builtins__ 散列为-1196389688,这减小了表大小(32)的模数,这意味着它存储在表的#8插槽中。

without_else哈希到505688136,将模32的值降低为8,因此发生冲突。要解决此问题,Python计算:

从…开始:

j = hash % 32
perturb = hash

重复此过程,直到找到可用的插槽:

j = (5*j) + 1 + perturb;
perturb >>= 5;
use j % 2**i as the next table index;

这使它可以用作下一个索引17。幸运的是,它是免费的,因此循环仅重复一次。哈希表的大小是2的幂,哈希表的大小也是2的幂2**ii是哈希值使用的位数j

对该表的每个探测都可以找到以下之一:

  • 插槽为空,在这种情况下,探测停止,并且我们知道该值不在表中。

  • 该广告位尚未使用,但已在过去使用过,在这种情况下,我们尝试使用如上计算的下一个值。

  • 插槽已满,但是存储在表中的完整哈希值与我们要查找的键的哈希值不同(在__builtins__vs 的情况下就是这样without_else)。

  • 插槽已满,并且恰好具有我们想要的哈希值,然后Python检查以查看键和我们正在查找的对象是否是同一对象(在这种情况下,这是因为可能会插入标识符的短字符串被插入了,所以相同的标识符使用完全相同的字符串)。

  • 最终,当插槽已满时,哈希值完全匹配,但是键不是同一对象,然后Python才会尝试比较它们是否相等。这相对较慢,但实际上不应该进行名称查找。

This is a pure guess, and I haven’t figured out an easy way to check whether it is right, but I have a theory for you.

I tried your code and get the same of results, without_else() is repeatedly slightly slower than with_else():

>>> T(lambda : without_else()).repeat()
[0.42015745017874906, 0.3188967452567226, 0.31984281521812363]
>>> T(lambda : with_else()).repeat()
[0.36009842032996175, 0.28962249392031936, 0.2927151355828528]
>>> T(lambda : without_else(True)).repeat()
[0.31709728471076915, 0.3172671387005721, 0.3285821242644147]
>>> T(lambda : with_else(True)).repeat()
[0.30939889008243426, 0.3035132258429485, 0.3046679117038593]

Considering that the bytecode is identical, the only difference is the name of the function. In particular the timing test does a lookup on the global name. Try renaming without_else() and the difference disappears:

>>> def no_else(param=False):
    if param:
        return 1
    return 0

>>> T(lambda : no_else()).repeat()
[0.3359846013948413, 0.29025818923918223, 0.2921801513879245]
>>> T(lambda : no_else(True)).repeat()
[0.3810395594970828, 0.2969634408842694, 0.2960104566362247]

My guess is that without_else has a hash collision with something else in globals() so the global name lookup is slightly slower.

Edit: A dictionary with 7 or 8 keys probably has 32 slots, so on that basis without_else has a hash collision with __builtins__:

>>> [(k, hash(k) % 32) for k in globals().keys() ]
[('__builtins__', 8), ('with_else', 9), ('__package__', 15), ('without_else', 8), ('T', 21), ('__name__', 25), ('no_else', 28), ('__doc__', 29)]

To clarify how the hashing works:

__builtins__ hashes to -1196389688 which reduced modulo the table size (32) means it is stored in the #8 slot of the table.

without_else hashes to 505688136 which reduced modulo 32 is 8 so there’s a collision. To resolve this Python calculates:

Starting with:

j = hash % 32
perturb = hash

Repeat this until we find a free slot:

j = (5*j) + 1 + perturb;
perturb >>= 5;
use j % 2**i as the next table index;

which gives it 17 to use as the next index. Fortunately that’s free so the loop only repeats once. The hash table size is a power of 2, so 2**i is the size of the hash table, i is the number of bits used from the hash value j.

Each probe into the table can find one of these:

  • The slot is empty, in that case the probing stops and we know the value is not in the table.

  • The slot is unused but was used in the past in which case we go try the next value calculated as above.

  • The slot is full but the full hash value stored in the table isn’t the same as the hash of the key we are looking for (that’s what happens in the case of __builtins__ vs without_else).

  • The slot is full and has exactly the hash value we want, then Python checks to see if the key and the object we are looking up are the same object (which in this case they will be because short strings that could be identifiers are interned so identical identifiers use the exact same string).

  • Finally when the slot is full, the hash matches exactly, but the keys are not the identical object, then and only then will Python try comparing them for equality. This is comparatively slow, but in the case of name lookups shouldn’t actually happen.


加权版本的random.choice

问题:加权版本的random.choice

我需要写一个加权版本的random.choice(列表中的每个元素被选择的可能性都不同)。这是我想出的:

def weightedChoice(choices):
    """Like random.choice, but each element can have a different chance of
    being selected.

    choices can be any iterable containing iterables with two items each.
    Technically, they can have more than two items, the rest will just be
    ignored.  The first item is the thing being chosen, the second item is
    its weight.  The weights can be any numeric values, what matters is the
    relative differences between them.
    """
    space = {}
    current = 0
    for choice, weight in choices:
        if weight > 0:
            space[current] = choice
            current += weight
    rand = random.uniform(0, current)
    for key in sorted(space.keys() + [current]):
        if rand < key:
            return choice
        choice = space[key]
    return None

对于我来说,此功能似乎过于复杂且难看。我希望这里的每个人都可以提出一些改进建议或其他替代方法。对于我来说,效率并不像代码的清洁度和可读性那么重要。

I needed to write a weighted version of random.choice (each element in the list has a different probability for being selected). This is what I came up with:

def weightedChoice(choices):
    """Like random.choice, but each element can have a different chance of
    being selected.

    choices can be any iterable containing iterables with two items each.
    Technically, they can have more than two items, the rest will just be
    ignored.  The first item is the thing being chosen, the second item is
    its weight.  The weights can be any numeric values, what matters is the
    relative differences between them.
    """
    space = {}
    current = 0
    for choice, weight in choices:
        if weight > 0:
            space[current] = choice
            current += weight
    rand = random.uniform(0, current)
    for key in sorted(space.keys() + [current]):
        if rand < key:
            return choice
        choice = space[key]
    return None

This function seems overly complex to me, and ugly. I’m hoping everyone here can offer some suggestions on improving it or alternate ways of doing this. Efficiency isn’t as important to me as code cleanliness and readability.


回答 0

从1.7.0版开始,NumPy具有choice支持概率分布的功能。

from numpy.random import choice
draw = choice(list_of_candidates, number_of_items_to_pick,
              p=probability_distribution)

请注意,这probability_distribution是顺序相同的序列list_of_candidates。您也可以使用关键字replace=False来更改行为,以便不替换绘制的项目。

Since version 1.7.0, NumPy has a choice function that supports probability distributions.

from numpy.random import choice
draw = choice(list_of_candidates, number_of_items_to_pick,
              p=probability_distribution)

Note that probability_distribution is a sequence in the same order of list_of_candidates. You can also use the keyword replace=False to change the behavior so that drawn items are not replaced.


回答 1

从Python 3.6 choices开始,random模块提供了一种方法。

Python 3.6.1 (v3.6.1:69c0db5050, Mar 21 2017, 01:21:04)
Type 'copyright', 'credits' or 'license' for more information
IPython 6.0.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import random

In [2]: random.choices(
...:     population=[['a','b'], ['b','a'], ['c','b']],
...:     weights=[0.2, 0.2, 0.6],
...:     k=10
...: )

Out[2]:
[['c', 'b'],
 ['c', 'b'],
 ['b', 'a'],
 ['c', 'b'],
 ['c', 'b'],
 ['b', 'a'],
 ['c', 'b'],
 ['b', 'a'],
 ['c', 'b'],
 ['c', 'b']]

请注意,random.choices将根据docs 进行替换示例:

返回k从总体中选择并替换的元素的大小列表。

如果您需要采样而不进行替换,那么正如@ronan-paixão出色的回答所言,您可以使用numpy.choice,其replace参数控制着这种行为。

Since Python 3.6 there is a method choices from the random module.

Python 3.6.1 (v3.6.1:69c0db5050, Mar 21 2017, 01:21:04)
Type 'copyright', 'credits' or 'license' for more information
IPython 6.0.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import random

In [2]: random.choices(
...:     population=[['a','b'], ['b','a'], ['c','b']],
...:     weights=[0.2, 0.2, 0.6],
...:     k=10
...: )

Out[2]:
[['c', 'b'],
 ['c', 'b'],
 ['b', 'a'],
 ['c', 'b'],
 ['c', 'b'],
 ['b', 'a'],
 ['c', 'b'],
 ['b', 'a'],
 ['c', 'b'],
 ['c', 'b']]

Note that random.choices will sample with replacement, per the docs:

Return a k sized list of elements chosen from the population with replacement.

Note for completeness of answer:

When a sampling unit is drawn from a finite population and is returned to that population, after its characteristic(s) have been recorded, before the next unit is drawn, the sampling is said to be “with replacement”. It basically means each element may be chosen more than once.

If you need to sample without replacement, then as @ronan-paixão’s brilliant answer states, you can use numpy.choice, whose replace argument controls such behaviour.


回答 2

def weighted_choice(choices):
   total = sum(w for c, w in choices)
   r = random.uniform(0, total)
   upto = 0
   for c, w in choices:
      if upto + w >= r:
         return c
      upto += w
   assert False, "Shouldn't get here"
def weighted_choice(choices):
   total = sum(w for c, w in choices)
   r = random.uniform(0, total)
   upto = 0
   for c, w in choices:
      if upto + w >= r:
         return c
      upto += w
   assert False, "Shouldn't get here"

回答 3

  1. 将权重排列为累积分布。
  2. 使用random.random()选择一个随机float 0.0 <= x < total
  3. http://docs.python.org/dev/library/bisect.html#other-examples中的示例所示,使用bisect.bisect搜索发行版。
from random import random
from bisect import bisect

def weighted_choice(choices):
    values, weights = zip(*choices)
    total = 0
    cum_weights = []
    for w in weights:
        total += w
        cum_weights.append(total)
    x = random() * total
    i = bisect(cum_weights, x)
    return values[i]

>>> weighted_choice([("WHITE",90), ("RED",8), ("GREEN",2)])
'WHITE'

如果您需要做出多个选择,请将其拆分为两个函数,一个用于构建累加权重,另一个用于平分到随机点。

  1. Arrange the weights into a cumulative distribution.
  2. Use random.random() to pick a random float 0.0 <= x < total.
  3. Search the distribution using bisect.bisect as shown in the example at http://docs.python.org/dev/library/bisect.html#other-examples.
from random import random
from bisect import bisect

def weighted_choice(choices):
    values, weights = zip(*choices)
    total = 0
    cum_weights = []
    for w in weights:
        total += w
        cum_weights.append(total)
    x = random() * total
    i = bisect(cum_weights, x)
    return values[i]

>>> weighted_choice([("WHITE",90), ("RED",8), ("GREEN",2)])
'WHITE'

If you need to make more than one choice, split this into two functions, one to build the cumulative weights and another to bisect to a random point.


回答 4

如果您不介意使用numpy,则可以使用numpy.random.choice

例如:

import numpy

items  = [["item1", 0.2], ["item2", 0.3], ["item3", 0.45], ["item4", 0.05]
elems = [i[0] for i in items]
probs = [i[1] for i in items]

trials = 1000
results = [0] * len(items)
for i in range(trials):
    res = numpy.random.choice(items, p=probs)  #This is where the item is selected!
    results[items.index(res)] += 1
results = [r / float(trials) for r in results]
print "item\texpected\tactual"
for i in range(len(probs)):
    print "%s\t%0.4f\t%0.4f" % (items[i], probs[i], results[i])

如果您知道需要预先选择多少个选项,则可以像这样循环执行:

numpy.random.choice(items, trials, p=probs)

If you don’t mind using numpy, you can use numpy.random.choice.

For example:

import numpy

items  = [["item1", 0.2], ["item2", 0.3], ["item3", 0.45], ["item4", 0.05]
elems = [i[0] for i in items]
probs = [i[1] for i in items]

trials = 1000
results = [0] * len(items)
for i in range(trials):
    res = numpy.random.choice(items, p=probs)  #This is where the item is selected!
    results[items.index(res)] += 1
results = [r / float(trials) for r in results]
print "item\texpected\tactual"
for i in range(len(probs)):
    print "%s\t%0.4f\t%0.4f" % (items[i], probs[i], results[i])

If you know how many selections you need to make in advance, you can do it without a loop like this:

numpy.random.choice(items, trials, p=probs)

回答 5

粗略,但可能足够:

import random
weighted_choice = lambda s : random.choice(sum(([v]*wt for v,wt in s),[]))

它行得通吗?

# define choices and relative weights
choices = [("WHITE",90), ("RED",8), ("GREEN",2)]

# initialize tally dict
tally = dict.fromkeys(choices, 0)

# tally up 1000 weighted choices
for i in xrange(1000):
    tally[weighted_choice(choices)] += 1

print tally.items()

印刷品:

[('WHITE', 904), ('GREEN', 22), ('RED', 74)]

假定所有权重都是整数。他们不必相加100,我只是这样做以使测试结果更易于解释。(如果权重是浮点数,则将它们全部乘以10,直到所有权重> =1。)

weights = [.6, .2, .001, .199]
while any(w < 1.0 for w in weights):
    weights = [w*10 for w in weights]
weights = map(int, weights)

Crude, but may be sufficient:

import random
weighted_choice = lambda s : random.choice(sum(([v]*wt for v,wt in s),[]))

Does it work?

# define choices and relative weights
choices = [("WHITE",90), ("RED",8), ("GREEN",2)]

# initialize tally dict
tally = dict.fromkeys(choices, 0)

# tally up 1000 weighted choices
for i in xrange(1000):
    tally[weighted_choice(choices)] += 1

print tally.items()

Prints:

[('WHITE', 904), ('GREEN', 22), ('RED', 74)]

Assumes that all weights are integers. They don’t have to add up to 100, I just did that to make the test results easier to interpret. (If weights are floating point numbers, multiply them all by 10 repeatedly until all weights >= 1.)

weights = [.6, .2, .001, .199]
while any(w < 1.0 for w in weights):
    weights = [w*10 for w in weights]
weights = map(int, weights)

回答 6

如果您有加权词典而不是列表,则可以这样写

items = { "a": 10, "b": 5, "c": 1 } 
random.choice([k for k in items for dummy in range(items[k])])

请注意[k for k in items for dummy in range(items[k])]生成此列表['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'c', 'b', 'b', 'b', 'b', 'b']

If you have a weighted dictionary instead of a list you can write this

items = { "a": 10, "b": 5, "c": 1 } 
random.choice([k for k in items for dummy in range(items[k])])

Note that [k for k in items for dummy in range(items[k])] produces this list ['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'c', 'b', 'b', 'b', 'b', 'b']


回答 7

从Python开始v3.6random.choices可用于list从给定总体中返回具有可选权重的指定大小的元素。

random.choices(population, weights=None, *, cum_weights=None, k=1)

  • 人口list包含独特的观察结果。(如果为空,则引发IndexError

  • 权重:更精确地进行选择所需的相对权重。

  • cum_weights:进行选择所需的累积权重。

  • k:要输出的size(lenlist。(默认len()=1


注意事项:

1)它使用加权抽样进行替换,因此抽取的项目以后将被替换。权重序列中的值本身并不重要,但它们的相对比率却无关紧要。

不同于np.random.choice仅可以将概率作为权重并且还必须确保单个概率的总和不超过1个标准的方法,这里没有这样的规定。只要它们属于数字类型(类型int/float/fraction除外Decimal),它们仍然会执行。

>>> import random
# weights being integers
>>> random.choices(["white", "green", "red"], [12, 12, 4], k=10)
['green', 'red', 'green', 'white', 'white', 'white', 'green', 'white', 'red', 'white']
# weights being floats
>>> random.choices(["white", "green", "red"], [.12, .12, .04], k=10)
['white', 'white', 'green', 'green', 'red', 'red', 'white', 'green', 'white', 'green']
# weights being fractions
>>> random.choices(["white", "green", "red"], [12/100, 12/100, 4/100], k=10)
['green', 'green', 'white', 'red', 'green', 'red', 'white', 'green', 'green', 'green']

2)如果既未指定权重指定cum_weights,则选择的可能性相等。如果提供了权重序列,则其长度必须与总体序列的长度相同。

同时指定权重cum_weights会引发一个TypeError

>>> random.choices(["white", "green", "red"], k=10)
['white', 'white', 'green', 'red', 'red', 'red', 'white', 'white', 'white', 'green']

3)cum_weights通常是itertools.accumulate函数的结果,在这种情况下确实很方便。

从链接的文档中:

在内部,相对权重在选择之前会转换为累积权重,因此提供累积权重可以节省工作。

因此,无论是供应weights=[12, 12, 4]还是cum_weights=[12, 24, 28]为我们精心策划的案例产生相同的结果,后者似乎更快/更有效率。

As of Python v3.6, random.choices could be used to return a list of elements of specified size from the given population with optional weights.

random.choices(population, weights=None, *, cum_weights=None, k=1)

  • population : list containing unique observations. (If empty, raises IndexError)

  • weights : More precisely relative weights required to make selections.

  • cum_weights : cumulative weights required to make selections.

  • k : size(len) of the list to be outputted. (Default len()=1)


Few Caveats:

1) It makes use of weighted sampling with replacement so the drawn items would be later replaced. The values in the weights sequence in itself do not matter, but their relative ratio does.

Unlike np.random.choice which can only take on probabilities as weights and also which must ensure summation of individual probabilities upto 1 criteria, there are no such regulations here. As long as they belong to numeric types (int/float/fraction except Decimal type) , these would still perform.

>>> import random
# weights being integers
>>> random.choices(["white", "green", "red"], [12, 12, 4], k=10)
['green', 'red', 'green', 'white', 'white', 'white', 'green', 'white', 'red', 'white']
# weights being floats
>>> random.choices(["white", "green", "red"], [.12, .12, .04], k=10)
['white', 'white', 'green', 'green', 'red', 'red', 'white', 'green', 'white', 'green']
# weights being fractions
>>> random.choices(["white", "green", "red"], [12/100, 12/100, 4/100], k=10)
['green', 'green', 'white', 'red', 'green', 'red', 'white', 'green', 'green', 'green']

2) If neither weights nor cum_weights are specified, selections are made with equal probability. If a weights sequence is supplied, it must be the same length as the population sequence.

Specifying both weights and cum_weights raises a TypeError.

>>> random.choices(["white", "green", "red"], k=10)
['white', 'white', 'green', 'red', 'red', 'red', 'white', 'white', 'white', 'green']

3) cum_weights are typically a result of itertools.accumulate function which are really handy in such situations.

From the documentation linked:

Internally, the relative weights are converted to cumulative weights before making selections, so supplying the cumulative weights saves work.

So, either supplying weights=[12, 12, 4] or cum_weights=[12, 24, 28] for our contrived case produces the same outcome and the latter seems to be more faster / efficient.


回答 8

这是Python 3.6标准库中包含的版本:

import itertools as _itertools
import bisect as _bisect

class Random36(random.Random):
    "Show the code included in the Python 3.6 version of the Random class"

    def choices(self, population, weights=None, *, cum_weights=None, k=1):
        """Return a k sized list of population elements chosen with replacement.

        If the relative weights or cumulative weights are not specified,
        the selections are made with equal probability.

        """
        random = self.random
        if cum_weights is None:
            if weights is None:
                _int = int
                total = len(population)
                return [population[_int(random() * total)] for i in range(k)]
            cum_weights = list(_itertools.accumulate(weights))
        elif weights is not None:
            raise TypeError('Cannot specify both weights and cumulative weights')
        if len(cum_weights) != len(population):
            raise ValueError('The number of weights does not match the population')
        bisect = _bisect.bisect
        total = cum_weights[-1]
        return [population[bisect(cum_weights, random() * total)] for i in range(k)]

来源:https : //hg.python.org/cpython/file/tip/Lib/random.py#l340

Here’s is the version that is being included in the standard library for Python 3.6:

import itertools as _itertools
import bisect as _bisect

class Random36(random.Random):
    "Show the code included in the Python 3.6 version of the Random class"

    def choices(self, population, weights=None, *, cum_weights=None, k=1):
        """Return a k sized list of population elements chosen with replacement.

        If the relative weights or cumulative weights are not specified,
        the selections are made with equal probability.

        """
        random = self.random
        if cum_weights is None:
            if weights is None:
                _int = int
                total = len(population)
                return [population[_int(random() * total)] for i in range(k)]
            cum_weights = list(_itertools.accumulate(weights))
        elif weights is not None:
            raise TypeError('Cannot specify both weights and cumulative weights')
        if len(cum_weights) != len(population):
            raise ValueError('The number of weights does not match the population')
        bisect = _bisect.bisect
        total = cum_weights[-1]
        return [population[bisect(cum_weights, random() * total)] for i in range(k)]

Source: https://hg.python.org/cpython/file/tip/Lib/random.py#l340


回答 9

import numpy as np
w=np.array([ 0.4,  0.8,  1.6,  0.8,  0.4])
np.random.choice(w, p=w/sum(w))
import numpy as np
w=np.array([ 0.4,  0.8,  1.6,  0.8,  0.4])
np.random.choice(w, p=w/sum(w))

回答 10

我可能为时已晚,无法提供任何有用的信息,但这是一个简单,简短且非常有效的代码段:

def choose_index(probabilies):
    cmf = probabilies[0]
    choice = random.random()
    for k in xrange(len(probabilies)):
        if choice <= cmf:
            return k
        else:
            cmf += probabilies[k+1]

无需对您的概率进行排序或使用cmf创建向量,并且一旦找到选择就终止。内存:O(1),时间:O(N),平均运行时间约为N / 2。

如果您有权重,只需添加一行:

def choose_index(weights):
    probabilities = weights / sum(weights)
    cmf = probabilies[0]
    choice = random.random()
    for k in xrange(len(probabilies)):
        if choice <= cmf:
            return k
        else:
            cmf += probabilies[k+1]

I’m probably too late to contribute anything useful, but here’s a simple, short, and very efficient snippet:

def choose_index(probabilies):
    cmf = probabilies[0]
    choice = random.random()
    for k in xrange(len(probabilies)):
        if choice <= cmf:
            return k
        else:
            cmf += probabilies[k+1]

No need to sort your probabilities or create a vector with your cmf, and it terminates once it finds its choice. Memory: O(1), time: O(N), with average running time ~ N/2.

If you have weights, simply add one line:

def choose_index(weights):
    probabilities = weights / sum(weights)
    cmf = probabilies[0]
    choice = random.random()
    for k in xrange(len(probabilies)):
        if choice <= cmf:
            return k
        else:
            cmf += probabilies[k+1]

回答 11

如果您的加权选择列表相对静态,并且您想要频繁采样,则可以执行一个O(N)预处理步骤,然后使用此相关答案中的函数在O(1)中进行选择。

# run only when `choices` changes.
preprocessed_data = prep(weight for _,weight in choices)

# O(1) selection
value = choices[sample(preprocessed_data)][0]

If your list of weighted choices is relatively static, and you want frequent sampling, you can do one O(N) preprocessing step, and then do the selection in O(1), using the functions in this related answer.

# run only when `choices` changes.
preprocessed_data = prep(weight for _,weight in choices)

# O(1) selection
value = choices[sample(preprocessed_data)][0]

回答 12

我查看了所指向的其他线程,并在我的编码样式中提出了此变体,它返回用于计算目的的选择索引,但是返回字符串很简单(注释返回替代):

import random
import bisect

try:
    range = xrange
except:
    pass

def weighted_choice(choices):
    total, cumulative = 0, []
    for c,w in choices:
        total += w
        cumulative.append((total, c))
    r = random.uniform(0, total)
    # return index
    return bisect.bisect(cumulative, (r,))
    # return item string
    #return choices[bisect.bisect(cumulative, (r,))][0]

# define choices and relative weights
choices = [("WHITE",90), ("RED",8), ("GREEN",2)]

tally = [0 for item in choices]

n = 100000
# tally up n weighted choices
for i in range(n):
    tally[weighted_choice(choices)] += 1

print([t/sum(tally)*100 for t in tally])

I looked the pointed other thread and came up with this variation in my coding style, this returns the index of choice for purpose of tallying, but it is simple to return the string ( commented return alternative):

import random
import bisect

try:
    range = xrange
except:
    pass

def weighted_choice(choices):
    total, cumulative = 0, []
    for c,w in choices:
        total += w
        cumulative.append((total, c))
    r = random.uniform(0, total)
    # return index
    return bisect.bisect(cumulative, (r,))
    # return item string
    #return choices[bisect.bisect(cumulative, (r,))][0]

# define choices and relative weights
choices = [("WHITE",90), ("RED",8), ("GREEN",2)]

tally = [0 for item in choices]

n = 100000
# tally up n weighted choices
for i in range(n):
    tally[weighted_choice(choices)] += 1

print([t/sum(tally)*100 for t in tally])

回答 13

这取决于您要对分布进行采样的次数。

假设您要采样K次分布。那么,np.random.choice()每次使用的时间复杂度是O(K(n + log(n)))when n是分发中项目的数量。

在我的情况下,我需要对同一分布进行多次采样,采样顺序为10 ^ 3,其中n为10 ^ 6。我使用了下面的代码,该代码预先计算了累积分布并将其采样到中O(log(n))。总时间复杂度为O(n+K*log(n))

import numpy as np

n,k = 10**6,10**3

# Create dummy distribution
a = np.array([i+1 for i in range(n)])
p = np.array([1.0/n]*n)

cfd = p.cumsum()
for _ in range(k):
    x = np.random.uniform()
    idx = cfd.searchsorted(x, side='right')
    sampled_element = a[idx]

It depends on how many times you want to sample the distribution.

Suppose you want to sample the distribution K times. Then, the time complexity using np.random.choice() each time is O(K(n + log(n))) when n is the number of items in the distribution.

In my case, I needed to sample the same distribution multiple times of the order of 10^3 where n is of the order of 10^6. I used the below code, which precomputes the cumulative distribution and samples it in O(log(n)). Overall time complexity is O(n+K*log(n)).

import numpy as np

n,k = 10**6,10**3

# Create dummy distribution
a = np.array([i+1 for i in range(n)])
p = np.array([1.0/n]*n)

cfd = p.cumsum()
for _ in range(k):
    x = np.random.uniform()
    idx = cfd.searchsorted(x, side='right')
    sampled_element = a[idx]

回答 14

通用解决方案:

import random
def weighted_choice(choices, weights):
    total = sum(weights)
    treshold = random.uniform(0, total)
    for k, weight in enumerate(weights):
        total -= weight
        if total < treshold:
            return choices[k]

A general solution:

import random
def weighted_choice(choices, weights):
    total = sum(weights)
    treshold = random.uniform(0, total)
    for k, weight in enumerate(weights):
        total -= weight
        if total < treshold:
            return choices[k]

回答 15

这是使用numpy的weighted_choice的另一个版本。传递权重向量,它将返回一个包含1的0数组,指示选择了哪个bin。该代码默认只进行一次抽奖,但是您可以传递要进行的抽奖次数,并且将返回每个抽奖箱的计数。

如果权重向量的总和不等于1,它将被归一化。

import numpy as np

def weighted_choice(weights, n=1):
    if np.sum(weights)!=1:
        weights = weights/np.sum(weights)

    draws = np.random.random_sample(size=n)

    weights = np.cumsum(weights)
    weights = np.insert(weights,0,0.0)

    counts = np.histogram(draws, bins=weights)
    return(counts[0])

Here is another version of weighted_choice that uses numpy. Pass in the weights vector and it will return an array of 0’s containing a 1 indicating which bin was chosen. The code defaults to just making a single draw but you can pass in the number of draws to be made and the counts per bin drawn will be returned.

If the weights vector does not sum to 1, it will be normalized so that it does.

import numpy as np

def weighted_choice(weights, n=1):
    if np.sum(weights)!=1:
        weights = weights/np.sum(weights)

    draws = np.random.random_sample(size=n)

    weights = np.cumsum(weights)
    weights = np.insert(weights,0,0.0)

    counts = np.histogram(draws, bins=weights)
    return(counts[0])

回答 16

假设我们的权重与元素数组中元素的索引相同,则这是另一种方法。

import numpy as np
weights = [0.1, 0.3, 0.5] #weights for the item at index 0,1,2
# sum of weights should be <=1, you can also divide each weight by sum of all weights to standardise it to <=1 constraint.
trials = 1 #number of trials
num_item = 1 #number of items that can be picked in each trial
selected_item_arr = np.random.multinomial(num_item, weights, trials)
# gives number of times an item was selected at a particular index
# this assumes selection with replacement
# one possible output
# selected_item_arr
# array([[0, 0, 1]])
# say if trials = 5, the the possible output could be 
# selected_item_arr
# array([[1, 0, 0],
#   [0, 0, 1],
#   [0, 0, 1],
#   [0, 1, 0],
#   [0, 0, 1]])

现在假设,我们必须在1次试用中抽取3个项目。您可以假设存在三个球R,G,B,它们的重量比是权重数组给出的重量之比,这可能是以下结果:

num_item = 3
trials = 1
selected_item_arr = np.random.multinomial(num_item, weights, trials)
# selected_item_arr can give output like :
# array([[1, 0, 2]])

您还可以将要选择的项目数视为一组中的二项式/多项式试验数。因此,以上示例仍可以按以下方式工作

num_binomial_trial = 5
weights = [0.1,0.9] #say an unfair coin weights for H/T
num_experiment_set = 1
selected_item_arr = np.random.multinomial(num_binomial_trial, weights, num_experiment_set)
# possible output
# selected_item_arr
# array([[1, 4]])
# i.e H came 1 time and T came 4 times in 5 binomial trials. And one set contains 5 binomial trails.

Another way of doing this, assuming we have weights at the same index as the elements in the element array.

import numpy as np
weights = [0.1, 0.3, 0.5] #weights for the item at index 0,1,2
# sum of weights should be <=1, you can also divide each weight by sum of all weights to standardise it to <=1 constraint.
trials = 1 #number of trials
num_item = 1 #number of items that can be picked in each trial
selected_item_arr = np.random.multinomial(num_item, weights, trials)
# gives number of times an item was selected at a particular index
# this assumes selection with replacement
# one possible output
# selected_item_arr
# array([[0, 0, 1]])
# say if trials = 5, the the possible output could be 
# selected_item_arr
# array([[1, 0, 0],
#   [0, 0, 1],
#   [0, 0, 1],
#   [0, 1, 0],
#   [0, 0, 1]])

Now let’s assume, we have to sample out 3 items in 1 trial. You can assume that there are three balls R,G,B present in large quantity in ratio of their weights given by weight array, the following could be possible outcome:

num_item = 3
trials = 1
selected_item_arr = np.random.multinomial(num_item, weights, trials)
# selected_item_arr can give output like :
# array([[1, 0, 2]])

you can also think number of items to be selected as number of binomial/ multinomial trials within a set. So, the above example can be still work as

num_binomial_trial = 5
weights = [0.1,0.9] #say an unfair coin weights for H/T
num_experiment_set = 1
selected_item_arr = np.random.multinomial(num_binomial_trial, weights, num_experiment_set)
# possible output
# selected_item_arr
# array([[1, 4]])
# i.e H came 1 time and T came 4 times in 5 binomial trials. And one set contains 5 binomial trails.

回答 17

Sebastien Thurn在免费的Udacity机器人技术类AI中对此进行了演讲。基本上,他使用mod运算符制作索引权重的圆形数组%,将变量beta设置为0,随机选择一个索引,通过N进行循环,其中N是索引数,并且在for循环中首先通过以下公式递增beta:

beta = beta +来自{0 … 2 * Weight_max}的均匀样本

然后嵌套在for循环中,下面是while循环:

while w[index] < beta:
    beta = beta - w[index]
    index = index + 1

select p[index]

然后转到下一个索引,以根据概率(或本类中提出的情况下的归一化概率)进行重新采样。

讲座链接:https : //classroom.udacity.com/courses/cs373/lessons/48704330/concepts/487480820923

我使用我的学校帐户登录到Udacity,因此,如果该链接不起作用,则是第8课,机器人人工智能的视频号码21,他正在讲解粒子过滤器。

There is lecture on this by Sebastien Thurn in the free Udacity course AI for Robotics. Basically he makes a circular array of the indexed weights using the mod operator %, sets a variable beta to 0, randomly chooses an index, for loops through N where N is the number of indices and in the for loop firstly increments beta by the formula:

beta = beta + uniform sample from {0…2* Weight_max}

and then nested in the for loop, a while loop per below:

while w[index] < beta:
    beta = beta - w[index]
    index = index + 1

select p[index]

Then on to the next index to resample based on the probabilities (or normalized probability in the case presented in the course).

The lecture link: https://classroom.udacity.com/courses/cs373/lessons/48704330/concepts/487480820923

I am logged into Udacity with my school account so if the link does not work, it is Lesson 8, video number 21 of Artificial Intelligence for Robotics where he is lecturing on particle filters.


回答 18

如果您碰巧拥有Python 3,并且害怕安装numpy或编写自己的循环,则可以执行以下操作:

import itertools, bisect, random

def weighted_choice(choices):
   weights = list(zip(*choices))[1]
   return choices[bisect.bisect(list(itertools.accumulate(weights)),
                                random.uniform(0, sum(weights)))][0]

因为您可以用一袋管道适配器来制造任何东西!尽管…我必须承认,内德的回答虽然稍长,但更容易理解。

If you happen to have Python 3, and are afraid of installing numpy or writing your own loops, you could do:

import itertools, bisect, random

def weighted_choice(choices):
   weights = list(zip(*choices))[1]
   return choices[bisect.bisect(list(itertools.accumulate(weights)),
                                random.uniform(0, sum(weights)))][0]

Because you can build anything out of a bag of plumbing adaptors! Although… I must admit that Ned’s answer, while slightly longer, is easier to understand.


回答 19

一种方法是对所有权重的总和进行随机化,然后将这些值用作每个变量的极限点。这是生成器的粗略实现。

def rand_weighted(weights):
    """
    Generator which uses the weights to generate a
    weighted random values
    """
    sum_weights = sum(weights.values())
    cum_weights = {}
    current_weight = 0
    for key, value in sorted(weights.iteritems()):
        current_weight += value
        cum_weights[key] = current_weight
    while True:
        sel = int(random.uniform(0, 1) * sum_weights)
        for key, value in sorted(cum_weights.iteritems()):
            if sel < value:
                break
        yield key

One way is to randomize on the total of all the weights and then use the values as the limit points for each var. Here is a crude implementation as a generator.

def rand_weighted(weights):
    """
    Generator which uses the weights to generate a
    weighted random values
    """
    sum_weights = sum(weights.values())
    cum_weights = {}
    current_weight = 0
    for key, value in sorted(weights.iteritems()):
        current_weight += value
        cum_weights[key] = current_weight
    while True:
        sel = int(random.uniform(0, 1) * sum_weights)
        for key, value in sorted(cum_weights.iteritems()):
            if sel < value:
                break
        yield key

回答 20

使用numpy

def choice(items, weights):
    return items[np.argmin((np.cumsum(weights) / sum(weights)) < np.random.rand())]

Using numpy

def choice(items, weights):
    return items[np.argmin((np.cumsum(weights) / sum(weights)) < np.random.rand())]

回答 21

我需要快速,非常简单地完成这样的工作,从寻找想法开始,我终于建立了这个模板。这个想法是从api接收json形式的加权值,这里是由dict模拟的。

然后将其转换为一个列表,其中每个值均按其权重成比例地重复,只需使用random.choice从列表中选择一个值即可。

我尝试了运行10、100和1000次迭代。分布似乎很稳定。

def weighted_choice(weighted_dict):
    """Input example: dict(apples=60, oranges=30, pineapples=10)"""
    weight_list = []
    for key in weighted_dict.keys():
        weight_list += [key] * weighted_dict[key]
    return random.choice(weight_list)

I needed to do something like this really fast really simple, from searching for ideas i finally built this template. The idea is receive the weighted values in a form of a json from the api, which here is simulated by the dict.

Then translate it into a list in which each value repeats proportionally to it’s weight, and just use random.choice to select a value from the list.

I tried it running with 10, 100 and 1000 iterations. The distribution seems pretty solid.

def weighted_choice(weighted_dict):
    """Input example: dict(apples=60, oranges=30, pineapples=10)"""
    weight_list = []
    for key in weighted_dict.keys():
        weight_list += [key] * weighted_dict[key]
    return random.choice(weight_list)

回答 22

我不喜欢那些语法。我真的只想指定项目是什么,每个项目的权重是什么。我意识到我可以使用,random.choices但我很快在下面编写了此类。

import random, string
from numpy import cumsum

class randomChoiceWithProportions:
    '''
    Accepts a dictionary of choices as keys and weights as values. Example if you want a unfair dice:


    choiceWeightDic = {"1":0.16666666666666666, "2": 0.16666666666666666, "3": 0.16666666666666666
    , "4": 0.16666666666666666, "5": .06666666666666666, "6": 0.26666666666666666}
    dice = randomChoiceWithProportions(choiceWeightDic)

    samples = []
    for i in range(100000):
        samples.append(dice.sample())

    # Should be close to .26666
    samples.count("6")/len(samples)

    # Should be close to .16666
    samples.count("1")/len(samples)
    '''
    def __init__(self, choiceWeightDic):
        self.choiceWeightDic = choiceWeightDic
        weightSum = sum(self.choiceWeightDic.values())
        assert weightSum == 1, 'Weights sum to ' + str(weightSum) + ', not 1.'
        self.valWeightDict = self._compute_valWeights()

    def _compute_valWeights(self):
        valWeights = list(cumsum(list(self.choiceWeightDic.values())))
        valWeightDict = dict(zip(list(self.choiceWeightDic.keys()), valWeights))
        return valWeightDict

    def sample(self):
        num = random.uniform(0,1)
        for key, val in self.valWeightDict.items():
            if val >= num:
                return key

I didn’t love the syntax of any of those. I really wanted to just specify what the items were and what the weighting of each was. I realize I could have used random.choices but instead I quickly wrote the class below.

import random, string
from numpy import cumsum

class randomChoiceWithProportions:
    '''
    Accepts a dictionary of choices as keys and weights as values. Example if you want a unfair dice:


    choiceWeightDic = {"1":0.16666666666666666, "2": 0.16666666666666666, "3": 0.16666666666666666
    , "4": 0.16666666666666666, "5": .06666666666666666, "6": 0.26666666666666666}
    dice = randomChoiceWithProportions(choiceWeightDic)

    samples = []
    for i in range(100000):
        samples.append(dice.sample())

    # Should be close to .26666
    samples.count("6")/len(samples)

    # Should be close to .16666
    samples.count("1")/len(samples)
    '''
    def __init__(self, choiceWeightDic):
        self.choiceWeightDic = choiceWeightDic
        weightSum = sum(self.choiceWeightDic.values())
        assert weightSum == 1, 'Weights sum to ' + str(weightSum) + ', not 1.'
        self.valWeightDict = self._compute_valWeights()

    def _compute_valWeights(self):
        valWeights = list(cumsum(list(self.choiceWeightDic.values())))
        valWeightDict = dict(zip(list(self.choiceWeightDic.keys()), valWeights))
        return valWeightDict

    def sample(self):
        num = random.uniform(0,1)
        for key, val in self.valWeightDict.items():
            if val >= num:
                return key

回答 23

为random.choice()提供预加权列表:

解决方案和测试:

import random

options = ['a', 'b', 'c', 'd']
weights = [1, 2, 5, 2]

weighted_options = [[opt]*wgt for opt, wgt in zip(options, weights)]
weighted_options = [opt for sublist in weighted_options for opt in sublist]
print(weighted_options)

# test

counts = {c: 0 for c in options}
for x in range(10000):
    counts[random.choice(weighted_options)] += 1

for opt, wgt in zip(options, weights):
    wgt_r = counts[opt] / 10000 * sum(weights)
    print(opt, counts[opt], wgt, wgt_r)

输出:

['a', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'd', 'd']
a 1025 1 1.025
b 1948 2 1.948
c 5019 5 5.019
d 2008 2 2.008

Provide random.choice() with a pre-weighted list:

Solution & Test:

import random

options = ['a', 'b', 'c', 'd']
weights = [1, 2, 5, 2]

weighted_options = [[opt]*wgt for opt, wgt in zip(options, weights)]
weighted_options = [opt for sublist in weighted_options for opt in sublist]
print(weighted_options)

# test

counts = {c: 0 for c in options}
for x in range(10000):
    counts[random.choice(weighted_options)] += 1

for opt, wgt in zip(options, weights):
    wgt_r = counts[opt] / 10000 * sum(weights)
    print(opt, counts[opt], wgt, wgt_r)

Output:

['a', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'd', 'd']
a 1025 1 1.025
b 1948 2 1.948
c 5019 5 5.019
d 2008 2 2.008

列出N以下所有素数的最快方法

问题:列出N以下所有素数的最快方法

这是我能想到的最好的算法。

def get_primes(n):
    numbers = set(range(n, 1, -1))
    primes = []
    while numbers:
        p = numbers.pop()
        primes.append(p)
        numbers.difference_update(set(range(p*2, n+1, p)))
    return primes

>>> timeit.Timer(stmt='get_primes.get_primes(1000000)', setup='import   get_primes').timeit(1)
1.1499958793645562

可以使它更快吗?

此代码有一个缺陷:由于numbers是无序集合,因此不能保证numbers.pop()从集合中删除最低的数字。但是,它对某些输入数字有效(至少对我而言):

>>> sum(get_primes(2000000))
142913828922L
#That's the correct sum of all numbers below 2 million
>>> 529 in get_primes(1000)
False
>>> 529 in get_primes(530)
True

This is the best algorithm I could come up.

def get_primes(n):
    numbers = set(range(n, 1, -1))
    primes = []
    while numbers:
        p = numbers.pop()
        primes.append(p)
        numbers.difference_update(set(range(p*2, n+1, p)))
    return primes

>>> timeit.Timer(stmt='get_primes.get_primes(1000000)', setup='import   get_primes').timeit(1)
1.1499958793645562

Can it be made even faster?

This code has a flaw: Since numbers is an unordered set, there is no guarantee that numbers.pop() will remove the lowest number from the set. Nevertheless, it works (at least for me) for some input numbers:

>>> sum(get_primes(2000000))
142913828922L
#That's the correct sum of all numbers below 2 million
>>> 529 in get_primes(1000)
False
>>> 529 in get_primes(530)
True

回答 0

警告: timeit由于硬件或Python版本的差异,结果可能会有所不同。

下面是一个脚本,比较了许多实现:

非常感谢斯蒂芬为使sieve_wheel_30引起我的注意。幸得罗伯特·威廉·汉克斯为primesfrom2to,primesfrom3to,rwh_primes,rwh_primes1和rwh_primes2。

使用psyco测试的简单Python方法中,对于n = 1000000, rwh_primes1是测试最快的方法。

+---------------------+-------+
| Method              | ms    |
+---------------------+-------+
| rwh_primes1         | 43.0  |
| sieveOfAtkin        | 46.4  |
| rwh_primes          | 57.4  |
| sieve_wheel_30      | 63.0  |
| rwh_primes2         | 67.8  |    
| sieveOfEratosthenes | 147.0 |
| ambi_sieve_plain    | 152.0 |
| sundaram3           | 194.0 |
+---------------------+-------+

没有psyco的情况下经过测试的普通Python方法中,对于n = 1000000, rwh_primes2是最快的。

+---------------------+-------+
| Method              | ms    |
+---------------------+-------+
| rwh_primes2         | 68.1  |
| rwh_primes1         | 93.7  |
| rwh_primes          | 94.6  |
| sieve_wheel_30      | 97.4  |
| sieveOfEratosthenes | 178.0 |
| ambi_sieve_plain    | 286.0 |
| sieveOfAtkin        | 314.0 |
| sundaram3           | 416.0 |
+---------------------+-------+

在所有测试的方法中,允许numpy,对于n = 1000000, primesfrom2to是测试最快的方法。

+---------------------+-------+
| Method              | ms    |
+---------------------+-------+
| primesfrom2to       | 15.9  |
| primesfrom3to       | 18.4  |
| ambi_sieve          | 29.3  |
+---------------------+-------+

使用以下命令测量时间:

python -mtimeit -s"import primes" "primes.{method}(1000000)"

{method}每个方法名称替换。

primes.py:

#!/usr/bin/env python
import psyco; psyco.full()
from math import sqrt, ceil
import numpy as np

def rwh_primes(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns  a list of primes < n """
    sieve = [True] * n
    for i in xrange(3,int(n**0.5)+1,2):
        if sieve[i]:
            sieve[i*i::2*i]=[False]*((n-i*i-1)/(2*i)+1)
    return [2] + [i for i in xrange(3,n,2) if sieve[i]]

def rwh_primes1(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns  a list of primes < n """
    sieve = [True] * (n/2)
    for i in xrange(3,int(n**0.5)+1,2):
        if sieve[i/2]:
            sieve[i*i/2::i] = [False] * ((n-i*i-1)/(2*i)+1)
    return [2] + [2*i+1 for i in xrange(1,n/2) if sieve[i]]

def rwh_primes2(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Input n>=6, Returns a list of primes, 2 <= p < n """
    correction = (n%6>1)
    n = {0:n,1:n-1,2:n+4,3:n+3,4:n+2,5:n+1}[n%6]
    sieve = [True] * (n/3)
    sieve[0] = False
    for i in xrange(int(n**0.5)/3+1):
      if sieve[i]:
        k=3*i+1|1
        sieve[      ((k*k)/3)      ::2*k]=[False]*((n/6-(k*k)/6-1)/k+1)
        sieve[(k*k+4*k-2*k*(i&1))/3::2*k]=[False]*((n/6-(k*k+4*k-2*k*(i&1))/6-1)/k+1)
    return [2,3] + [3*i+1|1 for i in xrange(1,n/3-correction) if sieve[i]]

def sieve_wheel_30(N):
    # http://zerovolt.com/?p=88
    ''' Returns a list of primes <= N using wheel criterion 2*3*5 = 30

Copyright 2009 by zerovolt.com
This code is free for non-commercial purposes, in which case you can just leave this comment as a credit for my work.
If you need this code for commercial purposes, please contact me by sending an email to: info [at] zerovolt [dot] com.'''
    __smallp = ( 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59,
    61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139,
    149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227,
    229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311,
    313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401,
    409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491,
    499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599,
    601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683,
    691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797,
    809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887,
    907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997)

    wheel = (2, 3, 5)
    const = 30
    if N < 2:
        return []
    if N <= const:
        pos = 0
        while __smallp[pos] <= N:
            pos += 1
        return list(__smallp[:pos])
    # make the offsets list
    offsets = (7, 11, 13, 17, 19, 23, 29, 1)
    # prepare the list
    p = [2, 3, 5]
    dim = 2 + N // const
    tk1  = [True] * dim
    tk7  = [True] * dim
    tk11 = [True] * dim
    tk13 = [True] * dim
    tk17 = [True] * dim
    tk19 = [True] * dim
    tk23 = [True] * dim
    tk29 = [True] * dim
    tk1[0] = False
    # help dictionary d
    # d[a , b] = c  ==> if I want to find the smallest useful multiple of (30*pos)+a
    # on tkc, then I need the index given by the product of [(30*pos)+a][(30*pos)+b]
    # in general. If b < a, I need [(30*pos)+a][(30*(pos+1))+b]
    d = {}
    for x in offsets:
        for y in offsets:
            res = (x*y) % const
            if res in offsets:
                d[(x, res)] = y
    # another help dictionary: gives tkx calling tmptk[x]
    tmptk = {1:tk1, 7:tk7, 11:tk11, 13:tk13, 17:tk17, 19:tk19, 23:tk23, 29:tk29}
    pos, prime, lastadded, stop = 0, 0, 0, int(ceil(sqrt(N)))
    # inner functions definition
    def del_mult(tk, start, step):
        for k in xrange(start, len(tk), step):
            tk[k] = False
    # end of inner functions definition
    cpos = const * pos
    while prime < stop:
        # 30k + 7
        if tk7[pos]:
            prime = cpos + 7
            p.append(prime)
            lastadded = 7
            for off in offsets:
                tmp = d[(7, off)]
                start = (pos + prime) if off == 7 else (prime * (const * (pos + 1 if tmp < 7 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 11
        if tk11[pos]:
            prime = cpos + 11
            p.append(prime)
            lastadded = 11
            for off in offsets:
                tmp = d[(11, off)]
                start = (pos + prime) if off == 11 else (prime * (const * (pos + 1 if tmp < 11 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 13
        if tk13[pos]:
            prime = cpos + 13
            p.append(prime)
            lastadded = 13
            for off in offsets:
                tmp = d[(13, off)]
                start = (pos + prime) if off == 13 else (prime * (const * (pos + 1 if tmp < 13 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 17
        if tk17[pos]:
            prime = cpos + 17
            p.append(prime)
            lastadded = 17
            for off in offsets:
                tmp = d[(17, off)]
                start = (pos + prime) if off == 17 else (prime * (const * (pos + 1 if tmp < 17 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 19
        if tk19[pos]:
            prime = cpos + 19
            p.append(prime)
            lastadded = 19
            for off in offsets:
                tmp = d[(19, off)]
                start = (pos + prime) if off == 19 else (prime * (const * (pos + 1 if tmp < 19 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 23
        if tk23[pos]:
            prime = cpos + 23
            p.append(prime)
            lastadded = 23
            for off in offsets:
                tmp = d[(23, off)]
                start = (pos + prime) if off == 23 else (prime * (const * (pos + 1 if tmp < 23 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 29
        if tk29[pos]:
            prime = cpos + 29
            p.append(prime)
            lastadded = 29
            for off in offsets:
                tmp = d[(29, off)]
                start = (pos + prime) if off == 29 else (prime * (const * (pos + 1 if tmp < 29 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # now we go back to top tk1, so we need to increase pos by 1
        pos += 1
        cpos = const * pos
        # 30k + 1
        if tk1[pos]:
            prime = cpos + 1
            p.append(prime)
            lastadded = 1
            for off in offsets:
                tmp = d[(1, off)]
                start = (pos + prime) if off == 1 else (prime * (const * pos + tmp) )//const
                del_mult(tmptk[off], start, prime)
    # time to add remaining primes
    # if lastadded == 1, remove last element and start adding them from tk1
    # this way we don't need an "if" within the last while
    if lastadded == 1:
        p.pop()
    # now complete for every other possible prime
    while pos < len(tk1):
        cpos = const * pos
        if tk1[pos]: p.append(cpos + 1)
        if tk7[pos]: p.append(cpos + 7)
        if tk11[pos]: p.append(cpos + 11)
        if tk13[pos]: p.append(cpos + 13)
        if tk17[pos]: p.append(cpos + 17)
        if tk19[pos]: p.append(cpos + 19)
        if tk23[pos]: p.append(cpos + 23)
        if tk29[pos]: p.append(cpos + 29)
        pos += 1
    # remove exceeding if present
    pos = len(p) - 1
    while p[pos] > N:
        pos -= 1
    if pos < len(p) - 1:
        del p[pos+1:]
    # return p list
    return p

def sieveOfEratosthenes(n):
    """sieveOfEratosthenes(n): return the list of the primes < n."""
    # Code from: <dickinsm@gmail.com>, Nov 30 2006
    # http://groups.google.com/group/comp.lang.python/msg/f1f10ced88c68c2d
    if n <= 2:
        return []
    sieve = range(3, n, 2)
    top = len(sieve)
    for si in sieve:
        if si:
            bottom = (si*si - 3) // 2
            if bottom >= top:
                break
            sieve[bottom::si] = [0] * -((bottom - top) // si)
    return [2] + [el for el in sieve if el]

def sieveOfAtkin(end):
    """sieveOfAtkin(end): return a list of all the prime numbers <end
    using the Sieve of Atkin."""
    # Code by Steve Krenzel, <Sgk284@gmail.com>, improved
    # Code: https://web.archive.org/web/20080324064651/http://krenzel.info/?p=83
    # Info: http://en.wikipedia.org/wiki/Sieve_of_Atkin
    assert end > 0
    lng = ((end-1) // 2)
    sieve = [False] * (lng + 1)

    x_max, x2, xd = int(sqrt((end-1)/4.0)), 0, 4
    for xd in xrange(4, 8*x_max + 2, 8):
        x2 += xd
        y_max = int(sqrt(end-x2))
        n, n_diff = x2 + y_max*y_max, (y_max << 1) - 1
        if not (n & 1):
            n -= n_diff
            n_diff -= 2
        for d in xrange((n_diff - 1) << 1, -1, -8):
            m = n % 12
            if m == 1 or m == 5:
                m = n >> 1
                sieve[m] = not sieve[m]
            n -= d

    x_max, x2, xd = int(sqrt((end-1) / 3.0)), 0, 3
    for xd in xrange(3, 6 * x_max + 2, 6):
        x2 += xd
        y_max = int(sqrt(end-x2))
        n, n_diff = x2 + y_max*y_max, (y_max << 1) - 1
        if not(n & 1):
            n -= n_diff
            n_diff -= 2
        for d in xrange((n_diff - 1) << 1, -1, -8):
            if n % 12 == 7:
                m = n >> 1
                sieve[m] = not sieve[m]
            n -= d

    x_max, y_min, x2, xd = int((2 + sqrt(4-8*(1-end)))/4), -1, 0, 3
    for x in xrange(1, x_max + 1):
        x2 += xd
        xd += 6
        if x2 >= end: y_min = (((int(ceil(sqrt(x2 - end))) - 1) << 1) - 2) << 1
        n, n_diff = ((x*x + x) << 1) - 1, (((x-1) << 1) - 2) << 1
        for d in xrange(n_diff, y_min, -8):
            if n % 12 == 11:
                m = n >> 1
                sieve[m] = not sieve[m]
            n += d

    primes = [2, 3]
    if end <= 3:
        return primes[:max(0,end-2)]

    for n in xrange(5 >> 1, (int(sqrt(end))+1) >> 1):
        if sieve[n]:
            primes.append((n << 1) + 1)
            aux = (n << 1) + 1
            aux *= aux
            for k in xrange(aux, end, 2 * aux):
                sieve[k >> 1] = False

    s  = int(sqrt(end)) + 1
    if s  % 2 == 0:
        s += 1
    primes.extend([i for i in xrange(s, end, 2) if sieve[i >> 1]])

    return primes

def ambi_sieve_plain(n):
    s = range(3, n, 2)
    for m in xrange(3, int(n**0.5)+1, 2): 
        if s[(m-3)/2]: 
            for t in xrange((m*m-3)/2,(n>>1)-1,m):
                s[t]=0
    return [2]+[t for t in s if t>0]

def sundaram3(max_n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/2073279#2073279
    numbers = range(3, max_n+1, 2)
    half = (max_n)//2
    initial = 4

    for step in xrange(3, max_n+1, 2):
        for i in xrange(initial, half, step):
            numbers[i-1] = 0
        initial += 2*(step+1)

        if initial > half:
            return [2] + filter(None, numbers)

################################################################################
# Using Numpy:
def ambi_sieve(n):
    # http://tommih.blogspot.com/2009/04/fast-prime-number-generator.html
    s = np.arange(3, n, 2)
    for m in xrange(3, int(n ** 0.5)+1, 2): 
        if s[(m-3)/2]: 
            s[(m*m-3)/2::m]=0
    return np.r_[2, s[s>0]]

def primesfrom3to(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns a array of primes, p < n """
    assert n>=2
    sieve = np.ones(n/2, dtype=np.bool)
    for i in xrange(3,int(n**0.5)+1,2):
        if sieve[i/2]:
            sieve[i*i/2::i] = False
    return np.r_[2, 2*np.nonzero(sieve)[0][1::]+1]    

def primesfrom2to(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Input n>=6, Returns a array of primes, 2 <= p < n """
    sieve = np.ones(n/3 + (n%6==2), dtype=np.bool)
    sieve[0] = False
    for i in xrange(int(n**0.5)/3+1):
        if sieve[i]:
            k=3*i+1|1
            sieve[      ((k*k)/3)      ::2*k] = False
            sieve[(k*k+4*k-2*k*(i&1))/3::2*k] = False
    return np.r_[2,3,((3*np.nonzero(sieve)[0]+1)|1)]

if __name__=='__main__':
    import itertools
    import sys

    def test(f1,f2,num):
        print('Testing {f1} and {f2} return same results'.format(
            f1=f1.func_name,
            f2=f2.func_name))
        if not all([a==b for a,b in itertools.izip_longest(f1(num),f2(num))]):
            sys.exit("Error: %s(%s) != %s(%s)"%(f1.func_name,num,f2.func_name,num))

    n=1000000
    test(sieveOfAtkin,sieveOfEratosthenes,n)
    test(sieveOfAtkin,ambi_sieve,n)
    test(sieveOfAtkin,ambi_sieve_plain,n) 
    test(sieveOfAtkin,sundaram3,n)
    test(sieveOfAtkin,sieve_wheel_30,n)
    test(sieveOfAtkin,primesfrom3to,n)
    test(sieveOfAtkin,primesfrom2to,n)
    test(sieveOfAtkin,rwh_primes,n)
    test(sieveOfAtkin,rwh_primes1,n)         
    test(sieveOfAtkin,rwh_primes2,n)

运行脚本会测试所有实现都给出相同的结果。

Warning: timeit results may vary due to differences in hardware or version of Python.

Below is a script which compares a number of implementations:

Many thanks to stephan for bringing sieve_wheel_30 to my attention. Credit goes to Robert William Hanks for primesfrom2to, primesfrom3to, rwh_primes, rwh_primes1, and rwh_primes2.

Of the plain Python methods tested, with psyco, for n=1000000, rwh_primes1 was the fastest tested.

+---------------------+-------+
| Method              | ms    |
+---------------------+-------+
| rwh_primes1         | 43.0  |
| sieveOfAtkin        | 46.4  |
| rwh_primes          | 57.4  |
| sieve_wheel_30      | 63.0  |
| rwh_primes2         | 67.8  |    
| sieveOfEratosthenes | 147.0 |
| ambi_sieve_plain    | 152.0 |
| sundaram3           | 194.0 |
+---------------------+-------+

Of the plain Python methods tested, without psyco, for n=1000000, rwh_primes2 was the fastest.

+---------------------+-------+
| Method              | ms    |
+---------------------+-------+
| rwh_primes2         | 68.1  |
| rwh_primes1         | 93.7  |
| rwh_primes          | 94.6  |
| sieve_wheel_30      | 97.4  |
| sieveOfEratosthenes | 178.0 |
| ambi_sieve_plain    | 286.0 |
| sieveOfAtkin        | 314.0 |
| sundaram3           | 416.0 |
+---------------------+-------+

Of all the methods tested, allowing numpy, for n=1000000, primesfrom2to was the fastest tested.

+---------------------+-------+
| Method              | ms    |
+---------------------+-------+
| primesfrom2to       | 15.9  |
| primesfrom3to       | 18.4  |
| ambi_sieve          | 29.3  |
+---------------------+-------+

Timings were measured using the command:

python -mtimeit -s"import primes" "primes.{method}(1000000)"

with {method} replaced by each of the method names.

primes.py:

#!/usr/bin/env python
import psyco; psyco.full()
from math import sqrt, ceil
import numpy as np

def rwh_primes(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns  a list of primes < n """
    sieve = [True] * n
    for i in xrange(3,int(n**0.5)+1,2):
        if sieve[i]:
            sieve[i*i::2*i]=[False]*((n-i*i-1)/(2*i)+1)
    return [2] + [i for i in xrange(3,n,2) if sieve[i]]

def rwh_primes1(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns  a list of primes < n """
    sieve = [True] * (n/2)
    for i in xrange(3,int(n**0.5)+1,2):
        if sieve[i/2]:
            sieve[i*i/2::i] = [False] * ((n-i*i-1)/(2*i)+1)
    return [2] + [2*i+1 for i in xrange(1,n/2) if sieve[i]]

def rwh_primes2(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Input n>=6, Returns a list of primes, 2 <= p < n """
    correction = (n%6>1)
    n = {0:n,1:n-1,2:n+4,3:n+3,4:n+2,5:n+1}[n%6]
    sieve = [True] * (n/3)
    sieve[0] = False
    for i in xrange(int(n**0.5)/3+1):
      if sieve[i]:
        k=3*i+1|1
        sieve[      ((k*k)/3)      ::2*k]=[False]*((n/6-(k*k)/6-1)/k+1)
        sieve[(k*k+4*k-2*k*(i&1))/3::2*k]=[False]*((n/6-(k*k+4*k-2*k*(i&1))/6-1)/k+1)
    return [2,3] + [3*i+1|1 for i in xrange(1,n/3-correction) if sieve[i]]

def sieve_wheel_30(N):
    # http://zerovolt.com/?p=88
    ''' Returns a list of primes <= N using wheel criterion 2*3*5 = 30

Copyright 2009 by zerovolt.com
This code is free for non-commercial purposes, in which case you can just leave this comment as a credit for my work.
If you need this code for commercial purposes, please contact me by sending an email to: info [at] zerovolt [dot] com.'''
    __smallp = ( 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59,
    61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139,
    149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227,
    229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311,
    313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401,
    409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491,
    499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599,
    601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683,
    691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797,
    809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887,
    907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997)

    wheel = (2, 3, 5)
    const = 30
    if N < 2:
        return []
    if N <= const:
        pos = 0
        while __smallp[pos] <= N:
            pos += 1
        return list(__smallp[:pos])
    # make the offsets list
    offsets = (7, 11, 13, 17, 19, 23, 29, 1)
    # prepare the list
    p = [2, 3, 5]
    dim = 2 + N // const
    tk1  = [True] * dim
    tk7  = [True] * dim
    tk11 = [True] * dim
    tk13 = [True] * dim
    tk17 = [True] * dim
    tk19 = [True] * dim
    tk23 = [True] * dim
    tk29 = [True] * dim
    tk1[0] = False
    # help dictionary d
    # d[a , b] = c  ==> if I want to find the smallest useful multiple of (30*pos)+a
    # on tkc, then I need the index given by the product of [(30*pos)+a][(30*pos)+b]
    # in general. If b < a, I need [(30*pos)+a][(30*(pos+1))+b]
    d = {}
    for x in offsets:
        for y in offsets:
            res = (x*y) % const
            if res in offsets:
                d[(x, res)] = y
    # another help dictionary: gives tkx calling tmptk[x]
    tmptk = {1:tk1, 7:tk7, 11:tk11, 13:tk13, 17:tk17, 19:tk19, 23:tk23, 29:tk29}
    pos, prime, lastadded, stop = 0, 0, 0, int(ceil(sqrt(N)))
    # inner functions definition
    def del_mult(tk, start, step):
        for k in xrange(start, len(tk), step):
            tk[k] = False
    # end of inner functions definition
    cpos = const * pos
    while prime < stop:
        # 30k + 7
        if tk7[pos]:
            prime = cpos + 7
            p.append(prime)
            lastadded = 7
            for off in offsets:
                tmp = d[(7, off)]
                start = (pos + prime) if off == 7 else (prime * (const * (pos + 1 if tmp < 7 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 11
        if tk11[pos]:
            prime = cpos + 11
            p.append(prime)
            lastadded = 11
            for off in offsets:
                tmp = d[(11, off)]
                start = (pos + prime) if off == 11 else (prime * (const * (pos + 1 if tmp < 11 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 13
        if tk13[pos]:
            prime = cpos + 13
            p.append(prime)
            lastadded = 13
            for off in offsets:
                tmp = d[(13, off)]
                start = (pos + prime) if off == 13 else (prime * (const * (pos + 1 if tmp < 13 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 17
        if tk17[pos]:
            prime = cpos + 17
            p.append(prime)
            lastadded = 17
            for off in offsets:
                tmp = d[(17, off)]
                start = (pos + prime) if off == 17 else (prime * (const * (pos + 1 if tmp < 17 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 19
        if tk19[pos]:
            prime = cpos + 19
            p.append(prime)
            lastadded = 19
            for off in offsets:
                tmp = d[(19, off)]
                start = (pos + prime) if off == 19 else (prime * (const * (pos + 1 if tmp < 19 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 23
        if tk23[pos]:
            prime = cpos + 23
            p.append(prime)
            lastadded = 23
            for off in offsets:
                tmp = d[(23, off)]
                start = (pos + prime) if off == 23 else (prime * (const * (pos + 1 if tmp < 23 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # 30k + 29
        if tk29[pos]:
            prime = cpos + 29
            p.append(prime)
            lastadded = 29
            for off in offsets:
                tmp = d[(29, off)]
                start = (pos + prime) if off == 29 else (prime * (const * (pos + 1 if tmp < 29 else 0) + tmp) )//const
                del_mult(tmptk[off], start, prime)
        # now we go back to top tk1, so we need to increase pos by 1
        pos += 1
        cpos = const * pos
        # 30k + 1
        if tk1[pos]:
            prime = cpos + 1
            p.append(prime)
            lastadded = 1
            for off in offsets:
                tmp = d[(1, off)]
                start = (pos + prime) if off == 1 else (prime * (const * pos + tmp) )//const
                del_mult(tmptk[off], start, prime)
    # time to add remaining primes
    # if lastadded == 1, remove last element and start adding them from tk1
    # this way we don't need an "if" within the last while
    if lastadded == 1:
        p.pop()
    # now complete for every other possible prime
    while pos < len(tk1):
        cpos = const * pos
        if tk1[pos]: p.append(cpos + 1)
        if tk7[pos]: p.append(cpos + 7)
        if tk11[pos]: p.append(cpos + 11)
        if tk13[pos]: p.append(cpos + 13)
        if tk17[pos]: p.append(cpos + 17)
        if tk19[pos]: p.append(cpos + 19)
        if tk23[pos]: p.append(cpos + 23)
        if tk29[pos]: p.append(cpos + 29)
        pos += 1
    # remove exceeding if present
    pos = len(p) - 1
    while p[pos] > N:
        pos -= 1
    if pos < len(p) - 1:
        del p[pos+1:]
    # return p list
    return p

def sieveOfEratosthenes(n):
    """sieveOfEratosthenes(n): return the list of the primes < n."""
    # Code from: <dickinsm@gmail.com>, Nov 30 2006
    # http://groups.google.com/group/comp.lang.python/msg/f1f10ced88c68c2d
    if n <= 2:
        return []
    sieve = range(3, n, 2)
    top = len(sieve)
    for si in sieve:
        if si:
            bottom = (si*si - 3) // 2
            if bottom >= top:
                break
            sieve[bottom::si] = [0] * -((bottom - top) // si)
    return [2] + [el for el in sieve if el]

def sieveOfAtkin(end):
    """sieveOfAtkin(end): return a list of all the prime numbers <end
    using the Sieve of Atkin."""
    # Code by Steve Krenzel, <Sgk284@gmail.com>, improved
    # Code: https://web.archive.org/web/20080324064651/http://krenzel.info/?p=83
    # Info: http://en.wikipedia.org/wiki/Sieve_of_Atkin
    assert end > 0
    lng = ((end-1) // 2)
    sieve = [False] * (lng + 1)

    x_max, x2, xd = int(sqrt((end-1)/4.0)), 0, 4
    for xd in xrange(4, 8*x_max + 2, 8):
        x2 += xd
        y_max = int(sqrt(end-x2))
        n, n_diff = x2 + y_max*y_max, (y_max << 1) - 1
        if not (n & 1):
            n -= n_diff
            n_diff -= 2
        for d in xrange((n_diff - 1) << 1, -1, -8):
            m = n % 12
            if m == 1 or m == 5:
                m = n >> 1
                sieve[m] = not sieve[m]
            n -= d

    x_max, x2, xd = int(sqrt((end-1) / 3.0)), 0, 3
    for xd in xrange(3, 6 * x_max + 2, 6):
        x2 += xd
        y_max = int(sqrt(end-x2))
        n, n_diff = x2 + y_max*y_max, (y_max << 1) - 1
        if not(n & 1):
            n -= n_diff
            n_diff -= 2
        for d in xrange((n_diff - 1) << 1, -1, -8):
            if n % 12 == 7:
                m = n >> 1
                sieve[m] = not sieve[m]
            n -= d

    x_max, y_min, x2, xd = int((2 + sqrt(4-8*(1-end)))/4), -1, 0, 3
    for x in xrange(1, x_max + 1):
        x2 += xd
        xd += 6
        if x2 >= end: y_min = (((int(ceil(sqrt(x2 - end))) - 1) << 1) - 2) << 1
        n, n_diff = ((x*x + x) << 1) - 1, (((x-1) << 1) - 2) << 1
        for d in xrange(n_diff, y_min, -8):
            if n % 12 == 11:
                m = n >> 1
                sieve[m] = not sieve[m]
            n += d

    primes = [2, 3]
    if end <= 3:
        return primes[:max(0,end-2)]

    for n in xrange(5 >> 1, (int(sqrt(end))+1) >> 1):
        if sieve[n]:
            primes.append((n << 1) + 1)
            aux = (n << 1) + 1
            aux *= aux
            for k in xrange(aux, end, 2 * aux):
                sieve[k >> 1] = False

    s  = int(sqrt(end)) + 1
    if s  % 2 == 0:
        s += 1
    primes.extend([i for i in xrange(s, end, 2) if sieve[i >> 1]])

    return primes

def ambi_sieve_plain(n):
    s = range(3, n, 2)
    for m in xrange(3, int(n**0.5)+1, 2): 
        if s[(m-3)/2]: 
            for t in xrange((m*m-3)/2,(n>>1)-1,m):
                s[t]=0
    return [2]+[t for t in s if t>0]

def sundaram3(max_n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/2073279#2073279
    numbers = range(3, max_n+1, 2)
    half = (max_n)//2
    initial = 4

    for step in xrange(3, max_n+1, 2):
        for i in xrange(initial, half, step):
            numbers[i-1] = 0
        initial += 2*(step+1)

        if initial > half:
            return [2] + filter(None, numbers)

################################################################################
# Using Numpy:
def ambi_sieve(n):
    # http://tommih.blogspot.com/2009/04/fast-prime-number-generator.html
    s = np.arange(3, n, 2)
    for m in xrange(3, int(n ** 0.5)+1, 2): 
        if s[(m-3)/2]: 
            s[(m*m-3)/2::m]=0
    return np.r_[2, s[s>0]]

def primesfrom3to(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns a array of primes, p < n """
    assert n>=2
    sieve = np.ones(n/2, dtype=np.bool)
    for i in xrange(3,int(n**0.5)+1,2):
        if sieve[i/2]:
            sieve[i*i/2::i] = False
    return np.r_[2, 2*np.nonzero(sieve)[0][1::]+1]    

def primesfrom2to(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Input n>=6, Returns a array of primes, 2 <= p < n """
    sieve = np.ones(n/3 + (n%6==2), dtype=np.bool)
    sieve[0] = False
    for i in xrange(int(n**0.5)/3+1):
        if sieve[i]:
            k=3*i+1|1
            sieve[      ((k*k)/3)      ::2*k] = False
            sieve[(k*k+4*k-2*k*(i&1))/3::2*k] = False
    return np.r_[2,3,((3*np.nonzero(sieve)[0]+1)|1)]

if __name__=='__main__':
    import itertools
    import sys

    def test(f1,f2,num):
        print('Testing {f1} and {f2} return same results'.format(
            f1=f1.func_name,
            f2=f2.func_name))
        if not all([a==b for a,b in itertools.izip_longest(f1(num),f2(num))]):
            sys.exit("Error: %s(%s) != %s(%s)"%(f1.func_name,num,f2.func_name,num))

    n=1000000
    test(sieveOfAtkin,sieveOfEratosthenes,n)
    test(sieveOfAtkin,ambi_sieve,n)
    test(sieveOfAtkin,ambi_sieve_plain,n) 
    test(sieveOfAtkin,sundaram3,n)
    test(sieveOfAtkin,sieve_wheel_30,n)
    test(sieveOfAtkin,primesfrom3to,n)
    test(sieveOfAtkin,primesfrom2to,n)
    test(sieveOfAtkin,rwh_primes,n)
    test(sieveOfAtkin,rwh_primes1,n)         
    test(sieveOfAtkin,rwh_primes2,n)

Running the script tests that all implementations give the same result.


回答 1

更快,更明智的纯Python代码:

def primes(n):
    """ Returns  a list of primes < n """
    sieve = [True] * n
    for i in range(3,int(n**0.5)+1,2):
        if sieve[i]:
            sieve[i*i::2*i]=[False]*((n-i*i-1)//(2*i)+1)
    return [2] + [i for i in range(3,n,2) if sieve[i]]

或从半筛开始

def primes1(n):
    """ Returns  a list of primes < n """
    sieve = [True] * (n//2)
    for i in range(3,int(n**0.5)+1,2):
        if sieve[i//2]:
            sieve[i*i//2::i] = [False] * ((n-i*i-1)//(2*i)+1)
    return [2] + [2*i+1 for i in range(1,n//2) if sieve[i]]

更快,更明智的内存numpy代码:

import numpy
def primesfrom3to(n):
    """ Returns a array of primes, 3 <= p < n """
    sieve = numpy.ones(n//2, dtype=numpy.bool)
    for i in range(3,int(n**0.5)+1,2):
        if sieve[i//2]:
            sieve[i*i//2::i] = False
    return 2*numpy.nonzero(sieve)[0][1::]+1

从三分之一的筛子开始的更快的变化:

import numpy
def primesfrom2to(n):
    """ Input n>=6, Returns a array of primes, 2 <= p < n """
    sieve = numpy.ones(n//3 + (n%6==2), dtype=numpy.bool)
    for i in range(1,int(n**0.5)//3+1):
        if sieve[i]:
            k=3*i+1|1
            sieve[       k*k//3     ::2*k] = False
            sieve[k*(k-2*(i&1)+4)//3::2*k] = False
    return numpy.r_[2,3,((3*numpy.nonzero(sieve)[0][1:]+1)|1)]

上述代码的(难编码)纯python版本为:

def primes2(n):
    """ Input n>=6, Returns a list of primes, 2 <= p < n """
    n, correction = n-n%6+6, 2-(n%6>1)
    sieve = [True] * (n//3)
    for i in range(1,int(n**0.5)//3+1):
      if sieve[i]:
        k=3*i+1|1
        sieve[      k*k//3      ::2*k] = [False] * ((n//6-k*k//6-1)//k+1)
        sieve[k*(k-2*(i&1)+4)//3::2*k] = [False] * ((n//6-k*(k-2*(i&1)+4)//6-1)//k+1)
    return [2,3] + [3*i+1|1 for i in range(1,n//3-correction) if sieve[i]]

不幸的是,纯python并没有采用更简单,更快速的numpy方式进行赋值,并且len()像in [False]*len(sieve[((k*k)//3)::2*k])中那样在循环内调用太慢了。因此,我不得不即兴改正输入(避免更多的数学运算),并做一些极端的(令人痛苦的)数学魔术。

我个人认为numpy(已被广泛使用)不是Python标准库的一部分,而Python开发人员似乎完全忽略了语法和速度方面的改进,这是一个遗憾。

Faster & more memory-wise pure Python code:

def primes(n):
    """ Returns  a list of primes < n """
    sieve = [True] * n
    for i in range(3,int(n**0.5)+1,2):
        if sieve[i]:
            sieve[i*i::2*i]=[False]*((n-i*i-1)//(2*i)+1)
    return [2] + [i for i in range(3,n,2) if sieve[i]]

or starting with half sieve

def primes1(n):
    """ Returns  a list of primes < n """
    sieve = [True] * (n//2)
    for i in range(3,int(n**0.5)+1,2):
        if sieve[i//2]:
            sieve[i*i//2::i] = [False] * ((n-i*i-1)//(2*i)+1)
    return [2] + [2*i+1 for i in range(1,n//2) if sieve[i]]

Faster & more memory-wise numpy code:

import numpy
def primesfrom3to(n):
    """ Returns a array of primes, 3 <= p < n """
    sieve = numpy.ones(n//2, dtype=numpy.bool)
    for i in range(3,int(n**0.5)+1,2):
        if sieve[i//2]:
            sieve[i*i//2::i] = False
    return 2*numpy.nonzero(sieve)[0][1::]+1

a faster variation starting with a third of a sieve:

import numpy
def primesfrom2to(n):
    """ Input n>=6, Returns a array of primes, 2 <= p < n """
    sieve = numpy.ones(n//3 + (n%6==2), dtype=numpy.bool)
    for i in range(1,int(n**0.5)//3+1):
        if sieve[i]:
            k=3*i+1|1
            sieve[       k*k//3     ::2*k] = False
            sieve[k*(k-2*(i&1)+4)//3::2*k] = False
    return numpy.r_[2,3,((3*numpy.nonzero(sieve)[0][1:]+1)|1)]

A (hard-to-code) pure-python version of the above code would be:

def primes2(n):
    """ Input n>=6, Returns a list of primes, 2 <= p < n """
    n, correction = n-n%6+6, 2-(n%6>1)
    sieve = [True] * (n//3)
    for i in range(1,int(n**0.5)//3+1):
      if sieve[i]:
        k=3*i+1|1
        sieve[      k*k//3      ::2*k] = [False] * ((n//6-k*k//6-1)//k+1)
        sieve[k*(k-2*(i&1)+4)//3::2*k] = [False] * ((n//6-k*(k-2*(i&1)+4)//6-1)//k+1)
    return [2,3] + [3*i+1|1 for i in range(1,n//3-correction) if sieve[i]]

Unfortunately pure-python don’t adopt the simpler and faster numpy way of doing assignment, and calling len() inside the loop as in [False]*len(sieve[((k*k)//3)::2*k]) is too slow. So I had to improvise to correct input (& avoid more math) and do some extreme (& painful) math-magic.

Personally I think it is a shame that numpy (which is so widely used) is not part of Python standard library, and that the improvements in syntax and speed seem to be completely overlooked by Python developers.


回答 2

有从Python食谱一个漂亮整洁的样品这里 -建议对URL最快的版本是:

import itertools
def erat2( ):
    D = {  }
    yield 2
    for q in itertools.islice(itertools.count(3), 0, None, 2):
        p = D.pop(q, None)
        if p is None:
            D[q*q] = q
            yield q
        else:
            x = p + q
            while x in D or not (x&1):
                x += p
            D[x] = p

这样就可以

def get_primes_erat(n):
  return list(itertools.takewhile(lambda p: p<n, erat2()))

使用pri.py中的这段代码在shell提示符下(我喜欢这样做)进行测量,我观察到:

$ python2.5 -mtimeit -s'import pri' 'pri.get_primes(1000000)'
10 loops, best of 3: 1.69 sec per loop
$ python2.5 -mtimeit -s'import pri' 'pri.get_primes_erat(1000000)'
10 loops, best of 3: 673 msec per loop

因此,看起来Cookbook解决方案的速度是以前的两倍。

There’s a pretty neat sample from the Python Cookbook here — the fastest version proposed on that URL is:

import itertools
def erat2( ):
    D = {  }
    yield 2
    for q in itertools.islice(itertools.count(3), 0, None, 2):
        p = D.pop(q, None)
        if p is None:
            D[q*q] = q
            yield q
        else:
            x = p + q
            while x in D or not (x&1):
                x += p
            D[x] = p

so that would give

def get_primes_erat(n):
  return list(itertools.takewhile(lambda p: p<n, erat2()))

Measuring at the shell prompt (as I prefer to do) with this code in pri.py, I observe:

$ python2.5 -mtimeit -s'import pri' 'pri.get_primes(1000000)'
10 loops, best of 3: 1.69 sec per loop
$ python2.5 -mtimeit -s'import pri' 'pri.get_primes_erat(1000000)'
10 loops, best of 3: 673 msec per loop

so it looks like the Cookbook solution is over twice as fast.


回答 3

我认为使用Sundaram的Sieve打破了纯Python的记录:

def sundaram3(max_n):
    numbers = range(3, max_n+1, 2)
    half = (max_n)//2
    initial = 4

    for step in xrange(3, max_n+1, 2):
        for i in xrange(initial, half, step):
            numbers[i-1] = 0
        initial += 2*(step+1)

        if initial > half:
            return [2] + filter(None, numbers)

对比:

C:\USERS>python -m timeit -n10 -s "import get_primes" "get_primes.get_primes_erat(1000000)"
10 loops, best of 3: 710 msec per loop

C:\USERS>python -m timeit -n10 -s "import get_primes" "get_primes.daniel_sieve_2(1000000)"
10 loops, best of 3: 435 msec per loop

C:\USERS>python -m timeit -n10 -s "import get_primes" "get_primes.sundaram3(1000000)"
10 loops, best of 3: 327 msec per loop

Using Sundaram’s Sieve, I think I broke pure-Python’s record:

def sundaram3(max_n):
    numbers = range(3, max_n+1, 2)
    half = (max_n)//2
    initial = 4

    for step in xrange(3, max_n+1, 2):
        for i in xrange(initial, half, step):
            numbers[i-1] = 0
        initial += 2*(step+1)

        if initial > half:
            return [2] + filter(None, numbers)

Comparasion:

C:\USERS>python -m timeit -n10 -s "import get_primes" "get_primes.get_primes_erat(1000000)"
10 loops, best of 3: 710 msec per loop

C:\USERS>python -m timeit -n10 -s "import get_primes" "get_primes.daniel_sieve_2(1000000)"
10 loops, best of 3: 435 msec per loop

C:\USERS>python -m timeit -n10 -s "import get_primes" "get_primes.sundaram3(1000000)"
10 loops, best of 3: 327 msec per loop

回答 4

该算法速度很快,但存在严重缺陷:

>>> sorted(get_primes(530))
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73,
79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163,
167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251,
257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349,
353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443,
449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 527, 529]
>>> 17*31
527
>>> 23*23
529

您假设这numbers.pop()将返回集合中的最小数字,但是完全不能保证。集是无序的,并且pop()删除并返回任意元素,因此不能用于从其余数字中选择下一个素数。

The algorithm is fast, but it has a serious flaw:

>>> sorted(get_primes(530))
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73,
79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163,
167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251,
257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349,
353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443,
449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 527, 529]
>>> 17*31
527
>>> 23*23
529

You assume that numbers.pop() would return the smallest number in the set, but this is not guaranteed at all. Sets are unordered and pop() removes and returns an arbitrary element, so it cannot be used to select the next prime from the remaining numbers.


回答 5

对于具有足够大N的真正最快的解决方案,将是下载预先计算的素数列表,将其存储为元组,然后执行以下操作:

for pos,i in enumerate(primes):
    if i > N:
        print primes[:pos]

如果N > primes[-1] 只是这样,则可以计算更多的质数并将新列表保存在您的代码中,因此下一次同样快。

始终在框外思考。

For truly fastest solution with sufficiently large N would be to download a pre-calculated list of primes, store it as a tuple and do something like:

for pos,i in enumerate(primes):
    if i > N:
        print primes[:pos]

If N > primes[-1] only then calculate more primes and save the new list in your code, so next time it is equally as fast.

Always think outside the box.


回答 6

如果您不想重新发明轮子,可以安装符号数学库sympy(是的,它与Python 3兼容)

pip install sympy

并使用primerange功能

from sympy import sieve
primes = list(sieve.primerange(1, 10**6))

If you don’t want to reinvent the wheel, you can install the symbolic maths library sympy (yes it’s Python 3 compatible)

pip install sympy

And use the primerange function

from sympy import sieve
primes = list(sieve.primerange(1, 10**6))

回答 7

如果您接受itertools但不接受numpy,则这是针对Python 3的rwh_primes2的改编版,其在我的计算机上的运行速度约为以前的两倍。唯一的实质性更改是使用字节数组而不是布尔值列表,并使用compress而不是列表推导来构建最终列表。(如果可以的话,我将其添加为类似moarningsun的评论。)

import itertools
izip = itertools.zip_longest
chain = itertools.chain.from_iterable
compress = itertools.compress
def rwh_primes2_python3(n):
    """ Input n>=6, Returns a list of primes, 2 <= p < n """
    zero = bytearray([False])
    size = n//3 + (n % 6 == 2)
    sieve = bytearray([True]) * size
    sieve[0] = False
    for i in range(int(n**0.5)//3+1):
      if sieve[i]:
        k=3*i+1|1
        start = (k*k+4*k-2*k*(i&1))//3
        sieve[(k*k)//3::2*k]=zero*((size - (k*k)//3 - 1) // (2 * k) + 1)
        sieve[  start ::2*k]=zero*((size -   start  - 1) // (2 * k) + 1)
    ans = [2,3]
    poss = chain(izip(*[range(i, n, 6) for i in (1,5)]))
    ans.extend(compress(poss, sieve))
    return ans

比较:

>>> timeit.timeit('primes.rwh_primes2(10**6)', setup='import primes', number=1)
0.0652179726976101
>>> timeit.timeit('primes.rwh_primes2_python3(10**6)', setup='import primes', number=1)
0.03267321276325674

>>> timeit.timeit('primes.rwh_primes2(10**8)', setup='import primes', number=1)
6.394284538007014
>>> timeit.timeit('primes.rwh_primes2_python3(10**8)', setup='import primes', number=1)
3.833829450302801

If you accept itertools but not numpy, here is an adaptation of rwh_primes2 for Python 3 that runs about twice as fast on my machine. The only substantial change is using a bytearray instead of a list for the boolean, and using compress instead of a list comprehension to build the final list. (I’d add this as a comment like moarningsun if I were able.)

import itertools
izip = itertools.zip_longest
chain = itertools.chain.from_iterable
compress = itertools.compress
def rwh_primes2_python3(n):
    """ Input n>=6, Returns a list of primes, 2 <= p < n """
    zero = bytearray([False])
    size = n//3 + (n % 6 == 2)
    sieve = bytearray([True]) * size
    sieve[0] = False
    for i in range(int(n**0.5)//3+1):
      if sieve[i]:
        k=3*i+1|1
        start = (k*k+4*k-2*k*(i&1))//3
        sieve[(k*k)//3::2*k]=zero*((size - (k*k)//3 - 1) // (2 * k) + 1)
        sieve[  start ::2*k]=zero*((size -   start  - 1) // (2 * k) + 1)
    ans = [2,3]
    poss = chain(izip(*[range(i, n, 6) for i in (1,5)]))
    ans.extend(compress(poss, sieve))
    return ans

Comparisons:

>>> timeit.timeit('primes.rwh_primes2(10**6)', setup='import primes', number=1)
0.0652179726976101
>>> timeit.timeit('primes.rwh_primes2_python3(10**6)', setup='import primes', number=1)
0.03267321276325674

and

>>> timeit.timeit('primes.rwh_primes2(10**8)', setup='import primes', number=1)
6.394284538007014
>>> timeit.timeit('primes.rwh_primes2_python3(10**8)', setup='import primes', number=1)
3.833829450302801

回答 8

编写自己的主要发现代码很有启发性,但是手头有一个快速可靠的库也很有用。我围绕C ++库primesieve编写了一个包装器,将其命名为primesieve-python

尝试一下 pip install primesieve

import primesieve
primes = primesieve.generate_primes(10**8)

我很好奇看到速度比较。

It’s instructive to write your own prime finding code, but it’s also useful to have a fast reliable library at hand. I wrote a wrapper around the C++ library primesieve, named it primesieve-python

Try it pip install primesieve

import primesieve
primes = primesieve.generate_primes(10**8)

I’d be curious to see the speed compared.


回答 9

这是最快的功能之一的两个更新版本(纯Python 3.6),

from itertools import compress

def rwh_primes1v1(n):
    """ Returns  a list of primes < n for n > 2 """
    sieve = bytearray([True]) * (n//2)
    for i in range(3,int(n**0.5)+1,2):
        if sieve[i//2]:
            sieve[i*i//2::i] = bytearray((n-i*i-1)//(2*i)+1)
    return [2,*compress(range(3,n,2), sieve[1:])]

def rwh_primes1v2(n):
    """ Returns a list of primes < n for n > 2 """
    sieve = bytearray([True]) * (n//2+1)
    for i in range(1,int(n**0.5)//2+1):
        if sieve[i]:
            sieve[2*i*(i+1)::2*i+1] = bytearray((n//2-2*i*(i+1))//(2*i+1)+1)
    return [2,*compress(range(3,n,2), sieve[1:])]

Here is two updated (pure Python 3.6) versions of one of the fastest functions,

from itertools import compress

def rwh_primes1v1(n):
    """ Returns  a list of primes < n for n > 2 """
    sieve = bytearray([True]) * (n//2)
    for i in range(3,int(n**0.5)+1,2):
        if sieve[i//2]:
            sieve[i*i//2::i] = bytearray((n-i*i-1)//(2*i)+1)
    return [2,*compress(range(3,n,2), sieve[1:])]

def rwh_primes1v2(n):
    """ Returns a list of primes < n for n > 2 """
    sieve = bytearray([True]) * (n//2+1)
    for i in range(1,int(n**0.5)//2+1):
        if sieve[i]:
            sieve[2*i*(i+1)::2*i+1] = bytearray((n//2-2*i*(i+1))//(2*i+1)+1)
    return [2,*compress(range(3,n,2), sieve[1:])]

回答 10

在N <9,080,191的假设下确定性实施Miller-Rabin素数检验

import sys
import random

def miller_rabin_pass(a, n):
    d = n - 1
    s = 0
    while d % 2 == 0:
        d >>= 1
        s += 1

    a_to_power = pow(a, d, n)
    if a_to_power == 1:
        return True
    for i in xrange(s-1):
        if a_to_power == n - 1:
            return True
        a_to_power = (a_to_power * a_to_power) % n
    return a_to_power == n - 1


def miller_rabin(n):
    for a in [2, 3, 37, 73]:
      if not miller_rabin_pass(a, n):
        return False
    return True


n = int(sys.argv[1])
primes = [2]
for p in range(3,n,2):
  if miller_rabin(p):
    primes.append(p)
print len(primes)

根据Wikipedia上的文章(http://en.wikipedia.org/wiki/Miller–Rabin_primality_test),测试N <9,080,191的a = 2,3,37和73足以确定N是否为复合值。

然后,我从此处找到的原始Miller-Rabin测试的概率实现中改编了源代码:http : //en.literateprograms.org/Miller-Rabin_primality_test_(Python)

A deterministic implementation of Miller-Rabin’s Primality test on the assumption that N < 9,080,191

import sys
import random

def miller_rabin_pass(a, n):
    d = n - 1
    s = 0
    while d % 2 == 0:
        d >>= 1
        s += 1

    a_to_power = pow(a, d, n)
    if a_to_power == 1:
        return True
    for i in xrange(s-1):
        if a_to_power == n - 1:
            return True
        a_to_power = (a_to_power * a_to_power) % n
    return a_to_power == n - 1


def miller_rabin(n):
    for a in [2, 3, 37, 73]:
      if not miller_rabin_pass(a, n):
        return False
    return True


n = int(sys.argv[1])
primes = [2]
for p in range(3,n,2):
  if miller_rabin(p):
    primes.append(p)
print len(primes)

According to the article on Wikipedia (http://en.wikipedia.org/wiki/Miller–Rabin_primality_test) testing N < 9,080,191 for a = 2,3,37, and 73 is enough to decide whether N is composite or not.

And I adapted the source code from the probabilistic implementation of original Miller-Rabin’s test found here: http://en.literateprograms.org/Miller-Rabin_primality_test_(Python)


回答 11

如果您可以控制N,则列出所有素数的最快方法是预先计算它们。说真的 预计算是一种被忽略的优化方法。

If you have control over N, the very fastest way to list all primes is to precompute them. Seriously. Precomputing is a way overlooked optimization.


回答 12

这是我通常用于在Python中生成素数的代码:

$ python -mtimeit -s'import sieve' 'sieve.sieve(1000000)' 
10 loops, best of 3: 445 msec per loop
$ cat sieve.py
from math import sqrt

def sieve(size):
 prime=[True]*size
 rng=xrange
 limit=int(sqrt(size))

 for i in rng(3,limit+1,+2):
  if prime[i]:
   prime[i*i::+i]=[False]*len(prime[i*i::+i])

 return [2]+[i for i in rng(3,size,+2) if prime[i]]

if __name__=='__main__':
 print sieve(100)

它无法与此处发布的更快的解决方案竞争,但至少它是纯python。

感谢您发布此问题。今天我真的学到了很多东西。

Here’s the code I normally use to generate primes in Python:

$ python -mtimeit -s'import sieve' 'sieve.sieve(1000000)' 
10 loops, best of 3: 445 msec per loop
$ cat sieve.py
from math import sqrt

def sieve(size):
 prime=[True]*size
 rng=xrange
 limit=int(sqrt(size))

 for i in rng(3,limit+1,+2):
  if prime[i]:
   prime[i*i::+i]=[False]*len(prime[i*i::+i])

 return [2]+[i for i in rng(3,size,+2) if prime[i]]

if __name__=='__main__':
 print sieve(100)

It can’t compete with the faster solutions posted here, but at least it is pure python.

Thanks for posting this question. I really learnt a lot today.


回答 13

对于最快的代码,numpy解决方案是最好的。不过,出于纯粹的学术原因,我要发布我的纯python版本,该版本比上面发布的食谱版本快50%。由于我将整个列表存储在内存中,因此您需要足够的空间来容纳所有内容,但它似乎可以很好地扩展。

def daniel_sieve_2(maxNumber):
    """
    Given a number, returns all numbers less than or equal to
    that number which are prime.
    """
    allNumbers = range(3, maxNumber+1, 2)
    for mIndex, number in enumerate(xrange(3, maxNumber+1, 2)):
        if allNumbers[mIndex] == 0:
            continue
        # now set all multiples to 0
        for index in xrange(mIndex+number, (maxNumber-3)/2+1, number):
            allNumbers[index] = 0
    return [2] + filter(lambda n: n!=0, allNumbers)

结果:

>>>mine = timeit.Timer("daniel_sieve_2(1000000)",
...                    "from sieves import daniel_sieve_2")
>>>prev = timeit.Timer("get_primes_erat(1000000)",
...                    "from sieves import get_primes_erat")
>>>print "Mine: {0:0.4f} ms".format(min(mine.repeat(3, 1))*1000)
Mine: 428.9446 ms
>>>print "Previous Best {0:0.4f} ms".format(min(prev.repeat(3, 1))*1000)
Previous Best 621.3581 ms

For the fastest code, the numpy solution is the best. For purely academic reasons, though, I’m posting my pure python version, which is a bit less than 50% faster than the cookbook version posted above. Since I make the entire list in memory, you need enough space to hold everything, but it seems to scale fairly well.

def daniel_sieve_2(maxNumber):
    """
    Given a number, returns all numbers less than or equal to
    that number which are prime.
    """
    allNumbers = range(3, maxNumber+1, 2)
    for mIndex, number in enumerate(xrange(3, maxNumber+1, 2)):
        if allNumbers[mIndex] == 0:
            continue
        # now set all multiples to 0
        for index in xrange(mIndex+number, (maxNumber-3)/2+1, number):
            allNumbers[index] = 0
    return [2] + filter(lambda n: n!=0, allNumbers)

And the results:

>>>mine = timeit.Timer("daniel_sieve_2(1000000)",
...                    "from sieves import daniel_sieve_2")
>>>prev = timeit.Timer("get_primes_erat(1000000)",
...                    "from sieves import get_primes_erat")
>>>print "Mine: {0:0.4f} ms".format(min(mine.repeat(3, 1))*1000)
Mine: 428.9446 ms
>>>print "Previous Best {0:0.4f} ms".format(min(prev.repeat(3, 1))*1000)
Previous Best 621.3581 ms

回答 14

使用Numpy的半筛的实现略有不同:

http://rebrained.com/?p=458

导入数学
导入numpy
def prime6(最高):
    primes = numpy.arange(3,最高+1,2)
    isprime = numpy.ones((最多1)/ 2,dtype = bool)
    对于素数[:int(math.sqrt(upto))]:
        如果isprime [(factor-2)/ 2]:isprime [(factor * 3-2)/ 2:(upto-1)/ 2:factor] = 0
    返回numpy.insert(primes [isprime],0,2)

有人可以将此与其他时间进行比较吗?在我的机器上,它看起来可以与其他Numpy半筛媲美。

A slightly different implementation of a half sieve using Numpy:

http://rebrained.com/?p=458

import math
import numpy
def prime6(upto):
    primes=numpy.arange(3,upto+1,2)
    isprime=numpy.ones((upto-1)/2,dtype=bool)
    for factor in primes[:int(math.sqrt(upto))]:
        if isprime[(factor-2)/2]: isprime[(factor*3-2)/2:(upto-1)/2:factor]=0
    return numpy.insert(primes[isprime],0,2)

Can someone compare this with the other timings? On my machine it seems pretty comparable to the other Numpy half-sieve.


回答 15

全部都经过编写和测试。因此,无需重新发明轮子。

python -m timeit -r10 -s"from sympy import sieve" "primes = list(sieve.primerange(1, 10**6))"

给我们打破了12.2毫秒的记录!

10 loops, best of 10: 12.2 msec per loop

如果这还不够快,您可以尝试PyPy:

pypy -m timeit -r10 -s"from sympy import sieve" "primes = list(sieve.primerange(1, 10**6))"

结果是:

10 loops, best of 10: 2.03 msec per loop

247次投票的答案列出了15.9毫秒的最佳解决方案。比较一下!!!

It’s all written and tested. So there is no need to reinvent the wheel.

python -m timeit -r10 -s"from sympy import sieve" "primes = list(sieve.primerange(1, 10**6))"

gives us a record breaking 12.2 msec!

10 loops, best of 10: 12.2 msec per loop

If this is not fast enough, you can try PyPy:

pypy -m timeit -r10 -s"from sympy import sieve" "primes = list(sieve.primerange(1, 10**6))"

which results in:

10 loops, best of 10: 2.03 msec per loop

The answer with 247 up-votes lists 15.9 ms for the best solution. Compare this!!!


回答 16

我测试了一些unutbu的函数,用饥饿的数百万来计算

获奖者是使用numpy库的函数,

注意:进行内存利用率测试也很有趣:)

计算时间结果

样例代码

我的github存储库上的完整代码

#!/usr/bin/env python

import lib
import timeit
import sys
import math
import datetime

import prettyplotlib as ppl
import numpy as np

import matplotlib.pyplot as plt
from prettyplotlib import brewer2mpl

primenumbers_gen = [
    'sieveOfEratosthenes',
    'ambi_sieve',
    'ambi_sieve_plain',
    'sundaram3',
    'sieve_wheel_30',
    'primesfrom3to',
    'primesfrom2to',
    'rwh_primes',
    'rwh_primes1',
    'rwh_primes2',
]

def human_format(num):
    # /programming/579310/formatting-long-numbers-as-strings-in-python?answertab=active#tab-top
    magnitude = 0
    while abs(num) >= 1000:
        magnitude += 1
        num /= 1000.0
    # add more suffixes if you need them
    return '%.2f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude])


if __name__=='__main__':

    # Vars
    n = 10000000 # number itereration generator
    nbcol = 5 # For decompose prime number generator
    nb_benchloop = 3 # Eliminate false positive value during the test (bench average time)
    datetimeformat = '%Y-%m-%d %H:%M:%S.%f'
    config = 'from __main__ import n; import lib'
    primenumbers_gen = {
        'sieveOfEratosthenes': {'color': 'b'},
        'ambi_sieve': {'color': 'b'},
        'ambi_sieve_plain': {'color': 'b'},
         'sundaram3': {'color': 'b'},
        'sieve_wheel_30': {'color': 'b'},
# # #        'primesfrom2to': {'color': 'b'},
        'primesfrom3to': {'color': 'b'},
        # 'rwh_primes': {'color': 'b'},
        # 'rwh_primes1': {'color': 'b'},
        'rwh_primes2': {'color': 'b'},
    }


    # Get n in command line
    if len(sys.argv)>1:
        n = int(sys.argv[1])

    step = int(math.ceil(n / float(nbcol)))
    nbs = np.array([i * step for i in range(1, int(nbcol) + 1)])
    set2 = brewer2mpl.get_map('Paired', 'qualitative', 12).mpl_colors

    print datetime.datetime.now().strftime(datetimeformat)
    print("Compute prime number to %(n)s" % locals())
    print("")

    results = dict()
    for pgen in primenumbers_gen:
        results[pgen] = dict()
        benchtimes = list()
        for n in nbs:
            t = timeit.Timer("lib.%(pgen)s(n)" % locals(), setup=config)
            execute_times = t.repeat(repeat=nb_benchloop,number=1)
            benchtime = np.mean(execute_times)
            benchtimes.append(benchtime)
        results[pgen] = {'benchtimes':np.array(benchtimes)}

fig, ax = plt.subplots(1)
plt.ylabel('Computation time (in second)')
plt.xlabel('Numbers computed')
i = 0
for pgen in primenumbers_gen:

    bench = results[pgen]['benchtimes']
    avgs = np.divide(bench,nbs)
    avg = np.average(bench, weights=nbs)

    # Compute linear regression
    A = np.vstack([nbs, np.ones(len(nbs))]).T
    a, b = np.linalg.lstsq(A, nbs*avgs)[0]

    # Plot
    i += 1
    #label="%(pgen)s" % locals()
    #ppl.plot(nbs, nbs*avgs, label=label, lw=1, linestyle='--', color=set2[i % 12])
    label="%(pgen)s avg" % locals()
    ppl.plot(nbs, a * nbs + b, label=label, lw=2, color=set2[i % 12])
print datetime.datetime.now().strftime(datetimeformat)

ppl.legend(ax, loc='upper left', ncol=4)

# Change x axis label
ax.get_xaxis().get_major_formatter().set_scientific(False)
fig.canvas.draw()
labels = [human_format(int(item.get_text())) for item in ax.get_xticklabels()]

ax.set_xticklabels(labels)
ax = plt.gca()

plt.show()

I tested some unutbu’s functions, i computed it with hungred millions number

The winners are the functions that use numpy library,

Note: It would also interesting make a memory utilization test :)

Computation time result

Sample code

Complete code on my github repository

#!/usr/bin/env python

import lib
import timeit
import sys
import math
import datetime

import prettyplotlib as ppl
import numpy as np

import matplotlib.pyplot as plt
from prettyplotlib import brewer2mpl

primenumbers_gen = [
    'sieveOfEratosthenes',
    'ambi_sieve',
    'ambi_sieve_plain',
    'sundaram3',
    'sieve_wheel_30',
    'primesfrom3to',
    'primesfrom2to',
    'rwh_primes',
    'rwh_primes1',
    'rwh_primes2',
]

def human_format(num):
    # https://stackoverflow.com/questions/579310/formatting-long-numbers-as-strings-in-python?answertab=active#tab-top
    magnitude = 0
    while abs(num) >= 1000:
        magnitude += 1
        num /= 1000.0
    # add more suffixes if you need them
    return '%.2f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude])


if __name__=='__main__':

    # Vars
    n = 10000000 # number itereration generator
    nbcol = 5 # For decompose prime number generator
    nb_benchloop = 3 # Eliminate false positive value during the test (bench average time)
    datetimeformat = '%Y-%m-%d %H:%M:%S.%f'
    config = 'from __main__ import n; import lib'
    primenumbers_gen = {
        'sieveOfEratosthenes': {'color': 'b'},
        'ambi_sieve': {'color': 'b'},
        'ambi_sieve_plain': {'color': 'b'},
         'sundaram3': {'color': 'b'},
        'sieve_wheel_30': {'color': 'b'},
# # #        'primesfrom2to': {'color': 'b'},
        'primesfrom3to': {'color': 'b'},
        # 'rwh_primes': {'color': 'b'},
        # 'rwh_primes1': {'color': 'b'},
        'rwh_primes2': {'color': 'b'},
    }


    # Get n in command line
    if len(sys.argv)>1:
        n = int(sys.argv[1])

    step = int(math.ceil(n / float(nbcol)))
    nbs = np.array([i * step for i in range(1, int(nbcol) + 1)])
    set2 = brewer2mpl.get_map('Paired', 'qualitative', 12).mpl_colors

    print datetime.datetime.now().strftime(datetimeformat)
    print("Compute prime number to %(n)s" % locals())
    print("")

    results = dict()
    for pgen in primenumbers_gen:
        results[pgen] = dict()
        benchtimes = list()
        for n in nbs:
            t = timeit.Timer("lib.%(pgen)s(n)" % locals(), setup=config)
            execute_times = t.repeat(repeat=nb_benchloop,number=1)
            benchtime = np.mean(execute_times)
            benchtimes.append(benchtime)
        results[pgen] = {'benchtimes':np.array(benchtimes)}

fig, ax = plt.subplots(1)
plt.ylabel('Computation time (in second)')
plt.xlabel('Numbers computed')
i = 0
for pgen in primenumbers_gen:

    bench = results[pgen]['benchtimes']
    avgs = np.divide(bench,nbs)
    avg = np.average(bench, weights=nbs)

    # Compute linear regression
    A = np.vstack([nbs, np.ones(len(nbs))]).T
    a, b = np.linalg.lstsq(A, nbs*avgs)[0]

    # Plot
    i += 1
    #label="%(pgen)s" % locals()
    #ppl.plot(nbs, nbs*avgs, label=label, lw=1, linestyle='--', color=set2[i % 12])
    label="%(pgen)s avg" % locals()
    ppl.plot(nbs, a * nbs + b, label=label, lw=2, color=set2[i % 12])
print datetime.datetime.now().strftime(datetimeformat)

ppl.legend(ax, loc='upper left', ncol=4)

# Change x axis label
ax.get_xaxis().get_major_formatter().set_scientific(False)
fig.canvas.draw()
labels = [human_format(int(item.get_text())) for item in ax.get_xticklabels()]

ax.set_xticklabels(labels)
ax = plt.gca()

plt.show()

回答 17

对于Python 3

def rwh_primes2(n):
    correction = (n%6>1)
    n = {0:n,1:n-1,2:n+4,3:n+3,4:n+2,5:n+1}[n%6]
    sieve = [True] * (n//3)
    sieve[0] = False
    for i in range(int(n**0.5)//3+1):
      if sieve[i]:
        k=3*i+1|1
        sieve[      ((k*k)//3)      ::2*k]=[False]*((n//6-(k*k)//6-1)//k+1)
        sieve[(k*k+4*k-2*k*(i&1))//3::2*k]=[False]*((n//6-(k*k+4*k-2*k*(i&1))//6-1)//k+1)
    return [2,3] + [3*i+1|1 for i in range(1,n//3-correction) if sieve[i]]

For Python 3

def rwh_primes2(n):
    correction = (n%6>1)
    n = {0:n,1:n-1,2:n+4,3:n+3,4:n+2,5:n+1}[n%6]
    sieve = [True] * (n//3)
    sieve[0] = False
    for i in range(int(n**0.5)//3+1):
      if sieve[i]:
        k=3*i+1|1
        sieve[      ((k*k)//3)      ::2*k]=[False]*((n//6-(k*k)//6-1)//k+1)
        sieve[(k*k+4*k-2*k*(i&1))//3::2*k]=[False]*((n//6-(k*k+4*k-2*k*(i&1))//6-1)//k+1)
    return [2,3] + [3*i+1|1 for i in range(1,n//3-correction) if sieve[i]]

回答 18

纯Python中最快的基本筛

from itertools import compress

def half_sieve(n):
    """
    Returns a list of prime numbers less than `n`.
    """
    if n <= 2:
        return []
    sieve = bytearray([True]) * (n // 2)
    for i in range(3, int(n ** 0.5) + 1, 2):
        if sieve[i // 2]:
            sieve[i * i // 2::i] = bytearray((n - i * i - 1) // (2 * i) + 1)
    primes = list(compress(range(1, n, 2), sieve))
    primes[0] = 2
    return primes

我优化了Eratosthenes筛的速度和内存。

基准测试

from time import clock
import platform

def benchmark(iterations, limit):
    start = clock()
    for x in range(iterations):
        half_sieve(limit)
    end = clock() - start
    print(f'{end/iterations:.4f} seconds for primes < {limit}')

if __name__ == '__main__':
    print(platform.python_version())
    print(platform.platform())
    print(platform.processor())
    it = 10
    for pw in range(4, 9):
        benchmark(it, 10**pw)

输出量

>>> 3.6.7
>>> Windows-10-10.0.17763-SP0
>>> Intel64 Family 6 Model 78 Stepping 3, GenuineIntel
>>> 0.0003 seconds for primes < 10000
>>> 0.0021 seconds for primes < 100000
>>> 0.0204 seconds for primes < 1000000
>>> 0.2389 seconds for primes < 10000000
>>> 2.6702 seconds for primes < 100000000

Fastest prime sieve in Pure Python:

from itertools import compress

def half_sieve(n):
    """
    Returns a list of prime numbers less than `n`.
    """
    if n <= 2:
        return []
    sieve = bytearray([True]) * (n // 2)
    for i in range(3, int(n ** 0.5) + 1, 2):
        if sieve[i // 2]:
            sieve[i * i // 2::i] = bytearray((n - i * i - 1) // (2 * i) + 1)
    primes = list(compress(range(1, n, 2), sieve))
    primes[0] = 2
    return primes

I optimised Sieve of Eratosthenes for speed and memory.

Benchmark

from time import clock
import platform

def benchmark(iterations, limit):
    start = clock()
    for x in range(iterations):
        half_sieve(limit)
    end = clock() - start
    print(f'{end/iterations:.4f} seconds for primes < {limit}')

if __name__ == '__main__':
    print(platform.python_version())
    print(platform.platform())
    print(platform.processor())
    it = 10
    for pw in range(4, 9):
        benchmark(it, 10**pw)

Output

>>> 3.6.7
>>> Windows-10-10.0.17763-SP0
>>> Intel64 Family 6 Model 78 Stepping 3, GenuineIntel
>>> 0.0003 seconds for primes < 10000
>>> 0.0021 seconds for primes < 100000
>>> 0.0204 seconds for primes < 1000000
>>> 0.2389 seconds for primes < 10000000
>>> 2.6702 seconds for primes < 100000000

回答 19

第一次使用python,因此我在其中使用的某些方法似乎有点麻烦。我只是将我的c ++代码直接转换为python,这就是我所拥有的(尽管在python中有点slowww)

#!/usr/bin/env python
import time

def GetPrimes(n):

    Sieve = [1 for x in xrange(n)]

    Done = False
    w = 3

    while not Done:

        for q in xrange (3, n, 2):
            Prod = w*q
            if Prod < n:
                Sieve[Prod] = 0
            else:
                break

        if w > (n/2):
            Done = True
        w += 2

    return Sieve



start = time.clock()

d = 10000000
Primes = GetPrimes(d)

count = 1 #This is for 2

for x in xrange (3, d, 2):
    if Primes[x]:
        count+=1

elapsed = (time.clock() - start)
print "\nFound", count, "primes in", elapsed, "seconds!\n"

pythonw Primes.py

在12.799119秒内找到664579质数!

#!/usr/bin/env python
import time

def GetPrimes2(n):

    Sieve = [1 for x in xrange(n)]

    for q in xrange (3, n, 2):
        k = q
        for y in xrange(k*3, n, k*2):
            Sieve[y] = 0

    return Sieve



start = time.clock()

d = 10000000
Primes = GetPrimes2(d)

count = 1 #This is for 2

for x in xrange (3, d, 2):
    if Primes[x]:
        count+=1

elapsed = (time.clock() - start)
print "\nFound", count, "primes in", elapsed, "seconds!\n"

pythonw Primes2.py

在10.230172秒内找到664579质数!

#!/usr/bin/env python
import time

def GetPrimes3(n):

    Sieve = [1 for x in xrange(n)]

    for q in xrange (3, n, 2):
        k = q
        for y in xrange(k*k, n, k << 1):
            Sieve[y] = 0

    return Sieve



start = time.clock()

d = 10000000
Primes = GetPrimes3(d)

count = 1 #This is for 2

for x in xrange (3, d, 2):
    if Primes[x]:
        count+=1

elapsed = (time.clock() - start)
print "\nFound", count, "primes in", elapsed, "seconds!\n"

Python Primes2.py

在7.113776秒内发现664579质数!

First time using python, so some of the methods I use in this might seem a bit cumbersome. I just straight converted my c++ code to python and this is what I have (albeit a tad bit slowww in python)

#!/usr/bin/env python
import time

def GetPrimes(n):

    Sieve = [1 for x in xrange(n)]

    Done = False
    w = 3

    while not Done:

        for q in xrange (3, n, 2):
            Prod = w*q
            if Prod < n:
                Sieve[Prod] = 0
            else:
                break

        if w > (n/2):
            Done = True
        w += 2

    return Sieve



start = time.clock()

d = 10000000
Primes = GetPrimes(d)

count = 1 #This is for 2

for x in xrange (3, d, 2):
    if Primes[x]:
        count+=1

elapsed = (time.clock() - start)
print "\nFound", count, "primes in", elapsed, "seconds!\n"

pythonw Primes.py

Found 664579 primes in 12.799119 seconds!

#!/usr/bin/env python
import time

def GetPrimes2(n):

    Sieve = [1 for x in xrange(n)]

    for q in xrange (3, n, 2):
        k = q
        for y in xrange(k*3, n, k*2):
            Sieve[y] = 0

    return Sieve



start = time.clock()

d = 10000000
Primes = GetPrimes2(d)

count = 1 #This is for 2

for x in xrange (3, d, 2):
    if Primes[x]:
        count+=1

elapsed = (time.clock() - start)
print "\nFound", count, "primes in", elapsed, "seconds!\n"

pythonw Primes2.py

Found 664579 primes in 10.230172 seconds!

#!/usr/bin/env python
import time

def GetPrimes3(n):

    Sieve = [1 for x in xrange(n)]

    for q in xrange (3, n, 2):
        k = q
        for y in xrange(k*k, n, k << 1):
            Sieve[y] = 0

    return Sieve



start = time.clock()

d = 10000000
Primes = GetPrimes3(d)

count = 1 #This is for 2

for x in xrange (3, d, 2):
    if Primes[x]:
        count+=1

elapsed = (time.clock() - start)
print "\nFound", count, "primes in", elapsed, "seconds!\n"

python Primes2.py

Found 664579 primes in 7.113776 seconds!


回答 20

我知道比赛已经结束了几年。…

尽管如此,这是我对纯python主筛的建议,它基于在处理向前的筛子时通过使用适当的步骤来省略2、3和5的倍数。但是,对于N <10 ^ 9而言,它实际上要比@Robert William Hanks高级解决方案rwh_primes2和rwh_primes1慢。通过使用高于1.5 * 10 ^ 8的ctypes.c_ushort筛子数组,可以以某种方式适应内存限制。

10 ^ 6

$ python -mtimeit -s“导入primeSieveSpeedComp”“ primeSieveSpeedComp.primeSieveSeq(1000000)” 10个循环,每个循环最好3:46.7毫秒

比较:$ python -mtimeit -s“导入primeSieveSpeedComp”“ primeSieveSpeedComp.rwh_primes1(1000000)” 10个循环,最好是3个循环:每个循环要进行43.2毫秒比较:$ python -m timeit -s“ import primeSieveSpeedComp”“ primeSieveSpeedComp.rwh_primes2 (1000000)“ 10个循环,最好为3:每个循环34.5毫秒

10 ^ 7

$ python -mtimeit -s“导入primeSieveSpeedComp”“ primeSieveSpeedComp.primeSieveSeq(10000000)” 10个循环,每循环最好530毫秒

比较:$ python -mtimeit -s“导入primeSieveSpeedComp”“ primeSieveSpeedComp.rwh_primes1(10000000)” 10个循环,最好是3个循环:494毫秒比较:$ python -m timeit -s“ import primeSieveSpeedComp”“ primeSieveSpeedComp.rwh_primes2 (10000000)“ 10个循环,每个循环最好3:375毫秒

10 ^ 8

$ python -mtimeit -s“ import primeSieveSpeedComp”“ primeSieveSpeedComp.primeSieveSeq(100000000)” 10个循环,每循环最好3:5.55秒

比较:$ python -mtimeit -s“导入primeSieveSpeedComp”“ primeSieveSpeedComp.rwh_primes1(100000000)” 10个循环,最好每个循环进行3:5.33秒进行比较:$ python -m timeit -s“ import primeSieveSpeedComp”“ primeSieveSpeedComp.rwh_primes2 (100000000)“ 10个循环,每循环3:3.95秒的最佳时间

10 ^ 9

$ python -mtimeit -s“ import primeSieveSpeedComp”“ primeSieveSpeedComp.primeSieveSeq(1000000000)” 10个循环,最好3个循环:每个循环61.2

比较:$ python -mtimeit -n 3 -s“ import primeSieveSpeedComp”“ primeSieveSpeedComp.rwh_primes1(1000000000)” 3个循环,最佳3:97.8 每个循环秒

比较:$ python -m timeit -s“ import primeSieveSpeedComp”“ primeSieveSpeedComp.rwh_primes2(1000000000)” 10个循环,最好3个循环:每个循环41.9秒

您可以将以下代码复制到ubuntus primeSieveSpeedComp中,以查看此测试。

def primeSieveSeq(MAX_Int):
    if MAX_Int > 5*10**8:
        import ctypes
        int16Array = ctypes.c_ushort * (MAX_Int >> 1)
        sieve = int16Array()
        #print 'uses ctypes "unsigned short int Array"'
    else:
        sieve = (MAX_Int >> 1) * [False]
        #print 'uses python list() of long long int'
    if MAX_Int < 10**8:
        sieve[4::3] = [True]*((MAX_Int - 8)/6+1)
        sieve[12::5] = [True]*((MAX_Int - 24)/10+1)
    r = [2, 3, 5]
    n = 0
    for i in xrange(int(MAX_Int**0.5)/30+1):
        n += 3
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 2
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 1
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 2
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 1
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 2
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 3
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 1
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
    if MAX_Int < 10**8:
        return [2, 3, 5]+[(p << 1) + 1 for p in [n for n in xrange(3, MAX_Int >> 1) if not sieve[n]]]
    n = n >> 1
    try:
        for i in xrange((MAX_Int-2*n)/30 + 1):
            n += 3
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 2
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 1
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 2
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 1
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 2
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 3
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 1
            if not sieve[n]:
                r.append((n << 1) + 1)
    except:
        pass
    return r

I know the competition is closed for some years. …

Nonetheless this is my suggestion for a pure python prime sieve, based on omitting the multiples of 2, 3 and 5 by using appropriate steps while processing the sieve forward. Nonetheless it is actually slower for N<10^9 than @Robert William Hanks superior solutions rwh_primes2 and rwh_primes1. By using a ctypes.c_ushort sieve array above 1.5* 10^8 it is somehow adaptive to memory limits.

10^6

$ python -mtimeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.primeSieveSeq(1000000)” 10 loops, best of 3: 46.7 msec per loop

to compare:$ python -mtimeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.rwh_primes1(1000000)” 10 loops, best of 3: 43.2 msec per loop to compare: $ python -m timeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.rwh_primes2(1000000)” 10 loops, best of 3: 34.5 msec per loop

10^7

$ python -mtimeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.primeSieveSeq(10000000)” 10 loops, best of 3: 530 msec per loop

to compare:$ python -mtimeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.rwh_primes1(10000000)” 10 loops, best of 3: 494 msec per loop to compare: $ python -m timeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.rwh_primes2(10000000)” 10 loops, best of 3: 375 msec per loop

10^8

$ python -mtimeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.primeSieveSeq(100000000)” 10 loops, best of 3: 5.55 sec per loop

to compare: $ python -mtimeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.rwh_primes1(100000000)” 10 loops, best of 3: 5.33 sec per loop to compare: $ python -m timeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.rwh_primes2(100000000)” 10 loops, best of 3: 3.95 sec per loop

10^9

$ python -mtimeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.primeSieveSeq(1000000000)” 10 loops, best of 3: 61.2 sec per loop

to compare: $ python -mtimeit -n 3 -s”import primeSieveSpeedComp” “primeSieveSpeedComp.rwh_primes1(1000000000)” 3 loops, best of 3: 97.8 sec per loop

to compare: $ python -m timeit -s”import primeSieveSpeedComp” “primeSieveSpeedComp.rwh_primes2(1000000000)” 10 loops, best of 3: 41.9 sec per loop

You may copy the code below into ubuntus primeSieveSpeedComp to review this tests.

def primeSieveSeq(MAX_Int):
    if MAX_Int > 5*10**8:
        import ctypes
        int16Array = ctypes.c_ushort * (MAX_Int >> 1)
        sieve = int16Array()
        #print 'uses ctypes "unsigned short int Array"'
    else:
        sieve = (MAX_Int >> 1) * [False]
        #print 'uses python list() of long long int'
    if MAX_Int < 10**8:
        sieve[4::3] = [True]*((MAX_Int - 8)/6+1)
        sieve[12::5] = [True]*((MAX_Int - 24)/10+1)
    r = [2, 3, 5]
    n = 0
    for i in xrange(int(MAX_Int**0.5)/30+1):
        n += 3
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 2
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 1
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 2
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 1
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 2
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 3
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
        n += 1
        if not sieve[n]:
            n2 = (n << 1) + 1
            r.append(n2)
            n2q = (n2**2) >> 1
            sieve[n2q::n2] = [True]*(((MAX_Int >> 1) - n2q - 1) / n2 + 1)
    if MAX_Int < 10**8:
        return [2, 3, 5]+[(p << 1) + 1 for p in [n for n in xrange(3, MAX_Int >> 1) if not sieve[n]]]
    n = n >> 1
    try:
        for i in xrange((MAX_Int-2*n)/30 + 1):
            n += 3
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 2
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 1
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 2
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 1
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 2
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 3
            if not sieve[n]:
                r.append((n << 1) + 1)
            n += 1
            if not sieve[n]:
                r.append((n << 1) + 1)
    except:
        pass
    return r

回答 21

这是Eratosthenes筛子的numpy版本,具有良好的复杂性(比对长度为n的数组排序要低)和矢量化。与@unutbu相比,此速度与使用46微微秒的软件包查找所有小于一百万的质数的速度一样快。

import numpy as np 
def generate_primes(n):
    is_prime = np.ones(n+1,dtype=bool)
    is_prime[0:2] = False
    for i in range(int(n**0.5)+1):
        if is_prime[i]:
            is_prime[i**2::i]=False
    return np.where(is_prime)[0]

时间:

import time    
for i in range(2,10):
    timer =time.time()
    generate_primes(10**i)
    print('n = 10^',i,' time =', round(time.time()-timer,6))

>> n = 10^ 2  time = 5.6e-05
>> n = 10^ 3  time = 6.4e-05
>> n = 10^ 4  time = 0.000114
>> n = 10^ 5  time = 0.000593
>> n = 10^ 6  time = 0.00467
>> n = 10^ 7  time = 0.177758
>> n = 10^ 8  time = 1.701312
>> n = 10^ 9  time = 19.322478

Here is a numpy version of Sieve of Eratosthenes having both good complexity (lower than sorting an array of length n) and vectorization. Compared to @unutbu times this just as fast as the packages with 46 microsecons to find all primes below a million.

import numpy as np 
def generate_primes(n):
    is_prime = np.ones(n+1,dtype=bool)
    is_prime[0:2] = False
    for i in range(int(n**0.5)+1):
        if is_prime[i]:
            is_prime[i**2::i]=False
    return np.where(is_prime)[0]

Timings:

import time    
for i in range(2,10):
    timer =time.time()
    generate_primes(10**i)
    print('n = 10^',i,' time =', round(time.time()-timer,6))

>> n = 10^ 2  time = 5.6e-05
>> n = 10^ 3  time = 6.4e-05
>> n = 10^ 4  time = 0.000114
>> n = 10^ 5  time = 0.000593
>> n = 10^ 6  time = 0.00467
>> n = 10^ 7  time = 0.177758
>> n = 10^ 8  time = 1.701312
>> n = 10^ 9  time = 19.322478

回答 22

我已经更新了许多Python 3代码,并将其扔到了perfplot(我的一个项目)上,以查看哪个实际上最快。事实证明,对于大蛋糕nprimesfrom{2,3}to可以选择蛋糕:

在此处输入图片说明


复制剧情的代码:

import perfplot
from math import sqrt, ceil
import numpy as np
import sympy


def rwh_primes(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns  a list of primes < n """
    sieve = [True] * n
    for i in range(3, int(n ** 0.5) + 1, 2):
        if sieve[i]:
            sieve[i * i::2 * i] = [False] * ((n - i * i - 1) // (2 * i) + 1)
    return [2] + [i for i in range(3, n, 2) if sieve[i]]


def rwh_primes1(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns  a list of primes < n """
    sieve = [True] * (n // 2)
    for i in range(3, int(n ** 0.5) + 1, 2):
        if sieve[i // 2]:
            sieve[i * i // 2::i] = [False] * ((n - i * i - 1) // (2 * i) + 1)
    return [2] + [2 * i + 1 for i in range(1, n // 2) if sieve[i]]


def rwh_primes2(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """Input n>=6, Returns a list of primes, 2 <= p < n"""
    assert n >= 6
    correction = n % 6 > 1
    n = {0: n, 1: n - 1, 2: n + 4, 3: n + 3, 4: n + 2, 5: n + 1}[n % 6]
    sieve = [True] * (n // 3)
    sieve[0] = False
    for i in range(int(n ** 0.5) // 3 + 1):
        if sieve[i]:
            k = 3 * i + 1 | 1
            sieve[((k * k) // 3)::2 * k] = [False] * (
                (n // 6 - (k * k) // 6 - 1) // k + 1
            )
            sieve[(k * k + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = [False] * (
                (n // 6 - (k * k + 4 * k - 2 * k * (i & 1)) // 6 - 1) // k + 1
            )
    return [2, 3] + [3 * i + 1 | 1 for i in range(1, n // 3 - correction) if sieve[i]]


def sieve_wheel_30(N):
    # http://zerovolt.com/?p=88
    """ Returns a list of primes <= N using wheel criterion 2*3*5 = 30

Copyright 2009 by zerovolt.com
This code is free for non-commercial purposes, in which case you can just leave this comment as a credit for my work.
If you need this code for commercial purposes, please contact me by sending an email to: info [at] zerovolt [dot] com."""
    __smallp = (
        2,
        3,
        5,
        7,
        11,
        13,
        17,
        19,
        23,
        29,
        31,
        37,
        41,
        43,
        47,
        53,
        59,
        61,
        67,
        71,
        73,
        79,
        83,
        89,
        97,
        101,
        103,
        107,
        109,
        113,
        127,
        131,
        137,
        139,
        149,
        151,
        157,
        163,
        167,
        173,
        179,
        181,
        191,
        193,
        197,
        199,
        211,
        223,
        227,
        229,
        233,
        239,
        241,
        251,
        257,
        263,
        269,
        271,
        277,
        281,
        283,
        293,
        307,
        311,
        313,
        317,
        331,
        337,
        347,
        349,
        353,
        359,
        367,
        373,
        379,
        383,
        389,
        397,
        401,
        409,
        419,
        421,
        431,
        433,
        439,
        443,
        449,
        457,
        461,
        463,
        467,
        479,
        487,
        491,
        499,
        503,
        509,
        521,
        523,
        541,
        547,
        557,
        563,
        569,
        571,
        577,
        587,
        593,
        599,
        601,
        607,
        613,
        617,
        619,
        631,
        641,
        643,
        647,
        653,
        659,
        661,
        673,
        677,
        683,
        691,
        701,
        709,
        719,
        727,
        733,
        739,
        743,
        751,
        757,
        761,
        769,
        773,
        787,
        797,
        809,
        811,
        821,
        823,
        827,
        829,
        839,
        853,
        857,
        859,
        863,
        877,
        881,
        883,
        887,
        907,
        911,
        919,
        929,
        937,
        941,
        947,
        953,
        967,
        971,
        977,
        983,
        991,
        997,
    )
    # wheel = (2, 3, 5)
    const = 30
    if N < 2:
        return []
    if N <= const:
        pos = 0
        while __smallp[pos] <= N:
            pos += 1
        return list(__smallp[:pos])
    # make the offsets list
    offsets = (7, 11, 13, 17, 19, 23, 29, 1)
    # prepare the list
    p = [2, 3, 5]
    dim = 2 + N // const
    tk1 = [True] * dim
    tk7 = [True] * dim
    tk11 = [True] * dim
    tk13 = [True] * dim
    tk17 = [True] * dim
    tk19 = [True] * dim
    tk23 = [True] * dim
    tk29 = [True] * dim
    tk1[0] = False
    # help dictionary d
    # d[a , b] = c  ==> if I want to find the smallest useful multiple of (30*pos)+a
    # on tkc, then I need the index given by the product of [(30*pos)+a][(30*pos)+b]
    # in general. If b < a, I need [(30*pos)+a][(30*(pos+1))+b]
    d = {}
    for x in offsets:
        for y in offsets:
            res = (x * y) % const
            if res in offsets:
                d[(x, res)] = y
    # another help dictionary: gives tkx calling tmptk[x]
    tmptk = {1: tk1, 7: tk7, 11: tk11, 13: tk13, 17: tk17, 19: tk19, 23: tk23, 29: tk29}
    pos, prime, lastadded, stop = 0, 0, 0, int(ceil(sqrt(N)))

    # inner functions definition
    def del_mult(tk, start, step):
        for k in range(start, len(tk), step):
            tk[k] = False

    # end of inner functions definition
    cpos = const * pos
    while prime < stop:
        # 30k + 7
        if tk7[pos]:
            prime = cpos + 7
            p.append(prime)
            lastadded = 7
            for off in offsets:
                tmp = d[(7, off)]
                start = (
                    (pos + prime)
                    if off == 7
                    else (prime * (const * (pos + 1 if tmp < 7 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 11
        if tk11[pos]:
            prime = cpos + 11
            p.append(prime)
            lastadded = 11
            for off in offsets:
                tmp = d[(11, off)]
                start = (
                    (pos + prime)
                    if off == 11
                    else (prime * (const * (pos + 1 if tmp < 11 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 13
        if tk13[pos]:
            prime = cpos + 13
            p.append(prime)
            lastadded = 13
            for off in offsets:
                tmp = d[(13, off)]
                start = (
                    (pos + prime)
                    if off == 13
                    else (prime * (const * (pos + 1 if tmp < 13 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 17
        if tk17[pos]:
            prime = cpos + 17
            p.append(prime)
            lastadded = 17
            for off in offsets:
                tmp = d[(17, off)]
                start = (
                    (pos + prime)
                    if off == 17
                    else (prime * (const * (pos + 1 if tmp < 17 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 19
        if tk19[pos]:
            prime = cpos + 19
            p.append(prime)
            lastadded = 19
            for off in offsets:
                tmp = d[(19, off)]
                start = (
                    (pos + prime)
                    if off == 19
                    else (prime * (const * (pos + 1 if tmp < 19 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 23
        if tk23[pos]:
            prime = cpos + 23
            p.append(prime)
            lastadded = 23
            for off in offsets:
                tmp = d[(23, off)]
                start = (
                    (pos + prime)
                    if off == 23
                    else (prime * (const * (pos + 1 if tmp < 23 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 29
        if tk29[pos]:
            prime = cpos + 29
            p.append(prime)
            lastadded = 29
            for off in offsets:
                tmp = d[(29, off)]
                start = (
                    (pos + prime)
                    if off == 29
                    else (prime * (const * (pos + 1 if tmp < 29 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # now we go back to top tk1, so we need to increase pos by 1
        pos += 1
        cpos = const * pos
        # 30k + 1
        if tk1[pos]:
            prime = cpos + 1
            p.append(prime)
            lastadded = 1
            for off in offsets:
                tmp = d[(1, off)]
                start = (
                    (pos + prime)
                    if off == 1
                    else (prime * (const * pos + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
    # time to add remaining primes
    # if lastadded == 1, remove last element and start adding them from tk1
    # this way we don't need an "if" within the last while
    if lastadded == 1:
        p.pop()
    # now complete for every other possible prime
    while pos < len(tk1):
        cpos = const * pos
        if tk1[pos]:
            p.append(cpos + 1)
        if tk7[pos]:
            p.append(cpos + 7)
        if tk11[pos]:
            p.append(cpos + 11)
        if tk13[pos]:
            p.append(cpos + 13)
        if tk17[pos]:
            p.append(cpos + 17)
        if tk19[pos]:
            p.append(cpos + 19)
        if tk23[pos]:
            p.append(cpos + 23)
        if tk29[pos]:
            p.append(cpos + 29)
        pos += 1
    # remove exceeding if present
    pos = len(p) - 1
    while p[pos] > N:
        pos -= 1
    if pos < len(p) - 1:
        del p[pos + 1 :]
    # return p list
    return p


def sieve_of_eratosthenes(n):
    """sieveOfEratosthenes(n): return the list of the primes < n."""
    # Code from: <dickinsm@gmail.com>, Nov 30 2006
    # http://groups.google.com/group/comp.lang.python/msg/f1f10ced88c68c2d
    if n <= 2:
        return []
    sieve = list(range(3, n, 2))
    top = len(sieve)
    for si in sieve:
        if si:
            bottom = (si * si - 3) // 2
            if bottom >= top:
                break
            sieve[bottom::si] = [0] * -((bottom - top) // si)
    return [2] + [el for el in sieve if el]


def sieve_of_atkin(end):
    """return a list of all the prime numbers <end using the Sieve of Atkin."""
    # Code by Steve Krenzel, <Sgk284@gmail.com>, improved
    # Code: https://web.archive.org/web/20080324064651/http://krenzel.info/?p=83
    # Info: http://en.wikipedia.org/wiki/Sieve_of_Atkin
    assert end > 0
    lng = (end - 1) // 2
    sieve = [False] * (lng + 1)

    x_max, x2, xd = int(sqrt((end - 1) / 4.0)), 0, 4
    for xd in range(4, 8 * x_max + 2, 8):
        x2 += xd
        y_max = int(sqrt(end - x2))
        n, n_diff = x2 + y_max * y_max, (y_max << 1) - 1
        if not (n & 1):
            n -= n_diff
            n_diff -= 2
        for d in range((n_diff - 1) << 1, -1, -8):
            m = n % 12
            if m == 1 or m == 5:
                m = n >> 1
                sieve[m] = not sieve[m]
            n -= d

    x_max, x2, xd = int(sqrt((end - 1) / 3.0)), 0, 3
    for xd in range(3, 6 * x_max + 2, 6):
        x2 += xd
        y_max = int(sqrt(end - x2))
        n, n_diff = x2 + y_max * y_max, (y_max << 1) - 1
        if not (n & 1):
            n -= n_diff
            n_diff -= 2
        for d in range((n_diff - 1) << 1, -1, -8):
            if n % 12 == 7:
                m = n >> 1
                sieve[m] = not sieve[m]
            n -= d

    x_max, y_min, x2, xd = int((2 + sqrt(4 - 8 * (1 - end))) / 4), -1, 0, 3
    for x in range(1, x_max + 1):
        x2 += xd
        xd += 6
        if x2 >= end:
            y_min = (((int(ceil(sqrt(x2 - end))) - 1) << 1) - 2) << 1
        n, n_diff = ((x * x + x) << 1) - 1, (((x - 1) << 1) - 2) << 1
        for d in range(n_diff, y_min, -8):
            if n % 12 == 11:
                m = n >> 1
                sieve[m] = not sieve[m]
            n += d

    primes = [2, 3]
    if end <= 3:
        return primes[: max(0, end - 2)]

    for n in range(5 >> 1, (int(sqrt(end)) + 1) >> 1):
        if sieve[n]:
            primes.append((n << 1) + 1)
            aux = (n << 1) + 1
            aux *= aux
            for k in range(aux, end, 2 * aux):
                sieve[k >> 1] = False

    s = int(sqrt(end)) + 1
    if s % 2 == 0:
        s += 1
    primes.extend([i for i in range(s, end, 2) if sieve[i >> 1]])

    return primes


def ambi_sieve_plain(n):
    s = list(range(3, n, 2))
    for m in range(3, int(n ** 0.5) + 1, 2):
        if s[(m - 3) // 2]:
            for t in range((m * m - 3) // 2, (n >> 1) - 1, m):
                s[t] = 0
    return [2] + [t for t in s if t > 0]


def sundaram3(max_n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/2073279#2073279
    numbers = range(3, max_n + 1, 2)
    half = (max_n) // 2
    initial = 4

    for step in range(3, max_n + 1, 2):
        for i in range(initial, half, step):
            numbers[i - 1] = 0
        initial += 2 * (step + 1)

        if initial > half:
            return [2] + filter(None, numbers)


# Using Numpy:
def ambi_sieve(n):
    # http://tommih.blogspot.com/2009/04/fast-prime-number-generator.html
    s = np.arange(3, n, 2)
    for m in range(3, int(n ** 0.5) + 1, 2):
        if s[(m - 3) // 2]:
            s[(m * m - 3) // 2::m] = 0
    return np.r_[2, s[s > 0]]


def primesfrom3to(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns an array of primes, p < n """
    assert n >= 2
    sieve = np.ones(n // 2, dtype=np.bool)
    for i in range(3, int(n ** 0.5) + 1, 2):
        if sieve[i // 2]:
            sieve[i * i // 2::i] = False
    return np.r_[2, 2 * np.nonzero(sieve)[0][1::] + 1]


def primesfrom2to(n):
    # /programming/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Input n>=6, Returns an array of primes, 2 <= p < n """
    assert n >= 6
    sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool)
    sieve[0] = False
    for i in range(int(n ** 0.5) // 3 + 1):
        if sieve[i]:
            k = 3 * i + 1 | 1
            sieve[((k * k) // 3)::2 * k] = False
            sieve[(k * k + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False
    return np.r_[2, 3, ((3 * np.nonzero(sieve)[0] + 1) | 1)]


def sympy_sieve(n):
    return list(sympy.sieve.primerange(1, n))


perfplot.save(
    "prime.png",
    setup=lambda n: n,
    kernels=[
        rwh_primes,
        rwh_primes1,
        rwh_primes2,
        sieve_wheel_30,
        sieve_of_eratosthenes,
        sieve_of_atkin,
        # ambi_sieve_plain,
        # sundaram3,
        ambi_sieve,
        primesfrom3to,
        primesfrom2to,
        sympy_sieve,
    ],
    n_range=[2 ** k for k in range(3, 25)],
    logx=True,
    logy=True,
    xlabel="n",
)

I’ve updated much of the code for Python 3 and threw it at perfplot (a project of mine) to see which is actually fastest. Turns out that, for large n, primesfrom{2,3}to take the cake:

enter image description here


Code to reproduce the plot:

import perfplot
from math import sqrt, ceil
import numpy as np
import sympy


def rwh_primes(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns  a list of primes < n """
    sieve = [True] * n
    for i in range(3, int(n ** 0.5) + 1, 2):
        if sieve[i]:
            sieve[i * i::2 * i] = [False] * ((n - i * i - 1) // (2 * i) + 1)
    return [2] + [i for i in range(3, n, 2) if sieve[i]]


def rwh_primes1(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns  a list of primes < n """
    sieve = [True] * (n // 2)
    for i in range(3, int(n ** 0.5) + 1, 2):
        if sieve[i // 2]:
            sieve[i * i // 2::i] = [False] * ((n - i * i - 1) // (2 * i) + 1)
    return [2] + [2 * i + 1 for i in range(1, n // 2) if sieve[i]]


def rwh_primes2(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """Input n>=6, Returns a list of primes, 2 <= p < n"""
    assert n >= 6
    correction = n % 6 > 1
    n = {0: n, 1: n - 1, 2: n + 4, 3: n + 3, 4: n + 2, 5: n + 1}[n % 6]
    sieve = [True] * (n // 3)
    sieve[0] = False
    for i in range(int(n ** 0.5) // 3 + 1):
        if sieve[i]:
            k = 3 * i + 1 | 1
            sieve[((k * k) // 3)::2 * k] = [False] * (
                (n // 6 - (k * k) // 6 - 1) // k + 1
            )
            sieve[(k * k + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = [False] * (
                (n // 6 - (k * k + 4 * k - 2 * k * (i & 1)) // 6 - 1) // k + 1
            )
    return [2, 3] + [3 * i + 1 | 1 for i in range(1, n // 3 - correction) if sieve[i]]


def sieve_wheel_30(N):
    # http://zerovolt.com/?p=88
    """ Returns a list of primes <= N using wheel criterion 2*3*5 = 30

Copyright 2009 by zerovolt.com
This code is free for non-commercial purposes, in which case you can just leave this comment as a credit for my work.
If you need this code for commercial purposes, please contact me by sending an email to: info [at] zerovolt [dot] com."""
    __smallp = (
        2,
        3,
        5,
        7,
        11,
        13,
        17,
        19,
        23,
        29,
        31,
        37,
        41,
        43,
        47,
        53,
        59,
        61,
        67,
        71,
        73,
        79,
        83,
        89,
        97,
        101,
        103,
        107,
        109,
        113,
        127,
        131,
        137,
        139,
        149,
        151,
        157,
        163,
        167,
        173,
        179,
        181,
        191,
        193,
        197,
        199,
        211,
        223,
        227,
        229,
        233,
        239,
        241,
        251,
        257,
        263,
        269,
        271,
        277,
        281,
        283,
        293,
        307,
        311,
        313,
        317,
        331,
        337,
        347,
        349,
        353,
        359,
        367,
        373,
        379,
        383,
        389,
        397,
        401,
        409,
        419,
        421,
        431,
        433,
        439,
        443,
        449,
        457,
        461,
        463,
        467,
        479,
        487,
        491,
        499,
        503,
        509,
        521,
        523,
        541,
        547,
        557,
        563,
        569,
        571,
        577,
        587,
        593,
        599,
        601,
        607,
        613,
        617,
        619,
        631,
        641,
        643,
        647,
        653,
        659,
        661,
        673,
        677,
        683,
        691,
        701,
        709,
        719,
        727,
        733,
        739,
        743,
        751,
        757,
        761,
        769,
        773,
        787,
        797,
        809,
        811,
        821,
        823,
        827,
        829,
        839,
        853,
        857,
        859,
        863,
        877,
        881,
        883,
        887,
        907,
        911,
        919,
        929,
        937,
        941,
        947,
        953,
        967,
        971,
        977,
        983,
        991,
        997,
    )
    # wheel = (2, 3, 5)
    const = 30
    if N < 2:
        return []
    if N <= const:
        pos = 0
        while __smallp[pos] <= N:
            pos += 1
        return list(__smallp[:pos])
    # make the offsets list
    offsets = (7, 11, 13, 17, 19, 23, 29, 1)
    # prepare the list
    p = [2, 3, 5]
    dim = 2 + N // const
    tk1 = [True] * dim
    tk7 = [True] * dim
    tk11 = [True] * dim
    tk13 = [True] * dim
    tk17 = [True] * dim
    tk19 = [True] * dim
    tk23 = [True] * dim
    tk29 = [True] * dim
    tk1[0] = False
    # help dictionary d
    # d[a , b] = c  ==> if I want to find the smallest useful multiple of (30*pos)+a
    # on tkc, then I need the index given by the product of [(30*pos)+a][(30*pos)+b]
    # in general. If b < a, I need [(30*pos)+a][(30*(pos+1))+b]
    d = {}
    for x in offsets:
        for y in offsets:
            res = (x * y) % const
            if res in offsets:
                d[(x, res)] = y
    # another help dictionary: gives tkx calling tmptk[x]
    tmptk = {1: tk1, 7: tk7, 11: tk11, 13: tk13, 17: tk17, 19: tk19, 23: tk23, 29: tk29}
    pos, prime, lastadded, stop = 0, 0, 0, int(ceil(sqrt(N)))

    # inner functions definition
    def del_mult(tk, start, step):
        for k in range(start, len(tk), step):
            tk[k] = False

    # end of inner functions definition
    cpos = const * pos
    while prime < stop:
        # 30k + 7
        if tk7[pos]:
            prime = cpos + 7
            p.append(prime)
            lastadded = 7
            for off in offsets:
                tmp = d[(7, off)]
                start = (
                    (pos + prime)
                    if off == 7
                    else (prime * (const * (pos + 1 if tmp < 7 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 11
        if tk11[pos]:
            prime = cpos + 11
            p.append(prime)
            lastadded = 11
            for off in offsets:
                tmp = d[(11, off)]
                start = (
                    (pos + prime)
                    if off == 11
                    else (prime * (const * (pos + 1 if tmp < 11 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 13
        if tk13[pos]:
            prime = cpos + 13
            p.append(prime)
            lastadded = 13
            for off in offsets:
                tmp = d[(13, off)]
                start = (
                    (pos + prime)
                    if off == 13
                    else (prime * (const * (pos + 1 if tmp < 13 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 17
        if tk17[pos]:
            prime = cpos + 17
            p.append(prime)
            lastadded = 17
            for off in offsets:
                tmp = d[(17, off)]
                start = (
                    (pos + prime)
                    if off == 17
                    else (prime * (const * (pos + 1 if tmp < 17 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 19
        if tk19[pos]:
            prime = cpos + 19
            p.append(prime)
            lastadded = 19
            for off in offsets:
                tmp = d[(19, off)]
                start = (
                    (pos + prime)
                    if off == 19
                    else (prime * (const * (pos + 1 if tmp < 19 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 23
        if tk23[pos]:
            prime = cpos + 23
            p.append(prime)
            lastadded = 23
            for off in offsets:
                tmp = d[(23, off)]
                start = (
                    (pos + prime)
                    if off == 23
                    else (prime * (const * (pos + 1 if tmp < 23 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # 30k + 29
        if tk29[pos]:
            prime = cpos + 29
            p.append(prime)
            lastadded = 29
            for off in offsets:
                tmp = d[(29, off)]
                start = (
                    (pos + prime)
                    if off == 29
                    else (prime * (const * (pos + 1 if tmp < 29 else 0) + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
        # now we go back to top tk1, so we need to increase pos by 1
        pos += 1
        cpos = const * pos
        # 30k + 1
        if tk1[pos]:
            prime = cpos + 1
            p.append(prime)
            lastadded = 1
            for off in offsets:
                tmp = d[(1, off)]
                start = (
                    (pos + prime)
                    if off == 1
                    else (prime * (const * pos + tmp)) // const
                )
                del_mult(tmptk[off], start, prime)
    # time to add remaining primes
    # if lastadded == 1, remove last element and start adding them from tk1
    # this way we don't need an "if" within the last while
    if lastadded == 1:
        p.pop()
    # now complete for every other possible prime
    while pos < len(tk1):
        cpos = const * pos
        if tk1[pos]:
            p.append(cpos + 1)
        if tk7[pos]:
            p.append(cpos + 7)
        if tk11[pos]:
            p.append(cpos + 11)
        if tk13[pos]:
            p.append(cpos + 13)
        if tk17[pos]:
            p.append(cpos + 17)
        if tk19[pos]:
            p.append(cpos + 19)
        if tk23[pos]:
            p.append(cpos + 23)
        if tk29[pos]:
            p.append(cpos + 29)
        pos += 1
    # remove exceeding if present
    pos = len(p) - 1
    while p[pos] > N:
        pos -= 1
    if pos < len(p) - 1:
        del p[pos + 1 :]
    # return p list
    return p


def sieve_of_eratosthenes(n):
    """sieveOfEratosthenes(n): return the list of the primes < n."""
    # Code from: <dickinsm@gmail.com>, Nov 30 2006
    # http://groups.google.com/group/comp.lang.python/msg/f1f10ced88c68c2d
    if n <= 2:
        return []
    sieve = list(range(3, n, 2))
    top = len(sieve)
    for si in sieve:
        if si:
            bottom = (si * si - 3) // 2
            if bottom >= top:
                break
            sieve[bottom::si] = [0] * -((bottom - top) // si)
    return [2] + [el for el in sieve if el]


def sieve_of_atkin(end):
    """return a list of all the prime numbers <end using the Sieve of Atkin."""
    # Code by Steve Krenzel, <Sgk284@gmail.com>, improved
    # Code: https://web.archive.org/web/20080324064651/http://krenzel.info/?p=83
    # Info: http://en.wikipedia.org/wiki/Sieve_of_Atkin
    assert end > 0
    lng = (end - 1) // 2
    sieve = [False] * (lng + 1)

    x_max, x2, xd = int(sqrt((end - 1) / 4.0)), 0, 4
    for xd in range(4, 8 * x_max + 2, 8):
        x2 += xd
        y_max = int(sqrt(end - x2))
        n, n_diff = x2 + y_max * y_max, (y_max << 1) - 1
        if not (n & 1):
            n -= n_diff
            n_diff -= 2
        for d in range((n_diff - 1) << 1, -1, -8):
            m = n % 12
            if m == 1 or m == 5:
                m = n >> 1
                sieve[m] = not sieve[m]
            n -= d

    x_max, x2, xd = int(sqrt((end - 1) / 3.0)), 0, 3
    for xd in range(3, 6 * x_max + 2, 6):
        x2 += xd
        y_max = int(sqrt(end - x2))
        n, n_diff = x2 + y_max * y_max, (y_max << 1) - 1
        if not (n & 1):
            n -= n_diff
            n_diff -= 2
        for d in range((n_diff - 1) << 1, -1, -8):
            if n % 12 == 7:
                m = n >> 1
                sieve[m] = not sieve[m]
            n -= d

    x_max, y_min, x2, xd = int((2 + sqrt(4 - 8 * (1 - end))) / 4), -1, 0, 3
    for x in range(1, x_max + 1):
        x2 += xd
        xd += 6
        if x2 >= end:
            y_min = (((int(ceil(sqrt(x2 - end))) - 1) << 1) - 2) << 1
        n, n_diff = ((x * x + x) << 1) - 1, (((x - 1) << 1) - 2) << 1
        for d in range(n_diff, y_min, -8):
            if n % 12 == 11:
                m = n >> 1
                sieve[m] = not sieve[m]
            n += d

    primes = [2, 3]
    if end <= 3:
        return primes[: max(0, end - 2)]

    for n in range(5 >> 1, (int(sqrt(end)) + 1) >> 1):
        if sieve[n]:
            primes.append((n << 1) + 1)
            aux = (n << 1) + 1
            aux *= aux
            for k in range(aux, end, 2 * aux):
                sieve[k >> 1] = False

    s = int(sqrt(end)) + 1
    if s % 2 == 0:
        s += 1
    primes.extend([i for i in range(s, end, 2) if sieve[i >> 1]])

    return primes


def ambi_sieve_plain(n):
    s = list(range(3, n, 2))
    for m in range(3, int(n ** 0.5) + 1, 2):
        if s[(m - 3) // 2]:
            for t in range((m * m - 3) // 2, (n >> 1) - 1, m):
                s[t] = 0
    return [2] + [t for t in s if t > 0]


def sundaram3(max_n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/2073279#2073279
    numbers = range(3, max_n + 1, 2)
    half = (max_n) // 2
    initial = 4

    for step in range(3, max_n + 1, 2):
        for i in range(initial, half, step):
            numbers[i - 1] = 0
        initial += 2 * (step + 1)

        if initial > half:
            return [2] + filter(None, numbers)


# Using Numpy:
def ambi_sieve(n):
    # http://tommih.blogspot.com/2009/04/fast-prime-number-generator.html
    s = np.arange(3, n, 2)
    for m in range(3, int(n ** 0.5) + 1, 2):
        if s[(m - 3) // 2]:
            s[(m * m - 3) // 2::m] = 0
    return np.r_[2, s[s > 0]]


def primesfrom3to(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Returns an array of primes, p < n """
    assert n >= 2
    sieve = np.ones(n // 2, dtype=np.bool)
    for i in range(3, int(n ** 0.5) + 1, 2):
        if sieve[i // 2]:
            sieve[i * i // 2::i] = False
    return np.r_[2, 2 * np.nonzero(sieve)[0][1::] + 1]


def primesfrom2to(n):
    # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
    """ Input n>=6, Returns an array of primes, 2 <= p < n """
    assert n >= 6
    sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool)
    sieve[0] = False
    for i in range(int(n ** 0.5) // 3 + 1):
        if sieve[i]:
            k = 3 * i + 1 | 1
            sieve[((k * k) // 3)::2 * k] = False
            sieve[(k * k + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False
    return np.r_[2, 3, ((3 * np.nonzero(sieve)[0] + 1) | 1)]


def sympy_sieve(n):
    return list(sympy.sieve.primerange(1, n))


perfplot.save(
    "prime.png",
    setup=lambda n: n,
    kernels=[
        rwh_primes,
        rwh_primes1,
        rwh_primes2,
        sieve_wheel_30,
        sieve_of_eratosthenes,
        sieve_of_atkin,
        # ambi_sieve_plain,
        # sundaram3,
        ambi_sieve,
        primesfrom3to,
        primesfrom2to,
        sympy_sieve,
    ],
    n_range=[2 ** k for k in range(3, 25)],
    logx=True,
    logy=True,
    xlabel="n",
)

回答 23

我的猜测是最快所有方式中的就是对代码中的素数进行硬编码。

因此,为什么不编写一个缓慢的脚本来生成另一个包含所有数字的源文件,然后在运行实际程序时导入该源文件。

当然,仅当您在编译时知道N的上限时,此方法才有效,但是(几乎)所有项目Euler问题都是如此。

 

PS: 我可能错了,尽管使用硬连线素数解析源代码比首先计算它们要慢,但是据我所知Python运行于编译.pyc文件中,因此读取所有素数不超过N的二进制数组应该很血腥在这种情况下很快。

My guess is that the fastest of all ways is to hard code the primes in your code.

So why not just write a slow script that generates another source file that has all numbers hardwired in it, and then import that source file when you run your actual program.

Of course, this works only if you know the upper bound of N at compile time, but thus is the case for (almost) all project Euler problems.

 

PS: I might be wrong though iff parsing the source with hard-wired primes is slower than computing them in the first place, but as far I know Python runs from compiled .pyc files so reading a binary array with all primes up to N should be bloody fast in that case.


回答 24

很抱歉打扰,但是erat2()在算法中存在严重缺陷。

在搜索下一个复合词时,我们仅需要测试奇数。q,p都是奇数;那么q + p是偶数,不需要测试,但是q + 2 * p总是奇数。这消除了while循环条件下的“ if even”测试,并节省了大约30%的运行时间。

当我们讨论它时:使用优雅的’D.pop(q,None)’获取和删除方法,而不是’if q in D:p = D [q],del D [q]’,它快两倍!至少在我的机器上(P3-1Ghz)。所以我建议这种聪明算法的实现:

def erat3( ):
    from itertools import islice, count

    # q is the running integer that's checked for primeness.
    # yield 2 and no other even number thereafter
    yield 2
    D = {}
    # no need to mark D[4] as we will test odd numbers only
    for q in islice(count(3),0,None,2):
        if q in D:                  #  is composite
            p = D[q]
            del D[q]
            # q is composite. p=D[q] is the first prime that
            # divides it. Since we've reached q, we no longer
            # need it in the map, but we'll mark the next
            # multiple of its witnesses to prepare for larger
            # numbers.
            x = q + p+p        # next odd(!) multiple
            while x in D:      # skip composites
                x += p+p
            D[x] = p
        else:                  # is prime
            # q is a new prime.
            # Yield it and mark its first multiple that isn't
            # already marked in previous iterations.
            D[q*q] = q
            yield q

Sorry to bother but erat2() has a serious flaw in the algorithm.

While searching for the next composite, we need to test odd numbers only. q,p both are odd; then q+p is even and doesn’t need to be tested, but q+2*p is always odd. This eliminates the “if even” test in the while loop condition and saves about 30% of the runtime.

While we’re at it: instead of the elegant ‘D.pop(q,None)’ get and delete method use ‘if q in D: p=D[q],del D[q]’ which is twice as fast! At least on my machine (P3-1Ghz). So I suggest this implementation of this clever algorithm:

def erat3( ):
    from itertools import islice, count

    # q is the running integer that's checked for primeness.
    # yield 2 and no other even number thereafter
    yield 2
    D = {}
    # no need to mark D[4] as we will test odd numbers only
    for q in islice(count(3),0,None,2):
        if q in D:                  #  is composite
            p = D[q]
            del D[q]
            # q is composite. p=D[q] is the first prime that
            # divides it. Since we've reached q, we no longer
            # need it in the map, but we'll mark the next
            # multiple of its witnesses to prepare for larger
            # numbers.
            x = q + p+p        # next odd(!) multiple
            while x in D:      # skip composites
                x += p+p
            D[x] = p
        else:                  # is prime
            # q is a new prime.
            # Yield it and mark its first multiple that isn't
            # already marked in previous iterations.
            D[q*q] = q
            yield q

回答 25

到目前为止,我尝试过的最快方法是基于Python Cookbookerat2函数:

import itertools as it
def erat2a( ):
    D = {  }
    yield 2
    for q in it.islice(it.count(3), 0, None, 2):
        p = D.pop(q, None)
        if p is None:
            D[q*q] = q
            yield q
        else:
            x = q + 2*p
            while x in D:
                x += 2*p
            D[x] = p

有关加速的说明,请参见此答案。

The fastest method I’ve tried so far is based on the Python cookbook erat2 function:

import itertools as it
def erat2a( ):
    D = {  }
    yield 2
    for q in it.islice(it.count(3), 0, None, 2):
        p = D.pop(q, None)
        if p is None:
            D[q*q] = q
            yield q
        else:
            x = q + 2*p
            while x in D:
                x += 2*p
            D[x] = p

See this answer for an explanation of the speeding-up.


回答 26

我可能聚会晚了,但是必须为此添加我自己的代码。它占用的空间大约为n / 2,因为我们不需要存储偶数,而且我还使用了bitarray python模块,从而进一步大大减少了内存消耗,并使所有素数的计算能力高达1,000,000,000

from bitarray import bitarray
def primes_to(n):
    size = n//2
    sieve = bitarray(size)
    sieve.setall(1)
    limit = int(n**0.5)
    for i in range(1,limit):
        if sieve[i]:
            val = 2*i+1
            sieve[(i+i*val)::val] = 0
    return [2] + [2*i+1 for i, v in enumerate(sieve) if v and i > 0]

python -m timeit -n10 -s "import euler" "euler.primes_to(1000000000)"
10 loops, best of 3: 46.5 sec per loop

这是在64位2.4GHZ MAC OSX 10.8.3上运行的

I may be late to the party but will have to add my own code for this. It uses approximately n/2 in space because we don’t need to store even numbers and I also make use of the bitarray python module, further draStically cutting down on memory consumption and enabling computing all primes up to 1,000,000,000

from bitarray import bitarray
def primes_to(n):
    size = n//2
    sieve = bitarray(size)
    sieve.setall(1)
    limit = int(n**0.5)
    for i in range(1,limit):
        if sieve[i]:
            val = 2*i+1
            sieve[(i+i*val)::val] = 0
    return [2] + [2*i+1 for i, v in enumerate(sieve) if v and i > 0]

python -m timeit -n10 -s "import euler" "euler.primes_to(1000000000)"
10 loops, best of 3: 46.5 sec per loop

This was run on a 64bit 2.4GHZ MAC OSX 10.8.3


回答 27

随着时间的推移,我收集了几个质数筛。我的计算机上最快的是:

from time import time
# 175 ms for all the primes up to the value 10**6
def primes_sieve(limit):
    a = [True] * limit
    a[0] = a[1] = False
    #a[2] = True
    for n in xrange(4, limit, 2):
        a[n] = False
    root_limit = int(limit**.5)+1
    for i in xrange(3,root_limit):
        if a[i]:
            for n in xrange(i*i, limit, 2*i):
                a[n] = False
    return a

LIMIT = 10**6
s=time()
primes = primes_sieve(LIMIT)
print time()-s

I collected several prime number sieves over time. The fastest on my computer is this:

from time import time
# 175 ms for all the primes up to the value 10**6
def primes_sieve(limit):
    a = [True] * limit
    a[0] = a[1] = False
    #a[2] = True
    for n in xrange(4, limit, 2):
        a[n] = False
    root_limit = int(limit**.5)+1
    for i in xrange(3,root_limit):
        if a[i]:
            for n in xrange(i*i, limit, 2*i):
                a[n] = False
    return a

LIMIT = 10**6
s=time()
primes = primes_sieve(LIMIT)
print time()-s

回答 28

我对这个问题的回答很慢,但这似乎很有趣。我正在使用numpy,这可能会作弊,我怀疑这种方法是最快的,但应该清楚。它筛选仅引用其索引的布尔数组,并从所有True值的索引中引出质数。无需取模。

import numpy as np
def ajs_primes3a(upto):
    mat = np.ones((upto), dtype=bool)
    mat[0] = False
    mat[1] = False
    mat[4::2] = False
    for idx in range(3, int(upto ** 0.5)+1, 2):
        mat[idx*2::idx] = False
    return np.where(mat == True)[0]

I’m slow responding to this question but it seemed like a fun exercise. I’m using numpy which might be cheating and I doubt this method is the fastest but it should be clear. It sieves a Boolean array referring to its indices only and elicits prime numbers from the indices of all True values. No modulo needed.

import numpy as np
def ajs_primes3a(upto):
    mat = np.ones((upto), dtype=bool)
    mat[0] = False
    mat[1] = False
    mat[4::2] = False
    for idx in range(3, int(upto ** 0.5)+1, 2):
        mat[idx*2::idx] = False
    return np.where(mat == True)[0]

回答 29

这是一种有趣的技术,可以使用python的列表推导来生成质数(但不是最有效的):

noprimes = [j for i in range(2, 8) for j in range(i*2, 50, i)]
primes = [x for x in range(2, 50) if x not in noprimes]

您可以在此处找到示例和一些说明

Here is an interesting technique to generate prime numbers (yet not the most efficient) using python’s list comprehensions:

noprimes = [j for i in range(2, 8) for j in range(i*2, 50, i)]
primes = [x for x in range(2, 50) if x not in noprimes]

You can find the example and some explanations right here


迭代访问列表的最“ pythonic”方法是什么?

问题:迭代访问列表的最“ pythonic”方法是什么?

我有一个Python脚本,它将一个整数列表作为输入,我需要一次处理四个整数。不幸的是,我无法控制输入,或者将其作为四元素元组的列表传递。目前,我正在以这种方式对其进行迭代:

for i in xrange(0, len(ints), 4):
    # dummy op for example code
    foo += ints[i] * ints[i + 1] + ints[i + 2] * ints[i + 3]

不过,它看起来很像“ C思维”,这使我怀疑还有一种处理这种情况的更Python的方法。该列表在迭代后被丢弃,因此不需要保留。也许这样的事情会更好?

while ints:
    foo += ints[0] * ints[1] + ints[2] * ints[3]
    ints[0:4] = []

不过,还是不太“正确”。:-/

相关问题:如何在Python中将列表分成均匀大小的块?

I have a Python script which takes as input a list of integers, which I need to work with four integers at a time. Unfortunately, I don’t have control of the input, or I’d have it passed in as a list of four-element tuples. Currently, I’m iterating over it this way:

for i in xrange(0, len(ints), 4):
    # dummy op for example code
    foo += ints[i] * ints[i + 1] + ints[i + 2] * ints[i + 3]

It looks a lot like “C-think”, though, which makes me suspect there’s a more pythonic way of dealing with this situation. The list is discarded after iterating, so it needn’t be preserved. Perhaps something like this would be better?

while ints:
    foo += ints[0] * ints[1] + ints[2] * ints[3]
    ints[0:4] = []

Still doesn’t quite “feel” right, though. :-/

Related question: How do you split a list into evenly sized chunks in Python?


回答 0

从Python的itertools文档的食谱部分进行了修改:

from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

示例
用伪代码保持示例简洁。

grouper('ABCDEFG', 3, 'x') --> 'ABC' 'DEF' 'Gxx'

注意:在Python 2上,请使用izip_longest代替zip_longest

Modified from the recipes section of Python’s itertools docs:

from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

Example
In pseudocode to keep the example terse.

grouper('ABCDEFG', 3, 'x') --> 'ABC' 'DEF' 'Gxx'

Note: on Python 2 use izip_longest instead of zip_longest.


回答 1

def chunker(seq, size):
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))
# (in python 2 use xrange() instead of range() to avoid allocating a list)

简单。简单。快速。适用于任何序列:

text = "I am a very, very helpful text"

for group in chunker(text, 7):
   print repr(group),
# 'I am a ' 'very, v' 'ery hel' 'pful te' 'xt'

print '|'.join(chunker(text, 10))
# I am a ver|y, very he|lpful text

animals = ['cat', 'dog', 'rabbit', 'duck', 'bird', 'cow', 'gnu', 'fish']

for group in chunker(animals, 3):
    print group
# ['cat', 'dog', 'rabbit']
# ['duck', 'bird', 'cow']
# ['gnu', 'fish']
def chunker(seq, size):
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))
# (in python 2 use xrange() instead of range() to avoid allocating a list)

Simple. Easy. Fast. Works with any sequence:

text = "I am a very, very helpful text"

for group in chunker(text, 7):
   print repr(group),
# 'I am a ' 'very, v' 'ery hel' 'pful te' 'xt'

print '|'.join(chunker(text, 10))
# I am a ver|y, very he|lpful text

animals = ['cat', 'dog', 'rabbit', 'duck', 'bird', 'cow', 'gnu', 'fish']

for group in chunker(animals, 3):
    print group
# ['cat', 'dog', 'rabbit']
# ['duck', 'bird', 'cow']
# ['gnu', 'fish']

回答 2

我是的粉丝

chunk_size= 4
for i in range(0, len(ints), chunk_size):
    chunk = ints[i:i+chunk_size]
    # process chunk of size <= chunk_size

I’m a fan of

chunk_size= 4
for i in range(0, len(ints), chunk_size):
    chunk = ints[i:i+chunk_size]
    # process chunk of size <= chunk_size

回答 3

import itertools
def chunks(iterable,size):
    it = iter(iterable)
    chunk = tuple(itertools.islice(it,size))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(it,size))

# though this will throw ValueError if the length of ints
# isn't a multiple of four:
for x1,x2,x3,x4 in chunks(ints,4):
    foo += x1 + x2 + x3 + x4

for chunk in chunks(ints,4):
    foo += sum(chunk)

另一种方式:

import itertools
def chunks2(iterable,size,filler=None):
    it = itertools.chain(iterable,itertools.repeat(filler,size-1))
    chunk = tuple(itertools.islice(it,size))
    while len(chunk) == size:
        yield chunk
        chunk = tuple(itertools.islice(it,size))

# x2, x3 and x4 could get the value 0 if the length is not
# a multiple of 4.
for x1,x2,x3,x4 in chunks2(ints,4,0):
    foo += x1 + x2 + x3 + x4
import itertools
def chunks(iterable,size):
    it = iter(iterable)
    chunk = tuple(itertools.islice(it,size))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(it,size))

# though this will throw ValueError if the length of ints
# isn't a multiple of four:
for x1,x2,x3,x4 in chunks(ints,4):
    foo += x1 + x2 + x3 + x4

for chunk in chunks(ints,4):
    foo += sum(chunk)

Another way:

import itertools
def chunks2(iterable,size,filler=None):
    it = itertools.chain(iterable,itertools.repeat(filler,size-1))
    chunk = tuple(itertools.islice(it,size))
    while len(chunk) == size:
        yield chunk
        chunk = tuple(itertools.islice(it,size))

# x2, x3 and x4 could get the value 0 if the length is not
# a multiple of 4.
for x1,x2,x3,x4 in chunks2(ints,4,0):
    foo += x1 + x2 + x3 + x4

回答 4

from itertools import izip_longest

def chunker(iterable, chunksize, filler):
    return izip_longest(*[iter(iterable)]*chunksize, fillvalue=filler)
from itertools import izip_longest

def chunker(iterable, chunksize, filler):
    return izip_longest(*[iter(iterable)]*chunksize, fillvalue=filler)

回答 5

此问题的理想解决方案适用于迭代器(而不仅仅是序列)。它也应该很快。

这是itertools文档提供的解决方案:

def grouper(n, iterable, fillvalue=None):
    #"grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
    args = [iter(iterable)] * n
    return itertools.izip_longest(fillvalue=fillvalue, *args)

%timeit在我的Macbook Air 上使用ipython ,每个循环可获得47.5美元。

但是,这对我来说真的不起作用,因为结果被填充为甚至大小的组。没有填充的解决方案稍微复杂一些。最幼稚的解决方案可能是:

def grouper(size, iterable):
    i = iter(iterable)
    while True:
        out = []
        try:
            for _ in range(size):
                out.append(i.next())
        except StopIteration:
            yield out
            break

        yield out

简单但很慢:每个循环693 us

我可以想出的最佳解决方案islice用于内部循环:

def grouper(size, iterable):
    it = iter(iterable)
    while True:
        group = tuple(itertools.islice(it, None, size))
        if not group:
            break
        yield group

使用相同的数据集,每个循环可获得305 us。

无法以比这更快的速度获得纯解决方案,我为以下解决方案提供了一个重要的警告:如果输入数据中包含实例,filldata则可能会得到错误的答案。

def grouper(n, iterable, fillvalue=None):
    #"grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
    args = [iter(iterable)] * n
    for i in itertools.izip_longest(fillvalue=fillvalue, *args):
        if tuple(i)[-1] == fillvalue:
            yield tuple(v for v in i if v != fillvalue)
        else:
            yield i

我真的不喜欢这个答案,但是速度更快。每个循环124 us

The ideal solution for this problem works with iterators (not just sequences). It should also be fast.

This is the solution provided by the documentation for itertools:

def grouper(n, iterable, fillvalue=None):
    #"grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
    args = [iter(iterable)] * n
    return itertools.izip_longest(fillvalue=fillvalue, *args)

Using ipython’s %timeit on my mac book air, I get 47.5 us per loop.

However, this really doesn’t work for me since the results are padded to be even sized groups. A solution without the padding is slightly more complicated. The most naive solution might be:

def grouper(size, iterable):
    i = iter(iterable)
    while True:
        out = []
        try:
            for _ in range(size):
                out.append(i.next())
        except StopIteration:
            yield out
            break

        yield out

Simple, but pretty slow: 693 us per loop

The best solution I could come up with uses islice for the inner loop:

def grouper(size, iterable):
    it = iter(iterable)
    while True:
        group = tuple(itertools.islice(it, None, size))
        if not group:
            break
        yield group

With the same dataset, I get 305 us per loop.

Unable to get a pure solution any faster than that, I provide the following solution with an important caveat: If your input data has instances of filldata in it, you could get wrong answer.

def grouper(n, iterable, fillvalue=None):
    #"grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
    args = [iter(iterable)] * n
    for i in itertools.izip_longest(fillvalue=fillvalue, *args):
        if tuple(i)[-1] == fillvalue:
            yield tuple(v for v in i if v != fillvalue)
        else:
            yield i

I really don’t like this answer, but it is significantly faster. 124 us per loop


回答 6

我需要一个可以与集合和生成器一起使用的解决方案。我无法提出任何简短而又漂亮的内容,但至少可以理解。

def chunker(seq, size):
    res = []
    for el in seq:
        res.append(el)
        if len(res) == size:
            yield res
            res = []
    if res:
        yield res

清单:

>>> list(chunker([i for i in range(10)], 3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

组:

>>> list(chunker(set([i for i in range(10)]), 3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

生成器:

>>> list(chunker((i for i in range(10)), 3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

I needed a solution that would also work with sets and generators. I couldn’t come up with anything very short and pretty, but it’s quite readable at least.

def chunker(seq, size):
    res = []
    for el in seq:
        res.append(el)
        if len(res) == size:
            yield res
            res = []
    if res:
        yield res

List:

>>> list(chunker([i for i in range(10)], 3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Set:

>>> list(chunker(set([i for i in range(10)]), 3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Generator:

>>> list(chunker((i for i in range(10)), 3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

回答 7

与其他建议类似,但不完全相同,我喜欢这样做,因为它简单易读:

it = iter([1, 2, 3, 4, 5, 6, 7, 8, 9])
for chunk in zip(it, it, it, it):
    print chunk

>>> (1, 2, 3, 4)
>>> (5, 6, 7, 8)

这样,您将不会得到最后的部分块。如果要获取(9, None, None, None)最后一块,只需使用izip_longestfrom即可itertools

Similar to other proposals, but not exactly identical, I like doing it this way, because it’s simple and easy to read:

it = iter([1, 2, 3, 4, 5, 6, 7, 8, 9])
for chunk in zip(it, it, it, it):
    print chunk

>>> (1, 2, 3, 4)
>>> (5, 6, 7, 8)

This way you won’t get the last partial chunk. If you want to get (9, None, None, None) as last chunk, just use izip_longest from itertools.


回答 8

如果您不介意使用外部软件包,则可以iteration_utilities.grouper1开始使用。它支持所有可迭代项(不仅限于序列):iteration_utilties

from iteration_utilities import grouper
seq = list(range(20))
for group in grouper(seq, 4):
    print(group)

打印:

(0, 1, 2, 3)
(4, 5, 6, 7)
(8, 9, 10, 11)
(12, 13, 14, 15)
(16, 17, 18, 19)

如果长度不是分组大小的倍数,则它还支持填充(不完整的最后一组)或截断(丢弃不完整的最后一组)最后一个:

from iteration_utilities import grouper
seq = list(range(17))
for group in grouper(seq, 4):
    print(group)
# (0, 1, 2, 3)
# (4, 5, 6, 7)
# (8, 9, 10, 11)
# (12, 13, 14, 15)
# (16,)

for group in grouper(seq, 4, fillvalue=None):
    print(group)
# (0, 1, 2, 3)
# (4, 5, 6, 7)
# (8, 9, 10, 11)
# (12, 13, 14, 15)
# (16, None, None, None)

for group in grouper(seq, 4, truncate=True):
    print(group)
# (0, 1, 2, 3)
# (4, 5, 6, 7)
# (8, 9, 10, 11)
# (12, 13, 14, 15)

基准测试

我还决定比较上述几种方法的运行时间。这是一个对数-对数图,根据大小不同的列表分为“ 10”个元素组。对于定性结果:较低意味着更快:

在此处输入图片说明

至少在此基准测试中,iteration_utilities.grouper效果最佳。其次是疯狂的方法。

基准是使用1创建的。用于运行该基准测试的代码为:simple_benchmark

import iteration_utilities
import itertools
from itertools import zip_longest

def consume_all(it):
    return iteration_utilities.consume(it, None)

import simple_benchmark
b = simple_benchmark.BenchmarkBuilder()

@b.add_function()
def grouper(l, n):
    return consume_all(iteration_utilities.grouper(l, n))

def Craz_inner(iterable, n, fillvalue=None):
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

@b.add_function()
def Craz(iterable, n, fillvalue=None):
    return consume_all(Craz_inner(iterable, n, fillvalue))

def nosklo_inner(seq, size):
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))

@b.add_function()
def nosklo(seq, size):
    return consume_all(nosklo_inner(seq, size))

def SLott_inner(ints, chunk_size):
    for i in range(0, len(ints), chunk_size):
        yield ints[i:i+chunk_size]

@b.add_function()
def SLott(ints, chunk_size):
    return consume_all(SLott_inner(ints, chunk_size))

def MarkusJarderot1_inner(iterable,size):
    it = iter(iterable)
    chunk = tuple(itertools.islice(it,size))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(it,size))

@b.add_function()
def MarkusJarderot1(iterable,size):
    return consume_all(MarkusJarderot1_inner(iterable,size))

def MarkusJarderot2_inner(iterable,size,filler=None):
    it = itertools.chain(iterable,itertools.repeat(filler,size-1))
    chunk = tuple(itertools.islice(it,size))
    while len(chunk) == size:
        yield chunk
        chunk = tuple(itertools.islice(it,size))

@b.add_function()
def MarkusJarderot2(iterable,size):
    return consume_all(MarkusJarderot2_inner(iterable,size))

@b.add_arguments()
def argument_provider():
    for exp in range(2, 20):
        size = 2**exp
        yield size, simple_benchmark.MultiArgument([[0] * size, 10])

r = b.run()

1免责声明:我是图书馆的作者iteration_utilitiessimple_benchmark

If you don’t mind using an external package you could use iteration_utilities.grouper from iteration_utilties 1. It supports all iterables (not just sequences):

from iteration_utilities import grouper
seq = list(range(20))
for group in grouper(seq, 4):
    print(group)

which prints:

(0, 1, 2, 3)
(4, 5, 6, 7)
(8, 9, 10, 11)
(12, 13, 14, 15)
(16, 17, 18, 19)

In case the length isn’t a multiple of the groupsize it also supports filling (the incomplete last group) or truncating (discarding the incomplete last group) the last one:

from iteration_utilities import grouper
seq = list(range(17))
for group in grouper(seq, 4):
    print(group)
# (0, 1, 2, 3)
# (4, 5, 6, 7)
# (8, 9, 10, 11)
# (12, 13, 14, 15)
# (16,)

for group in grouper(seq, 4, fillvalue=None):
    print(group)
# (0, 1, 2, 3)
# (4, 5, 6, 7)
# (8, 9, 10, 11)
# (12, 13, 14, 15)
# (16, None, None, None)

for group in grouper(seq, 4, truncate=True):
    print(group)
# (0, 1, 2, 3)
# (4, 5, 6, 7)
# (8, 9, 10, 11)
# (12, 13, 14, 15)

Benchmarks

I also decided to compare the run-time of a few of the mentioned approaches. It’s a log-log plot grouping into groups of “10” elements based on a list of varying size. For qualitative results: Lower means faster:

enter image description here

At least in this benchmark the iteration_utilities.grouper performs best. Followed by the approach of Craz.

The benchmark was created with simple_benchmark1. The code used to run this benchmark was:

import iteration_utilities
import itertools
from itertools import zip_longest

def consume_all(it):
    return iteration_utilities.consume(it, None)

import simple_benchmark
b = simple_benchmark.BenchmarkBuilder()

@b.add_function()
def grouper(l, n):
    return consume_all(iteration_utilities.grouper(l, n))

def Craz_inner(iterable, n, fillvalue=None):
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

@b.add_function()
def Craz(iterable, n, fillvalue=None):
    return consume_all(Craz_inner(iterable, n, fillvalue))

def nosklo_inner(seq, size):
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))

@b.add_function()
def nosklo(seq, size):
    return consume_all(nosklo_inner(seq, size))

def SLott_inner(ints, chunk_size):
    for i in range(0, len(ints), chunk_size):
        yield ints[i:i+chunk_size]

@b.add_function()
def SLott(ints, chunk_size):
    return consume_all(SLott_inner(ints, chunk_size))

def MarkusJarderot1_inner(iterable,size):
    it = iter(iterable)
    chunk = tuple(itertools.islice(it,size))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(it,size))

@b.add_function()
def MarkusJarderot1(iterable,size):
    return consume_all(MarkusJarderot1_inner(iterable,size))

def MarkusJarderot2_inner(iterable,size,filler=None):
    it = itertools.chain(iterable,itertools.repeat(filler,size-1))
    chunk = tuple(itertools.islice(it,size))
    while len(chunk) == size:
        yield chunk
        chunk = tuple(itertools.islice(it,size))

@b.add_function()
def MarkusJarderot2(iterable,size):
    return consume_all(MarkusJarderot2_inner(iterable,size))

@b.add_arguments()
def argument_provider():
    for exp in range(2, 20):
        size = 2**exp
        yield size, simple_benchmark.MultiArgument([[0] * size, 10])

r = b.run()

1 Disclaimer: I’m the author of the libraries iteration_utilities and simple_benchmark.


回答 9

由于没有人提到它,所以这里有一个zip()解决方案:

>>> def chunker(iterable, chunksize):
...     return zip(*[iter(iterable)]*chunksize)

仅当序列的长度始终可被块大小整除或不关心尾随的块时,它才起作用。

例:

>>> s = '1234567890'
>>> chunker(s, 3)
[('1', '2', '3'), ('4', '5', '6'), ('7', '8', '9')]
>>> chunker(s, 4)
[('1', '2', '3', '4'), ('5', '6', '7', '8')]
>>> chunker(s, 5)
[('1', '2', '3', '4', '5'), ('6', '7', '8', '9', '0')]

或使用itertools.izip返回迭代器而不是列表:

>>> from itertools import izip
>>> def chunker(iterable, chunksize):
...     return izip(*[iter(iterable)]*chunksize)

可以使用@▼ZΩΤZΙΟΥ的答案来固定填充:

>>> from itertools import chain, izip, repeat
>>> def chunker(iterable, chunksize, fillvalue=None):
...     it   = chain(iterable, repeat(fillvalue, chunksize-1))
...     args = [it] * chunksize
...     return izip(*args)

Since nobody’s mentioned it yet here’s a zip() solution:

>>> def chunker(iterable, chunksize):
...     return zip(*[iter(iterable)]*chunksize)

It works only if your sequence’s length is always divisible by the chunk size or you don’t care about a trailing chunk if it isn’t.

Example:

>>> s = '1234567890'
>>> chunker(s, 3)
[('1', '2', '3'), ('4', '5', '6'), ('7', '8', '9')]
>>> chunker(s, 4)
[('1', '2', '3', '4'), ('5', '6', '7', '8')]
>>> chunker(s, 5)
[('1', '2', '3', '4', '5'), ('6', '7', '8', '9', '0')]

Or using itertools.izip to return an iterator instead of a list:

>>> from itertools import izip
>>> def chunker(iterable, chunksize):
...     return izip(*[iter(iterable)]*chunksize)

Padding can be fixed using @ΤΖΩΤΖΙΟΥ’s answer:

>>> from itertools import chain, izip, repeat
>>> def chunker(iterable, chunksize, fillvalue=None):
...     it   = chain(iterable, repeat(fillvalue, chunksize-1))
...     args = [it] * chunksize
...     return izip(*args)

回答 10

使用map()而不是zip()可解决JF Sebastian的答案中的填充问题:

>>> def chunker(iterable, chunksize):
...   return map(None,*[iter(iterable)]*chunksize)

例:

>>> s = '1234567890'
>>> chunker(s, 3)
[('1', '2', '3'), ('4', '5', '6'), ('7', '8', '9'), ('0', None, None)]
>>> chunker(s, 4)
[('1', '2', '3', '4'), ('5', '6', '7', '8'), ('9', '0', None, None)]
>>> chunker(s, 5)
[('1', '2', '3', '4', '5'), ('6', '7', '8', '9', '0')]

Using map() instead of zip() fixes the padding issue in J.F. Sebastian’s answer:

>>> def chunker(iterable, chunksize):
...   return map(None,*[iter(iterable)]*chunksize)

Example:

>>> s = '1234567890'
>>> chunker(s, 3)
[('1', '2', '3'), ('4', '5', '6'), ('7', '8', '9'), ('0', None, None)]
>>> chunker(s, 4)
[('1', '2', '3', '4'), ('5', '6', '7', '8'), ('9', '0', None, None)]
>>> chunker(s, 5)
[('1', '2', '3', '4', '5'), ('6', '7', '8', '9', '0')]

回答 11

另一种方法是使用以下两个参数的形式iter

from itertools import islice

def group(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

可以很容易地调整它以使用填充(这类似于Markus Jarderot的回答):

from itertools import islice, chain, repeat

def group_pad(it, size, pad=None):
    it = chain(iter(it), repeat(pad))
    return iter(lambda: tuple(islice(it, size)), (pad,) * size)

这些甚至可以结合起来用于可选的填充:

_no_pad = object()
def group(it, size, pad=_no_pad):
    if pad == _no_pad:
        it = iter(it)
        sentinel = ()
    else:
        it = chain(iter(it), repeat(pad))
        sentinel = (pad,) * size
    return iter(lambda: tuple(islice(it, size)), sentinel)

Another approach would be to use the two-argument form of iter:

from itertools import islice

def group(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

This can be adapted easily to use padding (this is similar to Markus Jarderot’s answer):

from itertools import islice, chain, repeat

def group_pad(it, size, pad=None):
    it = chain(iter(it), repeat(pad))
    return iter(lambda: tuple(islice(it, size)), (pad,) * size)

These can even be combined for optional padding:

_no_pad = object()
def group(it, size, pad=_no_pad):
    if pad == _no_pad:
        it = iter(it)
        sentinel = ()
    else:
        it = chain(iter(it), repeat(pad))
        sentinel = (pad,) * size
    return iter(lambda: tuple(islice(it, size)), sentinel)

回答 12

如果列表很大,执行此操作的最高性能方法是使用生成器:

def get_chunk(iterable, chunk_size):
    result = []
    for item in iterable:
        result.append(item)
        if len(result) == chunk_size:
            yield tuple(result)
            result = []
    if len(result) > 0:
        yield tuple(result)

for x in get_chunk([1,2,3,4,5,6,7,8,9,10], 3):
    print x

(1, 2, 3)
(4, 5, 6)
(7, 8, 9)
(10,)

If the list is large, the highest-performing way to do this will be to use a generator:

def get_chunk(iterable, chunk_size):
    result = []
    for item in iterable:
        result.append(item)
        if len(result) == chunk_size:
            yield tuple(result)
            result = []
    if len(result) > 0:
        yield tuple(result)

for x in get_chunk([1,2,3,4,5,6,7,8,9,10], 3):
    print x

(1, 2, 3)
(4, 5, 6)
(7, 8, 9)
(10,)

回答 13

使用小功能和事情确实对我没有吸引力。我更喜欢只使用切片:

data = [...]
chunk_size = 10000 # or whatever
chunks = [data[i:i+chunk_size] for i in xrange(0,len(data),chunk_size)]
for chunk in chunks:
    ...

Using little functions and things really doesn’t appeal to me; I prefer to just use slices:

data = [...]
chunk_size = 10000 # or whatever
chunks = [data[i:i+chunk_size] for i in xrange(0,len(data),chunk_size)]
for chunk in chunks:
    ...

回答 14

为了避免所有转换为列表,import itertools并且:

>>> for k, g in itertools.groupby(xrange(35), lambda x: x/10):
...     list(g)

生成:

... 
0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
1 [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
2 [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
3 [30, 31, 32, 33, 34]
>>> 

我检查了一下groupby,它没有转换为列表或使用len所以我(认为)这将延迟每个值的解析,直到实际使用它为止。可悲的是,目前没有可用的答案似乎提供这种变化。

显然,如果您需要依次处理每个项目,请在g上嵌套for循环:

for k,g in itertools.groupby(xrange(35), lambda x: x/10):
    for i in g:
       # do what you need to do with individual items
    # now do what you need to do with the whole group

我对此的特别兴趣是需要消耗一个生成器才能将最多1000个更改批量提交给gmail API:

    messages = a_generator_which_would_not_be_smart_as_a_list
    for idx, batch in groupby(messages, lambda x: x/1000):
        batch_request = BatchHttpRequest()
        for message in batch:
            batch_request.add(self.service.users().messages().modify(userId='me', id=message['id'], body=msg_labels))
        http = httplib2.Http()
        self.credentials.authorize(http)
        batch_request.execute(http=http)

To avoid all conversions to a list import itertools and:

>>> for k, g in itertools.groupby(xrange(35), lambda x: x/10):
...     list(g)

Produces:

... 
0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
1 [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
2 [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
3 [30, 31, 32, 33, 34]
>>> 

I checked groupby and it doesn’t convert to list or use len so I (think) this will delay resolution of each value until it is actually used. Sadly none of the available answers (at this time) seemed to offer this variation.

Obviously if you need to handle each item in turn nest a for loop over g:

for k,g in itertools.groupby(xrange(35), lambda x: x/10):
    for i in g:
       # do what you need to do with individual items
    # now do what you need to do with the whole group

My specific interest in this was the need to consume a generator to submit changes in batches of up to 1000 to the gmail API:

    messages = a_generator_which_would_not_be_smart_as_a_list
    for idx, batch in groupby(messages, lambda x: x/1000):
        batch_request = BatchHttpRequest()
        for message in batch:
            batch_request.add(self.service.users().messages().modify(userId='me', id=message['id'], body=msg_labels))
        http = httplib2.Http()
        self.credentials.authorize(http)
        batch_request.execute(http=http)

回答 15

使用NumPy很简单:

ints = array([1, 2, 3, 4, 5, 6, 7, 8])
for int1, int2 in ints.reshape(-1, 2):
    print(int1, int2)

输出:

1 2
3 4
5 6
7 8

With NumPy it’s simple:

ints = array([1, 2, 3, 4, 5, 6, 7, 8])
for int1, int2 in ints.reshape(-1, 2):
    print(int1, int2)

output:

1 2
3 4
5 6
7 8

回答 16

def chunker(iterable, n):
    """Yield iterable in chunk sizes.

    >>> chunks = chunker('ABCDEF', n=4)
    >>> chunks.next()
    ['A', 'B', 'C', 'D']
    >>> chunks.next()
    ['E', 'F']
    """
    it = iter(iterable)
    while True:
        chunk = []
        for i in range(n):
            try:
                chunk.append(next(it))
            except StopIteration:
                yield chunk
                raise StopIteration
        yield chunk

if __name__ == '__main__':
    import doctest

    doctest.testmod()
def chunker(iterable, n):
    """Yield iterable in chunk sizes.

    >>> chunks = chunker('ABCDEF', n=4)
    >>> chunks.next()
    ['A', 'B', 'C', 'D']
    >>> chunks.next()
    ['E', 'F']
    """
    it = iter(iterable)
    while True:
        chunk = []
        for i in range(n):
            try:
                chunk.append(next(it))
            except StopIteration:
                yield chunk
                raise StopIteration
        yield chunk

if __name__ == '__main__':
    import doctest

    doctest.testmod()

回答 17

除非我错过任何事情,否则不会提及以下带有生成器表达式的简单解决方案。它假定块的大小和数量都是已知的(通常是这种情况),并且不需要填充:

def chunks(it, n, m):
    """Make an iterator over m first chunks of size n.
    """
    it = iter(it)
    # Chunks are presented as tuples.
    return (tuple(next(it) for _ in range(n)) for _ in range(m))

Unless I misses something, the following simple solution with generator expressions has not been mentioned. It assumes that both the size and the number of chunks are known (which is often the case), and that no padding is required:

def chunks(it, n, m):
    """Make an iterator over m first chunks of size n.
    """
    it = iter(it)
    # Chunks are presented as tuples.
    return (tuple(next(it) for _ in range(n)) for _ in range(m))

回答 18

在您的第二种方法中,我将通过执行以下操作进入下一组4:

ints = ints[4:]

但是,我还没有进行任何性能评估,所以我不知道哪个效率更高。

话虽如此,我通常会选择第一种方法。这不是很漂亮,但这通常是与外界交互的结果。

In your second method, I would advance to the next group of 4 by doing this:

ints = ints[4:]

However, I haven’t done any performance measurement so I don’t know which one might be more efficient.

Having said that, I would usually choose the first method. It’s not pretty, but that’s often a consequence of interfacing with the outside world.


回答 19

另一个答案,其优点是:

1)易于理解
2)可以处理任何可迭代的对象,而不仅是序列(上面的一些回答会在文件句柄上阻塞)
3)不会一次将块全部加载到内存中
4)不会对引用进行大块的长列表内存中的相同迭代器
5)列表末尾没有填充值

话虽这么说,我还没有计时,所以它可能比一些更聪明的方法要慢,并且某些优点可能与用例无关。

def chunkiter(iterable, size):
  def inneriter(first, iterator, size):
    yield first
    for _ in xrange(size - 1): 
      yield iterator.next()
  it = iter(iterable)
  while True:
    yield inneriter(it.next(), it, size)

In [2]: i = chunkiter('abcdefgh', 3)
In [3]: for ii in i:                                                
          for c in ii:
            print c,
          print ''
        ...:     
        a b c 
        d e f 
        g h 

更新:
由于内循环和外循环从同一个迭代器中提取值而造成的一些弊端:
1)继续无法在外循环中按预期方式工作-继续执行下一个项目而不是跳过一个块。但是,这似乎不是问题,因为在外循环中没有要测试的东西。
2)break不能在内部循环中按预期方式工作-控件将在迭代器中的下一个项目中再次进入内部循环。要跳过整个块,可以将内部迭代器(上面的ii)包装在一个元组中,例如for c in tuple(ii),或者设置一个标志并耗尽迭代器。

Yet another answer, the advantages of which are:

1) Easily understandable
2) Works on any iterable, not just sequences (some of the above answers will choke on filehandles)
3) Does not load the chunk into memory all at once
4) Does not make a chunk-long list of references to the same iterator in memory
5) No padding of fill values at the end of the list

That being said, I haven’t timed it so it might be slower than some of the more clever methods, and some of the advantages may be irrelevant given the use case.

def chunkiter(iterable, size):
  def inneriter(first, iterator, size):
    yield first
    for _ in xrange(size - 1): 
      yield iterator.next()
  it = iter(iterable)
  while True:
    yield inneriter(it.next(), it, size)

In [2]: i = chunkiter('abcdefgh', 3)
In [3]: for ii in i:                                                
          for c in ii:
            print c,
          print ''
        ...:     
        a b c 
        d e f 
        g h 

Update:
A couple of drawbacks due to the fact the inner and outer loops are pulling values from the same iterator:
1) continue doesn’t work as expected in the outer loop – it just continues on to the next item rather than skipping a chunk. However, this doesn’t seem like a problem as there’s nothing to test in the outer loop.
2) break doesn’t work as expected in the inner loop – control will wind up in the inner loop again with the next item in the iterator. To skip whole chunks, either wrap the inner iterator (ii above) in a tuple, e.g. for c in tuple(ii), or set a flag and exhaust the iterator.


回答 20

def group_by(iterable, size):
    """Group an iterable into lists that don't exceed the size given.

    >>> group_by([1,2,3,4,5], 2)
    [[1, 2], [3, 4], [5]]

    """
    sublist = []

    for index, item in enumerate(iterable):
        if index > 0 and index % size == 0:
            yield sublist
            sublist = []

        sublist.append(item)

    if sublist:
        yield sublist
def group_by(iterable, size):
    """Group an iterable into lists that don't exceed the size given.

    >>> group_by([1,2,3,4,5], 2)
    [[1, 2], [3, 4], [5]]

    """
    sublist = []

    for index, item in enumerate(iterable):
        if index > 0 and index % size == 0:
            yield sublist
            sublist = []

        sublist.append(item)

    if sublist:
        yield sublist

回答 21

您可以从funcy库中使用分区功能:

from funcy import partition

for a, b, c, d in partition(4, ints):
    foo += a * b * c * d

这些函数还具有迭代器版本ipartitionichunks,在这种情况下将更加高效。

您也可以查看它们的实现

You can use partition or chunks function from funcy library:

from funcy import partition

for a, b, c, d in partition(4, ints):
    foo += a * b * c * d

These functions also has iterator versions ipartition and ichunks, which will be more efficient in this case.

You can also peek at their implementation.


回答 22

关于J.F. Sebastian 这里给的解决方案:

def chunker(iterable, chunksize):
    return zip(*[iter(iterable)]*chunksize)

它很聪明,但是有一个缺点-总是返回元组。如何获取字符串呢?
当然您可以编写''.join(chunker(...)),但是无论如何都构造了临时元组。

您可以通过编写own来摆脱临时元组zip,如下所示:

class IteratorExhausted(Exception):
    pass

def translate_StopIteration(iterable, to=IteratorExhausted):
    for i in iterable:
        yield i
    raise to # StopIteration would get ignored because this is generator,
             # but custom exception can leave the generator.

def custom_zip(*iterables, reductor=tuple):
    iterators = tuple(map(translate_StopIteration, iterables))
    while True:
        try:
            yield reductor(next(i) for i in iterators)
        except IteratorExhausted: # when any of iterators get exhausted.
            break

然后

def chunker(data, size, reductor=tuple):
    return custom_zip(*[iter(data)]*size, reductor=reductor)

用法示例:

>>> for i in chunker('12345', 2):
...     print(repr(i))
...
('1', '2')
('3', '4')
>>> for i in chunker('12345', 2, ''.join):
...     print(repr(i))
...
'12'
'34'

About solution gave by J.F. Sebastian here:

def chunker(iterable, chunksize):
    return zip(*[iter(iterable)]*chunksize)

It’s clever, but has one disadvantage – always return tuple. How to get string instead?
Of course you can write ''.join(chunker(...)), but the temporary tuple is constructed anyway.

You can get rid of the temporary tuple by writing own zip, like this:

class IteratorExhausted(Exception):
    pass

def translate_StopIteration(iterable, to=IteratorExhausted):
    for i in iterable:
        yield i
    raise to # StopIteration would get ignored because this is generator,
             # but custom exception can leave the generator.

def custom_zip(*iterables, reductor=tuple):
    iterators = tuple(map(translate_StopIteration, iterables))
    while True:
        try:
            yield reductor(next(i) for i in iterators)
        except IteratorExhausted: # when any of iterators get exhausted.
            break

Then

def chunker(data, size, reductor=tuple):
    return custom_zip(*[iter(data)]*size, reductor=reductor)

Example usage:

>>> for i in chunker('12345', 2):
...     print(repr(i))
...
('1', '2')
('3', '4')
>>> for i in chunker('12345', 2, ''.join):
...     print(repr(i))
...
'12'
'34'

回答 23

我喜欢这种方法。它感觉简单而不是魔术,并且支持所有可迭代的类型,并且不需要导入。

def chunk_iter(iterable, chunk_size):
it = iter(iterable)
while True:
    chunk = tuple(next(it) for _ in range(chunk_size))
    if not chunk:
        break
    yield chunk

I like this approach. It feels simple and not magical and supports all iterable types and doesn’t require imports.

def chunk_iter(iterable, chunk_size):
it = iter(iterable)
while True:
    chunk = tuple(next(it) for _ in range(chunk_size))
    if not chunk:
        break
    yield chunk

回答 24

我从不希望填充我的数据块,因此这一要求至关重要。我发现还需要具有处理任何迭代的能力。鉴于此,我决定扩展接受的答案,https://stackoverflow.com/a/434411/1074659

如果由于需要比较和过滤填充值而不需要填充,则此方法的性能会受到轻微影响。但是,对于大块数据,此实用程序非常有效。

#!/usr/bin/env python3
from itertools import zip_longest


_UNDEFINED = object()


def chunker(iterable, chunksize, fillvalue=_UNDEFINED):
    """
    Collect data into chunks and optionally pad it.

    Performance worsens as `chunksize` approaches 1.

    Inspired by:
        https://docs.python.org/3/library/itertools.html#itertools-recipes

    """
    args = [iter(iterable)] * chunksize
    chunks = zip_longest(*args, fillvalue=fillvalue)
    yield from (
        filter(lambda val: val is not _UNDEFINED, chunk)
        if chunk[-1] is _UNDEFINED
        else chunk
        for chunk in chunks
    ) if fillvalue is _UNDEFINED else chunks

I never want my chunks padded, so that requirement is essential. I find that the ability to work on any iterable is also requirement. Given that, I decided to extend on the accepted answer, https://stackoverflow.com/a/434411/1074659.

Performance takes a slight hit in this approach if padding is not wanted due to the need to compare and filter the padded values. However, for large chunk sizes, this utility is very performant.

#!/usr/bin/env python3
from itertools import zip_longest


_UNDEFINED = object()


def chunker(iterable, chunksize, fillvalue=_UNDEFINED):
    """
    Collect data into chunks and optionally pad it.

    Performance worsens as `chunksize` approaches 1.

    Inspired by:
        https://docs.python.org/3/library/itertools.html#itertools-recipes

    """
    args = [iter(iterable)] * chunksize
    chunks = zip_longest(*args, fillvalue=fillvalue)
    yield from (
        filter(lambda val: val is not _UNDEFINED, chunk)
        if chunk[-1] is _UNDEFINED
        else chunk
        for chunk in chunks
    ) if fillvalue is _UNDEFINED else chunks

回答 25

这是一个没有导入功能的分块器,它支持生成器:

def chunks(seq, size):
    it = iter(seq)
    while True:
        ret = tuple(next(it) for _ in range(size))
        if len(ret) == size:
            yield ret
        else:
            raise StopIteration()

使用示例:

>>> def foo():
...     i = 0
...     while True:
...         i += 1
...         yield i
...
>>> c = chunks(foo(), 3)
>>> c.next()
(1, 2, 3)
>>> c.next()
(4, 5, 6)
>>> list(chunks('abcdefg', 2))
[('a', 'b'), ('c', 'd'), ('e', 'f')]

Here is a chunker without imports that supports generators:

def chunks(seq, size):
    it = iter(seq)
    while True:
        ret = tuple(next(it) for _ in range(size))
        if len(ret) == size:
            yield ret
        else:
            raise StopIteration()

Example of use:

>>> def foo():
...     i = 0
...     while True:
...         i += 1
...         yield i
...
>>> c = chunks(foo(), 3)
>>> c.next()
(1, 2, 3)
>>> c.next()
(4, 5, 6)
>>> list(chunks('abcdefg', 2))
[('a', 'b'), ('c', 'd'), ('e', 'f')]

回答 26

在Python 3.8中,您可以使用walrus运算符和itertools.islice

from itertools import islice

list_ = [i for i in range(10, 100)]

def chunker(it, size):
    iterator = iter(it)
    while chunk := list(islice(iterator, size)):
        print(chunk)
In [2]: chunker(list_, 10)                                                         
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
[50, 51, 52, 53, 54, 55, 56, 57, 58, 59]
[60, 61, 62, 63, 64, 65, 66, 67, 68, 69]
[70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
[80, 81, 82, 83, 84, 85, 86, 87, 88, 89]
[90, 91, 92, 93, 94, 95, 96, 97, 98, 99]

With Python 3.8 you can use the walrus operator and itertools.islice.

from itertools import islice

list_ = [i for i in range(10, 100)]

def chunker(it, size):
    iterator = iter(it)
    while chunk := list(islice(iterator, size)):
        print(chunk)
In [2]: chunker(list_, 10)                                                         
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
[50, 51, 52, 53, 54, 55, 56, 57, 58, 59]
[60, 61, 62, 63, 64, 65, 66, 67, 68, 69]
[70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
[80, 81, 82, 83, 84, 85, 86, 87, 88, 89]
[90, 91, 92, 93, 94, 95, 96, 97, 98, 99]


回答 27

似乎没有做到这一点的漂亮方法。 是一个包含许多方法的页面,包括:

def split_seq(seq, size):
    newseq = []
    splitsize = 1.0/size*len(seq)
    for i in range(size):
        newseq.append(seq[int(round(i*splitsize)):int(round((i+1)*splitsize))])
    return newseq

There doesn’t seem to be a pretty way to do this. Here is a page that has a number of methods, including:

def split_seq(seq, size):
    newseq = []
    splitsize = 1.0/size*len(seq)
    for i in range(size):
        newseq.append(seq[int(round(i*splitsize)):int(round((i+1)*splitsize))])
    return newseq

回答 28

如果列表大小相同,则可以将它们组合成4元组的列表zip()。例如:

# Four lists of four elements each.

l1 = range(0, 4)
l2 = range(4, 8)
l3 = range(8, 12)
l4 = range(12, 16)

for i1, i2, i3, i4 in zip(l1, l2, l3, l4):
    ...

下面是什么zip()函数生成:

>>> print l1
[0, 1, 2, 3]
>>> print l2
[4, 5, 6, 7]
>>> print l3
[8, 9, 10, 11]
>>> print l4
[12, 13, 14, 15]
>>> print zip(l1, l2, l3, l4)
[(0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)]

如果列表很大,并且您不想将它们组合成更大的列表,请使用itertools.izip(),它会生成一个迭代器,而不是列表。

from itertools import izip

for i1, i2, i3, i4 in izip(l1, l2, l3, l4):
    ...

If the lists are the same size, you can combine them into lists of 4-tuples with zip(). For example:

# Four lists of four elements each.

l1 = range(0, 4)
l2 = range(4, 8)
l3 = range(8, 12)
l4 = range(12, 16)

for i1, i2, i3, i4 in zip(l1, l2, l3, l4):
    ...

Here’s what the zip() function produces:

>>> print l1
[0, 1, 2, 3]
>>> print l2
[4, 5, 6, 7]
>>> print l3
[8, 9, 10, 11]
>>> print l4
[12, 13, 14, 15]
>>> print zip(l1, l2, l3, l4)
[(0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)]

If the lists are large, and you don’t want to combine them into a bigger list, use itertools.izip(), which produces an iterator, rather than a list.

from itertools import izip

for i1, i2, i3, i4 in izip(l1, l2, l3, l4):
    ...

回答 29

一种单行的即席解决方案,可x对大小成块的列表进行迭代4

for a, b, c, d in zip(x[0::4], x[1::4], x[2::4], x[3::4]):
    ... do something with a, b, c and d ...

One-liner, adhoc solution to iterate over a list x in chunks of size 4

for a, b, c, d in zip(x[0::4], x[1::4], x[2::4], x[3::4]):
    ... do something with a, b, c and d ...

拼合不规则的列表

问题:拼合不规则的列表

是的,我知道以前已经讨论过这个主题(这里这里这里这里),但是据我所知,除一个解决方案外,所有解决方案在这样的列表上都失败了:

L = [[[1, 2, 3], [4, 5]], 6]

所需的输出是

[1, 2, 3, 4, 5, 6]

甚至更好的迭代器。这个问题是我看到的唯一适用于任意嵌套的解决方案:

def flatten(x):
    result = []
    for el in x:
        if hasattr(el, "__iter__") and not isinstance(el, basestring):
            result.extend(flatten(el))
        else:
            result.append(el)
    return result

flatten(L)

这是最好的模型吗?我有事吗 任何问题?

Yes, I know this subject has been covered before (here, here, here, here), but as far as I know, all solutions, except for one, fail on a list like this:

L = [[[1, 2, 3], [4, 5]], 6]

Where the desired output is

[1, 2, 3, 4, 5, 6]

Or perhaps even better, an iterator. The only solution I saw that works for an arbitrary nesting is found in this question:

def flatten(x):
    result = []
    for el in x:
        if hasattr(el, "__iter__") and not isinstance(el, basestring):
            result.extend(flatten(el))
        else:
            result.append(el)
    return result

flatten(L)

Is this the best model? Did I overlook something? Any problems?


回答 0

使用生成器函数可以使您的示例更易于阅读,并可能提高性能。

Python 2

def flatten(l):
    for el in l:
        if isinstance(el, collections.Iterable) and not isinstance(el, basestring):
            for sub in flatten(el):
                yield sub
        else:
            yield el

我使用了2.6中添加的Iterable ABC

Python 3

在Python 3中,basestring是没有更多的,但你可以使用一个元组str,并bytes得到同样的效果存在。

yield from运营商从一时间产生一个返回的项目。这句法委派到子发生器在3.3加入

def flatten(l):
    for el in l:
        if isinstance(el, collections.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el

Using generator functions can make your example a little easier to read and probably boost the performance.

Python 2

def flatten(l):
    for el in l:
        if isinstance(el, collections.Iterable) and not isinstance(el, basestring):
            for sub in flatten(el):
                yield sub
        else:
            yield el

I used the Iterable ABC added in 2.6.

Python 3

In Python 3, the basestring is no more, but you can use a tuple of str and bytes to get the same effect there.

The yield from operator returns an item from a generator one at a time. This syntax for delegating to a subgenerator was added in 3.3

def flatten(l):
    for el in l:
        if isinstance(el, collections.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el

回答 1

我的解决方案:

import collections


def flatten(x):
    if isinstance(x, collections.Iterable):
        return [a for i in x for a in flatten(i)]
    else:
        return [x]

更加简洁,但几乎相同。

My solution:

import collections


def flatten(x):
    if isinstance(x, collections.Iterable):
        return [a for i in x for a in flatten(i)]
    else:
        return [x]

A little more concise, but pretty much the same.


回答 2

使用递归和鸭子类型生成器(针对Python 3更新):

def flatten(L):
    for item in L:
        try:
            yield from flatten(item)
        except TypeError:
            yield item

list(flatten([[[1, 2, 3], [4, 5]], 6]))
>>>[1, 2, 3, 4, 5, 6]

Generator using recursion and duck typing (updated for Python 3):

def flatten(L):
    for item in L:
        try:
            yield from flatten(item)
        except TypeError:
            yield item

list(flatten([[[1, 2, 3], [4, 5]], 6]))
>>>[1, 2, 3, 4, 5, 6]

回答 3

@unutbu的非递归解决方案的生成器版本,由@Andrew在注释中要求:

def genflat(l, ltypes=collections.Sequence):
    l = list(l)
    i = 0
    while i < len(l):
        while isinstance(l[i], ltypes):
            if not l[i]:
                l.pop(i)
                i -= 1
                break
            else:
                l[i:i + 1] = l[i]
        yield l[i]
        i += 1

此生成器的简化版本:

def genflat(l, ltypes=collections.Sequence):
    l = list(l)
    while l:
        while l and isinstance(l[0], ltypes):
            l[0:1] = l[0]
        if l: yield l.pop(0)

Generator version of @unutbu’s non-recursive solution, as requested by @Andrew in a comment:

def genflat(l, ltypes=collections.Sequence):
    l = list(l)
    i = 0
    while i < len(l):
        while isinstance(l[i], ltypes):
            if not l[i]:
                l.pop(i)
                i -= 1
                break
            else:
                l[i:i + 1] = l[i]
        yield l[i]
        i += 1

Slightly simplified version of this generator:

def genflat(l, ltypes=collections.Sequence):
    l = list(l)
    while l:
        while l and isinstance(l[0], ltypes):
            l[0:1] = l[0]
        if l: yield l.pop(0)

回答 4

这是我的功能性版本的递归展平,它既处理元组又处理列表,并允许您引入位置参数的任何组合。返回一个生成器,该生成器按arg由arg的顺序生成整个序列:

flatten = lambda *n: (e for a in n
    for e in (flatten(*a) if isinstance(a, (tuple, list)) else (a,)))

用法:

l1 = ['a', ['b', ('c', 'd')]]
l2 = [0, 1, (2, 3), [[4, 5, (6, 7, (8,), [9]), 10]], (11,)]
print list(flatten(l1, -2, -1, l2))
['a', 'b', 'c', 'd', -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

Here is my functional version of recursive flatten which handles both tuples and lists, and lets you throw in any mix of positional arguments. Returns a generator which produces the entire sequence in order, arg by arg:

flatten = lambda *n: (e for a in n
    for e in (flatten(*a) if isinstance(a, (tuple, list)) else (a,)))

Usage:

l1 = ['a', ['b', ('c', 'd')]]
l2 = [0, 1, (2, 3), [[4, 5, (6, 7, (8,), [9]), 10]], (11,)]
print list(flatten(l1, -2, -1, l2))
['a', 'b', 'c', 'd', -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

回答 5

此版本的版本flatten避免了python的递归限制(因此可用于任意深度的嵌套可迭代对象)。它是一个生成器,可以处理字符串和任意可迭代(甚至是无限的)。

import itertools as IT
import collections

def flatten(iterable, ltypes=collections.Iterable):
    remainder = iter(iterable)
    while True:
        first = next(remainder)
        if isinstance(first, ltypes) and not isinstance(first, (str, bytes)):
            remainder = IT.chain(first, remainder)
        else:
            yield first

以下是一些示例说明其用法:

print(list(IT.islice(flatten(IT.repeat(1)),10)))
# [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

print(list(IT.islice(flatten(IT.chain(IT.repeat(2,3),
                                       {10,20,30},
                                       'foo bar'.split(),
                                       IT.repeat(1),)),10)))
# [2, 2, 2, 10, 20, 30, 'foo', 'bar', 1, 1]

print(list(flatten([[1,2,[3,4]]])))
# [1, 2, 3, 4]

seq = ([[chr(i),chr(i-32)] for i in range(ord('a'), ord('z')+1)] + list(range(0,9)))
print(list(flatten(seq)))
# ['a', 'A', 'b', 'B', 'c', 'C', 'd', 'D', 'e', 'E', 'f', 'F', 'g', 'G', 'h', 'H',
# 'i', 'I', 'j', 'J', 'k', 'K', 'l', 'L', 'm', 'M', 'n', 'N', 'o', 'O', 'p', 'P',
# 'q', 'Q', 'r', 'R', 's', 'S', 't', 'T', 'u', 'U', 'v', 'V', 'w', 'W', 'x', 'X',
# 'y', 'Y', 'z', 'Z', 0, 1, 2, 3, 4, 5, 6, 7, 8]

尽管flatten可以处理无限生成器,但不能处理无限嵌套:

def infinitely_nested():
    while True:
        yield IT.chain(infinitely_nested(), IT.repeat(1))

print(list(IT.islice(flatten(infinitely_nested()), 10)))
# hangs

This version of flatten avoids python’s recursion limit (and thus works with arbitrarily deep, nested iterables). It is a generator which can handle strings and arbitrary iterables (even infinite ones).

import itertools as IT
import collections

def flatten(iterable, ltypes=collections.Iterable):
    remainder = iter(iterable)
    while True:
        first = next(remainder)
        if isinstance(first, ltypes) and not isinstance(first, (str, bytes)):
            remainder = IT.chain(first, remainder)
        else:
            yield first

Here are some examples demonstrating its use:

print(list(IT.islice(flatten(IT.repeat(1)),10)))
# [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

print(list(IT.islice(flatten(IT.chain(IT.repeat(2,3),
                                       {10,20,30},
                                       'foo bar'.split(),
                                       IT.repeat(1),)),10)))
# [2, 2, 2, 10, 20, 30, 'foo', 'bar', 1, 1]

print(list(flatten([[1,2,[3,4]]])))
# [1, 2, 3, 4]

seq = ([[chr(i),chr(i-32)] for i in range(ord('a'), ord('z')+1)] + list(range(0,9)))
print(list(flatten(seq)))
# ['a', 'A', 'b', 'B', 'c', 'C', 'd', 'D', 'e', 'E', 'f', 'F', 'g', 'G', 'h', 'H',
# 'i', 'I', 'j', 'J', 'k', 'K', 'l', 'L', 'm', 'M', 'n', 'N', 'o', 'O', 'p', 'P',
# 'q', 'Q', 'r', 'R', 's', 'S', 't', 'T', 'u', 'U', 'v', 'V', 'w', 'W', 'x', 'X',
# 'y', 'Y', 'z', 'Z', 0, 1, 2, 3, 4, 5, 6, 7, 8]

Although flatten can handle infinite generators, it can not handle infinite nesting:

def infinitely_nested():
    while True:
        yield IT.chain(infinitely_nested(), IT.repeat(1))

print(list(IT.islice(flatten(infinitely_nested()), 10)))
# hangs

回答 6

这是另一个更有趣的答案…

import re

def Flatten(TheList):
    a = str(TheList)
    b,crap = re.subn(r'[\[,\]]', ' ', a)
    c = b.split()
    d = [int(x) for x in c]

    return(d)

基本上,它将嵌套列表转换为字符串,使用正则表达式去除嵌套语法,然后将结果转换回(扁平化的)列表。

Here’s another answer that is even more interesting…

import re

def Flatten(TheList):
    a = str(TheList)
    b,crap = re.subn(r'[\[,\]]', ' ', a)
    c = b.split()
    d = [int(x) for x in c]

    return(d)

Basically, it converts the nested list to a string, uses a regex to strip out the nested syntax, and then converts the result back to a (flattened) list.


回答 7

def flatten(xs):
    res = []
    def loop(ys):
        for i in ys:
            if isinstance(i, list):
                loop(i)
            else:
                res.append(i)
    loop(xs)
    return res
def flatten(xs):
    res = []
    def loop(ys):
        for i in ys:
            if isinstance(i, list):
                loop(i)
            else:
                res.append(i)
    loop(xs)
    return res

回答 8

您可以deepflatten在第三方套餐中使用iteration_utilities

>>> from iteration_utilities import deepflatten
>>> L = [[[1, 2, 3], [4, 5]], 6]
>>> list(deepflatten(L))
[1, 2, 3, 4, 5, 6]

>>> list(deepflatten(L, types=list))  # only flatten "inner" lists
[1, 2, 3, 4, 5, 6]

这是一个迭代器,因此您需要对其进行迭代(例如,通过将其包装list或在循环中使用)。在内部,它使用迭代方法而不是递归方法,并且将其编写为C扩展,因此它可以比纯python方法更快:

>>> %timeit list(deepflatten(L))
12.6 µs ± 298 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>> %timeit list(deepflatten(L, types=list))
8.7 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit list(flatten(L))   # Cristian - Python 3.x approach from https://stackoverflow.com/a/2158532/5393381
86.4 µs ± 4.42 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

>>> %timeit list(flatten(L))   # Josh Lee - https://stackoverflow.com/a/2158522/5393381
107 µs ± 2.99 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

>>> %timeit list(genflat(L, list))  # Alex Martelli - https://stackoverflow.com/a/2159079/5393381
23.1 µs ± 710 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

我是iteration_utilities图书馆的作者。

You could use deepflatten from the 3rd party package iteration_utilities:

>>> from iteration_utilities import deepflatten
>>> L = [[[1, 2, 3], [4, 5]], 6]
>>> list(deepflatten(L))
[1, 2, 3, 4, 5, 6]

>>> list(deepflatten(L, types=list))  # only flatten "inner" lists
[1, 2, 3, 4, 5, 6]

It’s an iterator so you need to iterate it (for example by wrapping it with list or using it in a loop). Internally it uses an iterative approach instead of an recursive approach and it’s written as C extension so it can be faster than pure python approaches:

>>> %timeit list(deepflatten(L))
12.6 µs ± 298 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>> %timeit list(deepflatten(L, types=list))
8.7 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit list(flatten(L))   # Cristian - Python 3.x approach from https://stackoverflow.com/a/2158532/5393381
86.4 µs ± 4.42 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

>>> %timeit list(flatten(L))   # Josh Lee - https://stackoverflow.com/a/2158522/5393381
107 µs ± 2.99 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

>>> %timeit list(genflat(L, list))  # Alex Martelli - https://stackoverflow.com/a/2159079/5393381
23.1 µs ± 710 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

I’m the author of the iteration_utilities library.


回答 9

尝试创建一个可以平化Python中不规则列表的函数很有趣,但是当然这就是Python的目的(使编程变得有趣)。以下生成器在某些警告方面工作得很好:

def flatten(iterable):
    try:
        for item in iterable:
            yield from flatten(item)
    except TypeError:
        yield iterable

这将压扁的数据类型,你可能想独自离开(比如bytearraybytesstr对象)。此外,代码还依赖于以下事实:从不可迭代的对象请求迭代器会引发TypeError

>>> L = [[[1, 2, 3], [4, 5]], 6]
>>> def flatten(iterable):
    try:
        for item in iterable:
            yield from flatten(item)
    except TypeError:
        yield iterable


>>> list(flatten(L))
[1, 2, 3, 4, 5, 6]
>>>

编辑:

我不同意以前的实现。问题在于您不应该将无法迭代的东西弄平。这令人困惑,并给人以错误的印象。

>>> list(flatten(123))
[123]
>>>

下面的生成器与第一个生成器几乎相同,但是不存在试图展平不可迭代对象的问题。当给它一个不适当的参数时,它会像人们期望的那样失败。

def flatten(iterable):
    for item in iterable:
        try:
            yield from flatten(item)
        except TypeError:
            yield item

使用提供的列表对生成器进行测试可以正常工作。但是,TypeError当给它一个不可迭代的对象时,新代码将引发一个。下面显示了新行为的示例。

>>> L = [[[1, 2, 3], [4, 5]], 6]
>>> list(flatten(L))
[1, 2, 3, 4, 5, 6]
>>> list(flatten(123))
Traceback (most recent call last):
  File "<pyshell#32>", line 1, in <module>
    list(flatten(123))
  File "<pyshell#27>", line 2, in flatten
    for item in iterable:
TypeError: 'int' object is not iterable
>>>

It was fun trying to create a function that could flatten irregular list in Python, but of course that is what Python is for (to make programming fun). The following generator works fairly well with some caveats:

def flatten(iterable):
    try:
        for item in iterable:
            yield from flatten(item)
    except TypeError:
        yield iterable

It will flatten datatypes that you might want left alone (like bytearray, bytes, and str objects). Also, the code relies on the fact that requesting an iterator from a non-iterable raises a TypeError.

>>> L = [[[1, 2, 3], [4, 5]], 6]
>>> def flatten(iterable):
    try:
        for item in iterable:
            yield from flatten(item)
    except TypeError:
        yield iterable


>>> list(flatten(L))
[1, 2, 3, 4, 5, 6]
>>>

Edit:

I disagree with the previous implementation. The problem is that you should not be able to flatten something that is not an iterable. It is confusing and gives the wrong impression of the argument.

>>> list(flatten(123))
[123]
>>>

The following generator is almost the same as the first but does not have the problem of trying to flatten a non-iterable object. It fails as one would expect when an inappropriate argument is given to it.

def flatten(iterable):
    for item in iterable:
        try:
            yield from flatten(item)
        except TypeError:
            yield item

Testing the generator works fine with the list that was provided. However, the new code will raise a TypeError when a non-iterable object is given to it. Example are shown below of the new behavior.

>>> L = [[[1, 2, 3], [4, 5]], 6]
>>> list(flatten(L))
[1, 2, 3, 4, 5, 6]
>>> list(flatten(123))
Traceback (most recent call last):
  File "<pyshell#32>", line 1, in <module>
    list(flatten(123))
  File "<pyshell#27>", line 2, in flatten
    for item in iterable:
TypeError: 'int' object is not iterable
>>>

回答 10

尽管选择了一个优雅且非常Python化的答案,但我仅出于审查目的而提出我的解决方案:

def flat(l):
    ret = []
    for i in l:
        if isinstance(i, list) or isinstance(i, tuple):
            ret.extend(flat(i))
        else:
            ret.append(i)
    return ret

请告诉我们这段代码的好坏?

Although an elegant and very pythonic answer has been selected I would present my solution just for the review:

def flat(l):
    ret = []
    for i in l:
        if isinstance(i, list) or isinstance(i, tuple):
            ret.extend(flat(i))
        else:
            ret.append(i)
    return ret

Please tell how good or bad this code is?


回答 11

我喜欢简单的答案。没有生成器。没有递归或递归限制。只是迭代:

def flatten(TheList):
    listIsNested = True

    while listIsNested:                 #outer loop
        keepChecking = False
        Temp = []

        for element in TheList:         #inner loop
            if isinstance(element,list):
                Temp.extend(element)
                keepChecking = True
            else:
                Temp.append(element)

        listIsNested = keepChecking     #determine if outer loop exits
        TheList = Temp[:]

    return TheList

这适用于两个列表:内部for循环和外部while循环。

内部的for循环遍历列表。如果找到列表元素,则(1)使用list.extend()展平该部分嵌套的层次,并且(2)将keepChecking切换为True。keepchecking用于控制外部while循环。如果将外部循环设置为true,则会触发内部循环进行另一遍处理。

这些通行证一直发生,直到找不到更多的嵌套列表。当最后一次通过但找不到任何地方的传递时,keepChecking永远不会变为true,这意味着listIsNested保持为false,而外部while循环退出。

然后返回扁平化列表。

测试运行

flatten([1,2,3,4,[100,200,300,[1000,2000,3000]]])

[1, 2, 3, 4, 100, 200, 300, 1000, 2000, 3000]

I prefer simple answers. No generators. No recursion or recursion limits. Just iteration:

def flatten(TheList):
    listIsNested = True

    while listIsNested:                 #outer loop
        keepChecking = False
        Temp = []

        for element in TheList:         #inner loop
            if isinstance(element,list):
                Temp.extend(element)
                keepChecking = True
            else:
                Temp.append(element)

        listIsNested = keepChecking     #determine if outer loop exits
        TheList = Temp[:]

    return TheList

This works with two lists: an inner for loop and an outer while loop.

The inner for loop iterates through the list. If it finds a list element, it (1) uses list.extend() to flatten that part one level of nesting and (2) switches keepChecking to True. keepchecking is used to control the outer while loop. If the outer loop gets set to true, it triggers the inner loop for another pass.

Those passes keep happening until no more nested lists are found. When a pass finally occurs where none are found, keepChecking never gets tripped to true, which means listIsNested stays false and the outer while loop exits.

The flattened list is then returned.

Test-run

flatten([1,2,3,4,[100,200,300,[1000,2000,3000]]])

[1, 2, 3, 4, 100, 200, 300, 1000, 2000, 3000]


回答 12

这是一个简单的函数,可以平铺任意深度的列表。没有递归,以避免堆栈溢出。

from copy import deepcopy

def flatten_list(nested_list):
    """Flatten an arbitrarily nested list, without recursion (to avoid
    stack overflows). Returns a new list, the original list is unchanged.

    >> list(flatten_list([1, 2, 3, [4], [], [[[[[[[[[5]]]]]]]]]]))
    [1, 2, 3, 4, 5]
    >> list(flatten_list([[1, 2], 3]))
    [1, 2, 3]

    """
    nested_list = deepcopy(nested_list)

    while nested_list:
        sublist = nested_list.pop(0)

        if isinstance(sublist, list):
            nested_list = sublist + nested_list
        else:
            yield sublist

Here’s a simple function that flattens lists of arbitrary depth. No recursion, to avoid stack overflow.

from copy import deepcopy

def flatten_list(nested_list):
    """Flatten an arbitrarily nested list, without recursion (to avoid
    stack overflows). Returns a new list, the original list is unchanged.

    >> list(flatten_list([1, 2, 3, [4], [], [[[[[[[[[5]]]]]]]]]]))
    [1, 2, 3, 4, 5]
    >> list(flatten_list([[1, 2], 3]))
    [1, 2, 3]

    """
    nested_list = deepcopy(nested_list)

    while nested_list:
        sublist = nested_list.pop(0)

        if isinstance(sublist, list):
            nested_list = sublist + nested_list
        else:
            yield sublist

回答 13

我很惊讶没有人想到这一点。该死的递归我没有这里的高级人员做出的递归答案。无论如何,这是我的尝试。请注意,这是非常特定于OP的用例的

import re

L = [[[1, 2, 3], [4, 5]], 6]
flattened_list = re.sub("[\[\]]", "", str(L)).replace(" ", "").split(",")
new_list = list(map(int, flattened_list))
print(new_list)

输出:

[1, 2, 3, 4, 5, 6]

I’m surprised no one has thought of this. Damn recursion I don’t get the recursive answers that the advanced people here made. anyway here is my attempt on this. caveat is it’s very specific to the OP’s use case

import re

L = [[[1, 2, 3], [4, 5]], 6]
flattened_list = re.sub("[\[\]]", "", str(L)).replace(" ", "").split(",")
new_list = list(map(int, flattened_list))
print(new_list)

output:

[1, 2, 3, 4, 5, 6]

回答 14

我没有在这里浏览所有已经可用的答案,但这是我想到的一个衬里,它借鉴了Lisp的第一张清单和其余清单的处理方式

def flatten(l): return flatten(l[0]) + (flatten(l[1:]) if len(l) > 1 else []) if type(l) is list else [l]

这是一种简单而又不太简单的情况-

>>> flatten([1,[2,3],4])
[1, 2, 3, 4]

>>> flatten([1, [2, 3], 4, [5, [6, {'name': 'some_name', 'age':30}, 7]], [8, 9, [10, [11, [12, [13, {'some', 'set'}, 14, [15, 'some_string'], 16], 17, 18], 19], 20], 21, 22, [23, 24], 25], 26, 27, 28, 29, 30])
[1, 2, 3, 4, 5, 6, {'age': 30, 'name': 'some_name'}, 7, 8, 9, 10, 11, 12, 13, set(['set', 'some']), 14, 15, 'some_string', 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
>>> 

I didn’t go through all the already available answers here, but here is a one liner I came up with, borrowing from lisp’s way of first and rest list processing

def flatten(l): return flatten(l[0]) + (flatten(l[1:]) if len(l) > 1 else []) if type(l) is list else [l]

here is one simple and one not-so-simple case –

>>> flatten([1,[2,3],4])
[1, 2, 3, 4]

>>> flatten([1, [2, 3], 4, [5, [6, {'name': 'some_name', 'age':30}, 7]], [8, 9, [10, [11, [12, [13, {'some', 'set'}, 14, [15, 'some_string'], 16], 17, 18], 19], 20], 21, 22, [23, 24], 25], 26, 27, 28, 29, 30])
[1, 2, 3, 4, 5, 6, {'age': 30, 'name': 'some_name'}, 7, 8, 9, 10, 11, 12, 13, set(['set', 'some']), 14, 15, 'some_string', 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
>>> 

回答 15

当试图回答这样的问题时,您确实需要给出您提议作为解决方案的代码的限制。如果只考虑性能,我不会太在意,但是提议作为解决方案的大多数代码(包括可接受的答案)都无法使深度大于1000的列表变平。

当我说大多数代码我指的是所有使用任何形式的递归的代码(或调用递归的标准库函数)。所有这些代码都会失败,因为对于每个递归调用,(调用)堆栈都增加一个单位,而(默认)python调用堆栈的大小为1000。

如果您不太熟悉调用堆栈,那么以下内容可能会有所帮助(否则,您可以滚动到Implementation)。

调用堆栈大小和递归编程(类似于地下城)

寻找宝藏并退出

想象一下,您进入一个带编号房间的巨大地牢,寻找宝藏。您不知道这个地方,但是对于如何找到宝藏有一些指示。每个指示都是一个谜(难度各不相同,但是您无法预测它们的难易程度)。您决定对节省时间的策略进行一点思考,然后进行两个观察:

  1. 很难(很长)找到宝藏,因为您必须解决(可能很难)谜团才能到达那里。
  2. 找到宝藏后,返回入口可能很容易,您只需要在另一个方向上使用相同的路径即可(尽管这需要一点记忆才能调用您的路径)。

进入地牢时,您会在这里注意到一个小笔记本。您决定使用它来写下谜题(当进入新房间时)之后退出的每个房间,这样您就可以返回到入口。那是个天才的主意,您甚至都不会花一分钱实施自己的策略。

您进入了地牢,成功地解决了前1001个难题,但是这是您未曾计划的事情,您借用的笔记本中没有剩余空间。您决定放弃自己的任务,因为您更喜欢没有宝物,而不是永远迷失在地牢中(确实看起来很聪明)。

执行递归程序

基本上,这与寻找宝藏完全相同。地牢是计算机的内存,您现在的目标不是找到宝藏,而是要计算某些函数(对于给定x找到f(x))。这些指示只是子例程,可以帮助您解决f(x)。您的策略与调用堆栈策略相同,笔记本是堆栈,房间是函数的返回地址:

x = ["over here", "am", "I"]
y = sorted(x) # You're about to enter a room named `sorted`, note down the current room address here so you can return back: 0x4004f4 (that room address looks weird)
# Seems like you went back from your quest using the return address 0x4004f4
# Let's see what you've collected 
print(' '.join(y))

您在地牢中遇到的问题在这里将是相同的,调用堆栈的大小是有限的(此处为1000),因此,如果您输入了太多函数而没有返回,则您将填充调用堆栈并出现错误就像一次调用自己-一遍又一遍-),您将一遍又一遍地输入,直到计算完成(直到找到宝藏为止),然后返回,直到返回到调用的位置为止 “亲爱的冒险家,很抱歉,您的笔记本已经满了”:最初的地方。直到最后一次将调用栈从所有返回地址中释放出来之前,调用栈将永远不会被释放。RecursionError: maximum recursion depth exceeded。请注意,您不需要递归即可填充调用堆栈,但是非递归程序调用1000函数而永远不会返回的可能性很小。同样重要的是要了解,从函数返回后,调用栈将从使用的地址中释放出来(因此,名称“栈”,返回地址在进入函数之前就被压入,并在返回时被拉出)。在简单递归的特殊情况下(一个函数ffff

如何避免这个问题?

这实际上很简单:“如果您不知道递归的深度,请不要使用递归”。并非总是如此,因为在某些情况下,可以优化尾调用递归(TCO)。但是在python中,情况并非如此,即使“写得很好”的递归函数也无法优化堆栈的使用。Guido有一个有趣的帖子,关于这个问题:尾递归消除

您可以使用一种技术来迭代任何递归函数,我们可以称之为自带笔记本。例如,在我们的特定情况下,我们只是在探索一个列表,进入一个房间等同于进入一个子列表,您应该问自己的问题是如何从列表返回其父列表?答案并不那么复杂,请重复以下操作,直到stack为空:

  1. 推送当前列表,addressindexstack进入新的子列表时将其推入(请注意,列表地址+索引也是地址,因此我们只使用调用堆栈使用的完全相同的技术);
  2. 每次找到一个项目yield(或将它们添加到列表中);
  3. 完全浏览列表后,请使用stack return address(和index)返回父列表。

还要注意,这等效于树中的DFS,其中某些节点是子列表,A = [1, 2]而有些则是简单项:(0, 1, 2, 3, 4用于L = [0, [1,2], 3, 4])。树看起来像这样:

                    L
                    |
           -------------------
           |     |     |     |
           0   --A--   3     4
               |   |
               1   2

DFS遍历的顺序为:L,0,A,1、2、3、4。请记住,要实现迭代DFS,您还需要“堆栈”。我之前提出的实现导致具有以下状态(针对stackflat_list):

init.:  stack=[(L, 0)]
**0**:  stack=[(L, 0)],         flat_list=[0]
**A**:  stack=[(L, 1), (A, 0)], flat_list=[0]
**1**:  stack=[(L, 1), (A, 0)], flat_list=[0, 1]
**2**:  stack=[(L, 1), (A, 1)], flat_list=[0, 1, 2]
**3**:  stack=[(L, 2)],         flat_list=[0, 1, 2, 3]
**3**:  stack=[(L, 3)],         flat_list=[0, 1, 2, 3, 4]
return: stack=[],               flat_list=[0, 1, 2, 3, 4]

在此示例中,堆栈最大大小为2,因为输入列表(因此树)的深度为2。

实作

对于实现,在python中,您可以使用迭代器而不是简单的列表来简化一点。对(子)迭代器的引用将用于存储子列表的返回地址(而不是同时具有列表地址和索引)。这不是什么大的区别,但是我觉得这更具可读性(并且速度更快):

def flatten(iterable):
    return list(items_from(iterable))

def items_from(iterable):
    cursor_stack = [iter(iterable)]
    while cursor_stack:
        sub_iterable = cursor_stack[-1]
        try:
            item = next(sub_iterable)
        except StopIteration:   # post-order
            cursor_stack.pop()
            continue
        if is_list_like(item):  # pre-order
            cursor_stack.append(iter(item))
        elif item is not None:
            yield item          # in-order

def is_list_like(item):
    return isinstance(item, list)

另外,请注意,在is_list_likeI have中isinstance(item, list),可以将其更改为处理更多输入类型,在这里,我只想拥有最简单的版本,其中(可迭代)只是一个列表。但是您也可以这样做:

def is_list_like(item):
    try:
        iter(item)
        return not isinstance(item, str)  # strings are not lists (hmm...) 
    except TypeError:
        return False

flatten_iter([["test", "a"], "b])会将字符串视为“简单项目”,因此将返回["test", "a", "b"]而不是["t", "e", "s", "t", "a", "b"]。请注意,在这种情况下,iter(item)每个项目都会被调用两次,让我们假设这是读者练习此清洁器的一种练习。

测试和评论其他实现

最后,请记住,您不能使用来打印无限嵌套的列表Lprint(L)因为它在内部将使用对__repr__RecursionError: maximum recursion depth exceeded while getting the repr of an object)的递归调用。出于相同的原因,flatten涉及解决方案str将失败,并显示相同的错误消息。

如果您需要测试解决方案,则可以使用此函数生成一个简单的嵌套列表:

def build_deep_list(depth):
    """Returns a list of the form $l_{depth} = [depth-1, l_{depth-1}]$
    with $depth > 1$ and $l_0 = [0]$.
    """
    sub_list = [0]
    for d in range(1, depth):
        sub_list = [d, sub_list]
    return sub_list

给出:build_deep_list(5)>>> [4, [3, [2, [1, [0]]]]]

When trying to answer such a question you really need to give the limitations of the code you propose as a solution. If it was only about performances I wouldn’t mind too much, but most of the codes proposed as solution (including the accepted answer) fail to flatten any list that has a depth greater than 1000.

When I say most of the codes I mean all codes that use any form of recursion (or call a standard library function that is recursive). All these codes fail because for every of the recursive call made, the (call) stack grow by one unit, and the (default) python call stack has a size of 1000.

If you’re not too familiar with the call stack, then maybe the following will help (otherwise you can just scroll to the Implementation).

Call stack size and recursive programming (dungeon analogy)

Finding the treasure and exit

Imagine you enter a huge dungeon with numbered rooms, looking for a treasure. You don’t know the place but you have some indications on how to find the treasure. Each indication is a riddle (difficulty varies, but you can’t predict how hard they will be). You decide to think a little bit about a strategy to save time, you make two observations:

  1. It’s hard (long) to find the treasure as you’ll have to solve (potentially hard) riddles to get there.
  2. Once the treasure found, returning to the entrance may be easy, you just have to use the same path in the other direction (though this needs a bit of memory to recall your path).

When entering the dungeon, you notice a small notebook here. You decide to use it to write down every room you exit after solving a riddle (when entering a new room), this way you’ll be able to return back to the entrance. That’s a genius idea, you won’t even spend a cent implementing your strategy.

You enter the dungeon, solving with great success the first 1001 riddles, but here comes something you hadn’t planed, you have no space left in the notebook you borrowed. You decide to abandon your quest as you prefer not having the treasure than being lost forever inside the dungeon (that looks smart indeed).

Executing a recursive program

Basically, it’s the exact same thing as finding the treasure. The dungeon is the computer’s memory, your goal now is not to find a treasure but to compute some function (find f(x) for a given x). The indications simply are sub-routines that will help you solving f(x). Your strategy is the same as the call stack strategy, the notebook is the stack, the rooms are the functions’ return addresses:

x = ["over here", "am", "I"]
y = sorted(x) # You're about to enter a room named `sorted`, note down the current room address here so you can return back: 0x4004f4 (that room address looks weird)
# Seems like you went back from your quest using the return address 0x4004f4
# Let's see what you've collected 
print(' '.join(y))

The problem you encountered in the dungeon will be the same here, the call stack has a finite size (here 1000) and therefore, if you enter too many functions without returning back then you’ll fill the call stack and have an error that look like “Dear adventurer, I’m very sorry but your notebook is full”: RecursionError: maximum recursion depth exceeded. Note that you don’t need recursion to fill the call stack, but it’s very unlikely that a non-recursive program call 1000 functions without ever returning. It’s important to also understand that once you returned from a function, the call stack is freed from the address used (hence the name “stack”, return address are pushed in before entering a function and pulled out when returning). In the special case of a simple recursion (a function f that call itself once — over and over –) you will enter f over and over until the computation is finished (until the treasure is found) and return from f until you go back to the place where you called f in the first place. The call stack will never be freed from anything until the end where it will be freed from all return addresses one after the other.

How to avoid this issue?

That’s actually pretty simple: “don’t use recursion if you don’t know how deep it can go”. That’s not always true as in some cases, Tail Call recursion can be Optimized (TCO). But in python, this is not the case, and even “well written” recursive function will not optimize stack use. There is an interesting post from Guido about this question: Tail Recursion Elimination.

There is a technique that you can use to make any recursive function iterative, this technique we could call bring your own notebook. For example, in our particular case we simply are exploring a list, entering a room is equivalent to entering a sublist, the question you should ask yourself is how can I get back from a list to its parent list? The answer is not that complex, repeat the following until the stack is empty:

  1. push the current list address and index in a stack when entering a new sublist (note that a list address+index is also an address, therefore we just use the exact same technique used by the call stack);
  2. every time an item is found, yield it (or add them in a list);
  3. once a list is fully explored, go back to the parent list using the stack return address (and index).

Also note that this is equivalent to a DFS in a tree where some nodes are sublists A = [1, 2] and some are simple items: 0, 1, 2, 3, 4 (for L = [0, [1,2], 3, 4]). The tree looks like this:

                    L
                    |
           -------------------
           |     |     |     |
           0   --A--   3     4
               |   |
               1   2

The DFS traversal pre-order is: L, 0, A, 1, 2, 3, 4. Remember, in order to implement an iterative DFS you also “need” a stack. The implementation I proposed before result in having the following states (for the stack and the flat_list):

init.:  stack=[(L, 0)]
**0**:  stack=[(L, 0)],         flat_list=[0]
**A**:  stack=[(L, 1), (A, 0)], flat_list=[0]
**1**:  stack=[(L, 1), (A, 0)], flat_list=[0, 1]
**2**:  stack=[(L, 1), (A, 1)], flat_list=[0, 1, 2]
**3**:  stack=[(L, 2)],         flat_list=[0, 1, 2, 3]
**3**:  stack=[(L, 3)],         flat_list=[0, 1, 2, 3, 4]
return: stack=[],               flat_list=[0, 1, 2, 3, 4]

In this example, the stack maximum size is 2, because the input list (and therefore the tree) have depth 2.

Implementation

For the implementation, in python you can simplify a little bit by using iterators instead of simple lists. References to the (sub)iterators will be used to store sublists return addresses (instead of having both the list address and the index). This is not a big difference but I feel this is more readable (and also a bit faster):

def flatten(iterable):
    return list(items_from(iterable))

def items_from(iterable):
    cursor_stack = [iter(iterable)]
    while cursor_stack:
        sub_iterable = cursor_stack[-1]
        try:
            item = next(sub_iterable)
        except StopIteration:   # post-order
            cursor_stack.pop()
            continue
        if is_list_like(item):  # pre-order
            cursor_stack.append(iter(item))
        elif item is not None:
            yield item          # in-order

def is_list_like(item):
    return isinstance(item, list)

Also, notice that in is_list_like I have isinstance(item, list), which could be changed to handle more input types, here I just wanted to have the simplest version where (iterable) is just a list. But you could also do that:

def is_list_like(item):
    try:
        iter(item)
        return not isinstance(item, str)  # strings are not lists (hmm...) 
    except TypeError:
        return False

This considers strings as “simple items” and therefore flatten_iter([["test", "a"], "b]) will return ["test", "a", "b"] and not ["t", "e", "s", "t", "a", "b"]. Remark that in that case, iter(item) is called twice on each item, let’s pretend it’s an exercise for the reader to make this cleaner.

Testing and remarks on other implementations

In the end, remember that you can’t print a infinitely nested list L using print(L) because internally it will use recursive calls to __repr__ (RecursionError: maximum recursion depth exceeded while getting the repr of an object). For the same reason, solutions to flatten involving str will fail with the same error message.

If you need to test your solution, you can use this function to generate a simple nested list:

def build_deep_list(depth):
    """Returns a list of the form $l_{depth} = [depth-1, l_{depth-1}]$
    with $depth > 1$ and $l_0 = [0]$.
    """
    sub_list = [0]
    for d in range(1, depth):
        sub_list = [d, sub_list]
    return sub_list

Which gives: build_deep_list(5) >>> [4, [3, [2, [1, [0]]]]].


回答 16

这是compiler.ast.flatten2.7.5中的实现:

def flatten(seq):
    l = []
    for elt in seq:
        t = type(elt)
        if t is tuple or t is list:
            for elt2 in flatten(elt):
                l.append(elt2)
        else:
            l.append(elt)
    return l

有更好,更快的方法(如果您已经到达这里,您已经看到了它们)

另请注意:

自2.6版起弃用:编译器软件包已在Python 3中删除。

Here’s the compiler.ast.flatten implementation in 2.7.5:

def flatten(seq):
    l = []
    for elt in seq:
        t = type(elt)
        if t is tuple or t is list:
            for elt2 in flatten(elt):
                l.append(elt2)
        else:
            l.append(elt)
    return l

There are better, faster methods (If you’ve reached here, you have seen them already)

Also note:

Deprecated since version 2.6: The compiler package has been removed in Python 3.


回答 17

完全hacky,但我认为它可以工作(取决于您的data_type)

flat_list = ast.literal_eval("[%s]"%re.sub("[\[\]]","",str(the_list)))

totally hacky but I think it would work (depending on your data_type)

flat_list = ast.literal_eval("[%s]"%re.sub("[\[\]]","",str(the_list)))

回答 18

只需使用一个funcy库: pip install funcy

import funcy


funcy.flatten([[[[1, 1], 1], 2], 3]) # returns generator
funcy.lflatten([[[[1, 1], 1], 2], 3]) # returns list

Just use a funcy library: pip install funcy

import funcy


funcy.flatten([[[[1, 1], 1], 2], 3]) # returns generator
funcy.lflatten([[[[1, 1], 1], 2], 3]) # returns list

回答 19

这是另一种py2方法,我不确定它是最快还是最优雅也不最安全…

from collections import Iterable
from itertools import imap, repeat, chain


def flat(seqs, ignore=(int, long, float, basestring)):
    return repeat(seqs, 1) if any(imap(isinstance, repeat(seqs), ignore)) or not isinstance(seqs, Iterable) else chain.from_iterable(imap(flat, seqs))

它可以忽略您想要的任何特定(或派生)类型,它返回一个迭代器,因此您可以将其转换为任何特定的容器(例如list,tuple,dict或仅使用它)以减少内存占用,无论是好是坏它可以处理初始的不可迭代对象,例如int …

请注意,大多数繁重的工作都是在C中完成的,因为据我所知,这是itertools的实现方式,因此尽管是递归的,但AFAIK并不受python递归深度的限制,因为函数调用发生在C中,尽管这样做并不意味着您会受到内存的限制,特别是在OS X中,从今天开始,它的堆栈大小有了硬限制(OS X Mavericks)…

有一种稍微快一点的方法,但可移植性较低的方法,只有在可以假定可以明确确定输入的基本元素的情况下,才使用它,否则,将获得无限递归,并且具有有限堆栈大小的OS X将很快地引发细分错误…

def flat(seqs, ignore={int, long, float, str, unicode}):
    return repeat(seqs, 1) if type(seqs) in ignore or not isinstance(seqs, Iterable) else chain.from_iterable(imap(flat, seqs))

在这里,我们使用集合来检查类型,因此需要O(1)与O(类型数)来检查是否应忽略某个元素,尽管任何具有声明的被忽略类型的派生类型的值都将失败,这就是为什么要使用它strunicode因此请谨慎使用…

测试:

import random

def test_flat(test_size=2000):
    def increase_depth(value, depth=1):
        for func in xrange(depth):
            value = repeat(value, 1)
        return value

    def random_sub_chaining(nested_values):
        for values in nested_values:
            yield chain((values,), chain.from_iterable(imap(next, repeat(nested_values, random.randint(1, 10)))))

    expected_values = zip(xrange(test_size), imap(str, xrange(test_size)))
    nested_values = random_sub_chaining((increase_depth(value, depth) for depth, value in enumerate(expected_values)))
    assert not any(imap(cmp, chain.from_iterable(expected_values), flat(chain(((),), nested_values, ((),)))))

>>> test_flat()
>>> list(flat([[[1, 2, 3], [4, 5]], 6]))
[1, 2, 3, 4, 5, 6]
>>>  

$ uname -a
Darwin Samys-MacBook-Pro.local 13.3.0 Darwin Kernel Version 13.3.0: Tue Jun  3 21:27:35 PDT 2014; root:xnu-2422.110.17~1/RELEASE_X86_64 x86_64
$ python --version
Python 2.7.5

Here is another py2 approach, Im not sure if its the fastest or the most elegant nor safest …

from collections import Iterable
from itertools import imap, repeat, chain


def flat(seqs, ignore=(int, long, float, basestring)):
    return repeat(seqs, 1) if any(imap(isinstance, repeat(seqs), ignore)) or not isinstance(seqs, Iterable) else chain.from_iterable(imap(flat, seqs))

It can ignore any specific (or derived) type you would like, it returns an iterator, so you can convert it to any specific container such as list, tuple, dict or simply consume it in order to reduce memory footprint, for better or worse it can handle initial non-iterable objects such as int …

Note most of the heavy lifting is done in C, since as far as I know thats how itertools are implemented, so while it is recursive, AFAIK it isn’t bounded by python recursion depth since the function calls are happening in C, though this doesn’t mean you are bounded by memory, specially in OS X where its stack size has a hard limit as of today (OS X Mavericks) …

there is a slightly faster approach, but less portable method, only use it if you can assume that the base elements of the input can be explicitly determined otherwise, you’ll get an infinite recursion, and OS X with its limited stack size, will throw a segmentation fault fairly quickly …

def flat(seqs, ignore={int, long, float, str, unicode}):
    return repeat(seqs, 1) if type(seqs) in ignore or not isinstance(seqs, Iterable) else chain.from_iterable(imap(flat, seqs))

here we are using sets to check for the type so it takes O(1) vs O(number of types) to check whether or not an element should be ignored, though of course any value with derived type of the stated ignored types will fail, this is why its using str, unicode so use it with caution …

tests:

import random

def test_flat(test_size=2000):
    def increase_depth(value, depth=1):
        for func in xrange(depth):
            value = repeat(value, 1)
        return value

    def random_sub_chaining(nested_values):
        for values in nested_values:
            yield chain((values,), chain.from_iterable(imap(next, repeat(nested_values, random.randint(1, 10)))))

    expected_values = zip(xrange(test_size), imap(str, xrange(test_size)))
    nested_values = random_sub_chaining((increase_depth(value, depth) for depth, value in enumerate(expected_values)))
    assert not any(imap(cmp, chain.from_iterable(expected_values), flat(chain(((),), nested_values, ((),)))))

>>> test_flat()
>>> list(flat([[[1, 2, 3], [4, 5]], 6]))
[1, 2, 3, 4, 5, 6]
>>>  

$ uname -a
Darwin Samys-MacBook-Pro.local 13.3.0 Darwin Kernel Version 13.3.0: Tue Jun  3 21:27:35 PDT 2014; root:xnu-2422.110.17~1/RELEASE_X86_64 x86_64
$ python --version
Python 2.7.5

回答 20

不使用任何库:

def flat(l):
    def _flat(l, r):    
        if type(l) is not list:
            r.append(l)
        else:
            for i in l:
                r = r + flat(i)
        return r
    return _flat(l, [])



# example
test = [[1], [[2]], [3], [['a','b','c'] , [['z','x','y']], ['d','f','g']], 4]    
print flat(test) # prints [1, 2, 3, 'a', 'b', 'c', 'z', 'x', 'y', 'd', 'f', 'g', 4]

Without using any library:

def flat(l):
    def _flat(l, r):    
        if type(l) is not list:
            r.append(l)
        else:
            for i in l:
                r = r + flat(i)
        return r
    return _flat(l, [])



# example
test = [[1], [[2]], [3], [['a','b','c'] , [['z','x','y']], ['d','f','g']], 4]    
print flat(test) # prints [1, 2, 3, 'a', 'b', 'c', 'z', 'x', 'y', 'd', 'f', 'g', 4]

回答 21

使用itertools.chain

import itertools
from collections import Iterable

def list_flatten(lst):
    flat_lst = []
    for item in itertools.chain(lst):
        if isinstance(item, Iterable):
            item = list_flatten(item)
            flat_lst.extend(item)
        else:
            flat_lst.append(item)
    return flat_lst

或不链接:

def flatten(q, final):
    if not q:
        return
    if isinstance(q, list):
        if not isinstance(q[0], list):
            final.append(q[0])
        else:
            flatten(q[0], final)
        flatten(q[1:], final)
    else:
        final.append(q)

Using itertools.chain:

import itertools
from collections import Iterable

def list_flatten(lst):
    flat_lst = []
    for item in itertools.chain(lst):
        if isinstance(item, Iterable):
            item = list_flatten(item)
            flat_lst.extend(item)
        else:
            flat_lst.append(item)
    return flat_lst

Or without chaining:

def flatten(q, final):
    if not q:
        return
    if isinstance(q, list):
        if not isinstance(q[0], list):
            final.append(q[0])
        else:
            flatten(q[0], final)
        flatten(q[1:], final)
    else:
        final.append(q)

回答 22

我使用递归来解决任何深度的嵌套列表

def combine_nlist(nlist,init=0,combiner=lambda x,y: x+y):
    '''
    apply function: combiner to a nested list element by element(treated as flatten list)
    '''
    current_value=init
    for each_item in nlist:
        if isinstance(each_item,list):
            current_value =combine_nlist(each_item,current_value,combiner)
        else:
            current_value = combiner(current_value,each_item)
    return current_value

因此,在定义函数combin_nlist之后,就很容易使用此函数进行展平。或者,您可以将其组合为一个功能。我喜欢我的解决方案,因为它可以应用于任何嵌套列表。

def flatten_nlist(nlist):
    return combine_nlist(nlist,[],lambda x,y:x+[y])

结果

In [379]: flatten_nlist([1,2,3,[4,5],[6],[[[7],8],9],10])
Out[379]: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

I used recursive to solve nested list with any depth

def combine_nlist(nlist,init=0,combiner=lambda x,y: x+y):
    '''
    apply function: combiner to a nested list element by element(treated as flatten list)
    '''
    current_value=init
    for each_item in nlist:
        if isinstance(each_item,list):
            current_value =combine_nlist(each_item,current_value,combiner)
        else:
            current_value = combiner(current_value,each_item)
    return current_value

So after i define function combine_nlist, it is easy to use this function do flatting. Or you can combine it into one function. I like my solution because it can be applied to any nested list.

def flatten_nlist(nlist):
    return combine_nlist(nlist,[],lambda x,y:x+[y])

result

In [379]: flatten_nlist([1,2,3,[4,5],[6],[[[7],8],9],10])
Out[379]: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

回答 23

最简单的方法是使用变身利用图书馆pip install morph

代码是:

import morph

list = [[[1, 2, 3], [4, 5]], 6]
flattened_list = morph.flatten(list)  # returns [1, 2, 3, 4, 5, 6]

The easiest way is to use the morph library using pip install morph.

The code is:

import morph

list = [[[1, 2, 3], [4, 5]], 6]
flattened_list = morph.flatten(list)  # returns [1, 2, 3, 4, 5, 6]

回答 24

我知道已经有很多很棒的答案,但是我想添加一个使用功能性编程方法解决问题的答案。在这个答案中,我使用了双重递归:

def flatten_list(seq):
    if not seq:
        return []
    elif isinstance(seq[0],list):
        return (flatten_list(seq[0])+flatten_list(seq[1:]))
    else:
        return [seq[0]]+flatten_list(seq[1:])

print(flatten_list([1,2,[3,[4],5],[6,7]]))

输出:

[1, 2, 3, 4, 5, 6, 7]

I am aware that there are already many awesome answers but i wanted to add an answer that uses the functional programming method of solving the question. In this answer i make use of double recursion :

def flatten_list(seq):
    if not seq:
        return []
    elif isinstance(seq[0],list):
        return (flatten_list(seq[0])+flatten_list(seq[1:]))
    else:
        return [seq[0]]+flatten_list(seq[1:])

print(flatten_list([1,2,[3,[4],5],[6,7]]))

output:

[1, 2, 3, 4, 5, 6, 7]

回答 25

我不确定这是否一定更快或更有效,但这是我要做的:

def flatten(lst):
    return eval('[' + str(lst).replace('[', '').replace(']', '') + ']')

L = [[[1, 2, 3], [4, 5]], 6]
print(flatten(L))

flatten这里的函数将列表转换为字符串,取出所有方括号,将方括号附加到两端,然后将其重新转换为列表。

虽然,如果您知道列表中的方括号中包含字符串,例如[[1, 2], "[3, 4] and [5]"],则您需要做其他事情。

I’m not sure if this is necessarily quicker or more effective, but this is what I do:

def flatten(lst):
    return eval('[' + str(lst).replace('[', '').replace(']', '') + ']')

L = [[[1, 2, 3], [4, 5]], 6]
print(flatten(L))

The flatten function here turns the list into a string, takes out all of the square brackets, attaches square brackets back onto the ends, and turns it back into a list.

Although, if you knew you would have square brackets in your list in strings, like [[1, 2], "[3, 4] and [5]"], you would have to do something else.


回答 26

这是在python2上进行flatten的简单实现

flatten=lambda l: reduce(lambda x,y:x+y,map(flatten,l),[]) if isinstance(l,list) else [l]

test=[[1,2,3,[3,4,5],[6,7,[8,9,[10,[11,[12,13,14]]]]]],]
print flatten(test)

#output [1, 2, 3, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

This is a simple implement of flatten on python2

flatten=lambda l: reduce(lambda x,y:x+y,map(flatten,l),[]) if isinstance(l,list) else [l]

test=[[1,2,3,[3,4,5],[6,7,[8,9,[10,[11,[12,13,14]]]]]],]
print flatten(test)

#output [1, 2, 3, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

回答 27

这将使列表或字典(或列表列表或字典的字典等)变平。它假定值是字符串,并创建一个字符串,该字符串将每个项目与分隔符参数连接在一起。如果需要,可以使用分隔符随后将结果拆分为列表对象。如果下一个值是列表或字符串,则使用递归。使用key参数来告诉您要使用字典对象的键还是值(将key设置为false)。

def flatten_obj(n_obj, key=True, my_sep=''):
    my_string = ''
    if type(n_obj) == list:
        for val in n_obj:
            my_sep_setter = my_sep if my_string != '' else ''
            if type(val) == list or type(val) == dict:
                my_string += my_sep_setter + flatten_obj(val, key, my_sep)
            else:
                my_string += my_sep_setter + val
    elif type(n_obj) == dict:
        for k, v in n_obj.items():
            my_sep_setter = my_sep if my_string != '' else ''
            d_val = k if key else v
            if type(v) == list or type(v) == dict:
                my_string += my_sep_setter + flatten_obj(v, key, my_sep)
            else:
                my_string += my_sep_setter + d_val
    elif type(n_obj) == str:
        my_sep_setter = my_sep if my_string != '' else ''
        my_string += my_sep_setter + n_obj
        return my_string
    return my_string

print(flatten_obj(['just', 'a', ['test', 'to', 'try'], 'right', 'now', ['or', 'later', 'today'],
                [{'dictionary_test': 'test'}, {'dictionary_test_two': 'later_today'}, 'my power is 9000']], my_sep=', ')

Yield:

just, a, test, to, try, right, now, or, later, today, dictionary_test, dictionary_test_two, my power is 9000

This will flatten a list or dictionary (or list of lists or dictionaries of dictionaries etc). It assumes that the values are strings and it creates a string that concatenates each item with a separator argument. If you wanted you could use the separator to split the result into a list object afterward. It uses recursion if the next value is a list or a string. Use the key argument to tell whether you want the keys or the values (set key to false) from the dictionary object.

def flatten_obj(n_obj, key=True, my_sep=''):
    my_string = ''
    if type(n_obj) == list:
        for val in n_obj:
            my_sep_setter = my_sep if my_string != '' else ''
            if type(val) == list or type(val) == dict:
                my_string += my_sep_setter + flatten_obj(val, key, my_sep)
            else:
                my_string += my_sep_setter + val
    elif type(n_obj) == dict:
        for k, v in n_obj.items():
            my_sep_setter = my_sep if my_string != '' else ''
            d_val = k if key else v
            if type(v) == list or type(v) == dict:
                my_string += my_sep_setter + flatten_obj(v, key, my_sep)
            else:
                my_string += my_sep_setter + d_val
    elif type(n_obj) == str:
        my_sep_setter = my_sep if my_string != '' else ''
        my_string += my_sep_setter + n_obj
        return my_string
    return my_string

print(flatten_obj(['just', 'a', ['test', 'to', 'try'], 'right', 'now', ['or', 'later', 'today'],
                [{'dictionary_test': 'test'}, {'dictionary_test_two': 'later_today'}, 'my power is 9000']], my_sep=', ')

yields:

just, a, test, to, try, right, now, or, later, today, dictionary_test, dictionary_test_two, my power is 9000

回答 28

如果您喜欢递归,这可能是您感兴趣的解决方案:

def f(E):
    if E==[]: 
        return []
    elif type(E) != list: 
        return [E]
    else:
        a = f(E[0])
        b = f(E[1:])
        a.extend(b)
        return a

我实际上是从前一段时间写的一些练习Scheme代码中改编而成的。

请享用!

If you like recursion, this might be a solution of interest to you:

def f(E):
    if E==[]: 
        return []
    elif type(E) != list: 
        return [E]
    else:
        a = f(E[0])
        b = f(E[1:])
        a.extend(b)
        return a

I actually adapted this from some practice Scheme code that I had written a while back.

Enjoy!


回答 29

我是python的新手,来自Lisp背景。这是我想出的(查看lulz的var名称):

def flatten(lst):
    if lst:
        car,*cdr=lst
        if isinstance(car,(list,tuple)):
            if cdr: return flatten(car) + flatten(cdr)
            return flatten(car)
        if cdr: return [car] + flatten(cdr)
        return [car]

似乎可以工作。测试:

flatten((1,2,3,(4,5,6,(7,8,(((1,2)))))))

返回:

[1, 2, 3, 4, 5, 6, 7, 8, 1, 2]

I’m new to python and come from a lisp background. This is what I came up with (check out the var names for lulz):

def flatten(lst):
    if lst:
        car,*cdr=lst
        if isinstance(car,(list,tuple)):
            if cdr: return flatten(car) + flatten(cdr)
            return flatten(car)
        if cdr: return [car] + flatten(cdr)
        return [car]

Seems to work. Test:

flatten((1,2,3,(4,5,6,(7,8,(((1,2)))))))

returns:

[1, 2, 3, 4, 5, 6, 7, 8, 1, 2]

导入语句是否应该始终位于模块的顶部?

问题:导入语句是否应该始终位于模块的顶部?

PEP 08指出:

导入总是放在文件的顶部,紧随任何模块注释和文档字符串之后,以及模块全局变量和常量之前。

但是,如果仅在极少数情况下使用我要导入的类/方法/函数,那么在需要时进行导入肯定会更有效吗?

这不是吗?

class SomeClass(object):

    def not_often_called(self)
        from datetime import datetime
        self.datetime = datetime.now()

比这更有效?

from datetime import datetime

class SomeClass(object):

    def not_often_called(self)
        self.datetime = datetime.now()

PEP 08 states:

Imports are always put at the top of the file, just after any module comments and docstrings, and before module globals and constants.

However if the class/method/function that I am importing is only used in rare cases, surely it is more efficient to do the import when it is needed?

Isn’t this:

class SomeClass(object):

    def not_often_called(self)
        from datetime import datetime
        self.datetime = datetime.now()

more efficient than this?

from datetime import datetime

class SomeClass(object):

    def not_often_called(self)
        self.datetime = datetime.now()

回答 0

模块导入非常快,但不是即时的。这意味着:

  • 将导入放在模块顶部很好,因为这是微不足道的成本,只需要支付一次即可。
  • 将导入放在函数中会导致对该函数的调用花费更长时间。

因此,如果您关心效率,则将进口放在首位。仅在您的分析显示有帮助的情况下,才将它们移入函数中(您进行了概要分析以查看最能改善性能的地方,对吗?)


我见过执行延迟导入的最佳原因是:

  • 可选的库支持。如果您的代码具有使用不同库的多个路径,则在未安装可选库的情况下不要中断。
  • __init__.py插件的中,可能已导入但并未实际使用。例如Bazaar插件,它使用bzrlib的延迟加载框架。

Module importing is quite fast, but not instant. This means that:

  • Putting the imports at the top of the module is fine, because it’s a trivial cost that’s only paid once.
  • Putting the imports within a function will cause calls to that function to take longer.

So if you care about efficiency, put the imports at the top. Only move them into a function if your profiling shows that would help (you did profile to see where best to improve performance, right??)


The best reasons I’ve seen to perform lazy imports are:

  • Optional library support. If your code has multiple paths that use different libraries, don’t break if an optional library is not installed.
  • In the __init__.py of a plugin, which might be imported but not actually used. Examples are Bazaar plugins, which use bzrlib‘s lazy-loading framework.

回答 1

将import语句放在函数内部可以防止循环依赖。例如,如果您有两个模块X.py和Y.py,并且它们都需要互相导入,那么当您导入其中一个模块导致无限循环时,这将导致循环依赖。如果将import语句移动到一个模块中,则它将在调用该函数之前不会尝试导入另一个模块,并且该模块将已经被导入,因此不会出现无限循环。在此处阅读更多内容-effbot.org/zone/import-confusion.htm

Putting the import statement inside of a function can prevent circular dependencies. For example, if you have 2 modules, X.py and Y.py, and they both need to import each other, this will cause a circular dependency when you import one of the modules causing an infinite loop. If you move the import statement in one of the modules then it won’t try to import the other module till the function is called, and that module will already be imported, so no infinite loop. Read here for more – effbot.org/zone/import-confusion.htm


回答 2

我采用了将所有导入放入使用它们的函数中而不是放在模块顶部的做法。

我得到的好处是能够更可靠地进行重构。当我将一个功能从一个模块移动到另一个模块时,我知道该功能将继续使用其完整的测试遗留功能。如果我在模块的顶部放置了导入文件,那么当我移动一个函数时,我发现我花了很多时间来使新模块的导入文件完整而最少。重构IDE可能与此无关。

如其他地方提到的那样,存在速度损失。我已经在我的应用程序中对此进行了测量,发现对于我的目的而言它并不重要。

能够预先查看所有模块依赖性而无需借助搜索(例如grep),也很不错。但是,我关心模块依赖性的原因通常是因为我正在安装,重构或移动包含多个文件的整个系统,而不仅仅是一个模块。在这种情况下,无论如何,我将执行全局搜索以确保我具有系统级依赖项。因此,我还没有发现全局导入可以帮助我在实践中理解系统。

我通常将检查的内容sys放入if __name__=='__main__'检查中,然后将参数(如sys.argv[1:])传递给main()函数。这使我可以mainsys尚未导入的上下文中使用。

I have adopted the practice of putting all imports in the functions that use them, rather than at the top of the module.

The benefit I get is the ability to refactor more reliably. When I move a function from one module to another, I know that the function will continue to work with all of its legacy of testing intact. If I have my imports at the top of the module, when I move a function, I find that I end up spending a lot of time getting the new module’s imports complete and minimal. A refactoring IDE might make this irrelevant.

There is a speed penalty as mentioned elsewhere. I have measured this in my application and found it to be insignificant for my purposes.

It is also nice to be able to see all module dependencies up front without resorting to search (e.g. grep). However, the reason I care about module dependencies is generally because I’m installing, refactoring, or moving an entire system comprising multiple files, not just a single module. In that case, I’m going to perform a global search anyway to make sure I have the system-level dependencies. So I have not found global imports to aid my understanding of a system in practice.

I usually put the import of sys inside the if __name__=='__main__' check and then pass arguments (like sys.argv[1:]) to a main() function. This allows me to use main in a context where sys has not been imported.


回答 3

在大多数情况下,这样做对于保持清晰性和明智性很有用,但并非总是如此。以下是几个可能会在其他地方导入模块的情况的示例。

首先,您可以拥有一个带有以下形式的单元测试的模块:

if __name__ == '__main__':
    import foo
    aa = foo.xyz()         # initiate something for the test

其次,您可能需要在运行时有条件地导入一些不同的模块。

if [condition]:
    import foo as plugin_api
else:
    import bar as plugin_api
xx = plugin_api.Plugin()
[...]

在其他情况下,您可能会将导入放置在代码的其他部分中。

Most of the time this would be useful for clarity and sensible to do but it’s not always the case. Below are a couple of examples of circumstances where module imports might live elsewhere.

Firstly, you could have a module with a unit test of the form:

if __name__ == '__main__':
    import foo
    aa = foo.xyz()         # initiate something for the test

Secondly, you might have a requirement to conditionally import some different module at runtime.

if [condition]:
    import foo as plugin_api
else:
    import bar as plugin_api
xx = plugin_api.Plugin()
[...]

There are probably other situations where you might place imports in other parts in the code.


回答 4

当函数被调用为零或一次时,第一种变体的确比第二种变体更有效。但是,在第二次及其后的调用中,“导入每个调用”方法实际上效率较低。请参阅此链接以获取延迟加载技术,该技术通过执行“延迟导入”结合了两种方法的优点。

但是,除了效率之外,还有其他原因导致您可能会偏爱一个。一种方法是使阅读该模块相关代码的人更加清楚。它们还具有非常不同的故障特征-如果没有“ datetime”模块,第一个将在加载时失败,而第二个在调用该方法之前不会失败。

补充说明:在IronPython中,导入可能比CPython中昂贵得多,因为代码基本上是在导入时进行编译的。

The first variant is indeed more efficient than the second when the function is called either zero or one times. With the second and subsequent invocations, however, the “import every call” approach is actually less efficient. See this link for a lazy-loading technique that combines the best of both approaches by doing a “lazy import”.

But there are reasons other than efficiency why you might prefer one over the other. One approach is makes it much more clear to someone reading the code as to the dependencies that this module has. They also have very different failure characteristics — the first will fail at load time if there’s no “datetime” module while the second won’t fail until the method is called.

Added Note: In IronPython, imports can be quite a bit more expensive than in CPython because the code is basically being compiled as it’s being imported.


回答 5

Curt提出了一个很好的观点:第二个版本更清晰,它将在加载时而不是以后失败,并且出乎意料地失败。

通常,我不必担心模块的加载效率,因为它的速度(a)非常快,而(b)大多仅在启动时发生。

如果必须在意外的时刻加载重量级模块,则可以通过该__import__函数动态加载它们,并确保捕获ImportError异常并以合理的方式处理它们,这可能更有意义。

Curt makes a good point: the second version is clearer and will fail at load time rather than later, and unexpectedly.

Normally I don’t worry about the efficiency of loading modules, since it’s (a) pretty fast, and (b) mostly only happens at startup.

If you have to load heavyweight modules at unexpected times, it probably makes more sense to load them dynamically with the __import__ function, and be sure to catch ImportError exceptions, and handle them in a reasonable manner.


回答 6

我不会担心过多地预先加载模块的效率。模块占用的内存不会很大(假设它足够模块化),启动成本可以忽略不计。

在大多数情况下,您希望将模块加载到源文件的顶部。对于阅读您的代码的人来说,它更容易分辨出哪个功能或对象来自哪个模块。

将模块导入代码中其他位置的一个很好的理由是,如果该模块在调试语句中使用过。

例如:

do_something_with_x(x)

我可以使用以下命令调试它:

from pprint import pprint
pprint(x)
do_something_with_x(x)

当然,将模块导入代码中其他位置的另一个原因是是否需要动态导入它们。这是因为您几乎别无选择。

我不会担心过多地预先加载模块的效率。模块占用的内存不会很大(假设它足够模块化),启动成本可以忽略不计。

I wouldn’t worry about the efficiency of loading the module up front too much. The memory taken up by the module won’t be very big (assuming it’s modular enough) and the startup cost will be negligible.

In most cases you want to load the modules at the top of the source file. For somebody reading your code, it makes it much easier to tell what function or object came from what module.

One good reason to import a module elsewhere in the code is if it’s used in a debugging statement.

For example:

do_something_with_x(x)

I could debug this with:

from pprint import pprint
pprint(x)
do_something_with_x(x)

Of course, the other reason to import modules elsewhere in the code is if you need to dynamically import them. This is because you pretty much don’t have any choice.

I wouldn’t worry about the efficiency of loading the module up front too much. The memory taken up by the module won’t be very big (assuming it’s modular enough) and the startup cost will be negligible.


回答 7

这是一个折衷,只有程序员才能决定进行。

情况1通过在需要之前不导入datetime模块(并进行可能需要的任何初始化)来节省一些内存和启动时间。请注意,“仅在调用时”执行导入也意味着“在调用时每次”进行导入,因此第一个调用之后的每个调用仍会产生执行导入的额外开销。

情况2通过预先导入datetime来节省一些执行时间和延迟,以便not_often_drawn()在调用时将更快地返回,并且还不会在每次调用时都导致导入开销。

除了效率外,如果import语句在…前面,则更容易在前面看到模块依赖性。将它们隐藏在代码中会使您更难于找到所需的模块。

就个人而言,除了单元测试之类的东西外,我通常都遵循PEP,因此我不希望总是加载它,因为我知道除了测试代码之外不会使用它们。

It’s a tradeoff, that only the programmer can decide to make.

Case 1 saves some memory and startup time by not importing the datetime module (and doing whatever initialization it might require) until needed. Note that doing the import ‘only when called’ also means doing it ‘every time when called’, so each call after the first one is still incurring the additional overhead of doing the import.

Case 2 save some execution time and latency by importing datetime beforehand so that not_often_called() will return more quickly when it is called, and also by not incurring the overhead of an import on every call.

Besides efficiency, it’s easier to see module dependencies up front if the import statements are … up front. Hiding them down in the code can make it more difficult to easily find what modules something depends on.

Personally I generally follow the PEP except for things like unit tests and such that I don’t want always loaded because I know they aren’t going to be used except for test code.


回答 8

这是一个示例,其中所有导入都位于最顶部(这是我唯一需要这样做的时间)。我希望能够在Un * x和Windows上终止子进程。

import os
# ...
try:
    kill = os.kill  # will raise AttributeError on Windows
    from signal import SIGTERM
    def terminate(process):
        kill(process.pid, SIGTERM)
except (AttributeError, ImportError):
    try:
        from win32api import TerminateProcess  # use win32api if available
        def terminate(process):
            TerminateProcess(int(process._handle), -1)
    except ImportError:
        def terminate(process):
            raise NotImplementedError  # define a dummy function

(评论:约翰·米利金说的话。)

Here’s an example where all the imports are at the very top (this is the only time I’ve needed to do this). I want to be able to terminate a subprocess on both Un*x and Windows.

import os
# ...
try:
    kill = os.kill  # will raise AttributeError on Windows
    from signal import SIGTERM
    def terminate(process):
        kill(process.pid, SIGTERM)
except (AttributeError, ImportError):
    try:
        from win32api import TerminateProcess  # use win32api if available
        def terminate(process):
            TerminateProcess(int(process._handle), -1)
    except ImportError:
        def terminate(process):
            raise NotImplementedError  # define a dummy function

(On review: what John Millikin said.)


回答 9

就像许多其他优化一样,您会牺牲一些可读性来提高速度。如John所述,如果您完成了概要分析作业,并且发现这是一项非常有用的更改,并且您需要额外的速度,则可以继续进行。在所有其他进口商品上加上注释可能会很好:

from foo import bar
from baz import qux
# Note: datetime is imported in SomeClass below

This is like many other optimizations – you sacrifice some readability for speed. As John mentioned, if you’ve done your profiling homework and found this to be a significantly useful enough change and you need the extra speed, then go for it. It’d probably be good to put a note up with all the other imports:

from foo import bar
from baz import qux
# Note: datetime is imported in SomeClass below

回答 10

模块初始化仅发生一次-在首次导入时。如果有问题的模块来自标准库,那么您也可能会从程序中的其他模块导入它。对于像日期时间一样普遍的模块,它也可能是许多其他标准库的依赖项。由于模块初始化已经发生,因此import语句的花费很少。此时,它所做的全部工作就是将现有模块对象绑定到本地范围。

将该信息与用于可读性的参数相结合,我想说最好在模块范围内使用import语句。

Module initialization only occurs once – on the first import. If the module in question is from the standard library, then you will likely import it from other modules in your program as well. For a module as prevalent as datetime, it is also likely a dependency for a slew of other standard libraries. The import statement would cost very little then since the module intialization would have happened already. All it is doing at this point is binding the existing module object to the local scope.

Couple that information with the argument for readability and I would say that it is best to have the import statement at module scope.


回答 11

只是为了完成萌的答案和原始问题:

当我们不得不处理循环依赖时,我们可以做一些“技巧”。假设我们正在与模块的工作a.py,并b.py包含x()和B y()分别。然后:

  1. 我们可以移动from imports模块底部的之一。
  2. 我们可以移动from imports实际上需要导入的函数或方法的内部之一(这并不总是可能的,因为您可以在多个地方使用它)。
  3. 我们可以将两者之一更改from imports为如下所示的导入:import a

因此,总结一下。如果您不是在处理循环依赖关系,而是采取某种技巧来避免它们,那么最好将所有导入内容放在顶部,因为在此问题的其他答案中已经说明了这些原因。并且,请在做“技巧”时添加评论,我们始终欢迎您!:)

Just to complete Moe’s answer and the original question:

When we have to deal with circular dependences we can do some “tricks”. Assuming we’re working with modules a.py and b.py that contain x() and b y(), respectively. Then:

  1. We can move one of the from imports at the bottom of the module.
  2. We can move one of the from imports inside the function or method that is actually requiring the import (this isn’t always possible, as you may use it from several places).
  3. We can change one of the two from imports to be an import that looks like: import a

So, to conclude. If you aren’t dealing with circular dependencies and doing some kind of trick to avoid them, then it’s better to put all your imports at the top because of the reasons already explained in other answers to this question. And please, when doing this “tricks” include a comment, it’s always welcome! :)


回答 12

除了已经给出的出色答案外,值得注意的是,进口商品的摆放不仅是风格问题。有时,模块具有隐式依赖关系,需要首先导入或初始化,而顶级导入可能会导致违反所需的执行顺序。

这个问题通常出现在Apache Spark的Python API中,您需要在导入任何pyspark软件包或模块之前初始化SparkContext。最好将pyspark导入放置在保证SparkContext可用的范围内。

In addition to the excellent answers already given, it’s worth noting that the placement of imports is not merely a matter of style. Sometimes a module has implicit dependencies that need to be imported or initialized first, and a top-level import could lead to violations of the required order of execution.

This issue often comes up in Apache Spark’s Python API, where you need to initialize the SparkContext before importing any pyspark packages or modules. It’s best to place pyspark imports in a scope where the SparkContext is guaranteed to be available.


回答 13

我很惊讶地没有看到已经发布的重复负载检查的实际成本数字,尽管对预期的结果有很多很好的解释。

如果您在顶部导入,则无论如何都会承受重击。这个数字很小,但是通常以毫秒为单位,而不是纳秒。

如果导入功能(S)之内,那么你只需要命中的加载,如果首次调用这些功能之一。正如许多人指出的那样,如果根本不发生这种情况,则可以节省加载时间。但是,如果函数被调用很多,您将遭受一次重复的打击,尽管命中率要小得多(用于检查它是否已加载;不是实际重新加载)。另一方面,正如@aaronasterling指出的那样,您还可以节省一点,因为在函数中进行导入使该函数可以使用稍快的局部变量查找来稍后标识名称(http://stackoverflow.com/questions/477096/python- import-coding-style / 4789963#4789963)。

这是一个简单测试的结果,该测试从函数内部导入了一些东西。报告的时间(在2.3 GHz Intel Core i7上的Python 2.7.14中)显示如下(第二次调用比以后的调用更多,这似乎是一致的,尽管我不知道为什么)。

 0 foo:   14429.0924 µs
 1 foo:      63.8962 µs
 2 foo:      10.0136 µs
 3 foo:       7.1526 µs
 4 foo:       7.8678 µs
 0 bar:       9.0599 µs
 1 bar:       6.9141 µs
 2 bar:       7.1526 µs
 3 bar:       7.8678 µs
 4 bar:       7.1526 µs

编码:

from __future__ import print_function
from time import time

def foo():
    import collections
    import re
    import string
    import math
    import subprocess
    return

def bar():
    import collections
    import re
    import string
    import math
    import subprocess
    return

t0 = time()
for i in xrange(5):
    foo()
    t1 = time()
    print("    %2d foo: %12.4f \xC2\xB5s" % (i, (t1-t0)*1E6))
    t0 = t1
for i in xrange(5):
    bar()
    t1 = time()
    print("    %2d bar: %12.4f \xC2\xB5s" % (i, (t1-t0)*1E6))
    t0 = t1

I was surprised not to see actual cost numbers for the repeated load-checks posted already, although there are many good explanations of what to expect.

If you import at the top, you take the load hit no matter what. That’s pretty small, but commonly in the milliseconds, not nanoseconds.

If you import within a function(s), then you only take the hit for loading if and when one of those functions is first called. As many have pointed out, if that doesn’t happen at all, you save the load time. But if the function(s) get called a lot, you take a repeated though much smaller hit (for checking that it has been loaded; not for actually re-loading). On the other hand, as @aaronasterling pointed out you also save a little because importing within a function lets the function use slightly-faster local variable lookups to identify the name later (http://stackoverflow.com/questions/477096/python-import-coding-style/4789963#4789963).

Here are the results of a simple test that imports a few things from inside a function. The times reported (in Python 2.7.14 on a 2.3 GHz Intel Core i7) are shown below (the 2nd call taking more than later calls seems consistent, though I don’t know why).

 0 foo:   14429.0924 µs
 1 foo:      63.8962 µs
 2 foo:      10.0136 µs
 3 foo:       7.1526 µs
 4 foo:       7.8678 µs
 0 bar:       9.0599 µs
 1 bar:       6.9141 µs
 2 bar:       7.1526 µs
 3 bar:       7.8678 µs
 4 bar:       7.1526 µs

The code:

from __future__ import print_function
from time import time

def foo():
    import collections
    import re
    import string
    import math
    import subprocess
    return

def bar():
    import collections
    import re
    import string
    import math
    import subprocess
    return

t0 = time()
for i in xrange(5):
    foo()
    t1 = time()
    print("    %2d foo: %12.4f \xC2\xB5s" % (i, (t1-t0)*1E6))
    t0 = t1
for i in xrange(5):
    bar()
    t1 = time()
    print("    %2d bar: %12.4f \xC2\xB5s" % (i, (t1-t0)*1E6))
    t0 = t1

回答 14

我不希望提供完整的答案,因为其他人已经做得很好。当我发现在功能内部导入模块特别有用时,我只想提及一个用例。我的应用程序使用存储在特定位置的python软件包和模块作为插件。在应用程序启动期间,应用程序遍历该位置的所有模块并将其导入,然后在模块内部查找,如果找到了插件的安装点(在我的情况下,它是具有唯一标识的某些基类的子类ID)将其注册。插件的数量很大(现在有几十个,但将来可能有数百个),每个插件很少使用。在应用程序启动过程中,在我的插件模块顶部添加了第三方库,这会带来一些损失。尤其是某些第三方库的导入非常繁重(例如,密谋导入甚至尝试连接到Internet并下载一些内容,这些内容在启动时增加了大约一秒钟的时间)。通过优化插件中的导入(仅在使用它们的函数中调用它们),我设法将启动时间从10秒缩短到大约2秒。对于我的用户而言,这是一个很大的差异。

所以我的答案是不,不要总是将导入放在模块的顶部。

I do not aspire to provide complete answer, because others have already done this very well. I just want to mention one use case when I find especially useful to import modules inside functions. My application uses python packages and modules stored in certain location as plugins. During application startup, the application walks through all the modules in the location and imports them, then it looks inside the modules and if it finds some mounting points for the plugins (in my case it is a subclass of a certain base class having a unique ID) it registers them. The number of plugins is large (now dozens, but maybe hundreds in the future) and each of them is used quite rarely. Having imports of third party libraries at the top of my plugin modules was a bit penalty during application startup. Especially some thirdparty libraries are heavy to import (e.g. import of plotly even tries to connect to internet and download something which was adding about one second to startup). By optimizing imports (calling them only in the functions where they are used) in the plugins I managed to shrink the startup from 10 seconds to some 2 seconds. That is a big difference for my users.

So my answer is no, do not always put the imports at the top of your modules.


回答 15

有趣的是,到目前为止,没有一个答案提到了并行处理,当序列化的函数代码被推到其他内核时,例如ipyparallel的情况,可能需要在函数中引入导入。

It’s interesting that not a single answer mentioned parallel processing so far, where it might be REQUIRED that the imports are in the function, when the serialized function code is what is being pushed around to other cores, e.g. like in the case of ipyparallel.


回答 16

通过将变量/局部作用域导入函数内部,可以提高性能。这取决于函数中导入事物的用法。如果要循环很多次并访问模块全局对象,则将其作为本地导入可以有所帮助。

test.py

X=10
Y=11
Z=12
def add(i):
  i = i + 10

runlocal.py

from test import add, X, Y, Z

    def callme():
      x=X
      y=Y
      z=Z
      ladd=add 
      for i  in range(100000000):
        ladd(i)
        x+y+z

    callme()

运行

from test import add, X, Y, Z

def callme():
  for i in range(100000000):
    add(i)
    X+Y+Z

callme()

在Linux上使用一段时间显示收益很小

/usr/bin/time -f "\t%E real,\t%U user,\t%S sys" python run.py 
    0:17.80 real,   17.77 user, 0.01 sys
/tmp/test$ /usr/bin/time -f "\t%E real,\t%U user,\t%S sys" python runlocal.py 
    0:14.23 real,   14.22 user, 0.01 sys

真正的是壁钟。用户是程序中的时间。sys是时候进行系统调用了。

https://docs.python.org/3.5/reference/executionmodel.html#resolution-of-names

There can be a performance gain by importing variables/local scoping inside of a function. This depends on the usage of the imported thing inside the function. If you are looping many times and accessing a module global object, importing it as local can help.

test.py

X=10
Y=11
Z=12
def add(i):
  i = i + 10

runlocal.py

from test import add, X, Y, Z

    def callme():
      x=X
      y=Y
      z=Z
      ladd=add 
      for i  in range(100000000):
        ladd(i)
        x+y+z

    callme()

run.py

from test import add, X, Y, Z

def callme():
  for i in range(100000000):
    add(i)
    X+Y+Z

callme()

A time on Linux shows a small gain

/usr/bin/time -f "\t%E real,\t%U user,\t%S sys" python run.py 
    0:17.80 real,   17.77 user, 0.01 sys
/tmp/test$ /usr/bin/time -f "\t%E real,\t%U user,\t%S sys" python runlocal.py 
    0:14.23 real,   14.22 user, 0.01 sys

real is wall clock. user is time in program. sys is time for system calls.

https://docs.python.org/3.5/reference/executionmodel.html#resolution-of-names


回答 17

可读性

除了启动性能外,还有一个可读性参数可用于本地化import语句。例如,在我当前的第一个python项目中,使用python行号1283到1296:

listdata.append(['tk font version', font_version])
listdata.append(['Gtk version', str(Gtk.get_major_version())+"."+
                 str(Gtk.get_minor_version())+"."+
                 str(Gtk.get_micro_version())])

import xml.etree.ElementTree as ET

xmltree = ET.parse('/usr/share/gnome/gnome-version.xml')
xmlroot = xmltree.getroot()
result = []
for child in xmlroot:
    result.append(child.text)
listdata.append(['Gnome version', result[0]+"."+result[1]+"."+
                 result[2]+" "+result[3]])

如果该import语句位于文件的顶部,则必须向上滚动很长一段距离,或者按Home,以查找内容ET。然后,我将不得不导航回到第1283行以继续阅读代码。

确实,即使 import语句位于函数(或类)的顶部(如许多语句所放置的那样),也需要向上和向下分页。

显示Gnome版本号的操作很少,因此import文件顶部会引入不必要的启动延迟。

Readability

In addition to startup performance, there is a readability argument to be made for localizing import statements. For example take python line numbers 1283 through 1296 in my current first python project:

listdata.append(['tk font version', font_version])
listdata.append(['Gtk version', str(Gtk.get_major_version())+"."+
                 str(Gtk.get_minor_version())+"."+
                 str(Gtk.get_micro_version())])

import xml.etree.ElementTree as ET

xmltree = ET.parse('/usr/share/gnome/gnome-version.xml')
xmlroot = xmltree.getroot()
result = []
for child in xmlroot:
    result.append(child.text)
listdata.append(['Gnome version', result[0]+"."+result[1]+"."+
                 result[2]+" "+result[3]])

If the import statement was at the top of file I would have to scroll up a long way, or press Home, to find out what ET was. Then I would have to navigate back to line 1283 to continue reading code.

Indeed even if the import statement was at the top of the function (or class) as many would place it, paging up and back down would be required.

Displaying the Gnome version number will rarely be done so the import at top of file introduces unnecessary startup lag.


回答 18

我想提一下我的一个用例,与@John Millikin和@VK提到的用例非常相似:

可选进口

我使用Jupyter Notebook进行数据分析,并且使用相同的IPython Notebook作为所有分析的模板。在某些情况下,我需要导入Tensorflow来进行一些快速的模型运行,但有时我会在未设置tensorflow或导入缓慢的地方工作。在这些情况下,我将依赖Tensorflow的操作封装在一个辅助函数中,将tensorflow导入该函数内部,并将其绑定到按钮。

这样,我可以“重新启动并运行所有程序”,而不必等待导入,也不必在导入失败时恢复其余的单元格。

I would like to mention a usecase of mine, very similar to those mentioned by @John Millikin and @V.K. :

Optional Imports

I do data analysis with Jupyter Notebook, and I use the same IPython notebook as a template for all analyses. In some occasions, I need to import Tensorflow to do some quick model runs, but sometimes I work in places where tensorflow isn’t set up / is slow to import. In those cases, I encapsulate my Tensorflow-dependent operations in a helper function, import tensorflow inside that function, and bind it to a button.

This way, I could do “restart-and-run-all” without having to wait for the import, or having to resume the rest of the cells when it fails.


回答 19

这是一个有趣的讨论。像许多其他人一样,我什至从未考虑过这个话题。由于想要在我的一个库中使用Django ORM,我不得不在函数中具有导入功能。我不得不打电话django.setup()在导入模型类之前,我,因为这是文件的顶部,由于IoC注入器的构造,它被拖到了完全非Django的库代码中。

我有点四处乱窜,最后将django.setup()in放在单例构造函数中,并将相关的导入放在每个类方法的顶部。现在,这种方法工作正常,但是由于进口商品不在顶部而使我感到不安,而且我也开始担心进口商品的额外时间。然后我来到这里,以极大的兴趣阅读了大家对此的看法。

我有很长的C ++背景,现在使用Python / Cython。我对此的看法是,为什么不将导入内容放入函数中,除非它导致概要分析的瓶颈。这就像在需要变量之前为变量声明空间。麻烦的是,我有数千行代码,所有导入都在顶部!所以我想从现在开始,当我经过并有时间时,在这里和那里更改奇数文件。

This is a fascinating discussion. Like many others I had never even considered this topic. I got cornered into having to have the imports in the functions because of wanting to use the Django ORM in one of my libraries. I was having to call django.setup() before importing my model classes and because this was at the top of the file it was being dragged into completely non-Django library code because of the IoC injector construction.

I kind of hacked around a bit and ended up putting the django.setup() in the singleton constructor and the relevant import at the top of each class method. Now this worked fine but made me uneasy because the imports weren’t at the top and also I started worrying about the extra time hit of the imports. Then I came here and read with great interest everybody’s take on this.

I have a long C++ background and now use Python/Cython. My take on this is that why not put the imports in the function unless it causes you a profiled bottleneck. It’s only like declaring space for variables just before you need them. The trouble is I have thousands of lines of code with all the imports at the top! So I think I will do it from now on and change the odd file here and there when I’m passing through and have the time.


Ray 一个开放源码框架,为构建分布式应用程序提供简单、通用的API

https://github.com/ray-project/ray/raw/master/doc/source/images/ray_header_logo.png

Ray为构建分布式应用程序提供了简单、通用的API,为构建分布式应用程序提供简单、通用的API。Ray与RLlib(一个可伸缩的强化学习库)和Tune(一个可伸缩的超参数调整库)可以打包在一起。

Ray附带以下库,用于加速机器学习工作负载:

  • Tune:可伸缩的超参数调整
  • RLlib:可扩展强化学习
  • RaySGD:分布式培训包装器
  • Ray Serve:可扩展、可编程的服务

也有很多community integrations和Ray在一起,包括DaskMARSModinHorovodHugging FaceScikit-learn,以及其他。请查看full list of Ray distributed libraries here

使用以下选项安装Ray:pip install ray有关夜间车轮的信息,请参阅Installation page

快速入门

并行执行Python函数

import ray
ray.init()

@ray.remote
def f(x):
    return x * x

futures = [f.remote(i) for i in range(4)]
print(ray.get(futures))

要使用Ray的演员模型,请执行以下操作:

import ray
ray.init()

@ray.remote
class Counter(object):
    def __init__(self):
        self.n = 0

    def increment(self):
        self.n += 1

    def read(self):
        return self.n

counters = [Counter.remote() for i in range(4)]
[c.increment.remote() for c in counters]
futures = [c.read.remote() for c in counters]
print(ray.get(futures))

Ray程序可以在一台计算机上运行,也可以无缝扩展到大型群集。要在云中执行上述Ray脚本,只需下载this configuration file,然后运行:

ray submit [CLUSTER.YAML] example.py --start

阅读有关以下内容的更多信息launching clusters

调整快速入门

https://github.com/ray-project/ray/raw/master/doc/source/images/tune-wide.png

Tune是一个用于任何规模的超参数调优的库

要运行此示例,您需要安装以下软件:

$ pip install "ray[tune]"

此示例运行并行格网搜索以优化示例目标函数

from ray import tune


def objective(step, alpha, beta):
    return (0.1 + alpha * step / 100)**(-1) + beta * 0.1


def training_function(config):
    # Hyperparameters
    alpha, beta = config["alpha"], config["beta"]
    for step in range(10):
        # Iterative training function - can be any arbitrary training procedure.
        intermediate_score = objective(step, alpha, beta)
        # Feed the score back back to Tune.
        tune.report(mean_loss=intermediate_score)


analysis = tune.run(
    training_function,
    config={
        "alpha": tune.grid_search([0.001, 0.01, 0.1]),
        "beta": tune.choice([1, 2, 3])
    })

print("Best config: ", analysis.get_best_config(metric="mean_loss", mode="min"))

# Get a dataframe for analyzing trial results.
df = analysis.results_df

如果安装了TensorBoard,则自动可视化所有试验结果:

tensorboard --logdir ~/ray_results

RLlib快速入门

https://github.com/ray-project/ray/raw/master/doc/source/images/rllib-wide.jpg

RLlib是构建在Ray之上的用于强化学习的开源库,它为各种应用程序提供了高可伸缩性和统一的API

pip install tensorflow  # or tensorflow-gpu
pip install "ray[rllib]"

import gym
from gym.spaces import Discrete, Box
from ray import tune

class SimpleCorridor(gym.Env):
    def __init__(self, config):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, self.end_pos, shape=(1, ))

    def reset(self):
        self.cur_pos = 0
        return [self.cur_pos]

    def step(self, action):
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        return [self.cur_pos], 1 if done else 0, done, {}

tune.run(
    "PPO",
    config={
        "env": SimpleCorridor,
        "num_workers": 4,
        "env_config": {"corridor_length": 5}})

Ray Serve快速入门

Ray Serve是一个构建在Ray之上的可伸缩的模型服务库。它是:

  • 框架不可知性:使用相同的工具包提供各种服务,从使用PyTorch或TensorFlow&Kera等框架构建的深度学习模型到Scikit-Learning模型或任意业务逻辑
  • Python优先:在纯Python中配置声明性服务的模型,不需要YAML或JSON配置
  • 以性能为导向:启用批处理、流水线和GPU加速以提高模型的吞吐量
  • 原生合成:允许您通过将多个模型组合在一起来驱动单个预测来创建“模型管道”
  • 水平可扩展:随着您添加更多的机器,Serve可以线性扩展。使您的ML支持的服务能够处理不断增长的流量

要运行此示例,您需要安装以下软件:

$ pip install scikit-learn
$ pip install "ray[serve]"

此示例Run服务于一个SCRICKIT-LEARN梯度增强分类器

from ray import serve
import pickle
import requests
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

# Train model
iris_dataset = load_iris()
model = GradientBoostingClassifier()
model.fit(iris_dataset["data"], iris_dataset["target"])

# Define Ray Serve model,
class BoostingModel:
    def __init__(self):
        self.model = model
        self.label_list = iris_dataset["target_names"].tolist()

    def __call__(self, flask_request):
        payload = flask_request.json["vector"]
        print("Worker: received flask request with data", payload)

        prediction = self.model.predict([payload])[0]
        human_name = self.label_list[prediction]
        return {"result": human_name}


# Deploy model
client = serve.start()
client.create_backend("iris:v1", BoostingModel)
client.create_endpoint("iris_classifier", backend="iris:v1", route="/iris")

# Query it!
sample_request_input = {"vector": [1.2, 1.0, 1.1, 0.9]}
response = requests.get("http://localhost:8000/iris", json=sample_request_input)
print(response.text)
# Result:
# {
#  "result": "versicolor"
# }

更多信息

较旧的文档:

参与其中