Add hybrid research copilot workspace
This commit is contained in:
229
lib/server/repos/research-copilot.ts
Normal file
229
lib/server/repos/research-copilot.ts
Normal file
@@ -0,0 +1,229 @@
|
||||
import { and, asc, eq } from 'drizzle-orm';
|
||||
import type {
|
||||
ResearchCopilotCitation,
|
||||
ResearchCopilotMessage,
|
||||
ResearchCopilotSession,
|
||||
ResearchCopilotSuggestedAction,
|
||||
ResearchMemoSection,
|
||||
SearchSource
|
||||
} from '@/lib/types';
|
||||
import { db } from '@/lib/server/db';
|
||||
import {
|
||||
researchCopilotMessage,
|
||||
researchCopilotSession
|
||||
} from '@/lib/server/db/schema';
|
||||
|
||||
type ResearchCopilotSessionRow = typeof researchCopilotSession.$inferSelect;
|
||||
type ResearchCopilotMessageRow = typeof researchCopilotMessage.$inferSelect;
|
||||
|
||||
const DEFAULT_SELECTED_SOURCES: SearchSource[] = ['documents', 'filings', 'research'];
|
||||
|
||||
function normalizeTicker(ticker: string) {
|
||||
return ticker.trim().toUpperCase();
|
||||
}
|
||||
|
||||
function normalizeSources(value?: SearchSource[] | null) {
|
||||
const unique = new Set<SearchSource>();
|
||||
|
||||
for (const source of value ?? DEFAULT_SELECTED_SOURCES) {
|
||||
if (source === 'documents' || source === 'filings' || source === 'research') {
|
||||
unique.add(source);
|
||||
}
|
||||
}
|
||||
|
||||
return unique.size > 0 ? [...unique] : [...DEFAULT_SELECTED_SOURCES];
|
||||
}
|
||||
|
||||
function normalizePinnedArtifactIds(value?: number[] | null) {
|
||||
const unique = new Set<number>();
|
||||
|
||||
for (const id of value ?? []) {
|
||||
const normalized = Math.trunc(Number(id));
|
||||
if (Number.isInteger(normalized) && normalized > 0) {
|
||||
unique.add(normalized);
|
||||
}
|
||||
}
|
||||
|
||||
return [...unique];
|
||||
}
|
||||
|
||||
function normalizeOptionalString(value?: string | null) {
|
||||
const normalized = value?.trim();
|
||||
return normalized ? normalized : null;
|
||||
}
|
||||
|
||||
function toCitationArray(value: unknown): ResearchCopilotCitation[] {
|
||||
return Array.isArray(value) ? value as ResearchCopilotCitation[] : [];
|
||||
}
|
||||
|
||||
function toActionArray(value: unknown): ResearchCopilotSuggestedAction[] {
|
||||
return Array.isArray(value) ? value as ResearchCopilotSuggestedAction[] : [];
|
||||
}
|
||||
|
||||
function toFollowUps(value: unknown) {
|
||||
return Array.isArray(value)
|
||||
? value.filter((entry): entry is string => typeof entry === 'string' && entry.trim().length > 0)
|
||||
: [];
|
||||
}
|
||||
|
||||
function toMessage(row: ResearchCopilotMessageRow): ResearchCopilotMessage {
|
||||
return {
|
||||
id: row.id,
|
||||
session_id: row.session_id,
|
||||
user_id: row.user_id,
|
||||
role: row.role,
|
||||
content_markdown: row.content_markdown,
|
||||
citations: toCitationArray(row.citations),
|
||||
follow_ups: toFollowUps(row.follow_ups),
|
||||
suggested_actions: toActionArray(row.suggested_actions),
|
||||
selected_sources: normalizeSources(row.selected_sources),
|
||||
pinned_artifact_ids: normalizePinnedArtifactIds(row.pinned_artifact_ids),
|
||||
memo_section: row.memo_section ?? null,
|
||||
created_at: row.created_at
|
||||
};
|
||||
}
|
||||
|
||||
function toSession(row: ResearchCopilotSessionRow, messages: ResearchCopilotMessage[]): ResearchCopilotSession {
|
||||
return {
|
||||
id: row.id,
|
||||
user_id: row.user_id,
|
||||
ticker: row.ticker,
|
||||
title: row.title ?? null,
|
||||
selected_sources: normalizeSources(row.selected_sources),
|
||||
pinned_artifact_ids: normalizePinnedArtifactIds(row.pinned_artifact_ids),
|
||||
created_at: row.created_at,
|
||||
updated_at: row.updated_at,
|
||||
messages
|
||||
};
|
||||
}
|
||||
|
||||
async function listMessagesForSession(sessionId: number) {
|
||||
const rows = await db
|
||||
.select()
|
||||
.from(researchCopilotMessage)
|
||||
.where(eq(researchCopilotMessage.session_id, sessionId))
|
||||
.orderBy(asc(researchCopilotMessage.created_at), asc(researchCopilotMessage.id));
|
||||
|
||||
return rows.map(toMessage);
|
||||
}
|
||||
|
||||
async function getSessionRowByTicker(userId: string, ticker: string) {
|
||||
const [row] = await db
|
||||
.select()
|
||||
.from(researchCopilotSession)
|
||||
.where(and(
|
||||
eq(researchCopilotSession.user_id, userId),
|
||||
eq(researchCopilotSession.ticker, normalizeTicker(ticker))
|
||||
))
|
||||
.limit(1);
|
||||
|
||||
return row ?? null;
|
||||
}
|
||||
|
||||
export async function getResearchCopilotSessionByTicker(userId: string, ticker: string) {
|
||||
const row = await getSessionRowByTicker(userId, ticker);
|
||||
if (!row) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return toSession(row, await listMessagesForSession(row.id));
|
||||
}
|
||||
|
||||
export async function getOrCreateResearchCopilotSession(input: {
|
||||
userId: string;
|
||||
ticker: string;
|
||||
title?: string | null;
|
||||
selectedSources?: SearchSource[] | null;
|
||||
pinnedArtifactIds?: number[] | null;
|
||||
}) {
|
||||
const normalizedTicker = normalizeTicker(input.ticker);
|
||||
if (!normalizedTicker) {
|
||||
throw new Error('ticker is required');
|
||||
}
|
||||
|
||||
const existing = await getSessionRowByTicker(input.userId, normalizedTicker);
|
||||
if (existing) {
|
||||
const messages = await listMessagesForSession(existing.id);
|
||||
return toSession(existing, messages);
|
||||
}
|
||||
|
||||
const now = new Date().toISOString();
|
||||
const [created] = await db
|
||||
.insert(researchCopilotSession)
|
||||
.values({
|
||||
user_id: input.userId,
|
||||
ticker: normalizedTicker,
|
||||
title: normalizeOptionalString(input.title),
|
||||
selected_sources: normalizeSources(input.selectedSources),
|
||||
pinned_artifact_ids: normalizePinnedArtifactIds(input.pinnedArtifactIds),
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
})
|
||||
.returning();
|
||||
|
||||
return toSession(created, []);
|
||||
}
|
||||
|
||||
export async function upsertResearchCopilotSessionState(input: {
|
||||
userId: string;
|
||||
ticker: string;
|
||||
title?: string | null;
|
||||
selectedSources?: SearchSource[] | null;
|
||||
pinnedArtifactIds?: number[] | null;
|
||||
}) {
|
||||
const session = await getOrCreateResearchCopilotSession(input);
|
||||
const [updated] = await db
|
||||
.update(researchCopilotSession)
|
||||
.set({
|
||||
title: input.title === undefined ? session.title : normalizeOptionalString(input.title),
|
||||
selected_sources: input.selectedSources === undefined
|
||||
? session.selected_sources
|
||||
: normalizeSources(input.selectedSources),
|
||||
pinned_artifact_ids: input.pinnedArtifactIds === undefined
|
||||
? session.pinned_artifact_ids
|
||||
: normalizePinnedArtifactIds(input.pinnedArtifactIds),
|
||||
updated_at: new Date().toISOString()
|
||||
})
|
||||
.where(eq(researchCopilotSession.id, session.id))
|
||||
.returning();
|
||||
|
||||
return toSession(updated, await listMessagesForSession(updated.id));
|
||||
}
|
||||
|
||||
export async function appendResearchCopilotMessage(input: {
|
||||
userId: string;
|
||||
sessionId: number;
|
||||
role: ResearchCopilotMessage['role'];
|
||||
contentMarkdown: string;
|
||||
citations?: ResearchCopilotCitation[] | null;
|
||||
followUps?: string[] | null;
|
||||
suggestedActions?: ResearchCopilotSuggestedAction[] | null;
|
||||
selectedSources?: SearchSource[] | null;
|
||||
pinnedArtifactIds?: number[] | null;
|
||||
memoSection?: ResearchMemoSection | null;
|
||||
}) {
|
||||
const now = new Date().toISOString();
|
||||
const [created] = await db
|
||||
.insert(researchCopilotMessage)
|
||||
.values({
|
||||
session_id: input.sessionId,
|
||||
user_id: input.userId,
|
||||
role: input.role,
|
||||
content_markdown: input.contentMarkdown.trim(),
|
||||
citations: input.citations ?? [],
|
||||
follow_ups: input.followUps ?? [],
|
||||
suggested_actions: input.suggestedActions ?? [],
|
||||
selected_sources: input.selectedSources ? normalizeSources(input.selectedSources) : null,
|
||||
pinned_artifact_ids: input.pinnedArtifactIds ? normalizePinnedArtifactIds(input.pinnedArtifactIds) : null,
|
||||
memo_section: input.memoSection ?? null,
|
||||
created_at: now
|
||||
})
|
||||
.returning();
|
||||
|
||||
await db
|
||||
.update(researchCopilotSession)
|
||||
.set({ updated_at: now })
|
||||
.where(eq(researchCopilotSession.id, input.sessionId));
|
||||
|
||||
return toMessage(created);
|
||||
}
|
||||
Reference in New Issue
Block a user