Max Stack

Design a stack that supports push, pop, top, peekMax, and popMax operations

Language Selection

Choose your preferred programming language

Showing: Python

Problem Statement

Design a max stack data structure that supports the stack operations and supports finding the stack’s maximum element.

Implement the MaxStack class:

  • MaxStack() Initializes the stack object.
  • void push(int x) Pushes element x onto the stack.
  • int pop() Removes the element on top of the stack and returns it.
  • int top() Gets the element on the top of the stack without removing it.
  • int peekMax() Retrieves the maximum element in the stack without removing it.
  • int popMax() Retrieves the maximum element in the stack and removes it. If there is more than one maximum element, only remove the top-most one.

Example 1:

Input
["MaxStack", "push", "push", "push", "top", "popMax", "top", "peekMax", "pop", "top"]
[[], [5], [1], [5], [], [], [], [], [], []]
Output
[null, null, null, null, 5, 5, 1, 5, 1, 5]

Explanation
MaxStack stk = new MaxStack();
stk.push(5);   // [5] the top of the stack and the maximum number is 5.
stk.push(1);   // [5, 1] the top of the stack and the maximum number is 5.
stk.push(5);   // [5, 1, 5] the top of the stack and the maximum number is 5.
stk.top();     // return 5, [5, 1, 5] the stack did not change.
stk.popMax();  // return 5, [5, 1] the stack changes now, and the top is 1.
stk.top();     // return 1, [5, 1] the stack did not change.
stk.peekMax(); // return 5, [5, 1] the stack did not change.
stk.pop();     // return 1, [5] the stack changes now, the top is 5.
stk.top();     // return 5, [5] the stack did not change.

Constraints:

  • -10^7 <= x <= 10^7
  • At most 10^4 calls will be made to push, pop, top, peekMax, and popMax.
  • There will be at least one element in the stack when calling pop, top, peekMax, or popMax.

Approach 1: Two Stacks

Algorithm

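  1. Use two stacks: one holding every element, one holding the running maximums
  2. For push: push onto the element stack; also push onto the max stack when the value is greater than or equal to the current maximum (or the max stack is empty), so duplicate maxima are tracked
  3. For pop: pop the element stack; also pop the max stack when the removed value equals its top
  4. For popMax: pop elements off the element stack until the top-most maximum is reached. Remove that maximum from both stacks, then re-push the popped elements with the normal push so the max stack is rebuilt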

Solution

Python:

class MaxStack:
    """
    Two stacks approach.

    ``stack`` stores every element in push order. ``max_stack`` stores only
    the elements that were greater than or equal to the current maximum at
    the moment they were pushed. Its last entry is therefore always the
    maximum of ``stack``. Ties are recorded too, so duplicate maxima stay in
    step with ``stack``.

    Time: O(n) for popMax, O(1) for others
    Space: O(n) - two stacks

    Calls on an empty stack return None.
    """
    def __init__(self):
        self.stack = []
        self.max_stack = []

    def push(self, x):
        self.stack.append(x)
        is_new_max = not self.max_stack or x >= self.max_stack[-1]
        if is_new_max:
            self.max_stack.append(x)

    def pop(self):
        if not self.stack:
            return None
        removed = self.stack.pop()
        # Anything pushed after the latest recorded max is strictly smaller
        # than it. So the top of ``stack`` can equal the top of ``max_stack``
        # only when it is that recorded element.
        maxima = self.max_stack
        if maxima and removed == maxima[-1]:
            maxima.pop()
        return removed

    def top(self):
        if not self.stack:
            return None
        return self.stack[-1]

    def peekMax(self):
        if not self.max_stack:
            return None
        return self.max_stack[-1]

    def popMax(self):
        if not self.max_stack:
            return None
        target = self.max_stack.pop()
        # Elements above the top-most copy of the max are strictly smaller
        # and were never recorded in max_stack, so they can be lifted off
        # without touching it.
        lifted = []
        while self.stack[-1] != target:
            lifted.append(self.stack.pop())
        self.stack.pop()
        # Re-push in their original order so max_stack is rebuilt correctly.
        for value in reversed(lifted):
            self.push(value)
        return target

Java:

class MaxStack {
    /**
     * Two stacks approach.
     * {@code stack} holds every element; {@code maxStack} holds each element
     * that was >= the current maximum when pushed, so its top is always the
     * maximum of {@code stack}.
     * Time: O(n) for popMax, O(1) for others
     * Space: O(n) - two stacks
     */
    private Stack<Integer> stack;
    private Stack<Integer> maxStack;

    public MaxStack() {
        this.stack = new Stack<>();
        this.maxStack = new Stack<>();
    }

    public void push(int x) {
        stack.push(x);
        boolean newMax = maxStack.isEmpty() || x >= maxStack.peek();
        if (newMax) {
            maxStack.push(x);
        }
    }

    public int pop() {
        int removed = stack.pop();
        // Only the most recently recorded max can sit on top with its value.
        boolean wasRecordedMax = !maxStack.isEmpty() && maxStack.peek() == removed;
        if (wasRecordedMax) {
            maxStack.pop();
        }
        return removed;
    }

    public int top() {
        return stack.peek();
    }

    public int peekMax() {
        return maxStack.peek();
    }

    public int popMax() {
        int target = maxStack.pop();
        // Elements above the top-most max are smaller and not in maxStack.
        Stack<Integer> held = new Stack<>();
        while (stack.peek() != target) {
            held.push(stack.pop());
        }
        stack.pop();
        // Re-push through push() so maxStack is rebuilt for them.
        while (!held.isEmpty()) {
            push(held.pop());
        }
        return target;
    }
}

Go:

// MaxStack - Two stacks approach.
// stack holds every element; maxStack holds each element that was >= the
// current maximum when pushed, so its last entry is the current maximum.
// Time: O(n) for PopMax, O(1) for others
// Space: O(n) - two stacks
type MaxStack struct {
    stack    []int
    maxStack []int
}

// Constructor returns a MaxStack with empty (non-nil) backing slices.
func Constructor() MaxStack {
    var s MaxStack
    s.stack = []int{}
    s.maxStack = []int{}
    return s
}

// Push appends x to the stack; x also goes onto maxStack when maxStack is
// empty or x is >= its current top.
func (this *MaxStack) Push(x int) {
    this.stack = append(this.stack, x)
    if n := len(this.maxStack); n > 0 && x < this.maxStack[n-1] {
        return
    }
    this.maxStack = append(this.maxStack, x)
}

// Pop removes and returns the top element, or 0 if the stack is empty.
// maxStack is popped as well when the removed value equals its top.
func (this *MaxStack) Pop() int {
    n := len(this.stack)
    if n == 0 {
        return 0
    }

    val := this.stack[n-1]
    this.stack = this.stack[:n-1]

    if m := len(this.maxStack); m > 0 && this.maxStack[m-1] == val {
        this.maxStack = this.maxStack[:m-1]
    }

    return val
}

// Top returns the top element without removing it (panics when empty).
func (this *MaxStack) Top() int {
    last := len(this.stack) - 1
    return this.stack[last]
}

// PeekMax returns the current maximum without removing it (panics when empty).
func (this *MaxStack) PeekMax() int {
    last := len(this.maxStack) - 1
    return this.maxStack[last]
}

// PopMax removes and returns the top-most maximum element, or 0 if empty.
func (this *MaxStack) PopMax() int {
    if len(this.maxStack) == 0 {
        return 0
    }

    maxVal := this.maxStack[len(this.maxStack)-1]

    // Elements above the top-most max all differ from maxVal, so Pop never
    // removes anything from maxStack here.
    var buffer []int
    for this.stack[len(this.stack)-1] != maxVal {
        buffer = append(buffer, this.Pop())
    }

    // Drop the max from both stacks.
    this.stack = this.stack[:len(this.stack)-1]
    this.maxStack = this.maxStack[:len(this.maxStack)-1]

    // Re-push the buffered elements deepest-first to restore their order.
    for len(buffer) > 0 {
        last := len(buffer) - 1
        this.Push(buffer[last])
        buffer = buffer[:last]
    }

    return maxVal
}

JavaScript:

/**
 * Two stacks approach.
 * `stack` holds every element; `maxStack` holds each element that was >= the
 * current maximum when pushed, so its last entry is the current maximum.
 * Time: O(n) for popMax, O(1) for others
 * Space: O(n) - two stacks
 */
class MaxStack {
    constructor() {
        this.stack = [];
        this.maxStack = [];
    }

    push(x) {
        this.stack.push(x);
        const currentMax = this.maxStack[this.maxStack.length - 1];
        if (this.maxStack.length === 0 || x >= currentMax) {
            this.maxStack.push(x);
        }
    }

    pop() {
        if (!this.stack.length) return null;

        const val = this.stack.pop();
        const maxes = this.maxStack;
        if (maxes.length && maxes[maxes.length - 1] === val) maxes.pop();
        return val;
    }

    top() {
        const { stack } = this;
        return stack[stack.length - 1];
    }

    peekMax() {
        const { maxStack } = this;
        return maxStack[maxStack.length - 1];
    }

    popMax() {
        if (!this.maxStack.length) return null;

        const target = this.maxStack[this.maxStack.length - 1];
        // Lift off the (smaller) elements sitting above the top-most max.
        const held = [];
        while (this.stack[this.stack.length - 1] !== target) {
            held.push(this.stack.pop());
        }

        this.stack.pop();
        this.maxStack.pop();

        // Re-push deepest-first via push() so maxStack is rebuilt for them.
        for (let i = held.length - 1; i >= 0; i--) {
            this.push(held[i]);
        }

        return target;
    }
}

C#:

public class MaxStack {
    /// <summary>
    /// Two stacks approach. <c>stack</c> holds every element; <c>maxStack</c>
    /// holds each element that was &gt;= the current maximum when pushed, so
    /// its top is always the current maximum.
    /// Time: O(n) for PopMax, O(1) for others
    /// Space: O(n) - two stacks
    /// </summary>
    private readonly Stack<int> stack = new Stack<int>();
    private readonly Stack<int> maxStack = new Stack<int>();

    public MaxStack() {
    }

    public void Push(int x) {
        stack.Push(x);
        if (maxStack.Count > 0 && x < maxStack.Peek()) {
            return;
        }
        maxStack.Push(x);
    }

    public int Pop() {
        int removed = stack.Pop();
        if (maxStack.Count > 0 && maxStack.Peek() == removed) {
            maxStack.Pop();
        }
        return removed;
    }

    public int Top() => stack.Peek();

    public int PeekMax() => maxStack.Peek();

    public int PopMax() {
        int target = maxStack.Pop();

        // Elements above the top-most max are smaller and not in maxStack.
        var held = new Stack<int>();
        while (stack.Peek() != target) {
            held.Push(stack.Pop());
        }
        stack.Pop();

        // Stack<T> enumerates top-first, i.e. in the same order as popping.
        foreach (int item in held) {
            Push(item);
        }

        return target;
    }
}

Approach 2: TreeMap with Stack

Algorithm

  1. Use a stack to maintain insertion order
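  2. Use an ordered map from each value to the list of its stack positions (a TreeMap in Java; the Python version emulates it with a sorted list of values plus a dict of positions)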
  3. For popMax: find max in TreeMap, remove from both structures

Solution

Python:

from collections import defaultdict
import bisect

class MaxStack:
    """
    Stack with a sorted companion list (Python stand-in for a TreeMap).

    ``stack`` keeps elements in push order. ``max_values`` keeps the same
    elements sorted ascending. ``max_map`` maps each value to the ascending
    list of indices at which it currently appears in ``stack``.

    Time: O(1) for top and peekMax.
          O(n) for push and pop, because inserting into or removing from
          ``max_values`` shifts list elements.
          O(n) for popMax, because it deletes from the middle of ``stack``
          and re-indexes the stored positions.
    Space: O(n) - stack, sorted list and position map

    Calls on an empty stack return None.
    """
    def __init__(self):
        self.stack = []
        # value -> ascending list of that value's indices in self.stack
        self.max_map = defaultdict(list)
        # every element currently in the stack, sorted ascending
        self.max_values = []

    def push(self, x):
        self.stack.append(x)
        bisect.insort(self.max_values, x)
        # The new index is the largest in use, so each position list stays sorted.
        self.max_map[x].append(len(self.stack) - 1)

    def pop(self):
        if not self.stack:
            return None

        val = self.stack.pop()
        # The removed element had the highest index, which is the last entry
        # of its value's position list. Use .get so the defaultdict does not
        # re-create an empty entry for a missing key.
        positions = self.max_map.get(val)
        if positions:
            positions.pop()
            if not positions:
                del self.max_map[val]

        # Remove one occurrence of val from the sorted list.
        idx = bisect.bisect_left(self.max_values, val)
        if idx < len(self.max_values) and self.max_values[idx] == val:
            self.max_values.pop(idx)

        return val

    def top(self):
        return self.stack[-1] if self.stack else None

    def peekMax(self):
        return self.max_values[-1] if self.max_values else None

    def popMax(self):
        """Remove and return the maximum; among ties, the top-most copy goes."""
        if not self.max_values:
            return None

        max_val = self.max_values.pop()
        positions = self.max_map[max_val]
        # Positions are ascending, so the last one is the top-most occurrence.
        # It is removed from the position list exactly once, here.
        pos = positions.pop()
        if not positions:
            del self.max_map[max_val]
        del self.stack[pos]

        # Every element that was above ``pos`` slid down one slot.
        for plist in self.max_map.values():
            for i in range(bisect.bisect_right(plist, pos), len(plist)):
                plist[i] -= 1

        return max_val

Key Insights

  1. Two Stacks: Simple but popMax is O(n)
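  2. TreeMap: More complex. popMax is O(log n) only when the ordered map is paired with a doubly linked list (or with a heap plus lazy deletion). The list-based Python version above still pays O(n) to remove an element from the middle of the stack.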
  3. Position Tracking: Need to track positions for efficient removal
  4. Maintaining Order: Keep insertion order while supporting max operations

Edge Cases

  1. Empty Stack: Handle operations on empty stack
  2. Duplicate Max Values: Handle multiple occurrences of max
  3. Single Element: Stack with only one element
  4. All Same Values: Stack with all identical values

Follow-up Questions

  1. Min Stack: Implement min stack instead of max stack
  2. Range Queries: Support range max/min queries
  3. Persistent Stack: Support versioning of the stack
  4. Concurrent Access: Handle multiple threads

Common Mistakes

  1. Not updating max stack: Forgetting to update max stack on pop
  2. Wrong popMax logic: Not handling the case where max appears multiple times
  3. Position tracking: Incorrectly updating positions after removal
  4. Edge cases: Not handling empty stack or single element