Sunday, February 12, 2017

Error when using TensorArray and while_loop to build a dynamic RNN


My intention is to run a CNN (VGG16) on every frame and then feed its output into a two-layer LSTM. However, an error occurs when I call sess.run(). Since the error only appears at that point, this suggests that the graph itself is constructed correctly. So where is my bug? Here is the error message:

tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'total_loss' has inputs from different frames. The input 'vgg_16/fc8/weights/Regularizer/l2_regularizer' is in frame 'while/while/'. The input 'Mean' is in frame ''.

And here's some of my code:

import tensorflow as tf

lstm=tf.nn.rnn_cell.BasicLSTMCell(128)
cell=tf.nn.rnn_cell.MultiRNNCell([lstm]*2)
state=cell.zero_state(1, tf.float32)

inputImgA=tf.TensorArray(tf.string, length)
outputLSTM=tf.TensorArray(tf.float32, length)
lossLSTM=tf.TensorArray(tf.float32, length)

img=sequence_parsed['imgs_list']
inputImgA=inputImgA.unpack(img)

i=tf.constant(0)


def cond(i, state, inputImgA, outputLSTM, lossLSTM):
    return tf.less(i, length)

def body(i, state, inputImgA, outputLSTM, lossLSTM):
    imcontent=inputImgA.read(i)
    image=tf.image.decode_jpeg(imcontent, 3, name='decode_image')
    with tf.variable_scope('Image_Process'):
        image=tf.image.resize_images(image, [224, 224])
        channels = tf.split(2, 3, image)
        channels[0] -= _R_MEAN
        channels[1] -= _G_MEAN
        channels[2] -= _B_MEAN
        image=tf.concat(2, channels)
        images=tf.expand_dims(image, 0)

    net, end = VggNet(images, is_training=True)
    output, state=cell(net, state)
    outputLSTM=outputLSTM.write(i, output)

    loss=tf.nn.sparse_softmax_cross_entropy_with_logits(output, label)
    lossLSTM=lossLSTM.write(i, loss)
    return (i+1, state, inputImgA, outputLSTM, lossLSTM) 


_, _, _, outputLSTM, lossLSTM=tf.while_loop(cond, body, [i, state, inputImgA, outputLSTM, lossLSTM])

output=outputLSTM.pack()
loss=lossLSTM.pack()
loss=tf.reduce_mean(loss)
losses.add_loss(loss)
total_loss=losses.get_total_loss()
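The loop above relies on tf.while_loop, whose documented semantics are roughly `while cond(*loop_vars): loop_vars = body(*loop_vars)`. The sketch below is a plain-Python analogue of that structure, not TensorFlow code. Everything in it is a hypothetical stand-in: `py_while_loop` for tf.while_loop, `fake_features` for VggNet, `stacked_cell` (with a toy update rule, not an LSTM) for the two-layer MultiRNNCell, Python lists for the TensorArrays, and made-up `frames` and `targets` for the real data.

```python
# Plain-Python analogue of the tf.while_loop + TensorArray structure above.
# Hypothetical stand-ins (not TensorFlow APIs): fake_features() replaces VggNet,
# stacked_cell() replaces the two-layer MultiRNNCell, and Python lists replace
# the TensorArrays.


def py_while_loop(cond, body, loop_vars):
    """Repeat body while cond holds, threading the loop variables through."""
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


def fake_features(frame):
    # Stand-in for the CNN: reduce one "frame" to a single number.
    return sum(frame) / len(frame)


def stacked_cell(x, state):
    # Two-layer recurrence in the style of MultiRNNCell: state holds one entry
    # per layer, and each layer's output is the next layer's input.
    new_state = []
    for h in state:
        h = 0.5 * h + 0.5 * x  # toy update rule, not an LSTM
        new_state.append(h)
        x = h
    return x, tuple(new_state)


frames = [[0, 2], [2, 4], [4, 6]]  # hypothetical per-frame data
targets = [0.0, 1.0, 2.0]          # hypothetical per-frame labels
length = len(frames)


def cond(i, state, output_list, loss_list):
    return i < length


def body(i, state, output_list, loss_list):
    net = fake_features(frames[i])
    output, state = stacked_cell(net, state)
    # TensorArray.write returns a new TensorArray, so build new lists here
    # instead of mutating the ones passed in.
    output_list = output_list + [output]
    loss_list = loss_list + [(output - targets[i]) ** 2]
    return (i + 1, state, output_list, loss_list)


_, final_state, outputs, losses = py_while_loop(
    cond, body, (0, (0.0, 0.0), [], [])
)
mean_loss = sum(losses) / len(losses)  # like tf.reduce_mean(lossLSTM.pack())
```

One thing this analogue cannot show: in TensorFlow graph mode, tensors created inside `body` belong to the loop's frame, so per-frame values are meant to leave the loop through returned loop variables such as `lossLSTM`. That appears consistent with the error message above, where the regularizer term comes from frame 'while/while/' while 'Mean' is in the outer frame ''.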
