#Unembedding (reverse of embedding process) problem

hybrid ether

I am trying to unembed some predicted values (map embedding vectors back to token indices), but I get this error:

RuntimeError: [enforce fail at alloc_cpu.cpp:124] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 165478400000 bytes. Error code 12 (Cannot allocate memory)

As far as I understand, this means the allocation failed because my system does not have enough memory: it currently has 16 GB, and the request is for roughly 165 GB.
This was my (sample) code:

import torch

class MyModel(torch.nn.Module):
    def __init__(self, token_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(token_size, embedding_dim)

    def unembedding(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, L, ED). Broadcasting against the (V, ED) weight builds an
        # (N, L, V, ED) intermediate before the sum over the last dimension.
        weight = self.embedding.weight
        diff = x.unsqueeze(2) - weight.unsqueeze(0).unsqueeze(0)
        return torch.argmin(torch.abs(diff).sum(dim=3), dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        print(x)
        embedding_x = self.embedding(x)
        unembedding_x = self.unembedding(embedding_x)
        print(unembedding_x)
        return unembedding_x

device = torch.device("cpu")  # the error above comes from the CPU allocator
model = MyModel(101, 1000).to(device)
model.embedding.weight

sample = torch.randint(0, 100, (4096, 100), device=device)
model(sample)
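
For reference, the byte count in the error message can be reproduced from the shapes in the sample code. This is a quick sanity check that assumes the default float32 embedding weights (4 bytes per element); the variable names are only illustrative.

```python
# Shapes taken from the sample code: a (4096, 100) batch of token ids,
# 101 embedding rows, and an embedding dimension of 1000.
batch, seq_len, vocab, emb_dim = 4096, 100, 101, 1000
bytes_per_float32 = 4  # assumption: weights use the default float32 dtype

# The broadcasted difference has shape (batch, seq_len, vocab, emb_dim).
intermediate_bytes = batch * seq_len * vocab * emb_dim * bytes_per_float32

print(intermediate_bytes)                    # 165478400000
print(f"{intermediate_bytes / 1e9:.1f} GB")  # 165.5 GB
```

So the failing allocation appears to be the full (N, L, V, ED) difference tensor, which is about ten times the 16 GB available.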

So I tried to write a more memory-efficient unembedding function, and this was the result:

def unembedding(self, x: torch.Tensor) -> torch.Tensor:
    assert (x.shape[-1] == self.embedding.weight.shape[-1]), f"Last dimension must be equal to embedding dimension. Expected (*, {self.embedding.weight.shape[-1]}), got {tuple(x.shape)}"
    if len(x.shape) == 1:
        # x.shape == (ED): L1 distance to every embedding row, as in the first version
        result = (self.embedding.weight - x).abs().sum(dim=1).argmin()
        return result
    if len(x.shape) == 2:
        # x.shape == (L, ED); indices are integers, so use a long tensor
        result = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        for l in range(x.shape[0]):
            result[l] = self.unembedding(x[l, :])
        return result
    if len(x.shape) == 3:
        # x.shape == (N, L, ED)
        result = torch.zeros((x.shape[0], x.shape[1]), dtype=torch.long, device=x.device)
        for n in range(x.shape[0]):
            result[n] = self.unembedding(x[n, :])
        return result
    raise ValueError(f"Expected a 1-D, 2-D, or 3-D tensor, got {x.dim()}-D")

Now it takes about 40 seconds per batch of data (on an i5-9400F CPU).

Any ideas for a better way to do this, please? (Faster, while staying memory-efficient.)
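
One possible direction, sketched here under assumptions rather than as a tested solution: flatten the input to (N*L, ED), process it in chunks, and let `torch.cdist` with `p=1.0` compute the L1 distances to the embedding rows, so that only a (chunk, V) distance matrix is held at a time instead of the full (N, L, V, ED) difference. The function name `unembed_l1` and the `chunk_size` parameter are hypothetical, and the sketch has not been benchmarked on the i5-9400F mentioned above.

```python
import torch


@torch.no_grad()  # assumption: unembedding is only used for lookup, not training
def unembed_l1(x: torch.Tensor, weight: torch.Tensor, chunk_size: int = 8192) -> torch.Tensor:
    """For each vector in x (*, ED), return the index of the nearest row of
    weight (V, ED) under L1 distance. Output shape is x.shape[:-1]."""
    flat = x.reshape(-1, x.shape[-1])  # (N*L, ED)
    out = torch.empty(flat.shape[0], dtype=torch.long, device=x.device)
    for start in range(0, flat.shape[0], chunk_size):
        chunk = flat[start:start + chunk_size]
        # Pairwise L1 distances between chunk rows and embedding rows: (chunk, V)
        dists = torch.cdist(chunk, weight, p=1.0)
        out[start:start + chunk_size] = dists.argmin(dim=1)
    return out.reshape(x.shape[:-1])


# Small round-trip check: embed random ids, then recover them.
torch.manual_seed(0)
emb = torch.nn.Embedding(101, 16)
ids = torch.randint(0, 101, (8, 5))
recovered = unembed_l1(emb(ids), emb.weight, chunk_size=7)
print(torch.equal(recovered, ids))
```

The chunk size trades speed for memory: a larger chunk means fewer Python-level iterations but a larger temporary distance matrix, so it would need tuning for the machine at hand.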