gnn - Py Geometric - Use KNNGraph Transform on Dataset - Stack Overflow


I have a DataFrame representing a point cloud plus node data: X, Y, Z positions as well as some "node features". I would like to represent it as a graph so I can feed it to a GNN.

To do that, I would like to use the KNNGraph transform in PyTorch Geometric to extract the edge connectivity. There is an example of this in the docs.
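For reference, the connectivity that KNNGraph builds can be sketched in plain NumPy. This is a brute-force illustration written for this note, not PyG's implementation (which delegates to a k-NN search, typically via torch_cluster), and the function name knn_edge_index is hypothetical. It assumes PyG's default flow='source_to_target', where row 0 of edge_index holds the neighbours and row 1 the centre nodes, and no self-loops (loop=False).

```python
import numpy as np


def knn_edge_index(pos: np.ndarray, k: int) -> np.ndarray:
    """Brute-force k-NN graph in a COO layout similar to PyG's edge_index.

    Row 0 holds neighbour (source) indices, row 1 centre (target) indices,
    so every node receives edges from its k nearest neighbours.
    Self-loops are excluded.
    """
    n = pos.shape[0]
    # Pairwise squared Euclidean distances, shape (n, n).
    diff = pos[:, None, :] - pos[None, :, :]
    dist2 = (diff ** 2).sum(axis=-1)
    # A point must not count as its own neighbour.
    np.fill_diagonal(dist2, np.inf)
    # For every target node (row), the indices of its k closest points.
    nbrs = np.argsort(dist2, axis=1, kind="stable")[:, :k]
    targets = np.repeat(np.arange(n), k)
    sources = nbrs.reshape(-1)
    return np.stack([sources, targets])


rng = np.random.default_rng(0)
edge_index = knn_edge_index(rng.random((20, 3)), k=5)
print(edge_index.shape)  # (2, 100)
```

With 20 points and k=5 this gives 100 directed edges, k per node. Because the k-NN relation is not symmetric, the result is generally a directed graph; KNNGraph has a force_undirected option for that case.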

I can't get it to work. Do I need to represent my dataset as a torch_geometric.datasets.GeometricShapes object first?

This is what I tried:

import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data, Dataset, InMemoryDataset
from torch_geometric.transforms import SamplePoints, KNNGraph
import torch_geometric.transforms as T
from torch_geometric.datasets import GeometricShapes

class CustomDataset(InMemoryDataset):
    def __init__(self, listOfDataObjects):
        super().__init__()
        self.data, self.slices = self.collate(listOfDataObjects)
    
    def __len__(self):
        return len(self.slices)
    
    def __getitem__(self, idx):
        sample = self.get(idx)
        return sample

## My data
data_fake = pd.DataFrame(data=np.random.rand(20, 5), columns=['X', 'Y', 'Z', 'Node_feature_1', 'Node_feature_2'])

# Get a data object and then Dataset
my_fake_data = Data()
my_fake_data.pos = torch.from_numpy(data_fake[['X', 'Y', 'Z']].values)
dataset_cloud = CustomDataset([my_fake_data])


# Transform
dataset_cloud.transform = T.Compose([SamplePoints(num=20), KNNGraph(k=5)])
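The Compose above follows the docs example, which starts from GeometricShapes: there each sample is a mesh, with vertex positions in pos and triangles in face (shape [3, num_faces]), and SamplePoints turns the mesh into a point cloud by drawing points on the triangle surfaces. The NumPy sketch below shows the general idea (pick a triangle with probability proportional to its area, then a uniform point inside it). It is written from scratch for illustration, not taken from PyG, and sample_points_on_mesh is a hypothetical name. Since this step needs face, a point cloud that only has pos gives it nothing to sample from (PyG's SamplePoints would likely raise an error on it), so such data can presumably go straight to KNNGraph without a GeometricShapes wrapper.

```python
import numpy as np


def sample_points_on_mesh(pos: np.ndarray, face: np.ndarray, num: int,
                          rng: np.random.Generator) -> np.ndarray:
    """Area-weighted uniform sampling of `num` points on a triangle mesh.

    pos:  (n, 3) vertex coordinates
    face: (3, m) vertex indices of each triangle (PyG-style column layout)
    """
    a, b, c = pos[face[0]], pos[face[1]], pos[face[2]]
    # Triangle areas from the cross product of two edge vectors.
    area = 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=1)
    # Choose a triangle for each sample, proportionally to its area.
    idx = rng.choice(len(area), size=num, p=area / area.sum())
    # Uniform point in a triangle: draw (u, v) in the unit square and
    # reflect it back into the lower-left triangle when u + v > 1.
    u = rng.random((num, 2))
    flip = u.sum(axis=1) > 1
    u[flip] = 1 - u[flip]
    return a[idx] + u[:, :1] * (b[idx] - a[idx]) + u[:, 1:] * (c[idx] - a[idx])


rng = np.random.default_rng(0)
# Unit square in the z=0 plane, split into triangles (0, 1, 2) and (0, 2, 3).
square_pos = np.array([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]])
square_face = np.array([[0, 0], [1, 2], [2, 3]])
pts = sample_points_on_mesh(square_pos, square_face, num=20, rng=rng)
print(pts.shape)  # (20, 3)
```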

Alas, dataset_cloud[0] returns Data(pos=[20, 3])

so, unlike in the docs example, no nearest-neighbour edge connectivity (edge_index) is added.

Any hints please? Thanks

EDIT

A workaround seems to be to skip dataset.transform and apply KNNGraph to the point cloud directly:

knn = KNNGraph(k=3)
data_knn = knn(dataset_cloud[0])

This returns an edge_index attribute as expected. I would still like to understand what goes wrong with the approach above.
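A likely culprit is the __getitem__ override in CustomDataset. In PyTorch Geometric, the base Dataset.__getitem__ fetches a sample via get() and then applies self.transform, whereas get() on its own returns the stored sample untouched. Overriding __getitem__ to return self.get(idx) therefore skips the transform silently, which would also explain why no error appears even though SamplePoints needs mesh faces. The self-contained sketch below models only that mechanism with hypothetical stand-in classes (MiniDataset, OverridingDataset, add_edges); it is a simplification, not PyG's source.

```python
from typing import Callable, List, Optional


class MiniDataset:
    """Hypothetical stand-in for torch_geometric.data.Dataset, reduced to
    the relevant part: __getitem__ calls get() and then the transform."""

    def __init__(self, items: List[dict], transform: Optional[Callable] = None):
        self.items = items
        self.transform = transform

    def get(self, idx: int) -> dict:
        # Returns the stored sample as-is; no transform is applied here.
        return self.items[idx]

    def __getitem__(self, idx: int) -> dict:
        data = self.get(idx)
        # The transform is applied here, in the base-class __getitem__.
        return data if self.transform is None else self.transform(data)


class OverridingDataset(MiniDataset):
    # Same pattern as CustomDataset in the question: this override bypasses
    # the base-class __getitem__, so self.transform is never called.
    def __getitem__(self, idx: int) -> dict:
        return self.get(idx)


def add_edges(data: dict) -> dict:
    # Hypothetical transform standing in for KNNGraph: adds an 'edge_index' key.
    return {**data, "edge_index": "computed"}


base = MiniDataset([{"pos": "..."}], transform=add_edges)
custom = OverridingDataset([{"pos": "..."}], transform=add_edges)
print(sorted(base[0]))    # ['edge_index', 'pos']
print(sorted(custom[0]))  # ['pos']
```

If this is the cause, removing both overrides from CustomDataset (InMemoryDataset already implements indexing and length) should let dataset_cloud.transform = KNNGraph(k=5) take effect; SamplePoints should probably be dropped too, since the data has no face attribute. Separately, len(self.slices) counts the keys of the slices dict (one per attribute, here just 'pos'), not the number of graphs, so it returns 1 only by coincidence.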

