题目描述:
Given a binary tree, find the maximum path sum.
The path may start and end at any node in the tree.
For example:
Given the below binary tree,
       1
      / \
     2   3
Return 6 (the path 2 -> 1 -> 3).
代码如下:
// Binary tree node: an int value plus pointers to the two children.
struct TreeNode {
    int val;          // value stored in this node
    TreeNode *left;   // left child, NULL when absent
    TreeNode *right;  // right child, NULL when absent

    // Builds a leaf holding `value`; both child pointers start out NULL.
    TreeNode(int value) : val(value), left(NULL), right(NULL) {}
};
// Node of the undirected graph built from a TreeNode tree by Solution::generateGraph:
// left/right point to the copies of the children, prev points back to the parent's copy.
struct GraphNode
{
    int val;           // value copied from the corresponding tree node
    GraphNode *left;   // copy of the left child, NULL when absent
    GraphNode *right;  // copy of the right child, NULL when absent
    GraphNode *prev;   // copy of the parent, NULL for the root

    // Builds an isolated node holding `value`; all three links start out NULL.
    GraphNode(int value) : val(value), left(NULL), right(NULL), prev(NULL) {}
};
class Solution {
public:
int maxNum;
int maxPathSum(TreeNode *root)
{
if (!root)
return 0;
GraphNode *head = new GraphNode(root->val);
generateGraph(head, root);
maxNum = root->val;
head = getHead(head);
while (head->left || head->right || head->prev)
{
GraphNode *next=new GraphNode(0);
if (head->left)
next = head->left;
if (head->right)
next = head->right;
if (head->prev)
next = head->prev;
getMax(head, 0, NULL);
if (next->left == head)
next->left = NULL;
if (next->right == head)
next->right = NULL;
if (next->prev == head)
next->prev = NULL;
head = next;
}
maxNum = max(head->val , maxNum);
return maxNum;
}
//生成一个图
void generateGraph(GraphNode *start, TreeNode *root)
{
if (root->left)
{
GraphNode *left = new GraphNode(root->left->val);
left->prev = start;
start->left = left;
generateGraph(start->left, root->left);
}
if (root->right)
{
GraphNode *right = new GraphNode(root->right->val);
right->prev = start;
start->right = right;
generateGraph(start->right, root->right);
}
}
//获取从head开始的最大值,dir为0是向left传递,1时向right传递,2时向prev传递
int getMax(GraphNode *head, int count, GraphNode *prev)
{
if (head->left&&head->left!=prev)
return getMax(head->left, count, head);
if (head->right&&head->right != prev)
return getMax(head->right, count, head);
if (head->prev&&head->prev != prev)
return getMax(head->prev, count, head);
count += head->val;
maxNum = max(maxNum, count);
}
//找到一个端点
GraphNode* getHead(GraphNode* head)
{
int pNum = 0;
if (head->left)
pNum++;
if (head->right)
pNum++;
if (head->prev)
pNum++;
if (pNum <= 1)
return head;
if (head->left)
return getHead(head->left);
if (head->right)
return getHead(head->right);
if (head->prev)
return getHead(head->prev);
}
};
然后开始思考DP的解法。DP[i]表示以i为一端、向下延伸到i的子树中的路径的最大和(路径至少包含i本身)。用i1、i2分别表示i的左右孩子,sum表示当前找到的最大路径和。DP公式如下:
DP[i] = max(max(DP[i1], DP[i2]) + val(i), val(i))
sum = max(sum, DP[i], DP[i1] + DP[i2] + val(i))
其中不存在的孩子,其DP值记为0。
代码如下:
// DP approach. DP(node) returns the largest sum of a path that starts at `node` and
// goes downward into its subtree; maxSum tracks the best path seen anywhere. O(n).
class Solution {
public:
    int maxSum;  // best path sum found so far

    // Returns the maximum sum over all non-empty paths in the tree, or 0 for an
    // empty tree (instead of dereferencing a NULL root).
    int maxPathSum(TreeNode *root)
    {
        if (!root)
            return 0;
        maxSum = root->val;
        DP(root);
        return maxSum;
    }

    // Returns the best downward path sum starting at `root` (the path always contains
    // root itself) and updates maxSum with the best path whose topmost node is `root`.
    int DP(TreeNode *root)
    {
        // A missing branch, or one whose best sum is negative, is left out (gain 0).
        int leftGain = root->left ? std::max(DP(root->left), 0) : 0;
        int rightGain = root->right ? std::max(DP(root->right), 0) : 0;
        int dp = root->val + std::max(leftGain, rightGain);
        // Path that bends at root: best left branch + root + best right branch.
        // This is never smaller than dp, so it also covers the one-sided candidate.
        maxSum = std::max(maxSum, root->val + leftGain + rightGain);
        return dp;
    }
};
果然简单多了……