streaming response
This commit is contained in:
+85
-61
@@ -19,10 +19,20 @@ import {
|
|||||||
import { usePageContext } from "vike-react/usePageContext";
|
import { usePageContext } from "vike-react/usePageContext";
|
||||||
import { useData } from "vike-react/useData";
|
import { useData } from "vike-react/useData";
|
||||||
import type { Data } from "./+data";
|
import type { Data } from "./+data";
|
||||||
import type { CommittedMessage, DraftMessage } from "../../../types";
|
import type {
|
||||||
|
CommittedMessage,
|
||||||
|
DraftMessage,
|
||||||
|
OtherParameters,
|
||||||
|
} from "../../../types";
|
||||||
import Markdown from "react-markdown";
|
import Markdown from "react-markdown";
|
||||||
import { IconTrash, IconEdit, IconCheck, IconX } from "@tabler/icons-react";
|
import {
|
||||||
import { useTRPC } from "../../../trpc/client";
|
IconTrash,
|
||||||
|
IconEdit,
|
||||||
|
IconCheck,
|
||||||
|
IconX,
|
||||||
|
IconLoaderQuarter,
|
||||||
|
} from "@tabler/icons-react";
|
||||||
|
import { useTRPC, useTRPCClient } from "../../../trpc/client";
|
||||||
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
|
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
|
||||||
import { nanoid } from "nanoid";
|
import { nanoid } from "nanoid";
|
||||||
import type { Conversation } from "../../../database/common";
|
import type { Conversation } from "../../../database/common";
|
||||||
@@ -49,6 +59,7 @@ export default function ChatPage() {
|
|||||||
const setParameters = useStore((state) => state.setParameters);
|
const setParameters = useStore((state) => state.setParameters);
|
||||||
const setLoading = useStore((state) => state.setLoading);
|
const setLoading = useStore((state) => state.setLoading);
|
||||||
const trpc = useTRPC();
|
const trpc = useTRPC();
|
||||||
|
const trpcClient = useTRPCClient();
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
|
|
||||||
const messagesResult = useQuery(
|
const messagesResult = useQuery(
|
||||||
@@ -334,84 +345,95 @@ export default function ChatPage() {
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
const sendMessage = useMutation(
|
// Get state from Zustand store
|
||||||
trpc.chat.sendMessage.mutationOptions({
|
const sendMessageStatus = useStore((state) => state.sendMessageStatus);
|
||||||
onMutate: async ({
|
const isSendingMessage = useStore((state) => state.isSendingMessage);
|
||||||
|
const setSendMessageStatus = useStore((state) => state.setSendMessageStatus);
|
||||||
|
const setIsSendingMessage = useStore((state) => state.setIsSendingMessage);
|
||||||
|
|
||||||
|
// Function to send message using subscription
|
||||||
|
const sendSubscriptionMessage = async ({
|
||||||
conversationId,
|
conversationId,
|
||||||
messages,
|
messages,
|
||||||
systemPrompt,
|
systemPrompt,
|
||||||
parameters,
|
parameters,
|
||||||
|
}: {
|
||||||
|
conversationId: string;
|
||||||
|
messages: Array<DraftMessage | CommittedMessage>;
|
||||||
|
systemPrompt: string;
|
||||||
|
parameters: OtherParameters;
|
||||||
}) => {
|
}) => {
|
||||||
/** Cancel affected queries that may be in-flight: */
|
setIsSendingMessage(true);
|
||||||
await queryClient.cancelQueries({
|
setSendMessageStatus(null);
|
||||||
queryKey: trpc.chat.messages.fetchByConversationId.queryKey({
|
|
||||||
conversationId,
|
try {
|
||||||
}),
|
// Create an abort controller for the subscription
|
||||||
});
|
const abortController = new AbortController();
|
||||||
/** Optimistically update the affected queries in react-query's cache: */
|
|
||||||
const previousMessages: Array<CommittedMessage> | undefined =
|
// Start the subscription
|
||||||
await queryClient.getQueryData(
|
const subscription = trpcClient.chat.sendMessage.subscribe(
|
||||||
trpc.chat.messages.fetchByConversationId.queryKey({
|
|
||||||
conversationId,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
if (!previousMessages) {
|
|
||||||
return {
|
|
||||||
previousMessages: [],
|
|
||||||
newMessages: [],
|
|
||||||
};
|
|
||||||
}
|
|
||||||
const newMessages: Array<CommittedMessage> = [
|
|
||||||
...previousMessages,
|
|
||||||
{
|
{
|
||||||
/** placeholder id; will be overwritten when we get the true id from the backend */
|
|
||||||
id: nanoid(),
|
|
||||||
conversationId,
|
conversationId,
|
||||||
// content: messages[messages.length - 1].content,
|
messages,
|
||||||
// role: "user" as const,
|
systemPrompt,
|
||||||
...messages[messages.length - 1],
|
parameters,
|
||||||
index: previousMessages.length,
|
|
||||||
createdAt: new Date().toISOString(),
|
|
||||||
} as CommittedMessage,
|
|
||||||
];
|
|
||||||
queryClient.setQueryData(
|
|
||||||
trpc.chat.messages.fetchByConversationId.queryKey({
|
|
||||||
conversationId,
|
|
||||||
}),
|
|
||||||
newMessages
|
|
||||||
);
|
|
||||||
return { previousMessages, newMessages };
|
|
||||||
},
|
},
|
||||||
onSettled: async (data, variables, context) => {
|
{
|
||||||
await queryClient.invalidateQueries({
|
signal: abortController.signal,
|
||||||
|
onData: (data) => {
|
||||||
|
setSendMessageStatus(data);
|
||||||
|
|
||||||
|
// If we've completed, update the UI and invalidate queries
|
||||||
|
if (data.status === "completed") {
|
||||||
|
setIsSendingMessage(false);
|
||||||
|
// Invalidate queries to refresh the data
|
||||||
|
queryClient.invalidateQueries({
|
||||||
queryKey: trpc.chat.messages.fetchByConversationId.queryKey({
|
queryKey: trpc.chat.messages.fetchByConversationId.queryKey({
|
||||||
conversationId,
|
conversationId,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
await queryClient.invalidateQueries({
|
queryClient.invalidateQueries({
|
||||||
queryKey: trpc.chat.facts.fetchByConversationId.queryKey({
|
queryKey: trpc.chat.facts.fetchByConversationId.queryKey({
|
||||||
conversationId,
|
conversationId,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
await queryClient.invalidateQueries({
|
queryClient.invalidateQueries({
|
||||||
queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({
|
queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey(
|
||||||
|
{
|
||||||
conversationId,
|
conversationId,
|
||||||
}),
|
}
|
||||||
|
),
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
setSendMessageStatus(data);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onError: (error) => {
|
||||||
|
console.error("Subscription error:", error);
|
||||||
|
setIsSendingMessage(false);
|
||||||
|
setSendMessageStatus({
|
||||||
|
status: "error",
|
||||||
|
message: "An error occurred while sending the message",
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
onError: async (error, variables, context) => {
|
}
|
||||||
console.error(error);
|
|
||||||
if (!context) return;
|
|
||||||
queryClient.setQueryData(
|
|
||||||
trpc.chat.messages.fetchByConversationId.queryKey({
|
|
||||||
conversationId,
|
|
||||||
}),
|
|
||||||
context.previousMessages
|
|
||||||
);
|
|
||||||
},
|
|
||||||
})
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Return a function to unsubscribe if needed
|
||||||
|
return () => {
|
||||||
|
abortController.abort();
|
||||||
|
subscription.unsubscribe();
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to start subscription:", error);
|
||||||
|
setIsSendingMessage(false);
|
||||||
|
setSendMessageStatus({
|
||||||
|
status: "error",
|
||||||
|
message: "Failed to start message sending process",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// State for editing facts
|
// State for editing facts
|
||||||
const [editingFactId, setEditingFactId] = useState<string | null>(null);
|
const [editingFactId, setEditingFactId] = useState<string | null>(null);
|
||||||
const [editingFactContent, setEditingFactContent] = useState("");
|
const [editingFactContent, setEditingFactContent] = useState("");
|
||||||
@@ -483,6 +505,8 @@ export default function ChatPage() {
|
|||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
{isSendingMessage && <IconLoaderQuarter size={16} stroke={1.5} />}
|
||||||
|
{sendMessageStatus && <span>{sendMessageStatus.message}</span>}
|
||||||
</div>
|
</div>
|
||||||
<Tabs defaultValue="message">
|
<Tabs defaultValue="message">
|
||||||
<Tabs.List>
|
<Tabs.List>
|
||||||
@@ -504,7 +528,7 @@ export default function ChatPage() {
|
|||||||
if (e.key === "Enter") {
|
if (e.key === "Enter") {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
await sendMessage.mutateAsync({
|
await sendSubscriptionMessage({
|
||||||
conversationId,
|
conversationId,
|
||||||
messages: [
|
messages: [
|
||||||
...(messages || []),
|
...(messages || []),
|
||||||
|
|||||||
+50
-10
@@ -74,10 +74,9 @@ export const chat = router({
|
|||||||
parameters: OtherParameters;
|
parameters: OtherParameters;
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
.mutation(
|
.subscription(async function* ({
|
||||||
async ({
|
|
||||||
input: { conversationId, messages, systemPrompt, parameters },
|
input: { conversationId, messages, systemPrompt, parameters },
|
||||||
}) => {
|
}) {
|
||||||
/** TODO: Save all unsaved messages (i.e. those without an `id`) to the
|
/** TODO: Save all unsaved messages (i.e. those without an `id`) to the
|
||||||
* database. Is this dangerous? Can an attacker just send a bunch of
|
* database. Is this dangerous? Can an attacker just send a bunch of
|
||||||
* messages, omitting the ids, causing me to save a bunch of them to the
|
* messages, omitting the ids, causing me to save a bunch of them to the
|
||||||
@@ -95,6 +94,13 @@ export const chat = router({
|
|||||||
const messagesSincePreviousRunningSummary = messages.slice(
|
const messagesSincePreviousRunningSummary = messages.slice(
|
||||||
previousRunningSummaryIndex + 1
|
previousRunningSummaryIndex + 1
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Emit status update
|
||||||
|
yield {
|
||||||
|
status: "saving_user_message",
|
||||||
|
message: "Saving user message...",
|
||||||
|
} as const;
|
||||||
|
|
||||||
/** Save the incoming message to the database. */
|
/** Save the incoming message to the database. */
|
||||||
const insertedUserMessage = await db.messages.create({
|
const insertedUserMessage = await db.messages.create({
|
||||||
conversationId,
|
conversationId,
|
||||||
@@ -105,6 +111,12 @@ export const chat = router({
|
|||||||
createdAt: new Date().toISOString(),
|
createdAt: new Date().toISOString(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Emit status update
|
||||||
|
yield {
|
||||||
|
status: "generating_response",
|
||||||
|
message: "Generating AI response...",
|
||||||
|
} as const;
|
||||||
|
|
||||||
/** Generate a new message from the model, but hold-off on adding it to
|
/** Generate a new message from the model, but hold-off on adding it to
|
||||||
* the database until we produce the associated running-summary, below.
|
* the database until we produce the associated running-summary, below.
|
||||||
* The model should be given the conversation summary thus far, and of
|
* The model should be given the conversation summary thus far, and of
|
||||||
@@ -138,6 +150,13 @@ export const chat = router({
|
|||||||
tools: undefined,
|
tools: undefined,
|
||||||
...parameters,
|
...parameters,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Emit status update
|
||||||
|
yield {
|
||||||
|
status: "extracting_facts_from_user",
|
||||||
|
message: "Extracting facts from user message...",
|
||||||
|
} as const;
|
||||||
|
|
||||||
/** Extract Facts from the user's message, and add them to the database,
|
/** Extract Facts from the user's message, and add them to the database,
|
||||||
* linking the Facts with the messages they came from. (Yes, this should
|
* linking the Facts with the messages they came from. (Yes, this should
|
||||||
* be done *after* the model response, not before; because when we run a
|
* be done *after* the model response, not before; because when we run a
|
||||||
@@ -160,6 +179,12 @@ export const chat = router({
|
|||||||
}))
|
}))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Emit status update
|
||||||
|
yield {
|
||||||
|
status: "generating_summary",
|
||||||
|
message: "Generating conversation summary...",
|
||||||
|
} as const;
|
||||||
|
|
||||||
/** Produce a running summary of the conversation, and save that along
|
/** Produce a running summary of the conversation, and save that along
|
||||||
* with the model's response to the database. The new running summary is
|
* with the model's response to the database. The new running summary is
|
||||||
* based on the previous running summary combined with the all messages
|
* based on the previous running summary combined with the all messages
|
||||||
@@ -179,6 +204,13 @@ export const chat = router({
|
|||||||
index: messages.length,
|
index: messages.length,
|
||||||
createdAt: new Date().toISOString(),
|
createdAt: new Date().toISOString(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Emit status update
|
||||||
|
yield {
|
||||||
|
status: "extracting_facts_from_assistant",
|
||||||
|
message: "Extracting facts from assistant response...",
|
||||||
|
} as const;
|
||||||
|
|
||||||
/** Extract Facts from the model's response, and add them to the database,
|
/** Extract Facts from the model's response, and add them to the database,
|
||||||
* linking the Facts with the messages they came from. */
|
* linking the Facts with the messages they came from. */
|
||||||
const factsFromAssistantMessageResponse =
|
const factsFromAssistantMessageResponse =
|
||||||
@@ -208,6 +240,12 @@ export const chat = router({
|
|||||||
...insertedFactsFromAssistantMessage,
|
...insertedFactsFromAssistantMessage,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Emit status update
|
||||||
|
yield {
|
||||||
|
status: "generating_fact_triggers",
|
||||||
|
message: "Generating fact triggers...",
|
||||||
|
} as const;
|
||||||
|
|
||||||
/** For each Fact produced in the two fact-extraction steps, generate
|
/** For each Fact produced in the two fact-extraction steps, generate
|
||||||
* FactTriggers and add them to the database, linking the FactTriggers
|
* FactTriggers and add them to the database, linking the FactTriggers
|
||||||
* with the Facts they came from. A FactTrigger is a natural language
|
* with the Facts they came from. A FactTrigger is a natural language
|
||||||
@@ -229,18 +267,20 @@ export const chat = router({
|
|||||||
scopeConversationId: conversationId,
|
scopeConversationId: conversationId,
|
||||||
createdAt: new Date().toISOString(),
|
createdAt: new Date().toISOString(),
|
||||||
}));
|
}));
|
||||||
db.factTriggers.createMany(insertedFactTriggers);
|
await db.factTriggers.createMany(insertedFactTriggers);
|
||||||
}
|
}
|
||||||
|
|
||||||
// await db.write();
|
// Emit final result
|
||||||
|
yield {
|
||||||
return {
|
status: "completed",
|
||||||
|
message: "Completed!",
|
||||||
|
result: {
|
||||||
insertedAssistantMessage,
|
insertedAssistantMessage,
|
||||||
insertedUserMessage,
|
insertedUserMessage,
|
||||||
insertedFacts,
|
insertedFacts,
|
||||||
};
|
},
|
||||||
}
|
} as const;
|
||||||
),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const createCaller = createCallerFactory(chat);
|
export const createCaller = createCallerFactory(chat);
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ export const useStore = create<Store>()(
|
|||||||
facts: [],
|
facts: [],
|
||||||
factTriggers: [],
|
factTriggers: [],
|
||||||
loading: false,
|
loading: false,
|
||||||
|
sendMessageStatus: null,
|
||||||
|
isSendingMessage: false,
|
||||||
setConversationId: (conversationId) =>
|
setConversationId: (conversationId) =>
|
||||||
set((stateDraft) => {
|
set((stateDraft) => {
|
||||||
stateDraft.selectedConversationId = conversationId;
|
stateDraft.selectedConversationId = conversationId;
|
||||||
@@ -92,5 +94,13 @@ export const useStore = create<Store>()(
|
|||||||
set((stateDraft) => {
|
set((stateDraft) => {
|
||||||
stateDraft.loading = loading;
|
stateDraft.loading = loading;
|
||||||
}),
|
}),
|
||||||
|
setSendMessageStatus: (status) =>
|
||||||
|
set((stateDraft) => {
|
||||||
|
stateDraft.sendMessageStatus = status;
|
||||||
|
}),
|
||||||
|
setIsSendingMessage: (isSending) =>
|
||||||
|
set((stateDraft) => {
|
||||||
|
stateDraft.isSendingMessage = isSending;
|
||||||
|
}),
|
||||||
})),
|
})),
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -9,6 +9,12 @@ export type OtherParameters = Omit<
|
|||||||
|
|
||||||
export type ConversationUI = Conversation & {};
|
export type ConversationUI = Conversation & {};
|
||||||
|
|
||||||
|
export type SendMessageStatus = {
|
||||||
|
status: string;
|
||||||
|
message: string;
|
||||||
|
result?: any;
|
||||||
|
};
|
||||||
|
|
||||||
export type Store = {
|
export type Store = {
|
||||||
/** This is a string because Milvus sends it as a string, and the value
|
/** This is a string because Milvus sends it as a string, and the value
|
||||||
* overflows the JS integer anyway. */
|
* overflows the JS integer anyway. */
|
||||||
@@ -21,6 +27,8 @@ export type Store = {
|
|||||||
facts: Array<Fact>;
|
facts: Array<Fact>;
|
||||||
factTriggers: Array<FactTrigger>;
|
factTriggers: Array<FactTrigger>;
|
||||||
loading: boolean;
|
loading: boolean;
|
||||||
|
sendMessageStatus: SendMessageStatus | null;
|
||||||
|
isSendingMessage: boolean;
|
||||||
setConversationId: (conversationId: string) => void;
|
setConversationId: (conversationId: string) => void;
|
||||||
setConversationTitle: (conversationTitle: string) => void;
|
setConversationTitle: (conversationTitle: string) => void;
|
||||||
setConversations: (conversations: Array<ConversationUI>) => void;
|
setConversations: (conversations: Array<ConversationUI>) => void;
|
||||||
@@ -35,6 +43,8 @@ export type Store = {
|
|||||||
removeFact: (factId: string) => void;
|
removeFact: (factId: string) => void;
|
||||||
removeFactTrigger: (factTriggerId: string) => void;
|
removeFactTrigger: (factTriggerId: string) => void;
|
||||||
setLoading: (loading: boolean) => void;
|
setLoading: (loading: boolean) => void;
|
||||||
|
setSendMessageStatus: (status: SendMessageStatus | null) => void;
|
||||||
|
setIsSendingMessage: (isSending: boolean) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** The message while it's being typed in the input box. */
|
/** The message while it's being typed in the input box. */
|
||||||
|
|||||||
Reference in New Issue
Block a user