Chat Memory
Multi-turn conversations require memory management: the LLM needs enough earlier context to stay coherent, but the prompt must stay small enough to fit its context window.
1. Naive Approach: Full History
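Sending the full chat history on every turn works for short conversations. The prompt grows with every exchange, however, so long conversations eventually exceed the LLM's context window and drive up cost and latency.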
class ChatNode(Node):
    """Chat node that sends the entire conversation history to the LLM every turn.

    Shared store keys:
        history: list of {"role": ..., "content": ...} message dicts, created
            on first use and appended to after every turn.
    """

    def prep(self, shared):
        """Ensure a history list exists, then read the next user message.

        Returns:
            A ``(history, user_input)`` tuple. ``history`` is the live list
            stored in ``shared`` (not a copy).
        """
        history = shared.setdefault("history", [])
        user_input = input("You: ")
        return history, user_input

    def exec(self, inputs):
        """Build system prompt + full history + new input and return the LLM reply."""
        history, user_input = inputs
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            *history,
            {"role": "user", "content": user_input},
        ]
        return call_llm(messages)

    def post(self, shared, prep_res, exec_res):
        """Record the exchange, show the reply, and loop back for the next turn."""
        _, user_input = prep_res
        shared["history"].append({"role": "user", "content": user_input})
        shared["history"].append({"role": "assistant", "content": exec_res})
        # Without this the user never sees the model's answer.
        print(f"Assistant: {exec_res}")
        return "continue"
# Build a single-node flow whose "continue" action (returned by post) routes
# straight back to the same node, giving one turn per iteration.
chat = ChatNode()
(chat - "continue") >> chat
flow = Flow(start=chat)
2. Improved Memory Management
We can:
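- Limit the chat history sent verbatim to the most recent 4 exchanges (8 messages: one user message plus one assistant reply per exchange).
- Use vector search to retrieve relevant older exchanges that have dropped out of that recent window.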
class ChatWithMemory(Node):
    """Chat node with a short verbatim window plus vector recall of older turns.

    The last ``RECENT_MESSAGES`` messages are sent to the LLM as-is. Every
    exchange (one user message followed by one assistant reply) that slides
    out of that window is embedded and added to ``shared["memory_index"]``.
    Indexing only evicted exchanges, in order, keeps the invariant that index
    position ``k`` refers to exchange ``k`` == ``history[2k:2k + 2]``. It also
    means recalled context never duplicates the recent window.

    Shared store keys:
        history: list of {"role", "content"} dicts; post() appends a user
            message then an assistant message each turn.
        memory_index: vector index over evicted exchanges, or None until the
            first exchange leaves the recent window.
    """

    # Verbatim window size, in messages. Must be even so the window always
    # holds whole user/assistant exchanges (8 messages = 4 exchanges).
    RECENT_MESSAGES = 8
    # Number of older exchanges to recall from the vector index per turn.
    TOP_K = 2

    def prep(self, s):
        """Read the next user message and gather recent plus recalled context."""
        s.setdefault("history", [])
        s.setdefault("memory_index", None)
        history = s["history"]
        window = self.RECENT_MESSAGES

        user_input = input("You: ")

        relevant = []
        if len(history) > window and s["memory_index"] is not None:
            idx, _ = search_index(
                s["memory_index"], get_embedding(user_input), top_k=self.TOP_K
            )
            # The exact shape of idx is not visible here (a faiss-style search
            # presumably returns (1, top_k)); flattening yields every hit either way.
            for k in np.asarray(idx).ravel():
                if k < 0:
                    # assumes faiss-style padding: -1 means "fewer than top_k vectors"
                    continue
                start = 2 * int(k)  # index position k -> exchange k -> two messages
                relevant.extend(history[start:start + 2])

        return {
            "user_input": user_input,
            "recent": history[-window:],
            "relevant": relevant,
        }

    def exec(self, c):
        """Assemble system prompt, recalled context, recent window and input; call the LLM."""
        messages = [{"role": "system", "content": "You are a helpful assistant."}]
        if c["relevant"]:
            # Render recalled turns as a readable transcript instead of a
            # Python list repr.
            recalled = "\n".join(f"{m['role']}: {m['content']}" for m in c["relevant"])
            messages.append(
                {"role": "system", "content": f"Relevant earlier conversation:\n{recalled}"}
            )
        messages += c["recent"]
        messages.append({"role": "user", "content": c["user_input"]})
        return call_llm(messages)

    def post(self, s, pre, ans):
        """Record the exchange, show the reply, and index any exchange leaving the window."""
        history = s["history"]
        history.extend([
            {"role": "user", "content": pre["user_input"]},
            {"role": "assistant", "content": ans},
        ])
        # Show the answer before the (potentially slow) embedding work below.
        print(f"Assistant: {ans}")

        # Messages before this position are no longer in the verbatim window.
        evicted_end = len(history) - self.RECENT_MESSAGES
        if evicted_end <= 0:
            return "continue"

        if s["memory_index"] is None:
            # First eviction: index every evicted exchange from the start so
            # index position k lines up with exchange k.
            embeddings = [
                self._embed_exchange(history, start) for start in range(0, evicted_end, 2)
            ]
            s["memory_index"] = create_index(embeddings)
        else:
            # One turn adds two messages, so exactly one exchange was evicted.
            embedding = self._embed_exchange(history, evicted_end - 2)
            s["memory_index"].add(np.array([embedding]).astype("float32"))
        return "continue"

    @staticmethod
    def _embed_exchange(history, start):
        """Embed the user/assistant pair ``history[start:start + 2]`` as one text."""
        user_msg, assistant_msg = history[start], history[start + 1]
        return get_embedding(user_msg["content"] + " " + assistant_msg["content"])
# Single-node flow: post() returns "continue", which routes back to the same
# node, so the conversation runs turn after turn from an empty shared store.
chat = ChatWithMemory()
(chat - "continue") >> chat
flow = Flow(start=chat)
flow.run({})