I have a dataframe representing a point cloud plus node data: X, Y, Z positions as well as "node features". I would like to represent it as a graph so that I can use it in a GNN.
To do so, I would like to use the KNNGraph transform in PyTorch Geometric to extract edge connectivity information. An example in the docs can be found here.
I cannot get it to work. Do I need to represent my dataset as a torch_geometric.datasets.GeometricShapes object first?
This is what I tried:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data, Dataset, InMemoryDataset
from torch_geometric.transforms import SamplePoints, KNNGraph
import torch_geometric.transforms as T
from torch_geometric.datasets import GeometricShapes
class CustomDataset(InMemoryDataset):
    def __init__(self, listOfDataObjects):
        super().__init__()
        self.data, self.slices = self.collate(listOfDataObjects)

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        sample = self.get(idx)
        return sample

## My data
data_fake = pd.DataFrame(data=np.random.rand(20, 5), columns=['X', 'Y', 'Z', 'Node_feature_1', 'Node_feature_2'])
# Get a data object and then Dataset
my_fake_data = Data()
my_fake_data.pos = torch.from_numpy(data_fake[['X','Y', 'Z']].values)
dataset_cloud = CustomDataset([my_fake_data])
# Transform
dataset_cloud.transform = T.Compose([SamplePoints(num=20), KNNGraph(k=5)])
Alas, dataset_cloud[0] returns
>>> Data(pos=[20, 3])
so no edge connectivity information for the nearest neighbours is added, unlike in the docs example.
Any hints please? Thanks
EDIT
I think a workaround is to skip dataset.transform and apply KNNGraph to the point cloud directly, as in
knn = KNNGraph(k=3)
data_knn = knn(dataset_cloud[0])
which returns an edge_index attribute as expected.
I would still like to understand what goes wrong with the approach detailed above.
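For reference, here is a minimal NumPy-only sketch of the kind of connectivity I expect to end up in edge_index, which may help when sanity-checking whatever the transform returns. It is not PyG's implementation: the helper name knn_edge_index is hypothetical, and both the edge direction (neighbour in row 0, the point it belongs to in row 1) and the exclusion of self-loops are my assumptions.

```python
import numpy as np


def knn_edge_index(pos, k):
    """Return a (2, N*k) array of directed edges linking each point to its k nearest neighbours."""
    # Pairwise Euclidean distances between all points, shape (N, N).
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Exclude self-matches by making each point infinitely far from itself.
    np.fill_diagonal(dist, np.inf)
    # Indices of the k closest other points for every point, shape (N, k).
    neighbours = np.argsort(dist, axis=1)[:, :k]
    # Row 0 holds the neighbour, row 1 the point it belongs to (assumed convention).
    targets = np.repeat(np.arange(pos.shape[0]), k)
    return np.stack([neighbours.reshape(-1), targets])


rng = np.random.default_rng(0)
pos = rng.random((20, 3))  # same shape as the X, Y, Z columns above
edge_index = knn_edge_index(pos, k=5)
print(edge_index.shape)  # (2, 100)
```

With 20 points and k=5, every point receives exactly 5 incoming edges, so a correctly built graph should have 100 edges in total; that is the shape I would expect to see next to pos in the printed Data object.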