-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathConversationNode.py
More file actions
251 lines (202 loc) · 6.48 KB
/
ConversationNode.py
File metadata and controls
251 lines (202 loc) · 6.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""
This module provides a ConversationNode class to represent a node in a conversation tree.
Each node stores the user who sent the message and the text of the message.
Nodes can have children, which represent replies or follow-up messages.
Features:
- Add child nodes to simulate a conversation.
- Print the conversation along the path from the root down to any node.
- Serialize the conversation tree to a JSON file.
- Deserialize a conversation tree from a JSON file.
Tutorial:
1. Create a root node:
root = ConversationNode(user="ROBOT", text="Hey, who is this?")
2. Add child nodes:
resp1 = ConversationNode(text="Enrique", user="EM")
root.add(resp1)
3. Navigate through the conversation:
curr = resp1
4. Save the conversation to a file:
root.save_conversation_tree(curr=curr)
5. Load the conversation from a file:
root, curr = ConversationNode.load_conversation_tree()
"""
import json
class ConversationNode:
    """
    Represents a node in a conversation tree.

    Attributes:
        user (str): The user who sent the message.
        text (str): The text of the message.
        depth (int): The depth of the node in the tree. A new node starts at 0;
            add() sets a child's depth to its parent's depth + 1.
        parent (ConversationNode | None): The parent node, or None for a root.
        children (list[ConversationNode]): The child nodes, in insertion order.
    """

    def __init__(self, text="None", user="system"):
        """
        Initializes a ConversationNode with text and user.

        Parameters:
            text (str): The text of the message. Defaults to the string "None".
            user (str): The user who sent the message. Defaults to "system".
        """
        self.user = user
        self.text = text
        self.depth = 0
        self.parent = None
        self.children = []

    def __str__(self):
        """
        Returns 'user:"text"', followed by a "Children:[...]" line listing the
        string form of each child (recursively) when the node has children.
        """
        _str = f"{self.user}:\"{self.text}\""
        if self.children:
            _str += f"\nChildren:[{','.join(str(x) for x in self.children)}]"
        return _str

    def add(self, msg):
        """
        Adds a child node to the current node.

        If msg already belongs to a different parent, it is detached from that
        parent first, so a node never appears in two children lists. The depth
        of msg and of all its descendants is recomputed relative to this node.

        Parameters:
            msg (ConversationNode): The child node to add.

        Raises:
            TypeError: If msg is not a ConversationNode.
            ValueError: If msg is this node or one of its ancestors, because
                adding it would create a cycle.
        """
        if not isinstance(msg, ConversationNode):
            raise TypeError(
                f"expected ConversationNode, got {type(msg).__name__}"
            )
        # Walk up from self; finding msg means msg is self or an ancestor,
        # and attaching it would turn the tree into a cycle.
        ancestor = self
        while ancestor is not None:
            if ancestor is msg:
                raise ValueError(
                    "cannot add a node to itself or to one of its descendants"
                )
            ancestor = ancestor.parent
        # Re-parenting: drop msg from its previous parent's children.
        old_parent = msg.parent
        if (old_parent is not None and old_parent is not self
                and msg in old_parent.children):
            old_parent.children.remove(msg)
        if msg not in self.children:
            self.children.append(msg)
        msg.parent = self
        msg.depth = self.depth + 1
        # msg may carry a pre-built subtree whose depths are now stale.
        msg._update_descendant_depths()

    def _update_descendant_depths(self):
        """Recomputes depth for every descendant from this node's depth."""
        stack = [self]
        while stack:
            node = stack.pop()
            for child in node.children:
                child.depth = node.depth + 1
                stack.append(child)

    def delete(self):
        """
        Removes this node and its whole subtree from the tree.

        The node is removed from its parent's children and its parent is
        cleared. Every node in the subtree then has its parent cleared and its
        children emptied, so no parent/child references remain between them.
        """
        if self.parent is not None:
            if self in self.parent.children:
                self.parent.children.remove(self)
            self.parent = None
        stack = [self]
        while stack:
            node = stack.pop()
            stack.extend(node.children)
            for child in node.children:
                child.parent = None
            node.children.clear()

    def print_conversation(self):
        """
        Prints the conversation from the root node down to the current node,
        one "user: text" line per message.
        """
        for user, text in self.return_conversation():
            print(f"{user}: {text}")

    def return_conversation(self):
        """
        Returns the conversation from the root node down to the current node.

        Returns:
            list[tuple]: (user, text) pairs, root first and this node last.
        """
        conversation = []
        node = self
        while node is not None:
            conversation.append((node.user, node.text))
            node = node.parent
        # Collected bottom-up; callers expect root-first order.
        conversation.reverse()
        return conversation

    def save_conversation_tree(self, filename="test.conv", curr=None):
        """
        Saves the conversation tree to a file as JSON.

        Parameters:
            filename (str): The name of the file to save the tree to. Defaults to "test.conv".
            curr (ConversationNode): The current node in the conversation. Defaults to None.
                Its path from this node is stored as "curr_path"; the path is
                null when curr is None or is not in this tree.
        """
        data = self.serialize()
        curr_path = self.find_path_to_node(curr)
        with open(filename, 'w', encoding="utf-8") as file:
            json.dump({"tree": data, "curr_path": curr_path}, file, indent=4)

    @classmethod
    def load_conversation_tree(cls, filename="test.conv"):
        """
        Loads the conversation tree from a file written by save_conversation_tree.

        Parameters:
            filename (str): The name of the file to load the tree from. Defaults to "test.conv".

        Returns:
            tuple: The root and current nodes (root, curr). curr is None when
            the file stores no current-node path (e.g. it was saved with curr=None).
        """
        with open(filename, 'r', encoding="utf-8") as file:
            data = json.load(file)
        root = cls.deserialize(data["tree"])
        curr_path = data.get("curr_path")
        curr = root.find_node_by_path(curr_path) if curr_path is not None else None
        return root, curr

    def serialize(self):
        """
        Serializes the node (and its subtree) to a dictionary.

        Returns:
            dict: A dictionary with "user", "text", "depth" and "children"
            (a list of serialized children).
        """
        children_data = [child.serialize() for child in self.children]
        return {
            "user": self.user,
            "text": self.text,
            "depth": self.depth,
            "children": children_data
        }

    @classmethod
    def deserialize(cls, data):
        """
        Deserializes a node (and its subtree) from a dictionary.

        The root's depth is taken from the data; descendant depths are
        recomputed by add() from their parent, as before.

        Parameters:
            data (dict): The dictionary containing the serialized node data.

        Returns:
            ConversationNode: The deserialized node.
        """
        root = cls(data["text"], data["user"])
        root.depth = data.get("depth", 0)
        # Build top-down so each child is attached while it is still a leaf;
        # this keeps add()'s depth update trivial and avoids deep recursion.
        stack = [(root, data)]
        while stack:
            node, node_data = stack.pop()
            for child_data in node_data.get("children", []):
                child = cls(child_data["text"], child_data["user"])
                node.add(child)
                stack.append((child, child_data))
        return root

    def find_path_to_node(self, node, path=None):
        """
        Finds the path to a node in the tree.

        Parameters:
            node (ConversationNode): The node to find the path to.
            path (list): The current path. Should not be specified by the user.

        Returns:
            list | None: The path to the node as a list of child indices
            ([] when node is this node), or None if node is not in this tree.
        """
        if path is None:
            path = []
        if self is node:
            return path
        for i, child in enumerate(self.children):
            found_path = child.find_path_to_node(node, path + [i])
            # Compare against None explicitly: an empty list is a valid path.
            if found_path is not None:
                return found_path
        return None

    def find_node_by_path(self, path):
        """
        Finds a node by its path in the tree.

        Parameters:
            path (list): The path to the node as a list of child indices.

        Returns:
            ConversationNode: The node at the specified path.

        Raises:
            IndexError: If an index in path does not name an existing child.
        """
        node = self
        for i in path:
            node = node.children[i]
        return node
def main():
    """
    Demonstrates building, saving and reloading a small conversation tree.

    Writes "test.conv" (the default filename of save_conversation_tree) in the
    current working directory and prints progress messages.
    """
    root = ConversationNode(user="ROBOT", text="Hey, who is this?")
    resp1 = ConversationNode("Enrique", "EM")
    resp2 = ConversationNode("Grace", "GX")
    resp11 = ConversationNode("Nice to meet you, Enrique!", "ROBOT")
    resp21 = ConversationNode("What kind of name is Grace?", "ROBOT")
    root.add(resp1)
    root.add(resp2)
    resp1.add(resp11)
    resp2.add(resp21)
    curr = resp1
    hadparent = curr.parent is not None
    oldparent = curr.parent
    print(resp21.return_conversation())
    # Run the save/load cycle three times; from the second pass on, the tree
    # being saved is the one reloaded in the previous pass.
    for _ in range(3):
        try:
            root.save_conversation_tree(curr=curr)
            print("Saved conversation tree with position.")
        except Exception as e:
            print("Failed to save conversation tree with position.")
            print(e)
            # Loading now would read a missing or stale file.
            continue
        try:
            root, curr = ConversationNode.load_conversation_tree()
            assert isinstance(root, ConversationNode)
            assert isinstance(curr, ConversationNode)
            if hadparent:
                assert curr.parent is not None
                # Loading creates new node objects, so compare the parent's
                # serialized content rather than its identity.
                assert curr.parent.serialize() == oldparent.serialize()
                print("Loaded conversation tree with position and parent.")
            else:
                print("Loaded conversation tree with position.")
        except Exception as e:
            print("Failed to load conversation tree with position.")
            print(e)
if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not on import.
    main()