Graph Attention in TensorFlow 2
(3802 words)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,initializers
class BatchMultiHeadGraphAttention(keras.Model):
    """Multi-head graph attention (GAT-style) over a batch of nodes.

    Projects node features with a per-head weight matrix, scores every
    (node, neighbor) pair with a shared linear attention head, normalizes
    the scores with a softmax over each node's neighborhood, and aggregates
    neighbor features. Head outputs are averaged and an optional bias added.
    """

    def __init__(self, n_head, f_in, f_out, attn_dropout, bias=True):
        """
        Args:
            n_head: number of attention heads.
            f_in: input feature size per node.
            f_out: output feature size per head.
            attn_dropout: dropout rate applied to the attention weights.
            bias: if True, add a learnable bias of shape [f_out] to the output.
        """
        super(BatchMultiHeadGraphAttention, self).__init__()
        self.n_head = n_head            # number of attention heads
        self.f_in = f_in                # input feature dimension
        self.f_out = f_out              # output feature dimension per head
        self.attn_dropout = attn_dropout
        self.add_self_loop = True       # rebuild edges with self loops inside call()
        self.initializer = initializers.GlorotUniform()
        # Per-head projection matrix: [n_head, f_in, f_out].
        self.w = tf.Variable(self.initializer(shape=[self.n_head, self.f_in, self.f_out], dtype=tf.float32))
        self.adj = []  # per-call attention matrices (rebuilt on every forward pass)
        # Attention scorer over concatenated (self, neighbor) features: [n_head, 2*f_out, 1].
        self.fc = tf.Variable(self.initializer(shape=[self.n_head, 2 * self.f_out, 1], dtype=tf.float32))
        self.leaky_relu = layers.LeakyReLU(alpha=0.2)
        self.softmax = layers.Softmax(axis=-1)
        self.dropout = layers.Dropout(rate=self.attn_dropout)
        # BUG FIX: define self.bias unconditionally — call() reads it either way,
        # so bias=False used to raise AttributeError at the end of call().
        self.bias = tf.Variable(tf.zeros(self.f_out)) if bias else None

    def remove_self_loops(self, edge_index):
        """Drop every (i, i) edge from a [2, E] edge-index tensor (pure; returns a new tensor)."""
        row, col = edge_index
        keep = tf.where(row != col)
        return tf.transpose(tf.gather(tf.transpose(edge_index), tf.squeeze(keep)))

    def add_self_loops(self, edge_index, num_nodes):
        """Append one (i, i) edge per node to a [2, E] edge-index tensor (pure; returns a new tensor)."""
        loop_index = tf.range(0, num_nodes, dtype=tf.int64)
        loop_index = tf.tile(tf.expand_dims(loop_index, 0), [2, 1])
        return tf.concat([edge_index, loop_index], 1)

    def call(self, h, edge_index):
        """Run graph attention.

        Args:
            h: node features, shape [bs, f_in].
            edge_index: int64 edge list, shape [2, E] (row 0 = source, row 1 = target).

        Returns:
            Node embeddings of shape [bs, f_out] (mean over heads, plus bias if enabled).
        """
        bs = h.shape[0]
        if self.add_self_loop:
            # BUG FIX: both helpers are pure — the original discarded their
            # return values, so self loops were never actually added.
            edge_index = self.remove_self_loops(edge_index)
            edge_index = self.add_self_loops(edge_index, bs)
        h_prime = tf.matmul(h, self.w)  # [n_head, bs, f_out]
        # BUG FIX: reset per call; the original appended across calls, so a
        # second forward pass stacked 2*bs attention rows and broke the matmul.
        self.adj = []
        for i in range(h_prime.shape[1]):  # for each node i
            # Targets of every edge whose source is node i.
            neighbors = tf.gather(edge_index[1, :], tf.squeeze(tf.where(edge_index[0, :] == i)), 0)
            if self.n_head == 1:
                shape = tf.cast(tf.constant([bs]), dtype=tf.int64)
            else:
                shape = tf.cast(tf.constant([bs, self.n_head]), dtype=tf.int64)
            n_neighbors = neighbors.shape[0]
            # Pair node i with each neighbor: concat gives [n_head, n_neighbors, 2*f_out].
            curr_node = tf.tile(tf.expand_dims(h_prime[:, i, :], 1), [1, n_neighbors, 1])
            neighbors_node = tf.gather(h_prime, neighbors, axis=1)
            total_node = tf.concat((curr_node, neighbors_node), 2)
            att_node = self.leaky_relu(total_node @ self.fc)
            # Normalize over the neighborhood: [n_head, n_neighbors].
            att_node = self.softmax(tf.reshape(att_node, [self.n_head, n_neighbors]))
            att_node = self.dropout(att_node)
            att_node = tf.transpose(att_node, [1, 0])
            # Scatter the neighborhood weights into a dense row over all bs nodes.
            scatter = tf.scatter_nd(tf.expand_dims(neighbors, 1), tf.squeeze(att_node), shape)
            self.adj.append(tf.transpose(scatter))
        output = tf.matmul(tf.stack(self.adj, 1), h_prime)  # [n_head, bs, f_out]
        output = tf.reduce_mean(output, 0)  # average heads -> [bs, f_out]
        if self.bias is not None:
            return output + self.bias
        return output
#
def Get_Adj(bs):
    """Build the fully connected directed edge list (no self loops) for `bs` nodes.

    Returns an int64 tensor of shape [2, bs*(bs-1)]: row 0 holds source node
    ids, row 1 target node ids, enumerated in the same lexicographic order as
    itertools.permutations(range(bs), 2).
    """
    sources = [i for i in range(bs) for j in range(bs) if i != j]
    targets = [j for i in range(bs) for j in range(bs) if i != j]
    return tf.convert_to_tensor([sources, targets], dtype=tf.int64)
import time

# Benchmark configuration.
heads = 12   # attention heads
bs = 512     # number of nodes
fin = 256    # input feature size
fout = 128   # output feature size per head

a = Get_Adj(bs)                    # dense directed edge list over bs nodes
h = tf.random.normal([bs, fin])    # random node features
model = BatchMultiHeadGraphAttention(n_head=heads, f_in=fin, f_out=fout, attn_dropout=0.5)

start_time = time.time()
out = model(h, a)
end_time = time.time()
# BUG FIX: "%d" truncated sub-second timings to 0; report float seconds.
print("time:%.3f" % (end_time - start_time))
print(out.shape)

# Optional gradient check (left disabled):
# with tf.GradientTape(persistent=True) as tape:
#     out = model(h, a)
#     loss = tf.reduce_sum(out)
# print(out.shape)
# print(tape.gradient(loss, model.trainable_variables))