n次元ループを1次元ループに変換する

以前の日記の内容をまとめただけだが、以下のようになる。以下の flatten_for は、n次元ループを1次元ループに変換する汎用的な関数である。

from itertools import islice, izip, cycle
import operator

def product(L):
    return reduce(operator.mul, L)

def islow_list(L, n):
    for v in L:
        for i in range(n):
            yield v

def flatten_for(*seqs):
    lens = map(len, seqs)
    nums = [product(lens[i+1:] + [1]) for i, l in enumerate(seqs)]
    nrep = product(lens)
    return islice(
        izip(*(cycle(islow_list(L, n)) for L, n in izip(seqs, nums)))
        , nrep)

for i, j, k in flatten_for([0,1], 'abc', ['Foo', 'Bar']):
    print [i, j, k]	
	
# 以下と同じ意味
# for i in [0,1]:
#     for j in 'abc':
#         for k in ['Foo', 'Bar']:
#             print [i, j, k]

# 出力
[0, 'a', 'Foo']
[0, 'a', 'Bar']
[0, 'b', 'Foo']
[0, 'b', 'Bar']
[0, 'c', 'Foo']
[0, 'c', 'Bar']
[1, 'a', 'Foo']
[1, 'a', 'Bar']
[1, 'b', 'Foo']
[1, 'b', 'Bar']
[1, 'c', 'Foo']
[1, 'c', 'Bar']

を flatten_for を使って解いてみた。

def counter(symbols, n):
    L = []
    seqs = [symbols]
    cnt = 0
    while True:
        for i in flatten_for(*seqs):
            cnt += 1
            if cnt > n: break
            L.append(''.join(i))
        if cnt > n: break
        seqs.append(symbols)
    return L
	
print ', '.join(counter('ABCDEFGHIJKLMNOPQRSTUVWXYZ', 100))

考え方としては、桁の個数、つまりflatten_forに渡すシーケンスの個数を動的に増やしてループを回していく。但し、繰り返し回数が指定回数を超えたらループを終了する。flatten_forがあれば複雑なループ処理も結構簡単な考え方で書ける。


追記:
現在のバージョンの flatten_for は、以下の2つの問題点を持つ。

  1. flatten_for に渡す引数にイテレータやジェネレータを渡せない。リストのみ。
  2. ある階層のfor文とfor文の間に処理を入れたい場合、インデックスのリストもしくは、enumerateなリストを渡してインデックスを監視すれば可能だが面倒。

2つ目の問題は厄介で難しいが、逆に flatten_for を使うメリットとしては、forループの次数が可変で、かつ最終的な値のリストのみ必要な場合に実装が簡単になるというメリットがある。