Minimum Spanning Tree

Find the minimum spanning tree of a weighted undirected graph using Kruskal's and Prim's algorithms

Problem Statement

Given a connected, undirected graph with weighted edges, find the minimum spanning tree (MST) that connects all vertices with the minimum total edge weight.

A spanning tree of a graph is a subgraph that includes all vertices and is a tree (connected and acyclic). A minimum spanning tree is a spanning tree with the minimum possible total edge weight.

Input/Output Specifications

  • Input: A weighted undirected graph represented as:
    • n vertices (numbered 0 to n-1)
    • List of edges [u, v, weight] where u and v are vertices and weight is the edge cost
  • Output: The minimum total weight of the spanning tree

Constraints

  • 1 <= n <= 10^4
  • n-1 <= edges.length <= min(n*(n-1)/2, 10^4) (the lower bound follows from the graph being connected)
  • 0 <= u, v < n
  • u != v
  • 1 <= weight <= 10^6
  • The graph is connected

Examples

Example 1:

Input: n = 4, edges = [[0,1,10],[0,2,6],[0,3,5],[1,3,15],[2,3,4]]
Output: 19
Explanation: 
The MST includes edges: (2,3,4), (0,3,5), (0,1,10)
Total weight: 4 + 5 + 10 = 19

Example 2:

Input: n = 3, edges = [[0,1,1],[1,2,2],[0,2,3]]
Output: 3
Explanation:
The MST includes edges: (0,1,1), (1,2,2)
Total weight: 1 + 2 = 3

Example 3:

Input: n = 4, edges = [[0,1,1],[1,2,1],[2,3,1],[0,3,1]]
Output: 3
Explanation:
Any 3 edges with weight 1 will form the MST.
Total weight: 1 + 1 + 1 = 3
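
The expected outputs above can be cross-checked by exhaustive search on small inputs. The following is a minimal sketch, not part of the reference solutions: `brute_force_mst` is a hypothetical helper name, and it tries every set of n-1 edges, so it is only usable for tiny graphs.

```python
from itertools import combinations

def brute_force_mst(n, edges):
    """Return (minimum weight, number of minimum spanning trees) by trying
    every subset of n-1 edges. Exponential time; for tiny graphs only."""
    def is_spanning_tree(subset):
        # n-1 edges that never close a cycle necessarily connect all n vertices.
        parent = list(range(n))

        def find(x):
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        for u, v, _ in subset:
            ru, rv = find(u), find(v)
            if ru == rv:
                return False
            parent[ru] = rv
        return True

    best, count = None, 0
    for subset in combinations(edges, n - 1):
        if not is_spanning_tree(subset):
            continue
        total = sum(w for _, _, w in subset)
        if best is None or total < best:
            best, count = total, 1
        elif total == best:
            count += 1
    return best, count
```

For Example 3 this reports weight 3 with four distinct minimum spanning trees, one for each edge of the square that can be left out.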

Solution Approaches

Approach 1: Kruskal’s Algorithm with Union-Find (Optimal)

Algorithm Explanation:

  1. Sort all edges by weight in ascending order
  2. Initialize Union-Find data structure for cycle detection
  3. Iterate through sorted edges:
    • If adding the edge doesn’t create a cycle (vertices not in same component), add it to MST
    • Use Union-Find to check connectivity and union components
  4. Continue until MST has n-1 edges
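
To see which edges the steps above actually pick, here is a small sketch that returns the chosen edges rather than only their total. The name `kruskal_edges` is illustrative, and union by rank is left out for brevity (path compression alone keeps the sketch short); it is a tracing aid under those assumptions, not a replacement for the implementations that follow.

```python
def kruskal_edges(n, edges):
    """Return the MST edges in the order Kruskal's algorithm accepts them."""
    parent = list(range(n))

    def find(x):
        # Iterative find with path halving.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    chosen = []
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        ru, rv = find(u), find(v)
        if ru != rv:              # endpoints in different components: no cycle
            parent[ru] = rv
            chosen.append((u, v, w))
            if len(chosen) == n - 1:
                break
    return chosen
```

On Example 1 the accepted edges are (2,3,4), (0,3,5) and (0,1,10); edge (0,2,6) is skipped because 0 and 2 are already connected through 3.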

Implementation:

Python:

def minimum_spanning_tree_kruskal(n, edges):
    """
    Find MST using Kruskal's algorithm
    Time: O(E log E)
    Space: O(V)
    """
    class UnionFind:
        def __init__(self, n):
            self.parent = list(range(n))
            self.rank = [0] * n
            self.components = n
        
        def find(self, x):
            if self.parent[x] != x:
                self.parent[x] = self.find(self.parent[x])
            return self.parent[x]
        
        def union(self, x, y):
            px, py = self.find(x), self.find(y)
            if px == py:
                return False
            
            if self.rank[px] < self.rank[py]:
                px, py = py, px
            
            self.parent[py] = px
            if self.rank[px] == self.rank[py]:
                self.rank[px] += 1
            
            self.components -= 1
            return True
    
    # Sort a copy by weight so the caller's edge list is not reordered
    edges = sorted(edges, key=lambda x: x[2])
    
    uf = UnionFind(n)
    mst_weight = 0
    edges_added = 0
    
    for u, v, weight in edges:
        if uf.union(u, v):
            mst_weight += weight
            edges_added += 1
            if edges_added == n - 1:
                break
    
    return mst_weight

Java:

import java.util.Arrays;

class Solution {
    /**
     * Find MST using Kruskal's algorithm
     * Time: O(E log E)
     * Space: O(V)
     */
    public long minimumSpanningTreeKruskal(int n, int[][] edges) {
        // Sort edges by weight
        Arrays.sort(edges, (a, b) -> Integer.compare(a[2], b[2]));
        
        UnionFind uf = new UnionFind(n);
        // long: the total can reach ~10^10 under the constraints, which overflows int
        long mstWeight = 0;
        int edgesAdded = 0;
        
        for (int[] edge : edges) {
            int u = edge[0], v = edge[1], weight = edge[2];
            
            if (uf.union(u, v)) {
                mstWeight += weight;
                edgesAdded++;
                if (edgesAdded == n - 1) {
                    break;
                }
            }
        }
        
        return mstWeight;
    }
    
    class UnionFind {
        private int[] parent;
        private int[] rank;
        
        public UnionFind(int n) {
            parent = new int[n];
            rank = new int[n];
            for (int i = 0; i < n; i++) {
                parent[i] = i;
            }
        }
        
        public int find(int x) {
            if (parent[x] != x) {
                parent[x] = find(parent[x]);
            }
            return parent[x];
        }
        
        public boolean union(int x, int y) {
            int px = find(x), py = find(y);
            if (px == py) return false;
            
            if (rank[px] < rank[py]) {
                parent[px] = py;
            } else if (rank[px] > rank[py]) {
                parent[py] = px;
            } else {
                parent[py] = px;
                rank[px]++;
            }
            
            return true;
        }
    }
}

Go:

import "sort"

// minimumSpanningTreeKruskal - Find MST using Kruskal's algorithm
// Time: O(E log E)
// Space: O(V)
func minimumSpanningTreeKruskal(n int, edges [][]int) int {
    // Sort edges by weight
    sort.Slice(edges, func(i, j int) bool {
        return edges[i][2] < edges[j][2]
    })
    
    uf := NewUnionFind(n)
    mstWeight := 0
    edgesAdded := 0
    
    for _, edge := range edges {
        u, v, weight := edge[0], edge[1], edge[2]
        
        if uf.Union(u, v) {
            mstWeight += weight
            edgesAdded++
            if edgesAdded == n-1 {
                break
            }
        }
    }
    
    return mstWeight
}

type UnionFind struct {
    parent []int
    rank   []int
}

func NewUnionFind(n int) *UnionFind {
    parent := make([]int, n)
    rank := make([]int, n)
    for i := 0; i < n; i++ {
        parent[i] = i
    }
    return &UnionFind{parent, rank}
}

func (uf *UnionFind) Find(x int) int {
    if uf.parent[x] != x {
        uf.parent[x] = uf.Find(uf.parent[x])
    }
    return uf.parent[x]
}

func (uf *UnionFind) Union(x, y int) bool {
    px, py := uf.Find(x), uf.Find(y)
    if px == py {
        return false
    }
    
    if uf.rank[px] < uf.rank[py] {
        uf.parent[px] = py
    } else if uf.rank[px] > uf.rank[py] {
        uf.parent[py] = px
    } else {
        uf.parent[py] = px
        uf.rank[px]++
    }
    
    return true
}

JavaScript:

/**
 * Find MST using Kruskal's algorithm
 * Time: O(E log E)
 * Space: O(V)
 */
function minimumSpanningTreeKruskal(n, edges) {
    class UnionFind {
        constructor(n) {
            this.parent = Array.from({length: n}, (_, i) => i);
            this.rank = new Array(n).fill(0);
        }
        
        find(x) {
            if (this.parent[x] !== x) {
                this.parent[x] = this.find(this.parent[x]);
            }
            return this.parent[x];
        }
        
        union(x, y) {
            const px = this.find(x);
            const py = this.find(y);
            
            if (px === py) return false;
            
            if (this.rank[px] < this.rank[py]) {
                this.parent[px] = py;
            } else if (this.rank[px] > this.rank[py]) {
                this.parent[py] = px;
            } else {
                this.parent[py] = px;
                this.rank[px]++;
            }
            
            return true;
        }
    }
    
    // Sort edges by weight
    edges.sort((a, b) => a[2] - b[2]);
    
    const uf = new UnionFind(n);
    let mstWeight = 0;
    let edgesAdded = 0;
    
    for (const [u, v, weight] of edges) {
        if (uf.union(u, v)) {
            mstWeight += weight;
            edgesAdded++;
            if (edgesAdded === n - 1) {
                break;
            }
        }
    }
    
    return mstWeight;
}

C#:

using System;

public class Solution {
    /// <summary>
    /// Find MST using Kruskal's algorithm
    /// Time: O(E log E)
    /// Space: O(V)
    /// </summary>
    public long MinimumSpanningTreeKruskal(int n, int[][] edges) {
        // Sort edges by weight
        Array.Sort(edges, (a, b) => a[2].CompareTo(b[2]));
        
        var uf = new UnionFind(n);
        // long: the total can reach ~10^10 under the constraints, which overflows int
        long mstWeight = 0;
        int edgesAdded = 0;
        
        foreach (var edge in edges) {
            int u = edge[0], v = edge[1], weight = edge[2];
            
            if (uf.Union(u, v)) {
                mstWeight += weight;
                edgesAdded++;
                if (edgesAdded == n - 1) {
                    break;
                }
            }
        }
        
        return mstWeight;
    }
    
    public class UnionFind {
        private int[] parent;
        private int[] rank;
        
        public UnionFind(int n) {
            parent = new int[n];
            rank = new int[n];
            for (int i = 0; i < n; i++) {
                parent[i] = i;
            }
        }
        
        public int Find(int x) {
            if (parent[x] != x) {
                parent[x] = Find(parent[x]);
            }
            return parent[x];
        }
        
        public bool Union(int x, int y) {
            int px = Find(x), py = Find(y);
            if (px == py) return false;
            
            if (rank[px] < rank[py]) {
                parent[px] = py;
            } else if (rank[px] > rank[py]) {
                parent[py] = px;
            } else {
                parent[py] = px;
                rank[px]++;
            }
            
            return true;
        }
    }
}

Approach 2: Prim’s Algorithm with Priority Queue

Algorithm Explanation:

  1. Start with an arbitrary vertex (usually 0)
  2. Maintain a priority queue of edges from MST vertices to non-MST vertices
  3. Repeatedly select the minimum weight edge that connects MST to a new vertex
  4. Add the new vertex to MST and update the priority queue
  5. Continue until all vertices are included
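
The steps above can be traced with a variant that records each edge as it is accepted. This is a sketch under assumptions: `prim_edges` and its `start` parameter are illustrative names, and heap entries carry the source vertex so the accepted edge can be reported.

```python
import heapq

def prim_edges(n, edges, start=0):
    """Return the MST edges in the order Prim's algorithm accepts them."""
    graph = [[] for _ in range(n)]
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    visited = [False] * n
    visited[start] = True
    # Heap entries are (weight, tree vertex, candidate vertex).
    heap = [(w, start, v) for v, w in graph[start]]
    heapq.heapify(heap)

    chosen = []
    while heap and len(chosen) < n - 1:
        w, u, v = heapq.heappop(heap)
        if visited[v]:
            continue              # stale entry: v joined the tree earlier
        visited[v] = True
        chosen.append((u, v, w))
        for nxt, nw in graph[v]:
            if not visited[nxt]:
                heapq.heappush(heap, (nw, v, nxt))
    return chosen
```

Starting from vertex 0 on Example 1, the tree grows through (0,3,5), then (3,2,4), then (0,1,10), for the same total of 19 that Kruskal's produces.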

Implementation:

Python:

import heapq
from collections import defaultdict

def minimum_spanning_tree_prim(n, edges):
    """
    Find MST using Prim's algorithm
    Time: O(E log V)
    Space: O(V + E)
    """
    # Build adjacency list
    graph = defaultdict(list)
    for u, v, weight in edges:
        graph[u].append((v, weight))
        graph[v].append((u, weight))
    
    visited = set()
    min_heap = [(0, 0)]  # (weight, vertex)
    mst_weight = 0
    
    while min_heap and len(visited) < n:
        weight, vertex = heapq.heappop(min_heap)
        
        if vertex in visited:
            continue
        
        visited.add(vertex)
        mst_weight += weight
        
        # Add edges from this vertex to unvisited vertices
        for neighbor, edge_weight in graph[vertex]:
            if neighbor not in visited:
                heapq.heappush(min_heap, (edge_weight, neighbor))
    
    return mst_weight

Java:

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

class Solution {
    /**
     * Find MST using Prim's algorithm
     * Time: O(E log V)
     * Space: O(V + E)
     */
    public long minimumSpanningTreePrim(int n, int[][] edges) {
        // Build adjacency list
        List<List<int[]>> graph = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            graph.add(new ArrayList<>());
        }
        
        for (int[] edge : edges) {
            int u = edge[0], v = edge[1], weight = edge[2];
            graph.get(u).add(new int[]{v, weight});
            graph.get(v).add(new int[]{u, weight});
        }
        
        boolean[] visited = new boolean[n];
        PriorityQueue<int[]> minHeap = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        minHeap.offer(new int[]{0, 0}); // {weight, vertex}
        
        // long: the total can reach ~10^10 under the constraints, which overflows int
        long mstWeight = 0;
        int verticesAdded = 0;
        
        while (!minHeap.isEmpty() && verticesAdded < n) {
            int[] current = minHeap.poll();
            int weight = current[0];
            int vertex = current[1];
            
            if (visited[vertex]) continue;
            
            visited[vertex] = true;
            mstWeight += weight;
            verticesAdded++;
            
            for (int[] neighbor : graph.get(vertex)) {
                int nextVertex = neighbor[0];
                int edgeWeight = neighbor[1];
                
                if (!visited[nextVertex]) {
                    minHeap.offer(new int[]{edgeWeight, nextVertex});
                }
            }
        }
        
        return mstWeight;
    }
}

Go:

import (
    "container/heap"
)

// minimumSpanningTreePrim - Find MST using Prim's algorithm
// Time: O(E log V)
// Space: O(V + E)
func minimumSpanningTreePrim(n int, edges [][]int) int {
    // Build adjacency list
    graph := make([][][2]int, n)
    for _, edge := range edges {
        u, v, weight := edge[0], edge[1], edge[2]
        graph[u] = append(graph[u], [2]int{v, weight})
        graph[v] = append(graph[v], [2]int{u, weight})
    }
    
    visited := make([]bool, n)
    minHeap := &EdgeHeap{}
    heap.Init(minHeap)
    heap.Push(minHeap, Edge{0, 0}) // weight, vertex
    
    mstWeight := 0
    verticesAdded := 0
    
    for minHeap.Len() > 0 && verticesAdded < n {
        current := heap.Pop(minHeap).(Edge)
        
        if visited[current.vertex] {
            continue
        }
        
        visited[current.vertex] = true
        mstWeight += current.weight
        verticesAdded++
        
        for _, neighbor := range graph[current.vertex] {
            nextVertex, edgeWeight := neighbor[0], neighbor[1]
            if !visited[nextVertex] {
                heap.Push(minHeap, Edge{edgeWeight, nextVertex})
            }
        }
    }
    
    return mstWeight
}

type Edge struct {
    weight int
    vertex int
}

type EdgeHeap []Edge

func (h EdgeHeap) Len() int           { return len(h) }
func (h EdgeHeap) Less(i, j int) bool { return h[i].weight < h[j].weight }
func (h EdgeHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *EdgeHeap) Push(x interface{}) {
    *h = append(*h, x.(Edge))
}

func (h *EdgeHeap) Pop() interface{} {
    old := *h
    n := len(old)
    x := old[n-1]
    *h = old[0 : n-1]
    return x
}

JavaScript:

/**
 * Find MST using Prim's algorithm
 * Time: O(E log V)
 * Space: O(V + E)
 */
function minimumSpanningTreePrim(n, edges) {
    // Build adjacency list
    const graph = Array.from({length: n}, () => []);
    for (const [u, v, weight] of edges) {
        graph[u].push([v, weight]);
        graph[v].push([u, weight]);
    }
    
    const visited = new Array(n).fill(false);
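    // MinPriorityQueue is not built into JavaScript; this call shape appears to
    // match the v4-style API of the @datastructures-js/priority-queue package,
    // which some online judges preload
    const minHeap = new MinPriorityQueue({priority: x => x[0]});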
    minHeap.enqueue([0, 0]); // [weight, vertex]
    
    let mstWeight = 0;
    let verticesAdded = 0;
    
    while (!minHeap.isEmpty() && verticesAdded < n) {
        const [weight, vertex] = minHeap.dequeue().element;
        
        if (visited[vertex]) continue;
        
        visited[vertex] = true;
        mstWeight += weight;
        verticesAdded++;
        
        for (const [nextVertex, edgeWeight] of graph[vertex]) {
            if (!visited[nextVertex]) {
                minHeap.enqueue([edgeWeight, nextVertex]);
            }
        }
    }
    
    return mstWeight;
}

// Alternative implementation with simple array-based priority queue
function minimumSpanningTreePrimSimple(n, edges) {
    const graph = Array.from({length: n}, () => []);
    for (const [u, v, weight] of edges) {
        graph[u].push([v, weight]);
        graph[v].push([u, weight]);
    }
    
    const visited = new Array(n).fill(false);
    const minEdge = new Array(n).fill(Infinity);
    minEdge[0] = 0;
    
    let mstWeight = 0;
    
    for (let i = 0; i < n; i++) {
        let u = -1;
        for (let v = 0; v < n; v++) {
            if (!visited[v] && (u === -1 || minEdge[v] < minEdge[u])) {
                u = v;
            }
        }
        
        visited[u] = true;
        mstWeight += minEdge[u];
        
        for (const [v, weight] of graph[u]) {
            if (!visited[v] && weight < minEdge[v]) {
                minEdge[v] = weight;
            }
        }
    }
    
    return mstWeight;
}

C#:

using System.Collections.Generic;

// PriorityQueue<TElement, TPriority> requires .NET 6 or later
public class Solution {
    /// <summary>
    /// Find MST using Prim's algorithm
    /// Time: O(E log V)
    /// Space: O(V + E)
    /// </summary>
    public long MinimumSpanningTreePrim(int n, int[][] edges) {
        // Build adjacency list
        var graph = new List<(int vertex, int weight)>[n];
        for (int i = 0; i < n; i++) {
            graph[i] = new List<(int, int)>();
        }
        
        foreach (var edge in edges) {
            int u = edge[0], v = edge[1], weight = edge[2];
            graph[u].Add((v, weight));
            graph[v].Add((u, weight));
        }
        
        var visited = new bool[n];
        var minHeap = new PriorityQueue<(int weight, int vertex), int>();
        minHeap.Enqueue((0, 0), 0);
        
        // long: the total can reach ~10^10 under the constraints, which overflows int
        long mstWeight = 0;
        int verticesAdded = 0;
        
        while (minHeap.Count > 0 && verticesAdded < n) {
            var (weight, vertex) = minHeap.Dequeue();
            
            if (visited[vertex]) continue;
            
            visited[vertex] = true;
            mstWeight += weight;
            verticesAdded++;
            
            foreach (var (nextVertex, edgeWeight) in graph[vertex]) {
                if (!visited[nextVertex]) {
                    minHeap.Enqueue((edgeWeight, nextVertex), edgeWeight);
                }
            }
        }
        
        return mstWeight;
    }
}

Complexity Analysis:

Kruskal’s Algorithm:

  • Time Complexity: O(E log E) - Dominated by sorting edges
  • Space Complexity: O(V) - Union-Find data structure

Prim’s Algorithm:

  • Time Complexity: O(E log V) - Each edge enters the heap at most once, and every push or pop costs O(log E) = O(log V)
  • Space Complexity: O(V + E) - Adjacency list and priority queue

Trade-offs:

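  • Kruskal’s: Works directly on an edge list and is simple to implement; a good fit for sparse graphs, where the O(E log E) sort is cheap
  • Prim’s: Needs no global sort; on dense graphs an array-based O(V^2) variant (like minimumSpanningTreePrimSimple above) avoids the log factor, while the heap version is O(E log V), the same order as Kruskal’s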
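
As an illustration of the dense-graph case, here is a hedged Python sketch of the array-based O(V^2) form of Prim's over an adjacency matrix. `prim_dense` is a hypothetical name, and the sketch assumes a connected graph. The matrix needs O(V^2) memory, so in pure Python it is meant for modest n, not the full 10^4 allowed by the constraints.

```python
import math

def prim_dense(n, edges):
    """Array-based Prim's: O(V^2) time, no heap. Assumes a connected graph."""
    weight = [[math.inf] * n for _ in range(n)]
    for u, v, w in edges:
        if w < weight[u][v]:          # keep the lightest parallel edge
            weight[u][v] = weight[v][u] = w

    in_tree = [False] * n
    best = [math.inf] * n             # cheapest known edge into each vertex
    best[0] = 0
    total = 0
    for _ in range(n):
        # Linear scan replaces the heap: pick the closest vertex outside the tree.
        u = min((v for v in range(n) if not in_tree[v]), key=lambda v: best[v])
        in_tree[u] = True
        total += best[u]
        for v in range(n):
            if not in_tree[v] and weight[u][v] < best[v]:
                best[v] = weight[u][v]
    return total
```

When E approaches V^2, the two linear scans per step cost no more than reading the edges once, which is why this form can beat the heap version on dense inputs.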

Key Insights

  • Greedy Strategy: Both algorithms are greedy - each step takes the minimum weight edge that does not create a cycle
  • Cycle Detection: Kruskal’s uses Union-Find, Prim’s avoids cycles by construction
  • Edge vs Vertex Focus: Kruskal’s focuses on edges, Prim’s focuses on vertices
  • Graph Representation: Choice affects performance - edge list for Kruskal’s, adjacency list for Prim’s

Edge Cases

  1. Single Vertex: n = 1 → MST weight = 0
  2. Disconnected Graph: No spanning tree exists
  3. Self-loops: Should be ignored
  4. Multiple Edges: Keep minimum weight edge between same vertices
  5. All Equal Weights: Any spanning tree is minimum

How Solutions Handle Edge Cases:

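  • Single vertex: Kruskal’s has no edges to process and Prim’s pops only the initial (weight 0, vertex 0) entry, so both return 0
  • Disconnected graph: Not detected by the implementations above, which return the weight of a partial forest; check that n-1 edges were added (Kruskal’s) or n vertices were visited (Prim’s) before trusting the result
  • Self-loops: Both endpoints share a root, so Union-Find returns false; in Prim’s the vertex is already visited when its self-loop is scanned
  • Multiple edges: Kruskal’s sorted order and Prim’s heap both reach the lightest parallel edge first, and the heavier copies are then rejected as cycle-forming
  • Equal weights: Ties are broken by processing order; whichever spanning tree results is still minimum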
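
The disconnected-graph check mentioned above can be sketched as follows. The function name `mst_weight_or_none` and the choice of `None` as the "no spanning tree" signal are assumptions made for illustration; the original solutions rely on the connectivity guarantee instead.

```python
def mst_weight_or_none(n, edges):
    """Kruskal's with a final edge-count check.
    Returns None when the graph has no spanning tree."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total, used = 0, 0
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        ru, rv = find(u), find(v)
        if ru == rv:
            continue   # self-loop, heavier parallel edge, or other cycle-forming edge
        parent[ru] = rv
        total += w
        used += 1
    # A spanning tree on n vertices has exactly n-1 edges.
    return total if used == n - 1 else None
```

The same function covers self-loops and parallel edges without special handling, since both are rejected by the shared-root test.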

Test Cases

def test_minimum_spanning_tree():
    # Basic case
    assert minimum_spanning_tree_kruskal(4, [[0,1,10],[0,2,6],[0,3,5],[1,3,15],[2,3,4]]) == 19
    
    # Triangle
    assert minimum_spanning_tree_kruskal(3, [[0,1,1],[1,2,2],[0,2,3]]) == 3
    
    # Square
    assert minimum_spanning_tree_kruskal(4, [[0,1,1],[1,2,1],[2,3,1],[0,3,1]]) == 3
    
    # Single vertex
    assert minimum_spanning_tree_kruskal(1, []) == 0
    
    # Linear chain
    assert minimum_spanning_tree_kruskal(4, [[0,1,1],[1,2,2],[2,3,3]]) == 6
    
    # Complete graph
    edges = [[i,j,i+j] for i in range(4) for j in range(i+1, 4)]
    result = minimum_spanning_tree_kruskal(4, edges)
    assert result == 6  # edges (0,1,1), (0,2,2), (0,3,3)
    
    print("All tests passed!")

test_minimum_spanning_tree()

Follow-up Questions

  1. Maximum Spanning Tree: How would you find the maximum spanning tree?
  2. K Minimum Spanning Trees: Find the k smallest distinct spanning trees
  3. Dynamic MST: Handle edge insertions/deletions efficiently
  4. Degree Constrained MST: MST where each vertex has degree ≤ k
  5. Distributed MST: Find MST in distributed/parallel setting
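
For the first follow-up, one common answer is to run Kruskal's with the sort order reversed (equivalently, negate every weight and find a minimum spanning tree). A minimal sketch, with `maximum_spanning_tree_weight` as an illustrative name and a connected graph assumed:

```python
def maximum_spanning_tree_weight(n, edges):
    """Kruskal's with edges taken from heaviest to lightest."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total, used = 0, 0
    for u, v, w in sorted(edges, key=lambda e: e[2], reverse=True):
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            total += w
            used += 1
            if used == n - 1:
                break
    return total
```

On Example 1 this selects the edges of weight 15, 10 and 6 for a total of 31.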

Common Mistakes

  1. Wrong Edge Sorting: Sorting by vertex indices instead of weights

    • Problem: Doesn’t minimize total weight
    • Solution: Sort by edge[2] (weight)
  2. Missing Cycle Check: Adding edges without checking connectivity

    • Problem: Creates cycles, not a tree
    • Solution: Use Union-Find or visited tracking
  3. Incorrect Union-Find: Not using path compression or union by rank

    • Problem: Trees can degenerate into long chains, making find O(n) in the worst case
    • Solution: Implement both optimizations for near-constant amortized time per operation
  4. Off-by-one Errors: Adding wrong number of edges

    • Problem: Not a spanning tree (too few/many edges)
    • Solution: MST has exactly n-1 edges
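
Relating to mistake 3, here is one way a Union-Find with both optimizations might look when `find` is written iteratively, which also avoids any reliance on recursion depth. The class name `DisjointSet` is an assumption for this sketch; the behavior mirrors the Union-Find used in the solutions above.

```python
class DisjointSet:
    """Union-Find with path compression and union by rank."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root.
        while self.parent[x] != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False          # already connected: this edge would form a cycle
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx      # attach the shallower tree under the deeper one
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True
```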

Interview Tips

  1. Algorithm Choice: Ask about graph density to choose between Kruskal’s and Prim’s
  2. Implementation Details: Be prepared to implement Union-Find from scratch
  3. Optimization Discussion: Mention path compression and union by rank
  4. Edge Cases: Always discuss connectivity requirement
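  5. Follow-up Preparation: Understand how Prim’s relates to Dijkstra’s algorithm (the same heap-driven loop, keyed on edge weight rather than distance from the source), and that an MST does not generally contain shortest paths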

Concept Explanations

Minimum Spanning Tree Properties:

  • Has exactly n-1 edges for n vertices
  • Connects all vertices with minimum total weight
  • Removing any tree edge splits the tree into two components
  • Adding any non-tree edge to the tree creates exactly one cycle

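Cut Property: For any cut of the graph, a minimum weight edge crossing the cut belongs to some MST; if that edge is strictly lighter than every other crossing edge, it belongs to every MST. This justifies the greedy approaches: Kruskal’s and Prim’s only ever add an edge that is lightest across some cut.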
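
The cut property can be checked by brute force on a small graph with distinct weights. The sketch below uses a hypothetical helper, `lightest_crossing_edges`, that enumerates every cut with vertex 0 on one side and collects the lightest crossing edge of each; it is exponential in n and intended only as an illustration.

```python
from itertools import combinations

def lightest_crossing_edges(n, edges):
    """Collect the lightest crossing edge of every cut (S, V minus S) with 0 in S."""
    result = set()
    others = list(range(1, n))
    for size in range(n - 1):                 # S stays a proper subset of V
        for extra in combinations(others, size):
            side = {0, *extra}
            crossing = [tuple(e) for e in edges
                        if (e[0] in side) != (e[1] in side)]
            if crossing:
                result.add(min(crossing, key=lambda e: e[2]))
    return result
```

On Example 1, whose weights are all distinct, the collected edges are exactly the three MST edges (2,3,4), (0,3,5) and (0,1,10).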

When to Apply: Use MST algorithms for network design problems (minimum cost to connect all locations), clustering analysis, and approximation algorithms for TSP.
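
As one example of the clustering use, stopping Kruskal's once k components remain yields a single-linkage style grouping. This is a sketch under assumptions: `kruskal_clusters` and the parameter `k` are illustrative, and edge weights are treated as distances between vertices.

```python
def kruskal_clusters(n, edges, k):
    """Merge the closest components until k remain; return the groups as sorted lists."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    components = n
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        if components == k:
            break                 # the remaining, heavier edges separate the clusters
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            components -= 1

    groups = {}
    for v in range(n):
        groups.setdefault(find(v), []).append(v)
    return sorted(groups.values())
```

With the Example 1 graph and k = 2, vertices 0, 2 and 3 end up in one cluster and vertex 1 in its own, since the two cheapest edges join 2-3 and 0-3.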