문제 원본: https://leetcode.com/problems/3sum/

문제

정수 배열 nums가 주어졌을 때, i != j, i != k, j != k, nums[i] + nums[j] + nums[k] == 0일 경우에 [nums[i], nums[j], nums[k]]의 세 개의 숫자로 이루어진 모든 집합을 반환하라.

답에는 중복된 집합을 포함하지 않는다.

Example 1:

Input: nums = [-1,0,1,2,-1,-4]
Output: [[-1,-1,2],[-1,0,1]]
Explanation: 
nums[0] + nums[1] + nums[2] = (-1) + 0 + 1 = 0.
nums[1] + nums[2] + nums[4] = 0 + 1 + (-1) = 0.
nums[0] + nums[3] + nums[4] = (-1) + 2 + (-1) = 0.
서로 다른 집합은 [-1, 0, 1]과 [-1, -1, 2]이다.  
결과의 순서와 집합의 순서는 상관 없다.

Example 2:

Input: nums = [0,1,1]
Output: []
Explanation: 단 하나의 가능성 있는 집합의 원소 합은 0이 아니다.

Example 3:

Input: nums = [0,0,0]
Output: [[0,0,0]]
Explanation: 단 하나의 가능성 있는 집합의 원소 합은 0이다.

Constraints:

  • 3 <= nums.length <= 3000
  • -10^5 <= nums[i] <= 10^5

접근 방법

재귀 사용

나는 처음에 재귀를 사용한 방법을 떠올렸다. 그도 그럴 게, 이게 어찌보면 n개의 숫자 중에서 3개의 숫자를 뽑는 것과 같다고 봤다. 단, 여기서는 그 3개의 숫자에 조건이 추가되었다는 점만 다르다고 봤다.
그래서 내가 생각한 로직은 아래와 같았다.

  1. nums에서 어떤 숫자가 선택되었는지를 저장할 selectedIdx 배열을 nums.length와 똑같은 크기로 생성한다. 초기화는 0으로 한다. nums[1]이 선택되었다면, selectedIdx[1] = 집합의 번호와 같은 식으로 값이 저장된다.
  2. 재귀 함수를 호출하는데, 이때 패러미터로 넘겨줄 것은 nums 배열, selectedIdx 배열, 그리고 조합을 수행할 시작 인덱스인 startIdx를 넘겨준다. 물론 초기에 재귀 함수를 호출할 때에는 startIdx = 0일 것이다.
  3. 재귀함수 내부에서는 startIdx번째 숫자를 제외한 다른 두 개의 숫자에 대해 for문을 이중으로 돌면서 문제에서 주어진 조건에 부합하는지를 살펴보도록 했다. 내부 for문에서 해당 조건을 검사할 것이며, 만약 조건에 부합한다면 selectedIdx에 해당 위치와 집합의 번호를 저장하고 재귀 함수를 startIdx + 1인 상태로 호출한다.
  4. 재귀 함수의 종료 조건은 startIdx == nums.length - 1이면 세 개의 숫자를 뽑는 것이 끝났다는 의미에서 현재까지 저장한 selectedIdx를 반환하도록 했다. 만약 완성하지 못했다면 selectedIdx에 여지껏 저장해둔 값을 전부 0으로 초기화한 후 반환하도록 했다. 그러면 맨 처음 호출된 재귀함수에서 for문을 돌 때, startIdx + 1인 상태에서 다시 재귀를 시작할 것이다.
  5. 모든 재귀 함수 호출 스택이 실행되면 selectedIdx를 반환한다.
  6. selectedIdx에서 0이 아닌 같은 집합의 번호(원소)를 가진 인덱스 세 개를 구해서, 해당 인덱스를 이용하여 nums로부터 숫자를 가져와 리스트에 저장한다.
  7. 모든 숫자 집합을 구한 후에, 그 숫자 집합을 반환한다.

라고 생각을 했는데… 쉽지 않았다. 재귀 함수가 제대로 돌아가게 하는 것도 문제였다. nums = [0, 0, 0]의 케이스에서는 잘 작동했지만, 그 외의 케이스에서는 종료 조건에서 쓸데없이 걸리는 바람에 원하는 값이 출력되지 않았다.

아래의 코드가 내가 짜다 만(시간 제한을 걸어뒀는데, 그 시간 내에 못 풀었다) 코드이다. intellij에서 테스트한다고 main 함수가 있기는 한데, 그거 말고는 leetcode에서 문제될 건 없다. 그리고 이것저것 실험해본다고 좀 지저분하다…

package com.company;

import java.sql.Array;
import java.util.ArrayList;
import java.util.List;

public class Main {

    public static void main(String[] args) {
        class Solution {
            public List<List<Integer>> threeSum(int[] nums) {
                int[] selectedIdx = new int[nums.length];
                selectedIdx = initializeSelectedIdx(selectedIdx);
                selectedIdx = getSelectedIdx(nums, selectedIdx, 0);
                List<ArrayList<Integer>> triplets = getTripletsByIdx(nums, selectedIdx);
                return triplets; // Error
            }

            public int[] initializeSelectedIdx(int[] selectedIdx) {
                for (int i = 0; i < selectedIdx.length; i++)
                    selectedIdx[i] = 0;
                return selectedIdx;
            }

            public int[] getSelectedIdx(int[] nums, int[] selectedIdx, int startIdx) {
                if (startIdx == nums.length - 2) {
                    if ()
                        initializeSelectedIdx(selectedIdx);
                    return selectedIdx;
                }
                for (int i = startIdx + 1; i < nums.length; i++) {
                    for (int j = i + 1; j < nums.length; j++) {
                        if (nums[i] + nums[j] + nums[startIdx] == 0) {
                            selectedIdx[startIdx] = startIdx + 1;
                            selectedIdx[i] = startIdx + 1;
                            selectedIdx[j] = startIdx + 1;
                            selectedIdx = getSelectedIdx(nums, selectedIdx, startIdx + 1);
                        }
                    }
                }
                return selectedIdx;
            }

            public List<ArrayList<Integer>> getTripletsByIdx(int[] nums, int[] selectedIdx) {
                ArrayList<ArrayList<Integer>> triplets = new ArrayList<ArrayList<Integer>>();

                for (int i = 0; i < selectedIdx.length; i++) {
                    for (int j = i + 1; j < selectedIdx.length; j++) {
                        for (int n = j + 1; n < selectedIdx.length; n++) {
                            if ((selectedIdx[i] = selectedIdx[j] = selectedIdx[n]) != 0) {
                                ArrayList<Integer> triplet = new ArrayList<>();
                                triplet.add(nums[i]);
                            }
                        }
                    }
                }
                return triplets;
            }
        }

        Solution solution = new Solution();
        int[] nums = { -1, 0, 1, 2, -1, -4 };
        List<List<Integer>> triplets = solution.threeSum(nums);
        for (int i = 0; i < triplets.size(); i++) {
            System.out.println(triplets.get(i));
        }
    }
}

풀이

class Solution {
    public List<List<Integer>> threeSum(int[] nums) {
        Arrays.sort(nums);
        List<List<Integer>> result = new ArrayList<>();

        for (int i = 0; i < nums.length; i++) {
            if (i > 0 && nums[i] == nums[i - 1]) continue;

            int start = i + 1;
            int end = nums.length - 1;

            while (start < end) {
                int sum = nums[i] + nums[start] + nums[end];

                if (sum == 0) {
                    result.add(Arrays.asList(nums[i], nums[start], nums[end]));
                    start++;
                    end--;

                    while (start < end && nums[start] == nums[start - 1]) {
                        start++;
                    }
                    while (start < end && nums[end] == nums[end + 1]) {
                        end--;
                    }
                } else if (sum < 0) {
                    start++;
                } else {
                    end--;
                }
            }
        }
        return result;
    }
}

우선 nums 배열을 정렬하면서 시작한다. 그리고 nums 배열의 요소를 순회한다.

if (i > 0 && nums[i] == nums[i - 1]) continue;가 왜 있는지 의아한데, 이는 중복을 피하기 위함이다.

현재 nums는 정렬이 되어 있는 상태로, nums = [-4, -1, -1, 0, 1, 2]라면, i = 1일 때 이미 -1이 포함된 triplet이 완성되었을 것이다.
그 상황에서 i = 2일 때 -1이 포함된 triplet을 또 생성하려 한다면, 중복된 결과가 나온다.
따라서 중복을 피하기 위해 같은 숫자가 연속하여 배열되어있다면, 가장 마지막 숫자를 기준으로 처리한다.

그리고 starti 다음 숫자를, endnums의 가장 마지막 숫자를 할당하여 while문을 순회한다. 이때 조건은 start < end.

nums[i] + nums[start] + nums[end] == 0인지 확인하여 0이라면 result 리스트에 추가한다.
아니라면 nums[start] == nums[start - 1]인 경우, 즉 위의 예시에 이어서 i = 1일 때 start++을 함으로써 같은 숫자의 연속된 배열 중, 가장 마지막 숫자의 위치를 start에 할당한다.
nums[end] == nums[end + 1]의 경우도 마찬가지로, 다른 점이 있다면 같은 숫자의 연속된 배열 중, 가장 첫 번째 숫자의 위치를 end에 할당한다.

그 외에 sum < 0이라면 start번째 숫자가 너무 작다는 뜻이므로, start++을 해주어 이전보다 큰 수를 더하여 sum == 0이 될 수 있게 한다.
sum > 0이라면, end번째 숫자가 너무 크다는 뜻으로, end--를 해주어 이전보다 작은 수를 더함으로써 sum == 0이 될 수 있게 한다.

댓글남기기