An Introduction to Graph Neural Networks

Written by Coursera Staff • Updated on

Graphs are a powerful tool to represent data, but machines often find them difficult to analyze. Explore graph neural networks, a deep-learning method designed to address this problem, and learn about the impact this methodology has across industries.

[Featured Image] A data scientist uses graph neural networks to make predictions.

Neural networks are an important component of artificial intelligence and deep learning as they are systems designed to help machines process and produce information in a way similar to humans. When working with neural networks, you can design various types for different purposes, such as convolutional neural networks (CNNs), recurrent neural networks (RNNs), and graph neural networks (GNNs). In this article, we will focus on graph neural networks, including when you should choose this type of model and the advantages you might find.

What is a neural network in machine learning?

In machine learning, a neural network is a computational model designed to mimic neural pathways in the brain. Neural networks independently process and analyze information, learning from mistakes and continuously updating their own algorithms to improve results. This type of technology has been used in various fields, including aiding in medical diagnosis, target marking, compound identification, and computer vision.

A typical neural network consists of layers of interconnected nodes called neurons. These layers include an input layer, one or more hidden layers, and an output layer. After analysis in one layer, the data moves to the next layer for another layer of analysis, and so on. Processing involves a mathematical function with weights and biases, which are adjustable network parameters. Training a neural network involves repeatedly passing a set of inputs through the network, comparing the network's output with the desired output, and adjusting the weights and biases to reduce the error.

What is a graph neural network (GNN)?

Graph neural networks, or GNNs, are a type of neural network model designed specifically to process information represented in a graphical format. In traditional neural networks, like convolutional neural networks (CNNs), the data is typically assumed to be in Euclidean space (like text or time data), which can be represented in regular grid structures. These neural networks assume that inputs (like pixels in an image or words in a sentence) are independent of each other or they follow a sequential order. 

However, graphs are a powerful tool to represent many kinds of real-world data (like social networks, molecular structures, or computer networks), with varying distances between points and orders of connection. GNNs can handle data where entities connect in complex ways, like the nodes and edges in a graph. GNNs shine when you look at the relationships and connections between data points. GNNs show promising use in research areas within social media, social networks, roadways, financial trading, the internet, natural systems, and health care.

Benefits of GNNs

GNNs stand out due to their unique features that make them suitable for analyzing data represented as graphs. Key features of this type of model include the following.

They apply to broad graphical input types.

GNNs can effectively work with various types of graph structures. These include directed or undirected graphs (depending on whether nodes direct to one another), homogenous or heterogenous graphs (depending on whether nodes and edges have the same type), and static or dynamic graphs (depending on whether input features change with time). This adaptability makes them applicable in diverse real-world scenarios, from analyzing social networks (where connections represent friendships) to understanding molecular structures in chemistry.

They can capture node and edge data.

Unlike traditional neural networks that process individual data points independently, GNNs consider both the attributes of the nodes (individual entities) and the edges (relationships). For instance, in a social network graph, a node attribute could be a person’s profile information, while an edge attribute could represent the strength or type of relationship between two people.

They can learn with several training methods. 

When you build a machine learning model, you can choose between unsupervised, supervised, and semi-supervised training. Supervised training involves labeling data sets, while unsupervised training lets the machine uncover patterns without guidance. Semi-supervised training falls somewhere in the middle. Regarding GNNs, you can build models using any of these training types. 

Limitations of GNNs

GNNs are useful for many applications, but knowing their limitations can help you avoid common pitfalls. While many models require comprehensive training data, GNNs can learn from limited data pools. However, GNNs may inherit the bias present in the data they train on, making some resistant to adopting this model for certain applications, such as ranking applicants for a job. Another limitation is that GNNs may not generalize well to unseen data if you are using a large number of variables. By being aware of the ways in which GNNs generalize and learn, you can improve the structure of your training data and reduce the likelihood of a biased or ungeneralizable model. 

Real-world applications of GNNs

GNNs are popular in many fields, some of which lend themselves more easily to graphical structure than others. Some ways in which GNNs are already making headway in modern industries include the following. 

User recommendations

GNNs can analyze the complex web of user interactions and preferences in recommendation systems. GNNs can make highly personalized and accurate recommendations by understanding the intricate network of user-product interactions. This capability is particularly valuable in platforms like e-commerce or streaming services, where enhancing user experience through personalized content is a top priority.

Stock forecasting

Stock predictions are often based on highly volatile, related data. GNNs can process time-sharing diagrams, line graphs, and other stock trend representations, allowing more comprehensive historical data to influence the model. Embedding graphical data into prediction algorithms also allows the inclusion of relationship data, which can improve forecasts over time.

Predicting patient outcomes

To better predict patient outcomes and response to treatments, GNNs can include data on patient demographics and diagnosis information, combined with therapy options and broader patient outcome data. This type of model can then identify similar patients that have measured outcomes of different treatments, allowing for the prediction of how a particular patient is likely to respond to different treatment options. 

Modeling social networks

You can use GNNs to model relationships within networks of people, such as modeling collaborations between groups of people. For example, you could model the relationship between different actors and actresses within Hollywood. Each actor or actress would be a node, and the GNN algorithm would draw edges between each pair of professionals who had worked together in a movie or television show. The algorithm can then analyze this network model to see the strength of relationships between different actors or actresses. 

Keep learning on Coursera.

Deep learning is an exciting field to explore, and new applications pop up every day. If you want to build a broad foundation in machine learning that you can apply across industries, consider starting with the Deep Learning Specialization by DeepLearning.AI. 

Keep reading

Updated on
Written by:

Editorial Team

Coursera’s editorial team is comprised of highly experienced professional editors, writers, and fact...

This content has been made available for informational purposes only. Learners are advised to conduct additional research to ensure that courses and other credentials pursued meet their personal, professional, and financial goals.