Files
Neon-Desk/lib/server/repos/research-copilot.ts

230 lines
7.0 KiB
TypeScript

import { and, asc, eq } from 'drizzle-orm';
import type {
ResearchCopilotCitation,
ResearchCopilotMessage,
ResearchCopilotSession,
ResearchCopilotSuggestedAction,
ResearchMemoSection,
SearchSource
} from '@/lib/types';
import { db } from '@/lib/server/db';
import {
researchCopilotMessage,
researchCopilotSession
} from '@/lib/server/db/schema';
/** Row type produced by selecting from the `researchCopilotSession` table (inferred from its schema). */
type ResearchCopilotSessionRow = typeof researchCopilotSession.$inferSelect;
/** Row type produced by selecting from the `researchCopilotMessage` table (inferred from its schema). */
type ResearchCopilotMessageRow = typeof researchCopilotMessage.$inferSelect;
/**
 * Sources used by `normalizeSources` when the caller supplies nothing, or when
 * none of the supplied entries is a recognized source.
 */
const DEFAULT_SELECTED_SOURCES: SearchSource[] = ['documents', 'filings', 'research'];
/**
 * Produces the canonical form of a ticker symbol: surrounding whitespace is
 * removed and the remainder is upper-cased.
 *
 * @param ticker Raw ticker text.
 * @returns The trimmed, upper-cased ticker (possibly an empty string).
 */
function normalizeTicker(ticker: string) {
  const trimmed = ticker.trim();
  return trimmed.toUpperCase();
}
/**
 * Reduces a list of search sources to its recognized, de-duplicated entries.
 *
 * Only 'documents', 'filings' and 'research' are kept; the first occurrence of
 * each determines its position in the result. When `value` is null/undefined,
 * or when nothing recognized remains, a fresh copy of
 * `DEFAULT_SELECTED_SOURCES` is returned instead.
 *
 * @param value Sources to normalize.
 * @returns A new array of recognized sources, never empty.
 */
function normalizeSources(value?: SearchSource[] | null) {
  const accepted: SearchSource[] = [];

  for (const candidate of value ?? DEFAULT_SELECTED_SOURCES) {
    let recognized: boolean;
    switch (candidate) {
      case 'documents':
      case 'filings':
      case 'research':
        recognized = true;
        break;
      default:
        recognized = false;
    }

    // Keep only the first occurrence so the result preserves input order.
    if (recognized && !accepted.includes(candidate)) {
      accepted.push(candidate);
    }
  }

  if (accepted.length === 0) {
    return [...DEFAULT_SELECTED_SOURCES];
  }
  return accepted;
}
/**
 * Cleans a list of pinned artifact ids.
 *
 * Each entry is coerced with `Number` and truncated toward zero; only results
 * that are positive integers are kept. Duplicates are dropped, preserving the
 * position of the first occurrence.
 *
 * @param value Ids to normalize; null/undefined is treated as an empty list.
 * @returns A new array of unique positive integer ids.
 */
function normalizePinnedArtifactIds(value?: number[] | null) {
  const truncated = [...(value ?? [])].map((raw) => Math.trunc(Number(raw)));
  // NaN and +/-Infinity fail Number.isInteger, so they are filtered out here.
  const valid = truncated.filter((id) => Number.isInteger(id) && id > 0);
  return Array.from(new Set<number>(valid));
}
/**
 * Trims an optional string and collapses "nothing useful" to `null`.
 *
 * @param value Text that may be missing.
 * @returns The trimmed text, or `null` when `value` is null/undefined or
 *          contains only whitespace.
 */
function normalizeOptionalString(value?: string | null) {
  if (value == null) {
    return null;
  }
  const trimmed = value.trim();
  return trimmed.length > 0 ? trimmed : null;
}
/**
 * Treats an untyped value as a citation list.
 *
 * Only the outer container is checked: any array is returned as-is (same
 * reference) and its elements are not validated; every non-array value
 * yields an empty list.
 */
function toCitationArray(value: unknown): ResearchCopilotCitation[] {
  if (!Array.isArray(value)) {
    return [];
  }
  return value as ResearchCopilotCitation[];
}
/**
 * Treats an untyped value as a list of suggested actions.
 *
 * Only the outer container is checked: any array is returned as-is (same
 * reference) and its elements are not validated; every non-array value
 * yields an empty list.
 */
function toActionArray(value: unknown): ResearchCopilotSuggestedAction[] {
  if (!Array.isArray(value)) {
    return [];
  }
  return value as ResearchCopilotSuggestedAction[];
}
/**
 * Extracts follow-up prompts from an untyped value.
 *
 * @param value Expected to be an array; anything else yields an empty list.
 * @returns The string entries that are not blank. Kept entries are returned
 *          unmodified (they are not trimmed).
 */
function toFollowUps(value: unknown) {
  if (!Array.isArray(value)) {
    return [];
  }

  const kept: string[] = [];
  for (const entry of value) {
    if (typeof entry === 'string' && entry.trim() !== '') {
      kept.push(entry);
    }
  }
  return kept;
}
/**
 * Builds a `ResearchCopilotMessage` from a message table row.
 *
 * Scalar columns are copied through unchanged. The collection columns are
 * passed through this module's normalizers, and a missing `memo_section`
 * becomes `null`.
 */
function toMessage(row: ResearchCopilotMessageRow): ResearchCopilotMessage {
  const citations = toCitationArray(row.citations);
  const followUps = toFollowUps(row.follow_ups);
  const suggestedActions = toActionArray(row.suggested_actions);
  const selectedSources = normalizeSources(row.selected_sources);
  const pinnedArtifactIds = normalizePinnedArtifactIds(row.pinned_artifact_ids);

  const message: ResearchCopilotMessage = {
    id: row.id,
    session_id: row.session_id,
    user_id: row.user_id,
    role: row.role,
    content_markdown: row.content_markdown,
    citations,
    follow_ups: followUps,
    suggested_actions: suggestedActions,
    selected_sources: selectedSources,
    pinned_artifact_ids: pinnedArtifactIds,
    memo_section: row.memo_section ?? null,
    created_at: row.created_at
  };
  return message;
}
/**
 * Builds a `ResearchCopilotSession` from a session table row plus its
 * already-loaded messages.
 *
 * @param row Session row as read from the database.
 * @param messages Messages to attach; used as given, without copying.
 */
function toSession(row: ResearchCopilotSessionRow, messages: ResearchCopilotMessage[]): ResearchCopilotSession {
  const selectedSources = normalizeSources(row.selected_sources);
  const pinnedArtifactIds = normalizePinnedArtifactIds(row.pinned_artifact_ids);

  const session: ResearchCopilotSession = {
    id: row.id,
    user_id: row.user_id,
    ticker: row.ticker,
    title: row.title ?? null,
    selected_sources: selectedSources,
    pinned_artifact_ids: pinnedArtifactIds,
    created_at: row.created_at,
    updated_at: row.updated_at,
    messages
  };
  return session;
}
/**
 * Loads every message that belongs to a session.
 *
 * @param sessionId Id of the session whose messages are wanted.
 * @returns The messages ordered by `created_at` ascending, with `id`
 *          ascending as the tie-breaker.
 */
async function listMessagesForSession(sessionId: number) {
  const inSession = eq(researchCopilotMessage.session_id, sessionId);
  const byCreatedAt = asc(researchCopilotMessage.created_at);
  const byId = asc(researchCopilotMessage.id);

  const messageRows = await db
    .select()
    .from(researchCopilotMessage)
    .where(inSession)
    .orderBy(byCreatedAt, byId);

  return messageRows.map((messageRow) => toMessage(messageRow));
}
/**
 * Looks up the raw session row for a user/ticker pair.
 *
 * The ticker is normalized (trimmed, upper-cased) before it is compared.
 *
 * @returns The first matching row, or `null` when there is none.
 */
async function getSessionRowByTicker(userId: string, ticker: string) {
  const ownedByUser = eq(researchCopilotSession.user_id, userId);
  const forTicker = eq(researchCopilotSession.ticker, normalizeTicker(ticker));

  const matches = await db
    .select()
    .from(researchCopilotSession)
    .where(and(ownedByUser, forTicker))
    .limit(1);

  return matches[0] ?? null;
}
/**
 * Fetches a user's copilot session for a ticker, including its messages.
 *
 * @param userId Owner of the session.
 * @param ticker Ticker symbol; normalized before lookup.
 * @returns The session with its messages, or `null` when the user has no
 *          session for that ticker.
 */
export async function getResearchCopilotSessionByTicker(userId: string, ticker: string) {
  const sessionRow = await getSessionRowByTicker(userId, ticker);
  if (sessionRow === null) {
    return null;
  }

  const messages = await listMessagesForSession(sessionRow.id);
  return toSession(sessionRow, messages);
}
/**
 * Returns the user's session for a ticker, inserting one if none exists.
 *
 * When a session already exists it is returned as stored, together with its
 * messages; `title`, `selectedSources` and `pinnedArtifactIds` from `input`
 * are only used when a new row is inserted.
 *
 * @param input.userId Owner of the session.
 * @param input.ticker Ticker symbol; trimmed and upper-cased before use.
 * @param input.title Optional title for a new session (blank becomes `null`).
 * @param input.selectedSources Sources for a new session, normalized before storage.
 * @param input.pinnedArtifactIds Pinned ids for a new session, normalized before storage.
 * @returns The existing or newly created session.
 * @throws Error `'ticker is required'` when the ticker is empty after trimming.
 */
export async function getOrCreateResearchCopilotSession(input: {
  userId: string;
  ticker: string;
  title?: string | null;
  selectedSources?: SearchSource[] | null;
  pinnedArtifactIds?: number[] | null;
}) {
  const { userId, title, selectedSources, pinnedArtifactIds } = input;

  const ticker = normalizeTicker(input.ticker);
  if (ticker === '') {
    throw new Error('ticker is required');
  }

  const found = await getSessionRowByTicker(userId, ticker);
  if (found !== null) {
    return toSession(found, await listMessagesForSession(found.id));
  }

  // A single timestamp is shared so created_at and updated_at start out equal.
  const timestamp = new Date().toISOString();
  const [inserted] = await db
    .insert(researchCopilotSession)
    .values({
      user_id: userId,
      ticker,
      title: normalizeOptionalString(title),
      selected_sources: normalizeSources(selectedSources),
      pinned_artifact_ids: normalizePinnedArtifactIds(pinnedArtifactIds),
      created_at: timestamp,
      updated_at: timestamp
    })
    .returning();

  // A freshly inserted session has no messages yet.
  return toSession(inserted, []);
}
/**
 * Ensures a session exists for the user/ticker pair and writes the supplied
 * state onto it.
 *
 * Each of `title`, `selectedSources` and `pinnedArtifactIds` is only changed
 * when it is present in `input`; a property that is `undefined` keeps the
 * session's current value. Passing `null` is treated as an explicit value and
 * goes through the same normalizer as any other input. `updated_at` is always
 * refreshed.
 *
 * @param input.userId Owner of the session.
 * @param input.ticker Ticker symbol identifying the session.
 * @param input.title New title, or `undefined` to keep the current one.
 * @param input.selectedSources New sources, or `undefined` to keep the current ones.
 * @param input.pinnedArtifactIds New pinned ids, or `undefined` to keep the current ones.
 * @returns The updated session together with its messages.
 * @throws Error when the ticker is empty after trimming (raised while
 *         resolving the session), or when the update returns no row.
 */
export async function upsertResearchCopilotSessionState(input: {
  userId: string;
  ticker: string;
  title?: string | null;
  selectedSources?: SearchSource[] | null;
  pinnedArtifactIds?: number[] | null;
}) {
  const session = await getOrCreateResearchCopilotSession(input);
  const [updated] = await db
    .update(researchCopilotSession)
    .set({
      title: input.title === undefined ? session.title : normalizeOptionalString(input.title),
      selected_sources: input.selectedSources === undefined
        ? session.selected_sources
        : normalizeSources(input.selectedSources),
      pinned_artifact_ids: input.pinnedArtifactIds === undefined
        ? session.pinned_artifact_ids
        : normalizePinnedArtifactIds(input.pinnedArtifactIds),
      updated_at: new Date().toISOString()
    })
    .where(eq(researchCopilotSession.id, session.id))
    .returning();

  // No row comes back if the session vanished between the lookup and the
  // update; fail with a clear message instead of a property-access TypeError.
  if (!updated) {
    throw new Error(`research copilot session ${session.id} no longer exists`);
  }

  // The messages were already loaded while resolving the session and this
  // update does not touch them, so reuse them rather than querying again.
  return toSession(updated, session.messages);
}
/**
 * Inserts a message into a session and bumps the session's `updated_at` to
 * the message's creation timestamp.
 *
 * The markdown content is trimmed before storage. Missing citations,
 * follow-ups and suggested actions are stored as empty arrays. Selected
 * sources and pinned artifact ids are normalized when provided and stored as
 * `null` otherwise.
 *
 * NOTE(review): this function does not itself check that `sessionId` belongs
 * to `userId`; presumably callers resolve the session for the user first —
 * confirm.
 *
 * @returns The stored message.
 */
export async function appendResearchCopilotMessage(input: {
  userId: string;
  sessionId: number;
  role: ResearchCopilotMessage['role'];
  contentMarkdown: string;
  citations?: ResearchCopilotCitation[] | null;
  followUps?: string[] | null;
  suggestedActions?: ResearchCopilotSuggestedAction[] | null;
  selectedSources?: SearchSource[] | null;
  pinnedArtifactIds?: number[] | null;
  memoSection?: ResearchMemoSection | null;
}) {
  // One timestamp is used for both the message and the session bump.
  const timestamp = new Date().toISOString();

  const content = input.contentMarkdown.trim();
  const selectedSources = input.selectedSources
    ? normalizeSources(input.selectedSources)
    : null;
  const pinnedArtifactIds = input.pinnedArtifactIds
    ? normalizePinnedArtifactIds(input.pinnedArtifactIds)
    : null;

  const [inserted] = await db
    .insert(researchCopilotMessage)
    .values({
      session_id: input.sessionId,
      user_id: input.userId,
      role: input.role,
      content_markdown: content,
      citations: input.citations ?? [],
      follow_ups: input.followUps ?? [],
      suggested_actions: input.suggestedActions ?? [],
      selected_sources: selectedSources,
      pinned_artifact_ids: pinnedArtifactIds,
      memo_section: input.memoSection ?? null,
      created_at: timestamp
    })
    .returning();

  const targetSession = eq(researchCopilotSession.id, input.sessionId);
  await db
    .update(researchCopilotSession)
    .set({ updated_at: timestamp })
    .where(targetSession);

  return toMessage(inserted);
}