상호배타적 집합(Disjoint Set)

오리엔테이션이나 수학 여행 등에서 서로 친해지기 위해서 하는 활동 중에 이런 활동들을 한번쯤 해보셨을 겁니다. 다들 서로 낯서고 모르는 상태이기 때문에 어색어색하죠. 이 순간은 서로가 각각 혼자서 있습니다. 그럴때 사회자가 혈액형이 같은 사람끼리 모이라고 했을때, A, B, AB, O 형으로 나뉘어진 각 사람들은 같은 혈액형을 찾으려고 시도하겠죠? 그래서 같으면 같은 조로 묶이게 됩니다.

이처럼 같은 특성끼리 서로 모이는 집합이 여러 집합이 있는데, 각 집합에서는 공통 원소가 존재하지 않는 그런 집합을 상호배타적 집합(Disjoint Set)이라고 합니다. 한 집합에서 다른 특성을 갖는 집합을 배제한다는 뜻입니다. 이때 쓰이는 자료구조가 유니온-파인드(Union-Find) 자료구조라고 합니다.

구현

상호배타적 집합을 구현하기 위해서는 트리의 자료구조를 사용합니다. 각 트리 노드들은 루트 노드를 가지고 있지요. 그래서 그 루트노드를 기준으로 같다면 같은 집합에 속해있는 것이고, 다르다면 다른 집합에 속해있다는 것을 알 수 있습니다.

1. 초기화

초기의 상태는 각 노드들이 자신이 루트가 됩니다. 아직 어떤 노드도 만난적이 없으니까 집합에서의 리더는 자기 자신일테니까요. 그러다가 같은 속성이 있다면 서로를 합치게 되는 것이죠. 아래의 코드가 초기화 코드를 나타냅니다. parent는 자신의 상위 트리 노드를 말하며 처음에는 자기 자신이 됩니다. 나머지 rank 변수는 이 후 설명하도록 하겠습니다.

 

struct DisjointSet {
	vector<int> parent, rank;
	DisjointSet(int n) :parent(n), rank(n, 1) {
		for (int i = 0; i < n; i++)
			parent[i] = i;
	}
};

 

아래는 두 트리가 있음을 보여줍니다. 두 트리의 루트는 1과 6이라서 이 트리들은 서로 다른 집합을 나타내는 것이라고 할 수 있습니다.

루트가 다른 트리

 

2. 루트 찾기

여기서 각 루트를 찾는 코드는 아래와 같습니다. 

int find(int node) {
	if (node == parent[node]) return node;
	return parent[node] = find(parent[node]);
}

 

find는 현재 노드가 속해있는 트리의 루트를 구하는 함수인데요. 여기서 재귀적으로 find를 호출해서 부모를 계속 찾다보면 나중에는 결국 부모가 자기 자신인 노드를 발견하게 됩니다. 왜냐면 루트의 부모는 없으며 루트는 위의 초기화에서 자기 자신이 부모 노드입니다. 그리고 반복적으로 find를 계속 호출하지 않기 위해서 parent[node] = find(parent[node]) 로 부모를 계산된 루트로 기록해둡니다. 그러면 나중에 루트를 찾을때 단번에 루트로 가서 반환되게 때문에 최적화를 할 수 있습니다.

3. 트리 합치기

루트가 다른 이 트리, 알보고니 공통 속성이 있어서 합쳐야될 것 같습니다. 합쳐보긴할텐데 어느쪽으로 합쳐야될까요? 우리는 루트가 1인 트리를 루트가 6인 트리에 합쳐야할까요? 아니면 6인 트리를 1인 트리 밑으로 가게 만들까요? 우리가 find를 구현했을때 루트를 찾이 위해서 계속 재귀적으로 부모를 호출해가면서 확인해보았습니다. 

그렇다면 루트를 빨리 찾기 위해서는 노드의 부모수가 적으면 속도가 빨라지겠죠. 그런 원리로 최적화하면 됩니다. 즉, 트리의 높이를 작게 만들수록 유리합니다. 

루트가 다른 트리

 

1이 루트인 트리에 6이 루트인 트리를 밑에 합쳐놓은 상황입니다. 트리의 높이가 증가했음을 알 수 있죠? 그렇다면 반대로 합쳐놓는다면, 즉 6이 루트인 트리에 1이 루트인 트리를 보면 어떨까요?

 

트리의 높이 증가

 

아래의 그림이 그 상황을 보여줍니다. 트리의 높이가 증가하지 않았음을 볼 수 있습니다. 그렇다면 결국에는 트리의 높이가 큰 트리 밑에 작은 트리를 합쳐 놓는게 높이를 증가시키지 않는 방법이네요. 이제 구현해봅시다.

트리의 높이 증가하지 않음

 

아래의 구현이 병합하는 코드를 구현한 것입니다.

void merge(int left, int right) {
	left = find(left), right = find(right);
	if (left == right) return;
	if (rank[left] > rank[right]) swap(left, right);
	parent[left] = right;
	if (rank[left] == rank[right]) ++rank[right];
}

 

왼쪽 트리가 항상 오른쪽 트리의 자식이 될 수 있도록 구현한 코드인데요. 왼쪽 트리의 루트를 구하고, 오른쪽 트리의 루트를 구해서 같으면 이미 같은 트리에 속해있는 것으로 종료합니다. 

그 후 rank라는 배열을 통해서 트리의 높이를 얻어옵니다. 항상 왼쪽에 작아서 오른쪽으로 합칠 구현을 해야하니까 값이 왼쪽이 크다면 교환해줍니다. 그래서 결국 왼쪽 부모는 오른쪽의 부모와 같게 만들죠.

트리의 높이가 증가하는 경우는 왼쪽, 오른쪽 트리의 rank값이 같을때만 존재합니다.

전체 소스코드는 아래와 같게됩니다. 아래의 코드는 알고리즘 문제해결 전략, 구종만 님의 코드를 참고했습니다.

 

struct DisjointSet {
	vector<int> parent, rank;
	DisjointSet(int n) :parent(n), rank(n, 1) {
		for (int i = 0; i < n; i++)
			parent[i] = i;
	}

	int find(int node) {
		if (node == parent[node]) return node;
		return parent[node] = find(parent[node]);
	}

	void merge(int left, int right) {
		left = find(left), right = find(right);
		if (left == right) return;
		if (rank[left] > rank[right]) swap(left, right);
		parent[left] = right;
		if (rank[left] == rank[right]) ++rank[right];
	}
};

 

예제 : BOJ 1717번, 집합의 표현

정확한 이해를 돕기 위해서 아래의 문제를 풀어봅시다. DisjointSet을 이용해서 풀 수 있는 문제로 위의 동작을 이해할 수 있으면 거뜬하게 풀 수 있는 문제입니다.

www.acmicpc.net/problem/1717

 

1717번: 집합의 표현

첫째 줄에 n(1 ≤ n ≤ 1,000,000), m(1 ≤ m ≤ 100,000)이 주어진다. m은 입력으로 주어지는 연산의 개수이다. 다음 m개의 줄에는 각각의 연산이 주어진다. 합집합은 0 a b의 형태로 입력이 주어진다. 이는

www.acmicpc.net

 

문제의 입력은 아래와 같습니다.

7 8
0 1 3
1 1 7
0 7 6
1 7 1
0 3 7
0 4 2
0 1 1
1 1 1

 

처음 N과 M이 주어지며 N은 집합의 원소 중 가장 마지막 원소, M은 질의를 나타냅니다.

다음은 질의가 M가 나오는데 질의의 유형은 0 a b, 1 a b가 있습니다. 0 a b는 a가 포함된 집합, b가 포함된 집합을 합치는 연산이며 1 a b는 a와 b가 같은 집합에 속해있는지 확인하는 연산입니다. 이때 1 a b의 질의에서 같은 집합에 속한다면 YES를 출력, 아니면 NO를 출력하면 됩니다.

위의 입력에 대한 출력은 아래와 같습니다.

NO
NO
YES

 

풀이 

맨 처음에는 각각의 수들은 다른 집합에 속해있다가 연산 0이 나오면 합치면 됩니다. 이 역할이 DisjointSet의 merge가 되겠군요. 그리고 연산 1이면 a와 b가 같은 집합에 속해있는지 확인하는 연산이므로 find(a), find(b)를 통해 두 집합의 root를 구한후 같은지 다른지 확인하면 되는 간단한 문제입니다.

문제의 정답 코드는 이렇습니다.

#include <cstdio>
#include <vector>
using namespace std;
#define MERGE 0
#define SAME_GROUP 1

struct DisjointSet {
	vector<int> parent, rank;
	DisjointSet(int n) :parent(n), rank(n, 1) {
		for (int i = 0; i < n; i++)
			parent[i] = i;
	}

	int find(int node) {
		if (node == parent[node]) return node;
		return parent[node] = find(parent[node]);
	}

	void merge(int left, int right) {
		left = find(left), right = find(right);
		if (left == right) return;
		if (rank[left] > rank[right]) swap(left, right);
		parent[left] = right;
		if (rank[left] == rank[right]) ++rank[right];
	}
};

int main() {
	int N, M;
	scanf("%d %d", &N, &M);
	DisjointSet djs(N+1);
	for (int i = 0; i < M; i++) {
		int query, a, b;
		scanf("%d %d %d", &query, &a, &b);
		if (query == MERGE) djs.merge(a, b);
		
		if (query == SAME_GROUP) 
			printf("%s\n", djs.find(a) == djs.find(b) ? "YES" : "NO");
	}

}

 

이상으로 상호배타적 집합에 대한 개념과 코드 구현, DisjointSet을 이용해서 문제를 풀어보았습니다. DisjointSet 알고 있으면 도움이 될 것 같나요?

반응형
블로그 이미지

REAKWON

와나진짜

,

Linked List


링크드리스트란 짧게 이야기해서 노드를 연결시킨 자료구조입니다. 노드는 무엇일까요?


링크드리스트에서 데이터를 갖고 있는 데이터의 묶음입니다. 그림으로 보는 것이 편할 것 같네요.





데이터들을 갖고 있는 하나의 요소가 보이시나요? 이것이 노드입니다. 노드 속에 다음 노드를 가리키고 있습니다.

화살표 모양으로 보아하니, 포인터군요!

특히, 제일 앞에 있는 노드는 헤드(head), 제일 끝 노드는 테일(tail)이라고 부릅니다.




head와 tail은 데이터 필드는 있지만 쓰지 않을 겁니다. 구현의 용이함을 위해선데요. 만약 head와 tail의 데이터 필드를 쓰게 되면  추가, 삭제시 3가지를 고려해야합니다.


1) 추가, 삭제할 노드가 맨 앞 노드인가

2) 추가, 삭제할 노드가 맨 뒤 노드인가

3) 추가, 삭제할 노드가 중간 노드인가


하지만 head와 tail의 data를 쓰지 않는다면 3)번 조건만 고려하면 되기 때문입니다.


이처럼 아무 데이터가 없는 노드를 더미노드라고 합니다.


이딴걸 왜 쓰지?

링크드리스트의 가장 큰 장점은 리스트의 길이가 가변적이라는 겁니다. 배열의 단점을 링크드리스트가 커버 할 수 있습니다.

배열은 크기가 가변적이지 않죠. 우선 크기를 정해준 다음에 모자라면 메모리를 더 할당하고, 배열의 데이터를 복사해야 되죠. 오래걸리고 비효율적이라는 것을 알 수 있겠죠?

배열을 사용해서 다시 사이즈를 늘리는 코드를 이런식으로 짤 수 있겠죠.

data *newList = (data*)malloc(sizeof(data)*(2*size));
for (int i = 0; i < size; size++)
	list[i] = newList[i];
free(list);


위 코드에서 보는 것과 같이 메모리를 할당한 후 for루프로 기존의 있던 값을 복사합니다.


하지만 링크드리스트는 어떨까요?

링크드리스트는 다음 노드만 추가하면 되기 때문에 리스트의 사이즈를 조정하는데, 그리 큰 비용을 들이지 않습니다.


아 좋은거네! 이것만 쓰면 되겠구만

이라고 생각할 수도 있겠지만, 링크드리스트는 어떤 노드를 search하거나 데이터를 변경할때 바로 찾아낼 수가 없습니다.

링크드리스트를 전부 탐색해야할 수도 있습니다.


그러니까 데이터가 자주 추가되거나 가변적으로 자주 변하게 될 프로그램이라면 링크드리스트를 쓰는 것이 좋겠고, 주로 데이터의 변경이나 탐색을 위한 것이라면 배열을 쓰는 것이 좋습니다. case-by-case죠.




구현

포인터나 구조체를 제대로 배우지 않은 사람이라면 약간 어려울 수 있습니다.


우리는 다음과 같은 연산을 정의할 겁니다.



1. addFirst(list, data)

링크드리스트의 새로운 노드를 추가합니다. 가장 앞 있는 노드(head의 다음)에 새로운 노드를 추가하는 연산입니다.


2. addLast(list, data)

addFirst와 반대로 가장 뒤에 노드를 추가합니다.


3. removeNode(list, data)

링크드리스트가 갖고 있는 노드 중에 data값을 갖고 있는 노드를 삭제합니다. 


4. searchNode(list,data)

링크드 리스트에서 data의 값과 일치하는 노드를 검색합니다.


5. printList()

링크드리스트를 전부 탐색합니다. 리스트의 데이터를 전부 보여줍니다.


각각의 연산을 어떻게 구현하면 될 지 그림으로 설명하도록 하지요.


1. addFirst(list, data)

가장 첫번째에 노드를 추가합니다. 가장 첫번째라고 해서 head 앞에 추가해서는 안됩니다. head의 다음 노드에 추가해야합니다.

head의 다음 노드를 새로운 노드로 가리키게 만들고, 그 새로운 노드는 이전에 head가 가리키고 있던 노드를 가리키면 됩니다.


코드는 아래와 같습니다.


void addFirst(List *list,int data) {
	Node *newNode = (Node*)malloc(sizeof(Node));
	newNode->data = data;
	newNode->next = list->head->next;
	list->head->next = newNode;
	list->size++;
}


2. addLast(list, data)

가장 마지막(tail 앞)에 노드를 추가하는 연산입니다. 일단 tail앞까지 가야하지만, 그 전의 노드를 기억해야합니다. 그러니까 살짝 까다로울 수 있지요.


아래는 노드 삭제연산을 구현한 코드입니다.


void addLast(List* list, int data) {
	Node *last = list->head;
	
	while (last->next != list->tail) 
		last = last->next;
	
	Node *newNode = (Node*)malloc(sizeof(Node));
	newNode->data = data;
	newNode->next = list->tail;
	last->next = newNode;
	list->size++;
}




3. removeNode(list, data)

리스트의 노드를 하나씩 돌면서 data가 일치하면 그 노드를 삭제하는 겁니다.

주의할 점은 그 노드 다음 노드를 이전의 노드가 가리키는 작업이 우선적으로 이루어져야한다는 겁니다.



void removeNode(List *list, int data) {
	Node *node = list->head;
	while (node->next != list->tail) {
		if (node->next->data == data) {
			Node *delNode = node->next;
			node->next = delNode->next;
			free(delNode);
			list->size--;
			return;
		}
		node = node->next;
	}
	printf("데이터를 찾지 못했습니다.\n");
}

4. searchNode(list, data)

삭제 연산보다 쉽습니다. list를 돌면서 data와 값이 일치하면 그 노드를 반환하면 되니까요.


void removeNode(List *list, int data) {
	Node *node = list->head;
	while (node->next != list->tail) {
		if (node->next->data == data) {
			Node *delNode = node->next;
			node->next = delNode->next;
			free(delNode);
			list->size--;
			return;
		}
		node = node->next;
	}
	printf("데이터를 찾지 못했습니다.\n");
}


5. printList(list)

이 함수 역시 정말 쉽습니다. search 연산과 별 다를 것이 없죠. 그냥 list돌면서 하나하나 출력해주기만 하면 됩니다.




전체 코드


#include <stdio.h>
#include <stdlib.h>


typedef struct node {
	int data;
	struct node *next;
} Node;
typedef struct list {
	Node *head;
	Node *tail;
	int size;
} List;

void createlist(List *list) {
	
	list->head = (Node*)malloc(sizeof(Node));
	list->tail = (Node*)malloc(sizeof(Node));
	list->head->next = list->tail;
	list->tail->next = list->tail;
	list->size = 0;
}
void addFirst(List *list,int data) {
	Node *newNode = (Node*)malloc(sizeof(Node));
	newNode->data = data;
	newNode->next = list->head->next;
	list->head->next = newNode;
	list->size++;
}
void addLast(List* list, int data) {
	Node *last = list->head;
	
	while (last->next != list->tail) 
		last = last->next;
	
	Node *newNode = (Node*)malloc(sizeof(Node));
	newNode->data = data;
	newNode->next = list->tail;
	last->next = newNode;
	list->size++;
}

Node* searchNode(List *list, int data) {
	Node *node = list->head->next;
	while (node != list->tail) {
		if (node->data == data)
			return node;
		node = node->next;
	}
	printf("데이터를 찾지 못했습니다.\n");

	return NULL;
}

void removeNode(List *list, int data) {
	Node *node = list->head;
	while (node->next != list->tail) {
		if (node->next->data == data) {
			Node *delNode = node->next;
			node->next = delNode->next;
			free(delNode);
			list->size--;
			return;
		}
		node = node->next;
	}
	printf("데이터를 찾지 못했습니다.\n");
}



void printList(List *list) {
	Node *node = list->head->next;
	int i = 1;
	while (node != list->tail) {
		printf("%d 번째 노드 데이터 :%d\n", i++, node->data);
		node = node->next;
	}
}
void distroyList(List *list) {
	Node *node = list->head;
	while (node != list->tail) {
		Node *delNode = node;
		node = delNode->next;
		free(delNode);
	}
	free(list->head);
	free(list->tail);
}

int main() {

	int i;
	List list;
	createlist(&list);
	
	for (i = 1; i <= 5; i++)
		addLast(&list, i);
	for (i = 11; i <= 15; i++)
		addFirst(&list, i);
	removeNode(&list, 11);
	removeNode(&list, 15);
	removeNode(&list, 5);
	removeNode(&list, 4);
	removeNode(&list, 50);

	Node *node = searchNode(&list, 14);
	printf("search :%d\n", node->data);

	node=searchNode(&list,12);
	printf("search :%d\n", node->data);

	node = searchNode(&list, 3);
	printf("search :%d\n", node->data);

	printList(&list);
	return 0;
}


제가 구현한 코드가 삑사리가 있을지도 모릅니다. 그럴땐 여러분들이 고쳐보세요! 그러면서 실력이 느는거 아니겠습니까?! 하하


반응형
블로그 이미지

REAKWON

와나진짜

,