/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
TreeNode* lowestCommonAncestor(TreeNode* root, TreeNode* p, TreeNode* q) {
// every node carry information 0 or 1, 1 means find one of p/q, 0 means find none of them
// when root.left = 1 and root.right = 1, then return root
// edge case
if(root == p or root == q) return root;
// regular case, BFS
pair<TreeNode*, int> result = find(root, p, q);
return result.first;
}
// pair<TreeNode*, int> --> pair<ancestor, find 0/1/2>
pair<TreeNode*, int> find(TreeNode* root, TreeNode* p, TreeNode* q) {
// stop condition
if(root == NULL) return {NULL, 0};
if(root->left == NULL && root->right == NULL) {
if(root == p || root == q) return {NULL, 1};
return {NULL, 0};
}
// regular case
pair<TreeNode*, int> left = find(root->left, p, q);
pair<TreeNode*, int> right = find(root->right, p, q);
int r = (root == p || root == q) ? 1 : 0;
if(left.second == 2) return left;
if(right.second == 2) return right;
// return root if result != 2 doesn't matter, because this fake root will be replaced by real root eventually
return {root, left.second+right.second+r};
}
};