Welcome to ShenZhenJia Knowledge Sharing Community for programmer and developer-Open, Learning and Share

I have to use the output of a generator in another generator.

Below is the code.

Here, each output of Generator 1 is passed to Generator 2, and the final output is taken from Generator 2.

I am trying to use something like the following. Can anyone suggest a solution?

def sub_gen(data):
    for r in res_gen():
        yield each train_datagen(r)

Generator 1

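import numpy as np
from multiprocessing import Pool

# file_list (the input files) and gen_patches (presumably returns a list of
# 2-D patches for one file) are defined elsewhere in the original code.
def res_gen(num_threads=4):
    # one worker pool for the generator's lifetime; the with-block terminates
    # it when the generator is closed, instead of leaking a new Pool per chunk
    with Pool(num_threads) as p:
        while True:
            for i in range(0, len(file_list), num_threads):
                # use multiple processes to speed up patch extraction
                # (slicing already stops at the end of file_list)
                patch = p.map(gen_patches, file_list[i:i + num_threads])
                res = []
                for x in patch:
                    res += x
                res1 = np.array(res)
                # add a channel axis: (N, H, W) -> (N, H, W, 1)
                res1 = res1.reshape((res1.shape[0], res1.shape[1], res1.shape[2], 1))
                res1 = res1.astype('float32') / 255.0
                yield res1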
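The tail end of res_gen turns the per-file patch lists into one normalized batch. A minimal sketch of just that step, assuming each patch is a 2-D uint8 array of the same size (the helper stack_patches and the fake patches are hypothetical stand-ins for gen_patches output), shows the resulting shape and dtype:

```python
import numpy as np

def stack_patches(patch_lists):
    """Flatten per-file patch lists into a float32 (N, H, W, 1) array in [0, 1].

    Mirrors the end of res_gen; patch_lists is a hypothetical stand-in for
    the result of p.map(gen_patches, ...).
    """
    res = []
    for x in patch_lists:  # one list of 2-D patches per file
        res += x
    res1 = np.array(res)  # shape (N, H, W)
    res1 = res1.reshape((res1.shape[0], res1.shape[1], res1.shape[2], 1))
    return res1.astype('float32') / 255.0

# Two hypothetical "files": the first gives 2 patches, the second 1 patch.
fake_patches = [
    [np.full((8, 8), 255, dtype=np.uint8), np.zeros((8, 8), dtype=np.uint8)],
    [np.full((8, 8), 51, dtype=np.uint8)],
]
batch = stack_patches(fake_patches)
print(batch.shape, batch.dtype)  # (3, 8, 8, 1) float32
```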

Generator 2

import numpy as np

# sigma (the noise standard deviation on the 0-255 pixel scale) is defined
# elsewhere in the original code.
def train_datagen(res1, batch_size=4):
    indices = list(range(res1.shape[0]))
    while True:
        np.random.shuffle(indices)  # new sample order on every pass
        for i in range(0, len(indices), batch_size):
            ge_batch_y = res1[indices[i:i + batch_size]]
            noise = np.random.normal(0, sigma / 255.0, ge_batch_y.shape)
            # noise = K.random_normal(ge_batch_y.shape, mean=0, stddev=sigma/255.0)
            ge_batch_x = ge_batch_y + noise  # input image = clean image + noise
            yield ge_batch_x, ge_batch_y
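Because train_datagen loops forever (the while loop starts a fresh shuffled pass after every pass), a finite number of batches has to be pulled from it explicitly, for example with itertools.islice. A minimal sketch, assuming sigma is passed in as a required keyword argument instead of read from a global (25 is only an example value) and using a toy 10-sample array:

```python
import itertools
import numpy as np

def train_datagen(res1, batch_size=4, *, sigma):
    """Infinite generator of (noisy, clean) batches, as in the question.

    sigma is a keyword argument here (an assumption), not a global.
    """
    indices = list(range(res1.shape[0]))
    while True:
        np.random.shuffle(indices)  # new sample order on every pass
        for i in range(0, len(indices), batch_size):
            ge_batch_y = res1[indices[i:i + batch_size]]
            noise = np.random.normal(0, sigma / 255.0, ge_batch_y.shape)
            yield ge_batch_y + noise, ge_batch_y

clean = np.zeros((10, 8, 8, 1), dtype='float32')
# 10 samples with batch_size=4 give batches of 4, 4, 2 per pass;
# islice stops after 5 batches, i.e. partway through the second pass.
shapes = [x.shape[0] for x, y in itertools.islice(train_datagen(clean, sigma=25), 5)]
print(shapes)  # [4, 4, 2, 4, 4]
```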


1 Answer

I'm pretty sure the main issue in your short sub_gen generator is that you've written yield each instead of yield from. The latter expects an iterable after it (often another generator) and yields each of that iterable's values in turn, just like an explicit for loop would.
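A quick way to confirm that the line from the question fails before anything runs is to compile both versions with Python's built-in compile(); compile only parses, so the undefined names res_gen and train_datagen do not matter here. The helper compiles is just for illustration:

```python
def compiles(source):
    """Return True if source is syntactically valid Python."""
    try:
        compile(source, '<sub_gen>', 'exec')
        return True
    except SyntaxError:
        return False

broken = "def sub_gen(data):\n    for r in res_gen():\n        yield each train_datagen(r)\n"
fixed = "def sub_gen(data):\n    for r in res_gen():\n        yield from train_datagen(r)\n"
print(compiles(broken))  # False
print(compiles(fixed))   # True
```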

So I think your code should be:

def sub_gen(data):
    for r in res_gen():
        yield from train_datagen(r)
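One caveat worth checking: as posted, train_datagen wraps its loop in while(True), so yield from train_datagen(r) never finishes and res_gen is only advanced once, meaning every batch comes from the first chunk. If the intent is one shuffled pass per chunk, dropping that while loop makes the inner generator finite. A toy sketch of the difference, where chunks, batches_forever and batches_once are hypothetical stand-ins for res_gen and the two variants of train_datagen:

```python
import itertools

def chunks():
    """Hypothetical stand-in for res_gen: yields three 'chunks' forever."""
    while True:
        for chunk in (['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']):
            yield chunk

def batches_forever(chunk):
    """Like train_datagen as posted: repeats the same chunk endlessly."""
    while True:
        yield from chunk

def batches_once(chunk):
    """One pass per chunk, so control returns to the outer loop."""
    yield from chunk

def sub_gen(inner):
    for r in chunks():
        yield from inner(r)

print(list(itertools.islice(sub_gen(batches_forever), 6)))
# ['a1', 'a2', 'a1', 'a2', 'a1', 'a2']  (never reaches chunk b)
print(list(itertools.islice(sub_gen(batches_once), 6)))
# ['a1', 'a2', 'b1', 'b2', 'c1', 'c2']
```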

Let's test this with much simpler generator functions:

def foo():
    yield [1, 2]
    yield [3, 4]

def bar(iterable):
    for x in iterable:
        yield 10+x
        yield 20+x

def baz():
    for iterable in foo():
        yield from bar(iterable)

for value in baz():   # use the top-level generator!
    print(value)      # prints 11, 21, 12, 22, 13, 23, 14, 24 each on its own line
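For comparison, here is baz with the yield from line expanded into the explicit loop it stands for; both produce the same sequence. (yield from additionally forwards send() and throw() calls and returns the sub-generator's return value, which the plain loop does not, but that makes no difference for simple iteration like this.) The name baz_explicit is only for illustration:

```python
def foo():
    yield [1, 2]
    yield [3, 4]

def bar(iterable):
    for x in iterable:
        yield 10 + x
        yield 20 + x

def baz():
    for iterable in foo():
        yield from bar(iterable)

def baz_explicit():
    # Same as baz, with "yield from" written out as a for loop
    for iterable in foo():
        for value in bar(iterable):
            yield value

print(list(baz()))           # [11, 21, 12, 22, 13, 23, 14, 24]
print(list(baz_explicit()))  # [11, 21, 12, 22, 13, 23, 14, 24]
```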
