Files
Neon-Desk/lib/server/ai.ts

250 lines
6.0 KiB
TypeScript

import { embedMany, generateText } from 'ai';
import { createZhipu } from 'zhipu-ai-provider';
/** Kind of AI work; `runAiAnalysis` uses it to pick the report or extraction config. */
type AiWorkload = 'report' | 'extraction';

/** Provider identifiers; only Zhipu is declared. */
type AiProvider = 'zhipu';

/** Settings needed to build a text-generation model and call it. */
type AiConfig = {
  provider: AiProvider;
  /** API key; absent when none is configured. */
  apiKey?: string;
  baseUrl: string;
  model: string;
  temperature: number;
};

/** Environment-like lookup table (shape-compatible with `process.env`). */
type EnvSource = Record<string, string | undefined>;

/** Overrides accepted by the config getters, mainly so tests can inject state. */
type GetAiConfigOptions = {
  /** Environment to read from instead of `process.env`. */
  env?: EnvSource;
  /** Warning sink used instead of `console.warn`. */
  warn?: (message: string) => void;
};

/** Arguments passed to the text-generation function. */
type AiGenerateInput = {
  /** Opaque model handle as returned by the `createModel` factory. */
  model: unknown;
  system?: string;
  prompt: string;
  temperature: number;
  maxRetries?: number;
};

/** Result of a text-generation call. */
type AiGenerateOutput = {
  text: string;
};

/** Result of an embedding call: one vector per embedded value. */
type AiEmbedOutput = {
  embeddings: number[][];
};

/** Options for `runAiAnalysis`: config overrides plus injectable model factory / generator. */
type RunAiAnalysisOptions = GetAiConfigOptions & {
  workload?: AiWorkload;
  createModel?: (config: AiConfig) => unknown;
  generate?: (input: AiGenerateInput) => Promise<AiGenerateOutput>;
};

/** Settings for the embedding model; model name and vector size are fixed literals. */
type EmbeddingConfig = {
  provider: AiProvider;
  /** API key; absent when none is configured. */
  apiKey?: string;
  baseUrl: string;
  model: 'embedding-3';
  dimensions: 256;
};

/** Options for `runAiEmbeddings`: config overrides plus injectable model factory / embedder. */
type RunAiEmbeddingsOptions = GetAiConfigOptions & {
  createModel?: (config: EmbeddingConfig) => unknown;
  embed?: (input: { model: unknown; values: string[] }) => Promise<AiEmbedOutput>;
};
// Fixed Coding API endpoint used as `baseUrl` by the report and embedding configs.
// The ZHIPU_BASE_URL environment variable does not override it (a warning is logged instead).
const CODING_API_BASE_URL = 'https://api.z.ai/api/coding/paas/v4';
// Embedding model name and vector size placed into the embedding config.
const SEARCH_EMBEDDING_MODEL = 'embedding-3';
const SEARCH_EMBEDDING_DIMENSIONS = 256;
// Latch: flipped to true once the ZHIPU_BASE_URL warning has been emitted so it is logged
// at most once until `__resetAiWarningsForTests` clears it.
let warnedIgnoredZhipuBaseUrl = false;
/**
 * Looks up `name` in `env` and returns its whitespace-trimmed value.
 *
 * @param name - Variable name to read.
 * @param env - Lookup table; defaults to `process.env`.
 * @returns The trimmed value, or `undefined` when the variable is missing,
 *   empty, or consists only of whitespace.
 */
function envValue(name: string, env: EnvSource = process.env) {
  const raw = env[name];
  // Missing and empty values collapse to '' so a single check below covers both.
  const cleaned = raw ? raw.trim() : '';
  return cleaned === '' ? undefined : cleaned;
}
/**
 * Converts a raw temperature setting into a number clamped to the range [0, 2].
 *
 * @param value - Raw setting text, or `undefined` when not configured.
 * @returns The parsed temperature clamped to [0, 2], or the default `0.2` when
 *   the value is missing, blank, or not a finite number.
 */
function parseTemperature(value: string | undefined) {
  const defaultTemperature = 0.2;
  const minTemperature = 0;
  const maxTemperature = 2;

  // `Number('')` and `Number('   ')` evaluate to 0, which would silently turn a
  // blank setting into temperature 0 instead of the default; reject blanks first.
  if (value === undefined || value.trim() === '') {
    return defaultTemperature;
  }

  const parsed = Number(value);
  if (!Number.isFinite(parsed)) {
    return defaultTemperature;
  }

  return Math.min(Math.max(parsed, minTemperature), maxTemperature);
}
/**
 * Emits a one-time warning when ZHIPU_BASE_URL is set, because the endpoint is
 * hardcoded and the variable has no effect.
 *
 * Does nothing if the warning was already emitted or the variable is unset/blank.
 *
 * @param env - Environment to inspect for ZHIPU_BASE_URL.
 * @param warn - Sink that receives the warning message.
 */
function warnIgnoredZhipuBaseUrl(env: EnvSource, warn: (message: string) => void) {
  // The latch is checked first, so the environment is not read again once warned.
  if (warnedIgnoredZhipuBaseUrl || !envValue('ZHIPU_BASE_URL', env)) {
    return;
  }

  // Set the latch before calling out so it stays set even if `warn` throws.
  warnedIgnoredZhipuBaseUrl = true;

  const message = `[AI SDK] ZHIPU_BASE_URL is ignored. The Coding API endpoint is hardcoded to ${CODING_API_BASE_URL}.`;
  warn(message);
}
/**
 * Default model factory: creates a Zhipu provider from the config's API key and
 * base URL, then returns the provider's handle for `config.model`.
 *
 * @param config - Text-generation settings.
 * @returns The model object produced by the Zhipu provider.
 */
function defaultCreateModel(config: AiConfig) {
  const { apiKey, baseUrl, model } = config;
  const provider = createZhipu({ apiKey, baseURL: baseUrl });
  return provider(model);
}
/**
 * Default generator: forwards the request to the AI SDK's `generateText` and
 * keeps only the resulting text.
 *
 * @param input - Model handle, prompts, temperature, and optional retry count.
 * @returns The generated text wrapped in an `AiGenerateOutput`.
 */
async function defaultGenerate(input: AiGenerateInput): Promise<AiGenerateOutput> {
  // Retries are off unless the caller explicitly supplies a count.
  const retries = input.maxRetries ?? 0;

  const { text } = await generateText({
    // The handle is typed `unknown` in this module; cast to satisfy the SDK signature.
    model: input.model as never,
    system: input.system,
    prompt: input.prompt,
    temperature: input.temperature,
    maxRetries: retries
  });

  return { text };
}
/**
 * Default embedding-model factory: creates a Zhipu provider from the config's
 * API key and base URL, then requests a text-embedding model with the
 * configured name and dimensions.
 *
 * @param config - Embedding settings.
 * @returns The embedding model object produced by the Zhipu provider.
 */
function defaultCreateEmbeddingModel(config: EmbeddingConfig) {
  const { apiKey, baseUrl, model, dimensions } = config;
  const provider = createZhipu({ apiKey, baseURL: baseUrl });
  return provider.textEmbeddingModel(model, { dimensions });
}
/**
 * Default embedder: forwards the values to the AI SDK's `embedMany` with
 * retries disabled and returns the resulting vectors.
 *
 * @param input - Model handle and the strings to embed.
 * @returns The embeddings reported by the SDK.
 */
async function defaultEmbed(input: { model: unknown; values: string[] }): Promise<AiEmbedOutput> {
  const { embeddings } = await embedMany({
    // The handle is typed `unknown` in this module; cast to satisfy the SDK signature.
    model: input.model as never,
    values: input.values,
    maxRetries: 0
  });

  return { embeddings: embeddings as number[][] };
}
/**
 * Returns the AI configuration; this is a direct alias for `getReportAiConfig`.
 *
 * @param options - Optional environment and warning-sink overrides.
 * @returns Whatever `getReportAiConfig` returns for the same options.
 */
export function getAiConfig(options?: GetAiConfigOptions) {
return getReportAiConfig(options);
}
/**
 * Builds the text-generation config for report workloads from the environment.
 *
 * Reads ZHIPU_API_KEY, ZHIPU_MODEL (default `glm-5`), and AI_TEMPERATURE. The
 * base URL is always the hardcoded Coding API endpoint. May emit the one-time
 * "ZHIPU_BASE_URL is ignored" warning as a side effect.
 *
 * @param options - Optional environment and warning-sink overrides.
 * @returns The assembled config; `apiKey` is `undefined` when no key is set.
 */
export function getReportAiConfig(options?: GetAiConfigOptions) {
  const env = options?.env ?? process.env;
  const warn = options?.warn ?? console.warn;
  warnIgnoredZhipuBaseUrl(env, warn);

  const apiKey = envValue('ZHIPU_API_KEY', env);
  const model = envValue('ZHIPU_MODEL', env) ?? 'glm-5';
  const temperature = parseTemperature(envValue('AI_TEMPERATURE', env));

  return {
    provider: 'zhipu',
    apiKey,
    baseUrl: CODING_API_BASE_URL,
    model,
    temperature
  } satisfies AiConfig;
}
/**
 * Builds the config for extraction workloads: identical to the report config
 * except that `temperature` is forced to `0`, regardless of AI_TEMPERATURE.
 *
 * @param options - Optional environment and warning-sink overrides.
 * @returns The report config with `temperature` overridden to `0`.
 */
export function getExtractionAiConfig(options?: GetAiConfigOptions) {
  const reportConfig = getReportAiConfig(options);
  return { ...reportConfig, temperature: 0 };
}
/**
 * Builds the embedding config from the environment.
 *
 * Only ZHIPU_API_KEY is read; the base URL, model name, and vector dimensions
 * come from module constants. May emit the one-time "ZHIPU_BASE_URL is ignored"
 * warning as a side effect.
 *
 * @param options - Optional environment and warning-sink overrides.
 * @returns The assembled config; `apiKey` is `undefined` when no key is set.
 */
export function getEmbeddingAiConfig(options?: GetAiConfigOptions) {
  const env = options?.env ?? process.env;
  const warn = options?.warn ?? console.warn;
  warnIgnoredZhipuBaseUrl(env, warn);

  const apiKey = envValue('ZHIPU_API_KEY', env);

  return {
    provider: 'zhipu',
    apiKey,
    baseUrl: CODING_API_BASE_URL,
    model: SEARCH_EMBEDDING_MODEL,
    dimensions: SEARCH_EMBEDDING_DIMENSIONS
  } satisfies EmbeddingConfig;
}
/**
 * Reports whether an API key is available for AI calls.
 *
 * Resolves the report config to find the key, so it can also trigger the
 * one-time "ZHIPU_BASE_URL is ignored" warning.
 *
 * @param options - Optional environment and warning-sink overrides.
 * @returns `true` when a non-empty ZHIPU_API_KEY is present, otherwise `false`.
 */
export function isAiConfigured(options?: GetAiConfigOptions) {
  const { apiKey } = getReportAiConfig(options);
  return Boolean(apiKey);
}
/**
 * Runs a single text-generation request and returns the trimmed response.
 *
 * The `workload` option (default `'report'`) selects the report or extraction
 * config. The model factory and generator can be replaced through `options`;
 * the generator is always called with `maxRetries: 0`.
 *
 * @param prompt - User prompt sent to the model.
 * @param systemPrompt - Optional system prompt.
 * @param options - Workload selection, config overrides, and injectable factory/generator.
 * @returns The provider name, model name, and trimmed response text.
 * @throws Error when ZHIPU_API_KEY is not configured.
 * @throws Error when the response text is empty after trimming.
 */
export async function runAiAnalysis(prompt: string, systemPrompt?: string, options?: RunAiAnalysisOptions) {
  const isExtraction = (options?.workload ?? 'report') === 'extraction';
  const config = isExtraction ? getExtractionAiConfig(options) : getReportAiConfig(options);

  if (!config.apiKey) {
    throw new Error('ZHIPU_API_KEY is required for AI workloads');
  }

  // Resolve both injectable collaborators before invoking either of them.
  const buildModel = options?.createModel ?? defaultCreateModel;
  const generate = options?.generate ?? defaultGenerate;

  const response = await generate({
    model: buildModel(config),
    system: systemPrompt,
    prompt,
    temperature: config.temperature,
    maxRetries: 0
  });

  const text = response.text.trim();
  if (text.length === 0) {
    throw new Error('AI SDK returned an empty response');
  }

  return { provider: config.provider, model: config.model, text };
}
/**
 * Embeds the given strings and returns one numeric vector per embedded string.
 *
 * Inputs are trimmed and blank entries are dropped before embedding, so the
 * result can contain fewer vectors than `values` has entries, and positions
 * are not guaranteed to line up with the original array. When nothing remains
 * after filtering, `[]` is returned without resolving the config or calling
 * the embedder.
 *
 * @param values - Strings to embed.
 * @param options - Config overrides and injectable model factory/embedder.
 * @returns The embeddings, with every component coerced through `Number`.
 * @throws Error when ZHIPU_API_KEY is not configured and there is something to embed.
 */
export async function runAiEmbeddings(values: string[], options?: RunAiEmbeddingsOptions) {
  const inputs = values.reduce<string[]>((kept, value) => {
    const trimmed = value.trim();
    if (trimmed.length > 0) {
      kept.push(trimmed);
    }
    return kept;
  }, []);

  if (inputs.length === 0) {
    return [];
  }

  const config = getEmbeddingAiConfig(options);
  if (!config.apiKey) {
    throw new Error('ZHIPU_API_KEY is required for AI workloads');
  }

  // Resolve both injectable collaborators before invoking either of them.
  const buildModel = options?.createModel ?? defaultCreateEmbeddingModel;
  const embed = options?.embed ?? defaultEmbed;

  const { embeddings } = await embed({
    model: buildModel(config),
    values: inputs
  });

  return embeddings.map((vector) => vector.map((component) => Number(component)));
}
/**
 * Test helper: clears the module-level "already warned" latch so the
 * ZHIPU_BASE_URL warning can be emitted again.
 */
export function __resetAiWarningsForTests() {
warnedIgnoredZhipuBaseUrl = false;
}