authz/authn on all trpc procedures
This commit is contained in:
@@ -15,6 +15,7 @@ import type {
|
||||
} from "@universal-middleware/core";
|
||||
import { env } from "./env.js";
|
||||
import { getDbClient } from "../database/index.js";
|
||||
import { JWT } from "@auth/core/jwt";
|
||||
|
||||
const POSTGRES_CONNECTION_STRING =
|
||||
"postgres://neondb_owner:npg_sOVmj8vWq2zG@ep-withered-king-adiz9gpi-pooler.c-2.us-east-1.aws.neon.tech:5432/neondb?sslmode=require&channel_binding=true";
|
||||
@@ -125,6 +126,7 @@ const authjsConfig = {
|
||||
...session.user,
|
||||
id: token.id as string,
|
||||
},
|
||||
jwt: token,
|
||||
};
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,33 +1,41 @@
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import type { CommittedMessage } from "../../types";
|
||||
import { router, publicProcedure, createCallerFactory } from "./server";
|
||||
import { router, createCallerFactory, authProcedure } from "./server";
|
||||
import { z } from "zod";
|
||||
|
||||
export const conversations = router({
|
||||
fetchAll: publicProcedure.query(async ({ ctx: { dbClient, jwt } }) => {
|
||||
const userId = jwt?.id as string | null;
|
||||
if (!userId) return [];
|
||||
const authConversationProcedure = authProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.use(async ({ input: { id }, ctx: { dbClient, jwt }, next }) => {
|
||||
const rows = await dbClient
|
||||
.selectFrom("conversations")
|
||||
.where("userId", "=", userId)
|
||||
.selectAll()
|
||||
.where("id", "=", id)
|
||||
.execute();
|
||||
if (rows[0].userId !== jwt.id) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
return next({
|
||||
ctx: {
|
||||
conversationRow: rows[0],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
export const conversations = router({
|
||||
fetchAll: authProcedure.query(async ({ ctx: { dbClient, jwt } }) => {
|
||||
const rows = await dbClient
|
||||
.selectFrom("conversations")
|
||||
.where("userId", "=", jwt.id as string)
|
||||
.selectAll()
|
||||
.execute();
|
||||
return rows;
|
||||
}),
|
||||
fetchOne: publicProcedure
|
||||
.input((x) => x as { id: string })
|
||||
.query(async ({ input: { id }, ctx: { dbClient, jwt } }) => {
|
||||
const userId = jwt?.id as string | null;
|
||||
if (!userId) return null;
|
||||
const row = await dbClient
|
||||
.selectFrom("conversations")
|
||||
.selectAll()
|
||||
.where("id", "=", id)
|
||||
.where("userId", "=", userId)
|
||||
.execute();
|
||||
return row[0];
|
||||
}),
|
||||
start: publicProcedure.mutation(async ({ ctx: { dbClient, jwt } }) => {
|
||||
const userId = jwt?.id as string | null;
|
||||
if (!userId) return null;
|
||||
fetchOne: authConversationProcedure.query(
|
||||
async ({ ctx: { conversationRow } }) => {
|
||||
return conversationRow;
|
||||
}
|
||||
),
|
||||
start: authProcedure.mutation(async ({ ctx: { dbClient, jwt } }) => {
|
||||
const insertedRows = await dbClient
|
||||
.insertInto("conversations")
|
||||
.values({
|
||||
@@ -38,42 +46,34 @@ export const conversations = router({
|
||||
.execute();
|
||||
return insertedRows[0];
|
||||
}),
|
||||
deleteOne: publicProcedure
|
||||
.input((x) => x as { id: string })
|
||||
.mutation(async ({ input: { id }, ctx: { dbClient, jwt } }) => {
|
||||
const userId = jwt?.id as string | null;
|
||||
if (!userId) return { ok: false };
|
||||
deleteOne: authConversationProcedure.mutation(
|
||||
async ({ input: { id }, ctx: { dbClient, jwt } }) => {
|
||||
await dbClient
|
||||
.deleteFrom("conversations")
|
||||
.where("id", "=", id)
|
||||
.where("userId", "=", userId)
|
||||
.where("userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
return { ok: true };
|
||||
}),
|
||||
updateTitle: publicProcedure
|
||||
}
|
||||
),
|
||||
updateTitle: authConversationProcedure
|
||||
.input(
|
||||
(x) =>
|
||||
x as {
|
||||
id: string;
|
||||
title: string;
|
||||
}
|
||||
z.object({
|
||||
title: z.string(),
|
||||
})
|
||||
)
|
||||
.mutation(async ({ input: { id, title }, ctx: { dbClient, jwt } }) => {
|
||||
const userId = jwt?.id as string | null;
|
||||
if (!userId) return { ok: false };
|
||||
await dbClient
|
||||
.updateTable("conversations")
|
||||
.set({ title })
|
||||
.where("id", "=", id)
|
||||
.where("userId", "=", userId)
|
||||
.where("userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
return { ok: true };
|
||||
}),
|
||||
fetchMessages: publicProcedure
|
||||
.input((x) => x as { conversationId: string })
|
||||
fetchMessages: authProcedure
|
||||
.input(z.object({ conversationId: z.string() }))
|
||||
.query(async ({ input: { conversationId }, ctx: { dbClient, jwt } }) => {
|
||||
const userId = jwt?.id as string | null;
|
||||
if (!userId) return [];
|
||||
const rows = await dbClient
|
||||
.selectFrom("messages")
|
||||
.innerJoin(
|
||||
@@ -83,7 +83,7 @@ export const conversations = router({
|
||||
)
|
||||
.selectAll("messages")
|
||||
.where("conversationId", "=", conversationId)
|
||||
.where("conversations.userId", "=", userId)
|
||||
.where("conversations.userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
return rows as Array<CommittedMessage>;
|
||||
}),
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
import { router, publicProcedure, createCallerFactory } from "./server.js";
|
||||
import {
|
||||
router,
|
||||
publicProcedure,
|
||||
createCallerFactory,
|
||||
authProcedure,
|
||||
} from "./server.js";
|
||||
import type { DraftMessage } from "../../types.js";
|
||||
import { MODEL_NAME } from "../provider.js";
|
||||
import { generateObject, generateText, jsonSchema } from "ai";
|
||||
import type { Fact } from "@database/common.js";
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { z } from "zod";
|
||||
|
||||
const factTriggersSystemPrompt = ({
|
||||
previousRunningSummary,
|
||||
@@ -53,52 +60,71 @@ ${factContent}
|
||||
|
||||
Generate a list of situations in which the fact is useful.`;
|
||||
|
||||
const authFactTriggerProcedure = authProcedure
|
||||
.input(z.object({ factTriggerId: z.string() }))
|
||||
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
|
||||
const factTriggerRows = await dbClient
|
||||
.selectFrom("fact_triggers")
|
||||
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
|
||||
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
|
||||
.innerJoin("conversations", "conversations.id", "messages.conversationId")
|
||||
.where("fact_triggers.id", "=", input.factTriggerId)
|
||||
.where("conversations.userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
if (!factTriggerRows.length) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
return await next();
|
||||
});
|
||||
|
||||
export const factTriggers = router({
|
||||
fetchByFactId: publicProcedure
|
||||
fetchByFactId: authProcedure
|
||||
.input((x) => x as { factId: string })
|
||||
.query(async ({ input: { factId }, ctx: { dbClient } }) => {
|
||||
const rows = await dbClient
|
||||
.selectFrom("fact_triggers")
|
||||
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
|
||||
.selectAll("fact_triggers")
|
||||
.where("sourceFactId", "=", factId)
|
||||
.execute();
|
||||
return rows;
|
||||
}),
|
||||
fetchByConversationId: publicProcedure
|
||||
.input((x) => x as { conversationId: string })
|
||||
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => {
|
||||
.query(async ({ input: { factId }, ctx: { dbClient, jwt } }) => {
|
||||
const rows = await dbClient
|
||||
.selectFrom("fact_triggers")
|
||||
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
|
||||
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
|
||||
.innerJoin(
|
||||
"conversations",
|
||||
"conversations.id",
|
||||
"messages.conversationId"
|
||||
)
|
||||
.selectAll("fact_triggers")
|
||||
.where("messages.conversationId", "=", conversationId)
|
||||
.where("sourceFactId", "=", factId)
|
||||
.where("conversations.userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
return rows;
|
||||
}),
|
||||
deleteOne: publicProcedure
|
||||
.input(
|
||||
(x) =>
|
||||
x as {
|
||||
factTriggerId: string;
|
||||
}
|
||||
)
|
||||
.mutation(async ({ input: { factTriggerId }, ctx: { dbClient } }) => {
|
||||
fetchByConversationId: authProcedure
|
||||
.input((x) => x as { conversationId: string })
|
||||
.query(async ({ input: { conversationId }, ctx: { dbClient, jwt } }) => {
|
||||
const rows = await dbClient
|
||||
.selectFrom("fact_triggers")
|
||||
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
|
||||
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
|
||||
.innerJoin(
|
||||
"conversations",
|
||||
"conversations.id",
|
||||
"messages.conversationId"
|
||||
)
|
||||
.selectAll("fact_triggers")
|
||||
.where("messages.conversationId", "=", conversationId)
|
||||
.where("conversations.userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
return rows;
|
||||
}),
|
||||
deleteOne: authFactTriggerProcedure.mutation(
|
||||
async ({ input: { factTriggerId }, ctx: { dbClient, jwt } }) => {
|
||||
await dbClient
|
||||
.deleteFrom("fact_triggers")
|
||||
.where("id", "=", factTriggerId)
|
||||
.execute();
|
||||
return { ok: true };
|
||||
}),
|
||||
update: publicProcedure
|
||||
.input(
|
||||
(x) =>
|
||||
x as {
|
||||
factTriggerId: string;
|
||||
content: string;
|
||||
}
|
||||
)
|
||||
}
|
||||
),
|
||||
update: authFactTriggerProcedure
|
||||
.input(z.object({ content: z.string() }))
|
||||
.mutation(
|
||||
async ({ input: { factTriggerId, content }, ctx: { dbClient } }) => {
|
||||
await dbClient
|
||||
|
||||
+38
-20
@@ -1,7 +1,9 @@
|
||||
import { router, publicProcedure, createCallerFactory } from "./server.js";
|
||||
import { router, createCallerFactory, authProcedure } from "./server.js";
|
||||
import type { DraftMessage } from "../../types.js";
|
||||
import { MODEL_NAME, openrouter } from "../provider.js";
|
||||
import { generateObject, generateText, jsonSchema } from "ai";
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { z } from "zod";
|
||||
|
||||
const factsFromNewMessagesSystemPrompt = ({
|
||||
previousRunningSummary,
|
||||
@@ -48,9 +50,36 @@ const factsFromNewMessagesUserPrompt = ({
|
||||
|
||||
Extract new facts from these messages.`;
|
||||
|
||||
const authFactProcedure = authProcedure
|
||||
.input(z.object({ factId: z.string() }))
|
||||
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
|
||||
const factRows = await dbClient
|
||||
.selectFrom("facts")
|
||||
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
|
||||
.innerJoin("conversations", "conversations.id", "messages.conversationId")
|
||||
.where("facts.id", "=", input.factId)
|
||||
.where("conversations.userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
if (!factRows.length) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
return await next();
|
||||
});
|
||||
|
||||
export const facts = router({
|
||||
fetchByConversationId: publicProcedure
|
||||
fetchByConversationId: authProcedure
|
||||
.input((x) => x as { conversationId: string })
|
||||
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
|
||||
const conversationRows = await dbClient
|
||||
.selectFrom("conversations")
|
||||
.where("id", "=", input.conversationId)
|
||||
.where("userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
if (!conversationRows.length) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
return await next();
|
||||
})
|
||||
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => {
|
||||
const rows = await dbClient
|
||||
.selectFrom("facts")
|
||||
@@ -60,25 +89,14 @@ export const facts = router({
|
||||
.execute();
|
||||
return rows;
|
||||
}),
|
||||
deleteOne: publicProcedure
|
||||
.input(
|
||||
(x) =>
|
||||
x as {
|
||||
factId: string;
|
||||
}
|
||||
)
|
||||
.mutation(async ({ input: { factId }, ctx: { dbClient } }) => {
|
||||
deleteOne: authFactProcedure.mutation(
|
||||
async ({ input: { factId }, ctx: { dbClient } }) => {
|
||||
await dbClient.deleteFrom("facts").where("id", "=", factId).execute();
|
||||
return { ok: true };
|
||||
}),
|
||||
update: publicProcedure
|
||||
.input(
|
||||
(x) =>
|
||||
x as {
|
||||
factId: string;
|
||||
content: string;
|
||||
}
|
||||
)
|
||||
}
|
||||
),
|
||||
update: authFactProcedure
|
||||
.input(z.object({ content: z.string() }))
|
||||
.mutation(async ({ input: { factId, content }, ctx: { dbClient } }) => {
|
||||
await dbClient
|
||||
.updateTable("facts")
|
||||
@@ -87,7 +105,7 @@ export const facts = router({
|
||||
.execute();
|
||||
return { ok: true };
|
||||
}),
|
||||
extractFromNewMessages: publicProcedure
|
||||
extractFromNewMessages: authProcedure
|
||||
.input(
|
||||
(x) =>
|
||||
x as {
|
||||
|
||||
+40
-7
@@ -1,7 +1,14 @@
|
||||
import { router, publicProcedure, createCallerFactory } from "./server";
|
||||
import {
|
||||
router,
|
||||
publicProcedure,
|
||||
createCallerFactory,
|
||||
authProcedure,
|
||||
} from "./server";
|
||||
import { MODEL_NAME } from "../provider.js";
|
||||
import { generateObject, generateText, jsonSchema } from "ai";
|
||||
import type { CommittedMessage, DraftMessage } from "../../types.js";
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { z } from "zod";
|
||||
|
||||
const runningSummarySystemPrompt = ({
|
||||
previousRunningSummary,
|
||||
@@ -43,9 +50,35 @@ ${mainResponseContent}
|
||||
|
||||
Generate a new running summary of the conversation.`;
|
||||
|
||||
const authMessageProcedure = authProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
|
||||
const messageRows = await dbClient
|
||||
.selectFrom("messages")
|
||||
.innerJoin("conversations", "conversations.id", "messages.conversationId")
|
||||
.where("messages.id", "=", input.id)
|
||||
.where("conversations.userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
if (!messageRows.length) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
return await next();
|
||||
});
|
||||
|
||||
export const messages = router({
|
||||
fetchByConversationId: publicProcedure
|
||||
fetchByConversationId: authProcedure
|
||||
.input((x) => x as { conversationId: string })
|
||||
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
|
||||
const conversationRows = await dbClient
|
||||
.selectFrom("conversations")
|
||||
.where("id", "=", input.conversationId)
|
||||
.where("userId", "=", jwt.id as string)
|
||||
.execute();
|
||||
if (!conversationRows.length) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
return await next();
|
||||
})
|
||||
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => {
|
||||
const rows = (await dbClient
|
||||
.selectFrom("messages")
|
||||
@@ -54,13 +87,13 @@ export const messages = router({
|
||||
.execute()) as Array<CommittedMessage>;
|
||||
return rows;
|
||||
}),
|
||||
deleteOne: publicProcedure
|
||||
.input((x) => x as { id: string })
|
||||
.mutation(async ({ input: { id }, ctx: { dbClient } }) => {
|
||||
deleteOne: authMessageProcedure.mutation(
|
||||
async ({ input: { id }, ctx: { dbClient } }) => {
|
||||
await dbClient.deleteFrom("messages").where("id", "=", id).execute();
|
||||
return { success: true };
|
||||
}),
|
||||
generateRunningSummary: publicProcedure
|
||||
}
|
||||
),
|
||||
generateRunningSummary: authProcedure
|
||||
.input(
|
||||
(x) =>
|
||||
x as {
|
||||
|
||||
@@ -36,6 +36,20 @@ const t = initTRPC
|
||||
*/
|
||||
export const router = t.router;
|
||||
export const publicProcedure = t.procedure;
|
||||
export const authProcedure = publicProcedure.use(
|
||||
async ({ ctx: { jwt }, next }) => {
|
||||
if (!jwt) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
if (!jwt.id) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
jwt.email;
|
||||
return await next({
|
||||
ctx: { jwt },
|
||||
});
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Generate a TRPC-compatible validator function given a Typebox schema.
|
||||
|
||||
Reference in New Issue
Block a user