Graph Attention (GAT) — TensorFlow 2 implementation

3802 words

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,initializers


class BatchMultiHeadGraphAttention(keras.Model):
	"""Multi-head graph-attention (GAT) layer over a single dense graph.

	Inputs to ``call``:
		h:          float tensor of node features, shape [num_nodes, f_in].
		edge_index: int64 tensor of shape [2, num_edges]; row 0 holds source
		            node ids, row 1 holds target node ids.
	Output: float tensor of shape [num_nodes, f_out] (mean over heads).
	"""

	def __init__(self, n_head, f_in, f_out, attn_dropout, bias=True):
		super(BatchMultiHeadGraphAttention, self).__init__()
		self.n_head = n_head              # number of attention heads
		self.f_in = f_in                  # input feature size
		self.f_out = f_out                # output feature size per head
		self.attn_dropout = attn_dropout  # dropout rate on attention weights
		self.add_self_loop = True         # attach an i->i edge to every node in call()
		self.initializer = initializers.GlorotUniform()
		# Per-head projection weights: [n_head, f_in, f_out].
		self.w = tf.Variable(self.initializer(shape=[self.n_head, self.f_in, self.f_out], dtype=tf.float32))
		self.adj = []  # per-node attention rows of the most recent forward pass
		# Per-head attention vector applied to [h_i || h_j]: [n_head, 2*f_out, 1].
		self.fc = tf.Variable(self.initializer(shape=[self.n_head, 2 * self.f_out, 1], dtype=tf.float32))
		self.leaky_relu = layers.LeakyReLU(alpha=0.2)   # activation on raw attention scores
		self.softmax = layers.Softmax(axis=-1)          # normalisation over neighbors
		self.dropout = layers.Dropout(rate=self.attn_dropout)
		if bias:
			self.bias = tf.Variable(tf.zeros(self.f_out))  # shared output bias
		else:
			# FIX: previously left undefined when bias=False, so call()
			# crashed with AttributeError on `self.bias is not None`.
			self.bias = None

	def remove_self_loops(self, edge_index):
		"""Return a copy of edge_index with every i->i edge removed."""
		row, col = edge_index
		mask = tf.where(row != col)  # positions of edges whose endpoints differ
		edge_index = tf.transpose(tf.gather(tf.transpose(edge_index), tf.squeeze(mask)))
		return edge_index

	def add_self_loops(self, edge_index, num_nodes):
		"""Return edge_index with one i->i edge appended for every node."""
		loop_index = tf.range(0, num_nodes, dtype=tf.int64)
		loop_index = tf.tile(tf.expand_dims(loop_index, 0), [2, 1])
		edge_index = tf.concat([edge_index, loop_index], 1)
		return edge_index

	def call(self, h, edge_index):
		bs = h.shape[0]  # number of nodes; h is [bs, f_in]
		if self.add_self_loop:
			# FIX: the original discarded the return values of these two
			# calls, so the self-loop option silently did nothing.
			edge_index = self.remove_self_loops(edge_index)
			edge_index = self.add_self_loops(edge_index, bs)

		h_prime = tf.matmul(h, self.w)  # [n_head, bs, f_out]

		# FIX: reset instead of appending forever — the original grew
		# self.adj across calls, corrupting every forward pass after the first.
		self.adj = []
		for i in range(h_prime.shape[1]):  # for each node i
			# targets of all edges leaving node i
			neighbors = tf.gather(edge_index[1, :], tf.squeeze(tf.where(edge_index[0, :] == i)))
			if self.n_head == 1:
				shape = tf.cast(tf.constant([bs]), dtype=tf.int64)
			else:
				shape = tf.cast(tf.constant([bs, self.n_head]), dtype=tf.int64)
			n_neighbors = neighbors.shape[0]  # number of this node's neighbors
			# Broadcast node i's projected features against each neighbor.
			curr_node = tf.tile(tf.expand_dims(h_prime[:, i, :], 1), [1, n_neighbors, 1])  # [n_head, n_neighbors, f_out]
			neighbors_node = tf.gather(h_prime, neighbors, axis=1)  # [n_head, n_neighbors, f_out]
			total_node = tf.concat((curr_node, neighbors_node), 2)  # [n_head, n_neighbors, 2*f_out]

			att_node = self.leaky_relu(total_node @ self.fc)
			att_node = self.softmax(tf.reshape(att_node, [self.n_head, n_neighbors]))  # [n_head, n_neighbors]
			att_node = self.dropout(att_node)
			att_node = tf.transpose(att_node, [1, 0])  # [n_neighbors, n_head], as scatter_nd expects
			# Scatter the attention weights into a dense column of the
			# (per-head) attention matrix; unvisited nodes stay 0.
			scatter = tf.scatter_nd(tf.expand_dims(neighbors, 1), tf.squeeze(att_node), shape)
			self.adj.append(tf.transpose(scatter))
		output = tf.matmul(tf.stack(self.adj, 1), h_prime)  # [n_head, bs, f_out]
		output = tf.reduce_mean(output, 0)  # average heads -> [bs, f_out]

		if self.bias is not None:
			return output + self.bias
		else:
			return output
			
#       
def Get_Adj(bs):
	"""Build the edge index of a fully-connected graph on `bs` nodes.

	Returns an int64 tensor of shape [2, bs*(bs-1)]: row 0 holds source
	ids, row 1 target ids, one column per ordered pair (no self-loops).
	"""
	import itertools
	sources, targets = [], []
	for src, dst in itertools.permutations(range(bs), 2):
		sources.append(src)
		targets.append(dst)
	return tf.convert_to_tensor([sources, targets], dtype=tf.int64)
	
import time

# Smoke-test: one forward pass on a fully-connected 512-node graph.
heads = 12
bs = 512
fin = 256
fout = 128
a = Get_Adj(bs)                      # [2, bs*(bs-1)] dense edge list
h = tf.random.normal([bs, fin])      # random node features
model = BatchMultiHeadGraphAttention(n_head=heads, f_in=fin, f_out=fout, attn_dropout=0.5)
start_time = time.time()
out = model(h, a)
end_time = time.time()
# FIX: %d truncated the sub-second elapsed time to 0; report fractional seconds.
print("time:%.3f" % (end_time - start_time))
print(out.shape)

"""
with tf.GradientTape(persistent=True) as tape: 
	 out = model(h,a)
	 loss = tf.reduce_sum(out)
	 print(out.shape)
print(tape.gradient(loss, model.trainable_variables))
"""