Commit 67c84dc1 authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

conditioning batch-buffer on inner loop training

parent a3098084
Loading
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -668,6 +668,7 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
    logFreq   = params[ 'log_frequency' ]
    traceStep = params[ 'trace_step' ]
    maxTime = params.get('max_time', 1e12)
    inner_loop = params.get('inner_iter', 1e12)
    
    val_results = []
    loss_results = []
@@ -735,9 +736,10 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
            # constr_val = sess.run(constr_loss, feed_dict={psi_out_true:current_batch})
            # print_rank('\t\tstep={}, current constr_loss={:2.3e}'.format(train_elf.last_step, constr_val))
        # current_batch_list = []
        if inner_loop < 100:
            batch_buffer.append(current_batch)
        # print_rank(len(batch_buffer))
        if bool(train_elf.last_step % 10 == 0 and train_elf.last_step >= 10):
        if bool(train_elf.last_step % inner_loop == 0 and train_elf.last_step >= 10):
            for itr, current_batch in enumerate(batch_buffer):
                # noise = np.random.random(images.shape.as_list()[1:]) 
                # noise = noise.astype(np.float32)