PyTorch - to NumPy yields unsized object?
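Perhaps you want to use torch.argmax(nn_result, dim=1)? When dim is omitted it defaults to None, and argmax then returns the index of the maximum over the flattened tensor: a single number wrapped in a zero-dimensional tensor. The following NumPy session shows why that leads to the error: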
    >>> import numpy as np
    >>> x = np.array(1)
    >>> x.shape
    ()
    >>> len(x)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    TypeError: len() of unsized object
    >>> x = np.array([1])
    >>> x.shape
    (1,)
    >>> len(x)
    1
Essentially, np.array wraps whatever object you pass it. In the first case the input is a plain scalar, so the result is a zero-dimensional array with shape (). It has no first axis and therefore no length, which is why calling len on it throws an error.
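If you really do have a zero-dimensional array, there are two common ways out. The sketch below (plain NumPy, values chosen only for illustration) confirms the array has ndim 0, shows that len raises TypeError, and then either reads the value with .item() or promotes the array to 1-D with np.atleast_1d so that len works.

```python
import numpy as np

x = np.array(1)            # scalar input -> zero-dimensional array
print(x.ndim, x.shape)     # 0 ()

try:
    len(x)
    len_raised = False
except TypeError as err:
    len_raised = True
    print("len failed:", err)  # len() of unsized object

# Option 1: read the single value out as a Python scalar.
value = x.item()
# Option 2: promote to a 1-D array so len() and indexing work.
x1 = np.atleast_1d(x)
print(value, x1.shape, len(x1))  # 1 (1,) 1
```

Which option fits depends on whether downstream code expects a single number or a sequence.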
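torch.argmax called without dim returns a zero-dimensional tensor, just like the first case above, so converting it to NumPy gives a zero-dimensional array, hence the error. Passing dim=1 instead returns one index per row.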
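A minimal sketch of how dim changes the result, assuming nn_result is a 2-D (batch, classes) tensor of scores as is typical for classifier output; the shape and values here are made up for illustration and are not taken from the question. It prints the shape for no dim, dim=0 and dim=1, then converts to NumPy.

```python
import torch

# Hypothetical network output: 3 samples x 4 class scores (values made up).
nn_result = torch.tensor([[0.1, 0.7, 0.05, 0.15],
                          [0.8, 0.1, 0.05, 0.05],
                          [0.2, 0.2, 0.5, 0.1]])

# No dim: index of the maximum of the flattened tensor, as a 0-d tensor.
flat_idx = torch.argmax(nn_result)
# dim=0: one index per column (which sample scores highest for each class).
per_class = torch.argmax(nn_result, dim=0)
# dim=1: one index per row, i.e. the predicted class of each sample.
per_sample = torch.argmax(nn_result, dim=1)

print(flat_idx.shape, per_class.shape, per_sample.shape)
# torch.Size([]) torch.Size([4]) torch.Size([3])

# Converting the 0-d result gives a 0-d array, so len() would fail on it;
# .item() reads the value as a plain Python int instead.
flat_np = flat_idx.numpy()
print(flat_np.shape, flat_idx.item())  # () 4

# The per-sample predictions convert to a 1-D array that len() accepts.
preds = per_sample.numpy()
print(len(preds), preds.tolist())  # 3 [1, 0, 2]
```

For a batch of predictions, dim=1 is usually what you want; .item() only makes sense when you genuinely need the single flattened index.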