authz/authn on all trpc procedures

master
Avraham Sakal 3 weeks ago
parent fc70806b10
commit 0207e4fc47

@ -12,13 +12,12 @@ export const data = async (pageContext: PageContextServer) => {
openrouter: getOpenrouter(
(pageContext.env?.OPENROUTER_API_KEY || env.OPENROUTER_API_KEY) as string
),
// jwt: pageContext.,
jwt: pageContext.session?.jwt,
dbClient: getDbClient(
(pageContext.env?.POSTGRES_CONNECTION_STRING ||
env.POSTGRES_CONNECTION_STRING) as string
),
});
const [
conversation,
// messages,

@ -15,6 +15,7 @@ import type {
} from "@universal-middleware/core";
import { env } from "./env.js";
import { getDbClient } from "../database/index.js";
import { JWT } from "@auth/core/jwt";
const POSTGRES_CONNECTION_STRING =
"postgres://neondb_owner:npg_sOVmj8vWq2zG@ep-withered-king-adiz9gpi-pooler.c-2.us-east-1.aws.neon.tech:5432/neondb?sslmode=require&channel_binding=true";
@ -125,6 +126,7 @@ const authjsConfig = {
...session.user,
id: token.id as string,
},
jwt: token,
};
},
},

@ -1,33 +1,41 @@
import { TRPCError } from "@trpc/server";
import type { CommittedMessage } from "../../types";
import { router, publicProcedure, createCallerFactory } from "./server";
import { router, createCallerFactory, authProcedure } from "./server";
import { z } from "zod";
const authConversationProcedure = authProcedure
.input(z.object({ id: z.string() }))
.use(async ({ input: { id }, ctx: { dbClient, jwt }, next }) => {
const rows = await dbClient
.selectFrom("conversations")
.selectAll()
.where("id", "=", id)
.execute();
if (rows[0].userId !== jwt.id) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return next({
ctx: {
conversationRow: rows[0],
},
});
});
export const conversations = router({
fetchAll: publicProcedure.query(async ({ ctx: { dbClient, jwt } }) => {
const userId = jwt?.id as string | null;
if (!userId) return [];
fetchAll: authProcedure.query(async ({ ctx: { dbClient, jwt } }) => {
const rows = await dbClient
.selectFrom("conversations")
.where("userId", "=", userId)
.where("userId", "=", jwt.id as string)
.selectAll()
.execute();
return rows;
}),
fetchOne: publicProcedure
.input((x) => x as { id: string })
.query(async ({ input: { id }, ctx: { dbClient, jwt } }) => {
const userId = jwt?.id as string | null;
if (!userId) return null;
const row = await dbClient
.selectFrom("conversations")
.selectAll()
.where("id", "=", id)
.where("userId", "=", userId)
.execute();
return row[0];
}),
start: publicProcedure.mutation(async ({ ctx: { dbClient, jwt } }) => {
const userId = jwt?.id as string | null;
if (!userId) return null;
fetchOne: authConversationProcedure.query(
async ({ ctx: { conversationRow } }) => {
return conversationRow;
}
),
start: authProcedure.mutation(async ({ ctx: { dbClient, jwt } }) => {
const insertedRows = await dbClient
.insertInto("conversations")
.values({
@ -38,42 +46,34 @@ export const conversations = router({
.execute();
return insertedRows[0];
}),
deleteOne: publicProcedure
.input((x) => x as { id: string })
.mutation(async ({ input: { id }, ctx: { dbClient, jwt } }) => {
const userId = jwt?.id as string | null;
if (!userId) return { ok: false };
deleteOne: authConversationProcedure.mutation(
async ({ input: { id }, ctx: { dbClient, jwt } }) => {
await dbClient
.deleteFrom("conversations")
.where("id", "=", id)
.where("userId", "=", userId)
.where("userId", "=", jwt.id as string)
.execute();
return { ok: true };
}),
updateTitle: publicProcedure
}
),
updateTitle: authConversationProcedure
.input(
(x) =>
x as {
id: string;
title: string;
}
z.object({
title: z.string(),
})
)
.mutation(async ({ input: { id, title }, ctx: { dbClient, jwt } }) => {
const userId = jwt?.id as string | null;
if (!userId) return { ok: false };
await dbClient
.updateTable("conversations")
.set({ title })
.where("id", "=", id)
.where("userId", "=", userId)
.where("userId", "=", jwt.id as string)
.execute();
return { ok: true };
}),
fetchMessages: publicProcedure
.input((x) => x as { conversationId: string })
fetchMessages: authProcedure
.input(z.object({ conversationId: z.string() }))
.query(async ({ input: { conversationId }, ctx: { dbClient, jwt } }) => {
const userId = jwt?.id as string | null;
if (!userId) return [];
const rows = await dbClient
.selectFrom("messages")
.innerJoin(
@ -83,7 +83,7 @@ export const conversations = router({
)
.selectAll("messages")
.where("conversationId", "=", conversationId)
.where("conversations.userId", "=", userId)
.where("conversations.userId", "=", jwt.id as string)
.execute();
return rows as Array<CommittedMessage>;
}),

@ -1,8 +1,15 @@
import { router, publicProcedure, createCallerFactory } from "./server.js";
import {
router,
publicProcedure,
createCallerFactory,
authProcedure,
} from "./server.js";
import type { DraftMessage } from "../../types.js";
import { MODEL_NAME } from "../provider.js";
import { generateObject, generateText, jsonSchema } from "ai";
import type { Fact } from "@database/common.js";
import { TRPCError } from "@trpc/server";
import { z } from "zod";
const factTriggersSystemPrompt = ({
previousRunningSummary,
@ -53,52 +60,71 @@ ${factContent}
Generate a list of situations in which the fact is useful.`;
const authFactTriggerProcedure = authProcedure
.input(z.object({ factTriggerId: z.string() }))
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const factTriggerRows = await dbClient
.selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.innerJoin("conversations", "conversations.id", "messages.conversationId")
.where("fact_triggers.id", "=", input.factTriggerId)
.where("conversations.userId", "=", jwt.id as string)
.execute();
if (!factTriggerRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
});
export const factTriggers = router({
fetchByFactId: publicProcedure
fetchByFactId: authProcedure
.input((x) => x as { factId: string })
.query(async ({ input: { factId }, ctx: { dbClient } }) => {
.query(async ({ input: { factId }, ctx: { dbClient, jwt } }) => {
const rows = await dbClient
.selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.innerJoin(
"conversations",
"conversations.id",
"messages.conversationId"
)
.selectAll("fact_triggers")
.where("sourceFactId", "=", factId)
.where("conversations.userId", "=", jwt.id as string)
.execute();
return rows;
}),
fetchByConversationId: publicProcedure
fetchByConversationId: authProcedure
.input((x) => x as { conversationId: string })
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => {
.query(async ({ input: { conversationId }, ctx: { dbClient, jwt } }) => {
const rows = await dbClient
.selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.innerJoin(
"conversations",
"conversations.id",
"messages.conversationId"
)
.selectAll("fact_triggers")
.where("messages.conversationId", "=", conversationId)
.where("conversations.userId", "=", jwt.id as string)
.execute();
return rows;
}),
deleteOne: publicProcedure
.input(
(x) =>
x as {
factTriggerId: string;
}
)
.mutation(async ({ input: { factTriggerId }, ctx: { dbClient } }) => {
deleteOne: authFactTriggerProcedure.mutation(
async ({ input: { factTriggerId }, ctx: { dbClient, jwt } }) => {
await dbClient
.deleteFrom("fact_triggers")
.where("id", "=", factTriggerId)
.execute();
return { ok: true };
}),
update: publicProcedure
.input(
(x) =>
x as {
factTriggerId: string;
content: string;
}
)
}
),
update: authFactTriggerProcedure
.input(z.object({ content: z.string() }))
.mutation(
async ({ input: { factTriggerId, content }, ctx: { dbClient } }) => {
await dbClient

@ -1,7 +1,9 @@
import { router, publicProcedure, createCallerFactory } from "./server.js";
import { router, createCallerFactory, authProcedure } from "./server.js";
import type { DraftMessage } from "../../types.js";
import { MODEL_NAME, openrouter } from "../provider.js";
import { generateObject, generateText, jsonSchema } from "ai";
import { TRPCError } from "@trpc/server";
import { z } from "zod";
const factsFromNewMessagesSystemPrompt = ({
previousRunningSummary,
@ -48,9 +50,36 @@ const factsFromNewMessagesUserPrompt = ({
Extract new facts from these messages.`;
const authFactProcedure = authProcedure
.input(z.object({ factId: z.string() }))
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const factRows = await dbClient
.selectFrom("facts")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.innerJoin("conversations", "conversations.id", "messages.conversationId")
.where("facts.id", "=", input.factId)
.where("conversations.userId", "=", jwt.id as string)
.execute();
if (!factRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
});
export const facts = router({
fetchByConversationId: publicProcedure
fetchByConversationId: authProcedure
.input((x) => x as { conversationId: string })
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const conversationRows = await dbClient
.selectFrom("conversations")
.where("id", "=", input.conversationId)
.where("userId", "=", jwt.id as string)
.execute();
if (!conversationRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
})
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => {
const rows = await dbClient
.selectFrom("facts")
@ -60,25 +89,14 @@ export const facts = router({
.execute();
return rows;
}),
deleteOne: publicProcedure
.input(
(x) =>
x as {
factId: string;
}
)
.mutation(async ({ input: { factId }, ctx: { dbClient } }) => {
deleteOne: authFactProcedure.mutation(
async ({ input: { factId }, ctx: { dbClient } }) => {
await dbClient.deleteFrom("facts").where("id", "=", factId).execute();
return { ok: true };
}),
update: publicProcedure
.input(
(x) =>
x as {
factId: string;
content: string;
}
)
}
),
update: authFactProcedure
.input(z.object({ content: z.string() }))
.mutation(async ({ input: { factId, content }, ctx: { dbClient } }) => {
await dbClient
.updateTable("facts")
@ -87,7 +105,7 @@ export const facts = router({
.execute();
return { ok: true };
}),
extractFromNewMessages: publicProcedure
extractFromNewMessages: authProcedure
.input(
(x) =>
x as {

@ -1,7 +1,14 @@
import { router, publicProcedure, createCallerFactory } from "./server";
import {
router,
publicProcedure,
createCallerFactory,
authProcedure,
} from "./server";
import { MODEL_NAME } from "../provider.js";
import { generateObject, generateText, jsonSchema } from "ai";
import type { CommittedMessage, DraftMessage } from "../../types.js";
import { TRPCError } from "@trpc/server";
import { z } from "zod";
const runningSummarySystemPrompt = ({
previousRunningSummary,
@ -43,9 +50,35 @@ ${mainResponseContent}
Generate a new running summary of the conversation.`;
const authMessageProcedure = authProcedure
.input(z.object({ id: z.string() }))
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const messageRows = await dbClient
.selectFrom("messages")
.innerJoin("conversations", "conversations.id", "messages.conversationId")
.where("messages.id", "=", input.id)
.where("conversations.userId", "=", jwt.id as string)
.execute();
if (!messageRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
});
export const messages = router({
fetchByConversationId: publicProcedure
fetchByConversationId: authProcedure
.input((x) => x as { conversationId: string })
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const conversationRows = await dbClient
.selectFrom("conversations")
.where("id", "=", input.conversationId)
.where("userId", "=", jwt.id as string)
.execute();
if (!conversationRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
})
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => {
const rows = (await dbClient
.selectFrom("messages")
@ -54,13 +87,13 @@ export const messages = router({
.execute()) as Array<CommittedMessage>;
return rows;
}),
deleteOne: publicProcedure
.input((x) => x as { id: string })
.mutation(async ({ input: { id }, ctx: { dbClient } }) => {
deleteOne: authMessageProcedure.mutation(
async ({ input: { id }, ctx: { dbClient } }) => {
await dbClient.deleteFrom("messages").where("id", "=", id).execute();
return { success: true };
}),
generateRunningSummary: publicProcedure
}
),
generateRunningSummary: authProcedure
.input(
(x) =>
x as {

@ -36,6 +36,20 @@ const t = initTRPC
*/
export const router = t.router;
export const publicProcedure = t.procedure;
export const authProcedure = publicProcedure.use(
async ({ ctx: { jwt }, next }) => {
if (!jwt) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
if (!jwt.id) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
jwt.email;
return await next({
ctx: { jwt },
});
}
);
/**
* Generate a TRPC-compatible validator function given a Typebox schema.

Loading…
Cancel
Save