Add hybrid research copilot workspace
This commit is contained in:
165
lib/server/repos/research-copilot.test.ts
Normal file
165
lib/server/repos/research-copilot.test.ts
Normal file
@@ -0,0 +1,165 @@
|
||||
import {
|
||||
afterAll,
|
||||
beforeAll,
|
||||
beforeEach,
|
||||
describe,
|
||||
expect,
|
||||
it
|
||||
} from 'bun:test';
|
||||
import { mock } from 'bun:test';
|
||||
import { mkdtempSync, readFileSync, rmSync } from 'node:fs';
|
||||
import { tmpdir } from 'node:os';
|
||||
import { join } from 'node:path';
|
||||
import { Database } from 'bun:sqlite';
|
||||
|
||||
const TEST_USER_ID = 'copilot-user';
|
||||
|
||||
let tempDir: string | null = null;
|
||||
let sqliteClient: Database | null = null;
|
||||
let copilotRepo: typeof import('./research-copilot') | null = null;
|
||||
|
||||
async function loadRepoModule() {
|
||||
const moduleUrl = new URL(`./research-copilot.ts?test=${Date.now()}`, import.meta.url).href;
|
||||
return await import(moduleUrl) as typeof import('./research-copilot');
|
||||
}
|
||||
|
||||
function resetDbSingletons() {
|
||||
const globalState = globalThis as typeof globalThis & {
|
||||
__fiscalSqliteClient?: Database;
|
||||
__fiscalDrizzleDb?: unknown;
|
||||
};
|
||||
|
||||
globalState.__fiscalSqliteClient?.close();
|
||||
globalState.__fiscalSqliteClient = undefined;
|
||||
globalState.__fiscalDrizzleDb = undefined;
|
||||
}
|
||||
|
||||
function applyMigration(client: Database, fileName: string) {
|
||||
const sql = readFileSync(join(process.cwd(), 'drizzle', fileName), 'utf8');
|
||||
client.exec(sql);
|
||||
}
|
||||
|
||||
function ensureUser(client: Database) {
|
||||
const now = Date.now();
|
||||
client.exec(`
|
||||
INSERT OR REPLACE INTO user (id, name, email, emailVerified, image, createdAt, updatedAt, role, banned, banReason, banExpires)
|
||||
VALUES ('${TEST_USER_ID}', 'Copilot User', 'copilot@example.com', 1, NULL, ${now}, ${now}, NULL, 0, NULL, NULL);
|
||||
`);
|
||||
}
|
||||
|
||||
describe('research copilot repo', () => {
|
||||
beforeAll(async () => {
|
||||
mock.restore();
|
||||
tempDir = mkdtempSync(join(tmpdir(), 'fiscal-copilot-repo-'));
|
||||
process.env.DATABASE_URL = `file:${join(tempDir, 'repo.sqlite')}`;
|
||||
(process.env as Record<string, string | undefined>).NODE_ENV = 'test';
|
||||
|
||||
resetDbSingletons();
|
||||
sqliteClient = new Database(join(tempDir, 'repo.sqlite'), { create: true });
|
||||
sqliteClient.exec('PRAGMA foreign_keys = ON;');
|
||||
applyMigration(sqliteClient, '0000_cold_silver_centurion.sql');
|
||||
applyMigration(sqliteClient, '0008_research_workspace.sql');
|
||||
applyMigration(sqliteClient, '0013_research_copilot.sql');
|
||||
ensureUser(sqliteClient);
|
||||
|
||||
const globalState = globalThis as typeof globalThis & {
|
||||
__fiscalSqliteClient?: Database;
|
||||
__fiscalDrizzleDb?: unknown;
|
||||
};
|
||||
globalState.__fiscalSqliteClient = sqliteClient;
|
||||
globalState.__fiscalDrizzleDb = undefined;
|
||||
|
||||
copilotRepo = await loadRepoModule();
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
mock.restore();
|
||||
sqliteClient?.close();
|
||||
resetDbSingletons();
|
||||
if (tempDir) {
|
||||
rmSync(tempDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
sqliteClient?.exec('DELETE FROM research_copilot_message;');
|
||||
sqliteClient?.exec('DELETE FROM research_copilot_session;');
|
||||
});
|
||||
|
||||
it('creates and reloads ticker-scoped sessions', async () => {
|
||||
if (!copilotRepo) {
|
||||
throw new Error('repo not initialized');
|
||||
}
|
||||
|
||||
const session = await copilotRepo.getOrCreateResearchCopilotSession({
|
||||
userId: TEST_USER_ID,
|
||||
ticker: 'msft',
|
||||
selectedSources: ['documents', 'research'],
|
||||
pinnedArtifactIds: [2, 2, 5]
|
||||
});
|
||||
|
||||
const loaded = await copilotRepo.getResearchCopilotSessionByTicker(TEST_USER_ID, 'MSFT');
|
||||
|
||||
expect(session.ticker).toBe('MSFT');
|
||||
expect(session.selected_sources).toEqual(['documents', 'research']);
|
||||
expect(session.pinned_artifact_ids).toEqual([2, 5]);
|
||||
expect(loaded?.id).toBe(session.id);
|
||||
});
|
||||
|
||||
it('appends messages and updates session state', async () => {
|
||||
if (!copilotRepo) {
|
||||
throw new Error('repo not initialized');
|
||||
}
|
||||
|
||||
const session = await copilotRepo.getOrCreateResearchCopilotSession({
|
||||
userId: TEST_USER_ID,
|
||||
ticker: 'NVDA'
|
||||
});
|
||||
|
||||
await copilotRepo.appendResearchCopilotMessage({
|
||||
userId: TEST_USER_ID,
|
||||
sessionId: session.id,
|
||||
role: 'user',
|
||||
contentMarkdown: 'What changed in the latest filing?',
|
||||
selectedSources: ['filings'],
|
||||
pinnedArtifactIds: [7],
|
||||
memoSection: 'thesis'
|
||||
});
|
||||
|
||||
await copilotRepo.appendResearchCopilotMessage({
|
||||
userId: TEST_USER_ID,
|
||||
sessionId: session.id,
|
||||
role: 'assistant',
|
||||
contentMarkdown: 'Demand remained strong [1]',
|
||||
citations: [{
|
||||
index: 1,
|
||||
label: 'NVDA 10-K [1]',
|
||||
chunkId: 1,
|
||||
href: '/filings?ticker=NVDA',
|
||||
source: 'filings',
|
||||
sourceKind: 'filing_brief',
|
||||
sourceRef: '0001',
|
||||
title: '10-K brief',
|
||||
ticker: 'NVDA',
|
||||
accessionNumber: '0001',
|
||||
filingDate: '2026-01-01',
|
||||
excerpt: 'Demand remained strong.',
|
||||
artifactId: 3
|
||||
}]
|
||||
});
|
||||
|
||||
const updated = await copilotRepo.upsertResearchCopilotSessionState({
|
||||
userId: TEST_USER_ID,
|
||||
ticker: 'NVDA',
|
||||
title: 'NVDA demand update',
|
||||
selectedSources: ['filings'],
|
||||
pinnedArtifactIds: [7]
|
||||
});
|
||||
|
||||
expect(updated.title).toBe('NVDA demand update');
|
||||
expect(updated.messages).toHaveLength(2);
|
||||
expect(updated.messages[0]?.selected_sources).toEqual(['filings']);
|
||||
expect(updated.messages[0]?.memo_section).toBe('thesis');
|
||||
expect(updated.messages[1]?.citations[0]?.artifactId).toBe(3);
|
||||
});
|
||||
});
|
||||
229
lib/server/repos/research-copilot.ts
Normal file
229
lib/server/repos/research-copilot.ts
Normal file
@@ -0,0 +1,229 @@
|
||||
import { and, asc, eq } from 'drizzle-orm';
|
||||
import type {
|
||||
ResearchCopilotCitation,
|
||||
ResearchCopilotMessage,
|
||||
ResearchCopilotSession,
|
||||
ResearchCopilotSuggestedAction,
|
||||
ResearchMemoSection,
|
||||
SearchSource
|
||||
} from '@/lib/types';
|
||||
import { db } from '@/lib/server/db';
|
||||
import {
|
||||
researchCopilotMessage,
|
||||
researchCopilotSession
|
||||
} from '@/lib/server/db/schema';
|
||||
|
||||
type ResearchCopilotSessionRow = typeof researchCopilotSession.$inferSelect;
|
||||
type ResearchCopilotMessageRow = typeof researchCopilotMessage.$inferSelect;
|
||||
|
||||
const DEFAULT_SELECTED_SOURCES: SearchSource[] = ['documents', 'filings', 'research'];
|
||||
|
||||
function normalizeTicker(ticker: string) {
|
||||
return ticker.trim().toUpperCase();
|
||||
}
|
||||
|
||||
function normalizeSources(value?: SearchSource[] | null) {
|
||||
const unique = new Set<SearchSource>();
|
||||
|
||||
for (const source of value ?? DEFAULT_SELECTED_SOURCES) {
|
||||
if (source === 'documents' || source === 'filings' || source === 'research') {
|
||||
unique.add(source);
|
||||
}
|
||||
}
|
||||
|
||||
return unique.size > 0 ? [...unique] : [...DEFAULT_SELECTED_SOURCES];
|
||||
}
|
||||
|
||||
function normalizePinnedArtifactIds(value?: number[] | null) {
|
||||
const unique = new Set<number>();
|
||||
|
||||
for (const id of value ?? []) {
|
||||
const normalized = Math.trunc(Number(id));
|
||||
if (Number.isInteger(normalized) && normalized > 0) {
|
||||
unique.add(normalized);
|
||||
}
|
||||
}
|
||||
|
||||
return [...unique];
|
||||
}
|
||||
|
||||
function normalizeOptionalString(value?: string | null) {
|
||||
const normalized = value?.trim();
|
||||
return normalized ? normalized : null;
|
||||
}
|
||||
|
||||
function toCitationArray(value: unknown): ResearchCopilotCitation[] {
|
||||
return Array.isArray(value) ? value as ResearchCopilotCitation[] : [];
|
||||
}
|
||||
|
||||
function toActionArray(value: unknown): ResearchCopilotSuggestedAction[] {
|
||||
return Array.isArray(value) ? value as ResearchCopilotSuggestedAction[] : [];
|
||||
}
|
||||
|
||||
function toFollowUps(value: unknown) {
|
||||
return Array.isArray(value)
|
||||
? value.filter((entry): entry is string => typeof entry === 'string' && entry.trim().length > 0)
|
||||
: [];
|
||||
}
|
||||
|
||||
function toMessage(row: ResearchCopilotMessageRow): ResearchCopilotMessage {
|
||||
return {
|
||||
id: row.id,
|
||||
session_id: row.session_id,
|
||||
user_id: row.user_id,
|
||||
role: row.role,
|
||||
content_markdown: row.content_markdown,
|
||||
citations: toCitationArray(row.citations),
|
||||
follow_ups: toFollowUps(row.follow_ups),
|
||||
suggested_actions: toActionArray(row.suggested_actions),
|
||||
selected_sources: normalizeSources(row.selected_sources),
|
||||
pinned_artifact_ids: normalizePinnedArtifactIds(row.pinned_artifact_ids),
|
||||
memo_section: row.memo_section ?? null,
|
||||
created_at: row.created_at
|
||||
};
|
||||
}
|
||||
|
||||
function toSession(row: ResearchCopilotSessionRow, messages: ResearchCopilotMessage[]): ResearchCopilotSession {
|
||||
return {
|
||||
id: row.id,
|
||||
user_id: row.user_id,
|
||||
ticker: row.ticker,
|
||||
title: row.title ?? null,
|
||||
selected_sources: normalizeSources(row.selected_sources),
|
||||
pinned_artifact_ids: normalizePinnedArtifactIds(row.pinned_artifact_ids),
|
||||
created_at: row.created_at,
|
||||
updated_at: row.updated_at,
|
||||
messages
|
||||
};
|
||||
}
|
||||
|
||||
async function listMessagesForSession(sessionId: number) {
|
||||
const rows = await db
|
||||
.select()
|
||||
.from(researchCopilotMessage)
|
||||
.where(eq(researchCopilotMessage.session_id, sessionId))
|
||||
.orderBy(asc(researchCopilotMessage.created_at), asc(researchCopilotMessage.id));
|
||||
|
||||
return rows.map(toMessage);
|
||||
}
|
||||
|
||||
async function getSessionRowByTicker(userId: string, ticker: string) {
|
||||
const [row] = await db
|
||||
.select()
|
||||
.from(researchCopilotSession)
|
||||
.where(and(
|
||||
eq(researchCopilotSession.user_id, userId),
|
||||
eq(researchCopilotSession.ticker, normalizeTicker(ticker))
|
||||
))
|
||||
.limit(1);
|
||||
|
||||
return row ?? null;
|
||||
}
|
||||
|
||||
export async function getResearchCopilotSessionByTicker(userId: string, ticker: string) {
|
||||
const row = await getSessionRowByTicker(userId, ticker);
|
||||
if (!row) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return toSession(row, await listMessagesForSession(row.id));
|
||||
}
|
||||
|
||||
export async function getOrCreateResearchCopilotSession(input: {
|
||||
userId: string;
|
||||
ticker: string;
|
||||
title?: string | null;
|
||||
selectedSources?: SearchSource[] | null;
|
||||
pinnedArtifactIds?: number[] | null;
|
||||
}) {
|
||||
const normalizedTicker = normalizeTicker(input.ticker);
|
||||
if (!normalizedTicker) {
|
||||
throw new Error('ticker is required');
|
||||
}
|
||||
|
||||
const existing = await getSessionRowByTicker(input.userId, normalizedTicker);
|
||||
if (existing) {
|
||||
const messages = await listMessagesForSession(existing.id);
|
||||
return toSession(existing, messages);
|
||||
}
|
||||
|
||||
const now = new Date().toISOString();
|
||||
const [created] = await db
|
||||
.insert(researchCopilotSession)
|
||||
.values({
|
||||
user_id: input.userId,
|
||||
ticker: normalizedTicker,
|
||||
title: normalizeOptionalString(input.title),
|
||||
selected_sources: normalizeSources(input.selectedSources),
|
||||
pinned_artifact_ids: normalizePinnedArtifactIds(input.pinnedArtifactIds),
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
})
|
||||
.returning();
|
||||
|
||||
return toSession(created, []);
|
||||
}
|
||||
|
||||
export async function upsertResearchCopilotSessionState(input: {
|
||||
userId: string;
|
||||
ticker: string;
|
||||
title?: string | null;
|
||||
selectedSources?: SearchSource[] | null;
|
||||
pinnedArtifactIds?: number[] | null;
|
||||
}) {
|
||||
const session = await getOrCreateResearchCopilotSession(input);
|
||||
const [updated] = await db
|
||||
.update(researchCopilotSession)
|
||||
.set({
|
||||
title: input.title === undefined ? session.title : normalizeOptionalString(input.title),
|
||||
selected_sources: input.selectedSources === undefined
|
||||
? session.selected_sources
|
||||
: normalizeSources(input.selectedSources),
|
||||
pinned_artifact_ids: input.pinnedArtifactIds === undefined
|
||||
? session.pinned_artifact_ids
|
||||
: normalizePinnedArtifactIds(input.pinnedArtifactIds),
|
||||
updated_at: new Date().toISOString()
|
||||
})
|
||||
.where(eq(researchCopilotSession.id, session.id))
|
||||
.returning();
|
||||
|
||||
return toSession(updated, await listMessagesForSession(updated.id));
|
||||
}
|
||||
|
||||
export async function appendResearchCopilotMessage(input: {
|
||||
userId: string;
|
||||
sessionId: number;
|
||||
role: ResearchCopilotMessage['role'];
|
||||
contentMarkdown: string;
|
||||
citations?: ResearchCopilotCitation[] | null;
|
||||
followUps?: string[] | null;
|
||||
suggestedActions?: ResearchCopilotSuggestedAction[] | null;
|
||||
selectedSources?: SearchSource[] | null;
|
||||
pinnedArtifactIds?: number[] | null;
|
||||
memoSection?: ResearchMemoSection | null;
|
||||
}) {
|
||||
const now = new Date().toISOString();
|
||||
const [created] = await db
|
||||
.insert(researchCopilotMessage)
|
||||
.values({
|
||||
session_id: input.sessionId,
|
||||
user_id: input.userId,
|
||||
role: input.role,
|
||||
content_markdown: input.contentMarkdown.trim(),
|
||||
citations: input.citations ?? [],
|
||||
follow_ups: input.followUps ?? [],
|
||||
suggested_actions: input.suggestedActions ?? [],
|
||||
selected_sources: input.selectedSources ? normalizeSources(input.selectedSources) : null,
|
||||
pinned_artifact_ids: input.pinnedArtifactIds ? normalizePinnedArtifactIds(input.pinnedArtifactIds) : null,
|
||||
memo_section: input.memoSection ?? null,
|
||||
created_at: now
|
||||
})
|
||||
.returning();
|
||||
|
||||
await db
|
||||
.update(researchCopilotSession)
|
||||
.set({ updated_at: now })
|
||||
.where(eq(researchCopilotSession.id, input.sessionId));
|
||||
|
||||
return toMessage(created);
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import {
|
||||
researchMemo,
|
||||
researchMemoEvidence
|
||||
} from '@/lib/server/db/schema';
|
||||
import { getResearchCopilotSessionByTicker } from '@/lib/server/repos/research-copilot';
|
||||
import { getFilingByAccession, listFilingsRecords } from '@/lib/server/repos/filings';
|
||||
import { getWatchlistItemByTicker } from '@/lib/server/repos/watchlist';
|
||||
|
||||
@@ -374,6 +375,26 @@ async function getArtifactByIdForUser(id: number, userId: string) {
|
||||
return row ?? null;
|
||||
}
|
||||
|
||||
export async function getResearchArtifactsByIdsForUser(userId: string, ids: number[]) {
|
||||
const normalizedIds = [...new Set(ids.map((id) => Math.trunc(id)).filter((id) => Number.isInteger(id) && id > 0))];
|
||||
if (normalizedIds.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const rows = await db
|
||||
.select()
|
||||
.from(researchArtifact)
|
||||
.where(and(
|
||||
eq(researchArtifact.user_id, userId),
|
||||
sql`${researchArtifact.id} in (${sql.join(normalizedIds.map((id) => sql`${id}`), sql`, `)})`
|
||||
));
|
||||
|
||||
const order = new Map(normalizedIds.map((id, index) => [id, index]));
|
||||
return rows
|
||||
.sort((left, right) => (order.get(left.id) ?? Number.MAX_SAFE_INTEGER) - (order.get(right.id) ?? Number.MAX_SAFE_INTEGER))
|
||||
.map((row) => toResearchArtifact(row));
|
||||
}
|
||||
|
||||
async function getMemoByIdForUser(id: number, userId: string) {
|
||||
const [row] = await db
|
||||
.select()
|
||||
@@ -902,12 +923,13 @@ export async function getResearchPacket(userId: string, ticker: string): Promise
|
||||
|
||||
export async function getResearchWorkspace(userId: string, ticker: string): Promise<ResearchWorkspace> {
|
||||
const normalizedTicker = normalizeTicker(ticker);
|
||||
const [coverage, memo, library, packet, latestFiling] = await Promise.all([
|
||||
const [coverage, memo, library, packet, latestFiling, copilotSession] = await Promise.all([
|
||||
getWatchlistItemByTicker(userId, normalizedTicker),
|
||||
getResearchMemoByTicker(userId, normalizedTicker),
|
||||
listResearchArtifacts(userId, { ticker: normalizedTicker, limit: 40 }),
|
||||
getResearchPacket(userId, normalizedTicker),
|
||||
listFilingsRecords({ ticker: normalizedTicker, limit: 1 })
|
||||
listFilingsRecords({ ticker: normalizedTicker, limit: 1 }),
|
||||
getResearchCopilotSessionByTicker(userId, normalizedTicker)
|
||||
]);
|
||||
|
||||
return {
|
||||
@@ -918,7 +940,8 @@ export async function getResearchWorkspace(userId: string, ticker: string): Prom
|
||||
memo,
|
||||
library: library.artifacts,
|
||||
packet,
|
||||
availableTags: library.availableTags
|
||||
availableTags: library.availableTags,
|
||||
copilotSession
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1119,4 +1142,3 @@ export async function getResearchArtifactFileResponse(userId: string, id: number
|
||||
export function rebuildResearchArtifactIndex() {
|
||||
rebuildArtifactSearchIndex();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user