/**
 * Rearranges a singly linked list so that every node whose {@code val} is less than
 * {@code x} precedes every node whose {@code val} is greater than or equal to {@code x}.
 *
 * <p>Nodes keep their original relative order within each of the two groups. The
 * existing nodes are relinked in place; no nodes are allocated.
 *
 * @param head first node of the list; may be {@code null} for an empty list
 * @param x pivot value separating the "low" group from the "high" group
 * @return the first node of the rearranged list, or {@code null} if the list was empty
 */
public ListNode partition(ListNode head, int x) {
    // Chain of nodes with val < x.
    ListNode lowFirst = null;
    ListNode lowLast = null;
    // Chain of nodes with val >= x.
    ListNode highFirst = null;
    ListNode highLast = null;

    ListNode cursor = head;
    while (cursor != null) {
        ListNode following = cursor.next;
        // Detach the node so whichever node ends a chain is already null-terminated.
        cursor.next = null;

        if (cursor.val >= x) {
            if (highLast != null) {
                highLast.next = cursor;
            } else {
                highFirst = cursor;
            }
            highLast = cursor;
        } else {
            if (lowLast != null) {
                lowLast.next = cursor;
            } else {
                lowFirst = cursor;
            }
            lowLast = cursor;
        }

        cursor = following;
    }

    // Splice the high chain after the low chain when a low chain exists.
    if (lowLast != null) {
        lowLast.next = highFirst;
        return lowFirst;
    }
    return highFirst;
}