How to iterate through tensors in custom loss function?

As usual, don't loop. Looping over tensors has severe performance drawbacks and often introduces bugs. Use only backend functions unless it's totally unavoidable (it usually isn't).


Solution for example 3:

So, there is a very weird thing there...

Do you really want to simply ignore half of your model's predictions? (Example 3)

Assuming this is true, just duplicate your tensor in the last dimension, flatten it, and discard half of it. You get exactly the effect you want.

```python
from tensorflow.keras import backend as K

def custom_loss(true, pred):
    n = K.shape(pred)[0:1]

    pred = K.concatenate([pred]*2, axis=-1)  #duplicate in the last axis
    pred = K.flatten(pred)                   #flatten
    pred = K.slice(pred,                     #take only half (= n samples)
                   K.constant([0], dtype="int32"),
                   n)

    return K.abs(true - pred)
```
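To see what the duplicate/flatten/slice sequence actually produces, here is the same chain of operations in plain NumPy, on made-up numbers and assuming `pred` has shape `(n, 1)`:

```python
import numpy as np

# made-up predictions: 4 samples, one prediction each, shape (4, 1)
pred = np.array([[10.0], [20.0], [30.0], [40.0]])
n = pred.shape[0]

dup = np.concatenate([pred, pred], axis=-1)  # shape (4, 2): each row is [p, p]
flat = dup.flatten()                         # [10, 10, 20, 20, 30, 30, 40, 40]
kept = flat[:n]                              # first n values: [10, 10, 20, 20]

print(kept)
```

Each kept prediction appears twice and the second half of the predictions is discarded, which is the "ignore half" effect described above.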

Solution for your loss function:

If you have the times sorted from greatest to lowest, just take a cumulative sum.
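A quick NumPy sketch of why this works, with made-up numbers: once the samples are sorted by time in descending order, the cumulative sum at position i equals the explicit loop sum over all j with time_j >= time_i:

```python
import numpy as np

# made-up hazards exp(pred_j), already sorted by time from greatest to lowest
exp_pred = np.array([1.0, 2.0, 3.0, 4.0])

# explicit loop: for each i, sum over all j with time_j >= time_i (j <= i here)
loop_sums = np.array([exp_pred[:i + 1].sum() for i in range(len(exp_pred))])

# the same thing as one vectorized operation
cum_sums = np.cumsum(exp_pred)

print(loop_sums)  # [ 1.  3.  6. 10.]
print(cum_sums)   # [ 1.  3.  6. 10.]
```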

Warning: if you have one time per sample, you cannot train with mini-batches, because the cumulative sum must see every sample at once. Use the whole dataset as a single batch:
batch_size = len(labels)

It makes sense to have time in an additional dimension (many times per sample), as is done in recurrent and 1D convolutional networks. Anyway, taking your example as written, yTime has shape (samples_equal_times,):

```python
import tensorflow as tf
from tensorflow.keras import backend as K

def neg_log_likelihood(yTrue, yPred):
    yStatus = yTrue[:, 0]
    yTime = yTrue[:, 1]
    yPred = K.flatten(yPred)  #ensure shape (None,) so it broadcasts with yStatus

    n = K.shape(yTrue)[0]

    #sort the times and everything else from greatest to lowest:
    #obs: you can have the data sorted already and skip this step, for performance
    #important: yTime is sorted along the last dimension, so make sure it's (None,) here,
    #or (None, time_length) in the case of many times per sample
    sortedTime, sortedIndices = tf.math.top_k(yTime, n, True)
    sortedStatus = K.gather(yStatus, sortedIndices)
    sortedPreds = K.gather(yPred, sortedIndices)

    #do the calculations
    exp = K.exp(sortedPreds)
    sums = K.cumsum(exp)    #sums[i] holds the sum for j >= i in the loop
    logsums = K.log(sums)

    #negative log-likelihood: minimize this to maximize the likelihood
    return -K.sum(sortedStatus * (sortedPreds - logsums))
```
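For sanity-checking the backend version on small arrays, here is the same computation in plain NumPy (the function name and the toy data below are made up for illustration):

```python
import numpy as np

def neg_log_likelihood_np(y_status, y_time, y_pred):
    # sort everything by time, from greatest to lowest
    order = np.argsort(-y_time)
    status = y_status[order]
    preds = y_pred[order]

    # cumulative sum: at position i, the sum of exp(pred_j) over time_j >= time_i
    logsums = np.log(np.cumsum(np.exp(preds)))

    # negative log partial likelihood
    return -np.sum(status * (preds - logsums))

# made-up toy data: three samples, two observed events, all predictions equal
loss = neg_log_likelihood_np(
    y_status=np.array([1.0, 0.0, 1.0]),
    y_time=np.array([3.0, 1.0, 2.0]),
    y_pred=np.array([0.0, 0.0, 0.0]),
)
print(loss)  # log(2) ≈ 0.6931
```

With all predictions equal, the only nonzero term comes from the second-latest event, whose risk set contains two samples, hence log(2).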