n次元ループを1次元ループに変換する (3)

#2358をジェネレータに対応するように書き換えた。

import itertools

def flatten_for(*args):
    D = dict([(id(L), L) for L in args])
    def flatten_for_(*args):
        if not args:
            yield []
        elif len(args) == 1:
            D[id(args[0])], g = itertools.tee(D[id(args[0])])
            for x in g:
                yield [x]
        else:
            D[id(args[0])], g = itertools.tee(D[id(args[0])])
            for x in g:
                for y in flatten_for_(*args[1:]):
                    yield [x] + y
    return flatten_for_(*args)

def foo():
	yield 1
	yield 2
	
def bar():
	yield 'a'
	yield 'b'
	yield 'c'

print list(flatten_for(foo(), bar()))

# 出力
[[1, 'a'], [1, 'b'], [1, 'c'], [2, 'a'], [2, 'b'], [2, 'c']]

関数名がcross_productだと数学の外積と間違えやすいので名前はflatten_forのままにしておいた。実装上のポイントは以下の2つ。

  1. itertools.teeによりジェネレータを複製して、複製したジェネレータの1つを保存しておく
  2. 値渡しによりargsの値を書き換えられないので、staticなオブジェクトに保存する

ちなみに、以前の自分の実装方法はリストの長さ分、リストの回転を遅らせるという考え方だが、ジェネレータは当然リストの長さがあらかじめ分からないので使えない。今回の実装は再帰だが、この場合for文を使って「リストの個数がn個」を扱うのは難しいかもしれない。