Reshaping operations
Suppose we have the following tensor:
t = torch.tensor([
    [1,1,1,1],
    [2,2,2,2],
    [3,3,3,3]
], dtype=torch.float32)
We have two ways to get the shape:
> t.size()
torch.Size([3, 4])
> t.shape
torch.Size([3, 4])
The rank of a tensor is equal to the length of the tensor's shape.
> len(t.shape)
2
We can also deduce the number of elements contained within the tensor.
> torch.tensor(t.shape).prod()
tensor(12)
In PyTorch, there is a dedicated function for this:
> t.numel()
12
Reshaping a tensor in PyTorch
> t.reshape([2,6])
tensor([[1., 1., 1., 1., 2., 2.],
        [2., 2., 3., 3., 3., 3.]])
> t.reshape([3,4])
tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.]])
> t.reshape([4,3])
tensor([[1., 1., 1.],
        [1., 2., 2.],
        [2., 2., 3.],
        [3., 3., 3.]])
> t.reshape(6,2)
tensor([[1., 1.],
        [1., 1.],
        [2., 2.],
        [2., 2.],
        [3., 3.],
        [3., 3.]])
> t.reshape(12,1)
tensor([[1.],
        [1.],
        [1.],
        [1.],
        [2.],
        [2.],
        [2.],
        [2.],
        [3.],
        [3.],
        [3.],
        [3.]])
In this example, we increase the rank to 3:
> t.reshape(2,2,3)
tensor(
[
    [
        [1., 1., 1.],
        [1., 2., 2.]
    ],
    [
        [2., 2., 3.],
        [3., 3., 3.]
    ]
])
Note: PyTorch has another function, view(), that does much the same thing as reshape().
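The difference is that view() never copies data, so it only succeeds when the requested shape is compatible with how the tensor is laid out in memory; reshape() falls back to copying when it has to. A minimal sketch of the distinction:
> t.view(2, 6)        # works: t is contiguous in memory
tensor([[1., 1., 1., 1., 2., 2.],
        [2., 2., 3., 3., 3., 3.]])
> t.t().view(12)      # raises a RuntimeError: the transpose is not contiguous
> t.t().reshape(12)   # works: reshape() copies the data instead
tensor([1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.])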
Changing shape by squeezing and unsqueezing
These functions allow us to expand or shrink the rank (number of dimensions) of our tensor.
> print(t.reshape([1,12]))
> print(t.reshape([1,12]).shape)
tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
torch.Size([1, 12])
> print(t.reshape([1,12]).squeeze())
> print(t.reshape([1,12]).squeeze().shape)
tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
torch.Size([12])
> print(t.reshape([1,12]).squeeze().unsqueeze(dim=0))
> print(t.reshape([1,12]).squeeze().unsqueeze(dim=0).shape)
tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
torch.Size([1, 12])
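Note that squeeze() also accepts an optional dim argument; it then removes only that axis, and only when its length is 1. A quick sketch with a throwaway tensor:
> torch.ones(2, 1, 3).squeeze(dim=1).shape
torch.Size([2, 3])
> torch.ones(2, 1, 3).squeeze(dim=0).shape   # axis 0 has length 2, so nothing is removed
torch.Size([2, 1, 3])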
Let’s look at a common use case for squeezing a tensor by building a flatten function.
Flatten a tensor
Flattening a tensor means to remove all of the dimensions except for one.
A flatten operation on a tensor reshapes the tensor to have a single axis whose length is equal to the number of elements contained in the tensor. This is the same thing as a 1d array of elements.
def flatten(t):
    t = t.reshape(1, -1)
    t = t.squeeze()
    return t
> t = torch.ones(4, 3)
> t
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
> flatten(t)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
We'll see that flatten operations are required when passing an output tensor from a convolutional layer to a linear layer.
In these examples, we have flattened the entire tensor; however, it is possible to flatten only specific parts of a tensor. For example, suppose we have a tensor of shape [2,1,28,28] for a CNN. This means that we have a batch of 2 grayscale images with height and width dimensions of 28 x 28, respectively. Here, we can flatten just the two images to get the shape [2,1,784]. We could also squeeze off the channel axis to get the shape [2,784].
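To make this concrete, here is a quick sketch using a dummy batch of that shape:
> batch = torch.ones(2, 1, 28, 28)   # a dummy batch: 2 grayscale 28 x 28 images
> batch.reshape(2, 1, 784).shape     # flatten just the two images
torch.Size([2, 1, 784])
> batch.reshape(2, 784).shape        # also squeeze off the channel axis
torch.Size([2, 784])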
Concatenating tensors
We combine tensors using the cat() function.
> t1 = torch.tensor([
    [1,2],
    [3,4]
])
> t2 = torch.tensor([
    [5,6],
    [7,8]
])
We can combine t1 and t2 row-wise (axis-0) in the following way:
> torch.cat((t1, t2), dim=0)
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
We can combine them column-wise (axis-1) like this:
> torch.cat((t1, t2), dim=1)
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
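Notice how the resulting shape grows along the axis we concatenate on:
> torch.cat((t1, t2), dim=0).shape
torch.Size([4, 2])
> torch.cat((t1, t2), dim=1).shape
torch.Size([2, 4])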
Flatten operation for a batch of image inputs to a CNN
Flattening specific axes of a tensor
We know that the tensor inputs to a convolutional neural network typically have 4 axes, one for batch size, one for color channels, and one each for height and width.
To start, suppose we have the following three tensors.
t1 = torch.tensor([
    [1,1,1,1],
    [1,1,1,1],
    [1,1,1,1],
    [1,1,1,1]
])
t2 = torch.tensor([
    [2,2,2,2],
    [2,2,2,2],
    [2,2,2,2],
    [2,2,2,2]
])
t3 = torch.tensor([
    [3,3,3,3],
    [3,3,3,3],
    [3,3,3,3],
    [3,3,3,3]
])
Remember, batches are represented using a single tensor, so we’ll need to combine these three tensors into a single larger tensor that has three axes instead of two.
> t = torch.stack((t1, t2, t3))
> t.shape
torch.Size([3, 4, 4])
Here, we used the stack() function to concatenate our sequence of three tensors along a new axis.
> t
tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]],

        [[2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2]],

        [[3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3]]])
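As a quick sanity check, stack() is equivalent to unsqueezing each tensor along a new axis and then concatenating:
> torch.cat((t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0)), dim=0).shape
torch.Size([3, 4, 4])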
All we need to do now to get this tensor into a form that a CNN expects is add an axis for the color channels. We basically have an implicit single color channel for each of these image tensors, so in practice, these would be grayscale images.
> t = t.reshape(3,1,4,4)
> t
tensor(
[
    [
        [
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1]
        ]
    ],
    [
        [
            [2, 2, 2, 2],
            [2, 2, 2, 2],
            [2, 2, 2, 2],
            [2, 2, 2, 2]
        ]
    ],
    [
        [
            [3, 3, 3, 3],
            [3, 3, 3, 3],
            [3, 3, 3, 3],
            [3, 3, 3, 3]
        ]
    ]
])
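An equivalent way to add the channel axis is to unsqueeze() the stacked tensor:
> torch.stack((t1, t2, t3)).unsqueeze(dim=1).shape
torch.Size([3, 1, 4, 4])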
Flattening the tensor batch
Here are some alternative implementations of the flatten() function.
> t.reshape(1,-1)[0]
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
> t.reshape(-1)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
> t.view(t.numel())
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
> t.flatten()
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
This flattened batch won’t work well inside our CNN because we need individual predictions for each image within our batch tensor, and now we have a flattened mess.
The solution here is to flatten each image while still maintaining the batch axis. This means we want to flatten only part of the tensor: the color channel axis together with the height and width axes.
> t.flatten(start_dim=1).shape
torch.Size([3, 16])
> t.flatten(start_dim=1)
tensor(
[
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
]
)
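The same partial flatten can also be written with reshape(), keeping the batch axis and letting -1 infer the rest:
> t.reshape(t.shape[0], -1).shape
torch.Size([3, 16])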