Find the diameter of a binary tree
"Find the diameter of a binary tree." This problem looks deceptively simple but hides a subtle insight. It closely mirrors Diameter of Binary Tree on LeetCode (which measures the path in edges rather than nodes), builds directly on the tree height algorithm, and tests whether you can extend a familiar recursion to solve a related but trickier problem. If you can find the height of a binary tree, you are one observation away from solving this.
TL;DR
The diameter of a binary tree is the number of nodes on the longest path between any two nodes. At every node, compute the height of the left and right subtrees. The diameter through that node is 1 + leftHeight + rightHeight. Track the maximum across all nodes. The recursive helper is nearly identical to the tree height algorithm, with one extra line to update the running max.
Why This Problem Matters
This problem teaches you how to extract multiple pieces of information from a single recursive traversal. Instead of running two separate passes (one for heights, one for diameters), you combine both computations into one DFS. This pattern of "compute something while returning something else" appears constantly in tree problems: lowest common ancestor, balanced tree check, maximum path sum, and many others. Master it here and you will recognize it everywhere.
Understanding the Problem
Given the root of a binary tree, return the number of nodes on the longest path between any two nodes. The path does not need to pass through the root.
Here is the example tree:
Level-order serialization (# marks a missing child): 1 2 3 # # 4 5
    1
   / \
  2   3
     / \
    4   5
treeDiameter(root) -> 4
The longest path is 2 -> 1 -> 3 -> 4 (or equivalently 2 -> 1 -> 3 -> 5), which has 4 nodes.
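To make the serialization concrete, here is a minimal Python sketch that rebuilds the example tree from its level-order tokens and runs the height-plus-running-max recursion on it. The TreeNode class and the build_tree helper are assumptions for illustration only; the article's platform supplies its own node type and input parsing.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed shape, not defined by the article)."""

    def __init__(self, val: int):
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(tokens: List[str]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from level-order tokens, '#' = no child."""
    if not tokens or tokens[0] == "#":
        return None
    root = TreeNode(int(tokens[0]))
    queue = deque([root])
    i = 1
    while queue and i < len(tokens):
        node = queue.popleft()
        # The next two tokens are this node's left and right children.
        if tokens[i] != "#":
            node.left = TreeNode(int(tokens[i]))
            queue.append(node.left)
        i += 1
        if i < len(tokens) and tokens[i] != "#":
            node.right = TreeNode(int(tokens[i]))
            queue.append(node.right)
        i += 1
    return root


def tree_diameter(root: Optional[TreeNode]) -> int:
    """Number of nodes on the longest path between any two nodes."""
    best = 0

    def height(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:
            return 0
        left = height(node.left)
        right = height(node.right)
        best = max(best, 1 + left + right)  # diameter through this node
        return 1 + max(left, right)

    height(root)
    return best


example = build_tree("1 2 3 # # 4 5".split())
print(tree_diameter(example))  # 4
```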
The Critical Insight: Diameter Does Not Always Pass Through the Root
Many candidates assume the answer is just 1 + height(left) + height(right) computed at the root. That works for balanced trees, but fails when the longest path lives entirely inside a subtree.
Consider a tree whose diameter bypasses the root (node 1) entirely.
Here, the longest path is 2 -> 4 -> 5 -> 3 -> 6 -> 8 with 6 nodes, and it never touches node 1. This is why we must check every node as a potential center of the diameter, not just the root.
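The failure is easy to reproduce. The Python sketch below compares the root-only shortcut with a version that checks every node, on a small hypothetical tree of our own construction (not the article's figure): root 1 has only a left child 2, and node 2 carries two three-node chains. The TreeNode class and all function names are illustrative assumptions.

```python
from typing import Optional


class TreeNode:
    """Minimal node type assumed for this sketch."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(node: Optional[TreeNode]) -> int:
    """Number of nodes on the longest downward path from node."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def root_only_diameter(root: Optional[TreeNode]) -> int:
    """The flawed shortcut: only considers paths through the root."""
    if root is None:
        return 0
    return 1 + height(root.left) + height(root.right)


def tree_diameter(root: Optional[TreeNode]) -> int:
    """Treats every node as a potential center of the longest path."""
    best = 0

    def search(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:
            return 0
        left = search(node.left)
        right = search(node.right)
        best = max(best, 1 + left + right)
        return 1 + max(left, right)

    search(root)
    return best


# Hypothetical tree: 1 -> 2, and node 2 has chains 3 -> 4 -> 5 and 6 -> 7 -> 8.
lopsided = TreeNode(1, left=TreeNode(
    2,
    left=TreeNode(3, left=TreeNode(4, left=TreeNode(5))),
    right=TreeNode(6, right=TreeNode(7, right=TreeNode(8))),
))

print(root_only_diameter(lopsided))  # 5
print(tree_diameter(lopsided))       # 7
```

The true longest path here is 5 -> 4 -> 3 -> 2 -> 6 -> 7 -> 8 (7 nodes), centered at node 2, while the root-only formula reports 5.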
Solution Approach
The key observation: for any node in the tree, the diameter passing through that node equals 1 + leftHeight + rightHeight. The overall diameter is the maximum of this value across all nodes.
We can compute this in a single DFS by modifying the standard tree height algorithm:
- Recursively compute leftHeight and rightHeight for each node.
- At each node, calculate the diameter through it: 1 + leftHeight + rightHeight.
- Update a running maximum if this diameter is larger than any seen so far.
- Return the height, 1 + max(leftHeight, rightHeight), for the parent to use.
To see the recursion in action, annotate each node of the example tree with its height and the diameter through it.
Walking through this: leaf nodes 2, 4, and 5 each have height 1 and diameter 1. Node 3 has leftHeight = 1 (from 4) and rightHeight = 1 (from 5), giving height 2 and diameter 1 + 1 + 1 = 3. Node 1 has leftHeight = 1 (from 2) and rightHeight = 2 (from 3), giving height 3 and diameter 1 + 1 + 2 = 4. The maximum diameter across all nodes is 4.
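The same walkthrough can be checked mechanically. This Python sketch instruments the helper so that it records, for each node, the height it returns and the diameter through it. The TreeNode class and the trace_diameter name are assumptions, and keying the trace by node value assumes the values are unique, as they are in the example.

```python
from typing import Dict, Optional, Tuple


class TreeNode:
    """Minimal node type assumed for this sketch."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_diameter(root: Optional[TreeNode]) -> Tuple[int, Dict[int, Tuple[int, int]]]:
    """Return the diameter plus {value: (height, diameter through node)}."""
    trace: Dict[int, Tuple[int, int]] = {}
    best = 0

    def search(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:
            return 0
        left = search(node.left)
        right = search(node.right)
        through = 1 + left + right     # diameter through this node
        height = 1 + max(left, right)  # height reported to the parent
        trace[node.val] = (height, through)
        best = max(best, through)
        return height

    search(root)
    return best, trace


# The example tree: 1 with children 2 and 3; 3 with children 4 and 5.
root = TreeNode(1, TreeNode(2), TreeNode(3, TreeNode(4), TreeNode(5)))
diameter, trace = trace_diameter(root)
for val in sorted(trace):
    h, through = trace[val]
    print(f"node {val}: height={h}, diameter through node={through}")
print("diameter:", diameter)
```

Running it prints height 3 and diameter 4 for node 1, height 2 and diameter 3 for node 3, height 1 and diameter 1 for the leaves 2, 4, and 5, and an overall diameter of 4, matching the hand trace.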
Implementation
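public class Solution {
    public int treeDiameter(TreeNode root) {
        // Single-element array: a mutable box the helper can write into.
        int[] diameterHolder = {0};
        search(root, diameterHolder);
        return diameterHolder[0];
    }

    // Returns the height (in nodes) of the subtree rooted at node and
    // records the largest diameter seen so far in diameterHolder[0].
    private int search(TreeNode node, int[] diameterHolder) {
        if (node == null) return 0;
        int leftHeight = search(node.left, diameterHolder);
        int rightHeight = search(node.right, diameterHolder);
        // Number of nodes on the longest path that passes through this node.
        diameterHolder[0] = Math.max(diameterHolder[0], 1 + leftHeight + rightHeight);
        return 1 + Math.max(leftHeight, rightHeight);
    }
}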
Let's break down the two methods:
- treeDiameter: The entry point. It creates a single-element array, diameterHolder, to act as a pass-by-reference container. Java passes every argument by value, so search could not update a plain int belonging to its caller, but it can write into an array that both methods reference. After the recursive traversal completes, diameterHolder[0] holds the answer.
- search: A modified height function. At each node, it computes the left and right subtree heights recursively. Then it calculates the diameter through the current node (1 + leftHeight + rightHeight) and updates the running max. Finally, it returns the height at this node for its parent to use.
The single-element array trick is a common Java pattern for simulating pass-by-reference. In Python, you would use a list. In C++, you would use a pointer. In Go, you would pass a pointer to an integer.
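Why the list works in Python can be seen in isolation. In this small sketch (function names are illustrative), assigning to a parameter only rebinds a local name, while item assignment mutates an object the caller also holds.

```python
def try_rebind(counter: int) -> None:
    # Assignment only rebinds the local name; the caller's variable is untouched.
    counter = max(counter, 10)


def update_holder(holder: list) -> None:
    # Item assignment mutates the list object the caller also references.
    holder[0] = max(holder[0], 10)


count = 0
try_rebind(count)
print(count)  # 0

holder = [0]
update_holder(holder)
print(holder[0])  # 10
```

When the helper is a nested function, Python's nonlocal statement is another option: it lets the inner function rebind a variable of the enclosing function directly, without a wrapper list.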
Complexity Analysis
Time: O(n). Every node is visited exactly once, and the work at each node is O(1): a few additions and two max operations. Total work is proportional to the number of nodes.
Space: O(h), where h is the height of the tree, which is O(n) in the worst case. The space is consumed by the recursion call stack. In the worst case (a completely skewed tree that looks like a linked list), the recursion depth is n. For a balanced tree, the depth is O(log n). The diameterHolder array is O(1) additional space.
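Deep recursion can be a practical problem, not just an asymptotic one. CPython's default recursion limit is typically 1000, so the recursive Python solution would normally raise RecursionError on a skewed tree with a few thousand nodes. One common workaround, sketched below under the assumption of a simple TreeNode class, is a post-order traversal with an explicit stack: it uses the same height and diameter formulas but keeps the O(n) bookkeeping on the heap instead of the call stack.

```python
from typing import Dict, Optional


class TreeNode:
    """Minimal node type assumed for this sketch."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_diameter_iterative(root: Optional[TreeNode]) -> int:
    """Diameter in nodes, via post-order traversal on an explicit stack."""
    if root is None:
        return 0
    heights: Dict[TreeNode, int] = {}  # node -> height of its subtree
    best = 0
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node only after both children have been finished.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            left = heights[node.left] if node.left is not None else 0
            right = heights[node.right] if node.right is not None else 0
            best = max(best, 1 + left + right)  # diameter through this node
            heights[node] = 1 + max(left, right)
    return best


# A completely skewed tree of 5000 nodes: each node is the previous one's parent.
skewed = None
for val in range(5000):
    skewed = TreeNode(val, left=skewed)
print(tree_diameter_iterative(skewed))  # 5000
```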
Common Pitfalls
- Only checking the root: Computing 1 + height(left) + height(right) at the root alone misses cases where the diameter lies entirely within a subtree. You must check every node.
- Confusing edges vs. nodes: Some problem variants, including LeetCode's, define diameter as the number of edges, not nodes. In that case, the formula becomes leftHeight + rightHeight (without the +1). Read the problem statement carefully.
- Forgetting the base case: When the node is null, return 0 for height, because a null node contributes no height. Getting this wrong cascades through the entire recursion.
- Using a class field instead of pass-by-reference: A class field works, but interviewers may view it as less clean. It must be reset before every call, and concurrent calls on the same Solution instance would overwrite each other's running max. The array-based approach keeps that state local to each call, which makes the solution self-contained and thread-safe.
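For the edge-counting variant, a minimal Python sketch (TreeNode and diameter_in_edges are illustrative names) keeps heights measured in nodes and drops only the +1 from the path formula. For any non-empty tree, the result is one less than the node-count diameter.

```python
from typing import Optional


class TreeNode:
    """Minimal node type assumed for this sketch."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def diameter_in_edges(root: Optional[TreeNode]) -> int:
    """Longest path length counted in edges rather than nodes."""
    best = 0

    def search(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:
            return 0
        left = search(node.left)
        right = search(node.right)
        # A path through node uses `left` edges down-left and `right` edges
        # down-right, so there is no +1 for the node itself.
        best = max(best, left + right)
        return 1 + max(left, right)  # height is still counted in nodes

    search(root)
    return best


example = TreeNode(1, TreeNode(2), TreeNode(3, TreeNode(4), TreeNode(5)))
print(diameter_in_edges(example))  # 3
```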
Interview Tips
When presenting this solution:
- Start by stating the connection to tree height: "This is the height algorithm with one extra line of bookkeeping per node."
- Explain why a single pass works: "We compute height bottom-up, and at each node we have all the information needed to compute the diameter through it."
- Draw a tree where the diameter does not pass through the root to show you understand the subtlety.
- Mention the pass-by-reference pattern and why it is needed (the recursive function already returns height, so we need another channel for the diameter).
- If asked about edge counting vs. node counting, show that the only difference is whether you include the +1.
Key Takeaways
- The diameter of a binary tree is the longest path between any two nodes, measured in number of nodes. It may or may not pass through the root.
- At every node, diameter = 1 + leftHeight + rightHeight. The overall answer is the maximum of this across all nodes.
- A single DFS computes both height and diameter simultaneously. The function returns height to its caller while updating a shared variable with the best diameter seen so far.
- This pattern of extracting multiple results from one traversal is fundamental to tree problems. You will see it again in balanced tree checks, maximum path sum, and lowest common ancestor.
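The shared variable is one way to get a second result out of the traversal. Another common way, sketched here in Python with assumed names, is to have the helper return both values as a pair, (height, best diameter within this subtree), so no mutable holder is needed at all. It performs the same single pass with the same per-node arithmetic.

```python
from typing import Optional, Tuple


class TreeNode:
    """Minimal node type assumed for this sketch."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_diameter_pair(root: Optional[TreeNode]) -> int:
    def search(node: Optional[TreeNode]) -> Tuple[int, int]:
        """Return (height, best diameter found anywhere in this subtree)."""
        if node is None:
            return 0, 0
        left_height, left_best = search(node.left)
        right_height, right_best = search(node.right)
        through = 1 + left_height + right_height  # diameter through this node
        height = 1 + max(left_height, right_height)
        return height, max(through, left_best, right_best)

    return search(root)[1]


example = TreeNode(1, TreeNode(2), TreeNode(3, TreeNode(4), TreeNode(5)))
print(tree_diameter_pair(example))  # 4
```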
Practice and Related Problems
Once you have nailed the diameter problem, try these natural progressions:
- Height of a binary tree (the simpler version this builds on)
- Balanced binary tree check (another "return one thing, track another" pattern)
- Maximum path sum in a binary tree (same structure, harder arithmetic)
- Lowest common ancestor (similar recursive reasoning)
This problem and hundreds of others are available on Firecode, where spaced repetition ensures you internalize patterns like recursive tree traversal rather than just memorizing solutions. Building strong recursive thinking here pays dividends across your entire interview preparation.
Solutions in Other Languages
Python
class Solution:
    def tree_diameter(self, root):
        diameter_holder = [0]
        self.search(root, diameter_holder)
        return diameter_holder[0]

    def search(self, node, diameter_holder):
        if node is None:
            return 0
        left_height = self.search(node.left, diameter_holder)
        right_height = self.search(node.right, diameter_holder)
        diameter_holder[0] = max(diameter_holder[0], 1 + left_height + right_height)
        return 1 + max(left_height, right_height)
JavaScript
class Solution {
  treeDiameter(root) {
    const diameterHolder = [0];
    this.search(root, diameterHolder);
    return diameterHolder[0];
  }

  search(node, diameterHolder) {
    if (node === null) return 0;
    const leftHeight = this.search(node.left, diameterHolder);
    const rightHeight = this.search(node.right, diameterHolder);
    diameterHolder[0] = Math.max(diameterHolder[0], 1 + leftHeight + rightHeight);
    return 1 + Math.max(leftHeight, rightHeight);
  }
}
TypeScript
class Solution {
  treeDiameter(root: TreeNode | null): number {
    const diameterHolder: number[] = [0];
    this.search(root, diameterHolder);
    return diameterHolder[0];
  }

  search(node: TreeNode | null, diameterHolder: number[]): number {
    if (node === null) return 0;
    const leftHeight = this.search(node.left, diameterHolder);
    const rightHeight = this.search(node.right, diameterHolder);
    diameterHolder[0] = Math.max(diameterHolder[0], 1 + leftHeight + rightHeight);
    return 1 + Math.max(leftHeight, rightHeight);
  }
}
C++
#include <algorithm>  // std::max

class Solution {
public:
    int treeDiameter(TreeNode *root) {
        int diameter = 0;
        search(root, &diameter);
        return diameter;
    }

private:
    int search(TreeNode *node, int *diameterRef) {
        if (node == nullptr) return 0;
        int leftHeight = search(node->left, diameterRef);
        int rightHeight = search(node->right, diameterRef);
        *diameterRef = std::max(*diameterRef, 1 + leftHeight + rightHeight);
        return 1 + std::max(leftHeight, rightHeight);
    }
};
Go
package solution

func (s *Solution) TreeDiameter(root *TreeNode) int {
	diameter := 0
	search(root, &diameter)
	return diameter
}

func search(node *TreeNode, diameterRef *int) int {
	if node == nil {
		return 0
	}
	leftHeight := search(node.Left, diameterRef)
	rightHeight := search(node.Right, diameterRef)
	if 1+leftHeight+rightHeight > *diameterRef {
		*diameterRef = 1 + leftHeight + rightHeight
	}
	if leftHeight > rightHeight {
		return 1 + leftHeight
	}
	return 1 + rightHeight
}
Scala
class Solution {
  def treeDiameter(root: TreeNode): Int = {
    val diameterHolder = Array(0)
    search(root, diameterHolder)
    diameterHolder(0)
  }

  private def search(node: TreeNode, diameterHolder: Array[Int]): Int = {
    if (node == null) return 0
    val leftHeight = search(node.left, diameterHolder)
    val rightHeight = search(node.right, diameterHolder)
    diameterHolder(0) = Math.max(diameterHolder(0), 1 + leftHeight + rightHeight)
    1 + Math.max(leftHeight, rightHeight)
  }
}
Kotlin
import kotlin.math.max

class Solution {
    fun treeDiameter(root: TreeNode?): Int {
        val diameterHolder = intArrayOf(0)
        search(root, diameterHolder)
        return diameterHolder[0]
    }

    private fun search(node: TreeNode?, diameterHolder: IntArray): Int {
        if (node == null) return 0
        val leftHeight = search(node.left, diameterHolder)
        val rightHeight = search(node.right, diameterHolder)
        diameterHolder[0] = max(diameterHolder[0], 1 + leftHeight + rightHeight)
        return 1 + max(leftHeight, rightHeight)
    }
}
Swift
class Solution {
    func treeDiameter(_ root: TreeNode?) -> Int {
        var diameterHolder = 0
        search(root, &diameterHolder)
        return diameterHolder
    }

    private func search(_ node: TreeNode?, _ diameterHolder: inout Int) -> Int {
        guard let node = node else { return 0 }
        let leftHeight = search(node.left, &diameterHolder)
        let rightHeight = search(node.right, &diameterHolder)
        diameterHolder = max(diameterHolder, 1 + leftHeight + rightHeight)
        return 1 + max(leftHeight, rightHeight)
    }
}
Rust
impl Solution {
    pub fn tree_diameter(&self, root: Option<Box<TreeNode>>) -> i32 {
        let mut diameter = 0;
        Self::search(&root, &mut diameter);
        diameter
    }

    fn search(node: &Option<Box<TreeNode>>, diameter: &mut i32) -> i32 {
        match node {
            None => 0,
            Some(current) => {
                let left_height = Self::search(&current.left, diameter);
                let right_height = Self::search(&current.right, diameter);
                *diameter = std::cmp::max(*diameter, 1 + left_height + right_height);
                1 + std::cmp::max(left_height, right_height)
            }
        }
    }
}
C#
using System;  // Math.Max

public class Solution {
    public int TreeDiameter(TreeNode? root) {
        int[] diameterHolder = {0};
        Search(root, diameterHolder);
        return diameterHolder[0];
    }

    private int Search(TreeNode? node, int[] diameterHolder) {
        if (node == null) return 0;
        int leftHeight = Search(node.left, diameterHolder);
        int rightHeight = Search(node.right, diameterHolder);
        diameterHolder[0] = Math.Max(diameterHolder[0], 1 + leftHeight + rightHeight);
        return 1 + Math.Max(leftHeight, rightHeight);
    }
}
Dart
import 'dart:math';  // max

class Solution {
  int treeDiameter(TreeNode? root) {
    List<int> diameterHolder = [0];
    search(root, diameterHolder);
    return diameterHolder[0];
  }

  int search(TreeNode? node, List<int> diameterHolder) {
    if (node == null) return 0;
    int leftHeight = search(node.left, diameterHolder);
    int rightHeight = search(node.right, diameterHolder);
    diameterHolder[0] = max(diameterHolder[0], 1 + leftHeight + rightHeight);
    return 1 + max(leftHeight, rightHeight);
  }
}
PHP
class Solution {
    public function treeDiameter(?TreeNode $root): int {
        $diameterHolder = 0;
        $this->search($root, $diameterHolder);
        return $diameterHolder;
    }

    private function search(?TreeNode $node, int &$diameterHolder): int {
        if ($node === null) return 0;
        $leftHeight = $this->search($node->left, $diameterHolder);
        $rightHeight = $this->search($node->right, $diameterHolder);
        $diameterHolder = max($diameterHolder, 1 + $leftHeight + $rightHeight);
        return 1 + max($leftHeight, $rightHeight);
    }
}
Ruby
class Solution
  def tree_diameter(root)
    diameter_holder = [0]
    search(root, diameter_holder)
    diameter_holder[0]
  end

  def search(node, diameter_holder)
    return 0 if node.nil?
    left_height = search(node.left, diameter_holder)
    right_height = search(node.right, diameter_holder)
    diameter_holder[0] = [diameter_holder[0], 1 + left_height + right_height].max
    1 + [left_height, right_height].max
  end
end