梶研 [GCN完全に理解してる途中]
2024年08月27日
GCN完全に理解してる途中
出席率
- 3年セミナー:??%
スケジュール
短期的な予定
- 9/3 GCN 完全に理解する
- 9/24 ST-GCN完全に理解して実装する
長期的な予定
進捗報告
[WIP] nn.Conv2d を使って実装する
データセットを読み込む
1# データセットを読み込む 2dataset = Planetoid(root="./Cora", name="Cora") 3# グラフ構造の数 4print(len(dataset)) # -> 1 5# クラス数 6print(dataset.num_classes) # -> 7 7# 徴量の次元数(1433種類の特定のワードが論文中に含まれているか. 0:ない,1:ある) 8print(dataset.num_node_features) # -> 1433 9# ノード数 10print(dataset[0].num_nodes) # -> 2708
必要なデータを定義する
1# ノードの数 2num_nodes = dataset[0].num_nodes 3# 特徴量の次元数 4in_channels = dataset.num_node_features 5# 6out_channels = dataset.num_classes 7# 特徴量行列 (迷走中) 8# X = rearrange(dataset[0].x, "V C -> C 1 V") 9# X = dataset[0].x 10X = rearrange(dataset[0].x, "V C -> V 1 C") 11# ラベル 12y = dataset[0].y 13# 隣接行列 14A = create_adjacency_matrix(dataset[0].edge_index, num_nodes) 15# 次数行列 16D = create_degree_matrix(A) 17 18# DAD行列 19DAD = D @ A @ D 20 21def create_adjacency_matrix(edge_index, num_nodes): 22 adjacency_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float32) 23 24 for i in range(edge_index.size(1)): 25 source = edge_index[0, i] 26 target = edge_index[1, i] 27 adjacency_matrix[source, target] = 1 28 29 return adjacency_matrix 30 31 32def create_degree_matrix(adjacency_matrix): 33 # 各行の要素の合計を計算し、それを逆数にする 34 degree_vector = torch.sum(adjacency_matrix, dim=1) 35 # 逆数を計算し、ゼロ除算を防ぐためにepsを加算 36 inv_degree_vector = 1.0 / (degree_vector + torch.finfo(torch.float32).eps) 37 # 対角行列として設定 38 degree_matrix = torch.diag(inv_degree_vector) 39 40 return degree_matrix
モデルの定義 (まだコピペ状態)
1# モデルの定義 2model = GraphConv(in_channels, out_channels) 3 4class GraphConv(nn.Module): 5 def __init__(self, in_features, out_features): 6 super(GraphConv, self).__init__() 7 self.in_features = in_features 8 self.out_features = out_features 9 self.conv = nn.Conv2d(in_features, out_features, kernel_size=1) 10 11 def forward(self, input, adj): 12 """ 13 Args: 14 input (Tensor): graph feature 15 input.size() = (N, V, C) 16 adj (Tensor): normalized adjacency matrix. 17 e.g. DAD or DA 18 input.size() = (V, V) 19 Returns: 20 Tensor: out.size() = (N, V, C_out) 21 """ 22 input = rearrange(input, "N V C -> N C 1 V") 23 XW = self.conv(input) 24 DADXW = torch.einsum("NCTV,VW->NCTW", XW, adj) 25 DADXW = rearrange(DADXW, "N C 1 V -> N V C") 26 return DADXW
描画する
1# 結果の図示 2fig, ax = plt.subplots(1, 2, width_ratios=[4, 8]) 3ax[0].pcolor(X[0], cmap=plt.cm.Blues) 4ax[0].set_aspect('equal', 'box') 5ax[0].set_title('X', fontsize=10) 6ax[0].invert_yaxis() 7 8ax[1].pcolor(new_X[0], cmap=plt.cm.Blues) 9ax[1].set_aspect('equal', 'box') 10ax[1].set_title('new_X', fontsize=10) 11ax[1].invert_yaxis()
おかしい
理想
1dataset[0].x.shape # -> torch.Size([2708, 1433])
正しい場合
1X.shape # -> (10, 5, 4)
(バッチ数, ノード数, 特徴量の次元数)