From a87daebe14120871d8d2c502ac7ee8fc060bbae2 Mon Sep 17 00:00:00 2001 From: YoVinchen Date: Wed, 19 Nov 2025 02:06:31 +0800 Subject: [PATCH] feat: add model name configuration support for Codex and fix Gemini model handling - Add visual model name input field for Codex providers - Add model name extraction and update utilities in providerConfigUtils - Implement model name state management in useCodexConfigState hook - Add conditional model field rendering in CodexFormFields (non-official only) - Integrate model name sync with TOML config in ProviderForm - Fix Gemini deeplink model injection bug - Correct environment variable name from GOOGLE_GEMINI_MODEL to GEMINI_MODEL - Add test cases for Gemini model injection (with/without model) - All tests passing (9/9) - Fix Gemini model field binding in edit mode - Add geminiModel state to useGeminiConfigState hook - Extract model value during initialization and reset - Sync model field with geminiEnv state to prevent data loss on submit - Fix missing model value display when editing Gemini providers Changes: - 6 files changed, 245 insertions(+), 13 deletions(-) --- src-tauri/src/deeplink.rs | 56 ++++++++++++++++- .../providers/forms/CodexFormFields.tsx | 35 +++++++++++ .../providers/forms/ProviderForm.tsx | 21 +++++-- .../forms/hooks/useCodexConfigState.ts | 60 +++++++++++++++++- .../forms/hooks/useGeminiConfigState.ts | 23 +++++-- src/utils/providerConfigUtils.ts | 63 +++++++++++++++++++ 6 files changed, 245 insertions(+), 13 deletions(-) diff --git a/src-tauri/src/deeplink.rs b/src-tauri/src/deeplink.rs index 86b6a17..b6d062f 100644 --- a/src-tauri/src/deeplink.rs +++ b/src-tauri/src/deeplink.rs @@ -302,7 +302,7 @@ requires_openai_auth = true // Add model if provided if let Some(model) = &request.model { - env.insert("GOOGLE_GEMINI_MODEL".to_string(), json!(model)); + env.insert("GEMINI_MODEL".to_string(), json!(model)); } json!({ "env": env }) @@ -400,4 +400,58 @@ mod tests { .to_string() .contains("must be http or https")); } + + #[test] + fn test_build_gemini_provider_with_model() { + let request = DeepLinkImportRequest { + version: "v1".to_string(), + resource: "provider".to_string(), + app: "gemini".to_string(), + name: "Test Gemini".to_string(), + homepage: "https://example.com".to_string(), + endpoint: "https://api.example.com".to_string(), + api_key: "test-api-key".to_string(), + model: Some("gemini-2.0-flash".to_string()), + notes: None, + }; + + let provider = build_provider_from_request(&AppType::Gemini, &request).unwrap(); + + // Verify provider basic info + assert_eq!(provider.name, "Test Gemini"); + assert_eq!( + provider.website_url, + Some("https://example.com".to_string()) + ); + + // Verify settings_config structure + let env = provider.settings_config["env"].as_object().unwrap(); + assert_eq!(env["GEMINI_API_KEY"], "test-api-key"); + assert_eq!(env["GOOGLE_GEMINI_BASE_URL"], "https://api.example.com"); + assert_eq!(env["GEMINI_MODEL"], "gemini-2.0-flash"); + } + + #[test] + fn test_build_gemini_provider_without_model() { + let request = DeepLinkImportRequest { + version: "v1".to_string(), + resource: "provider".to_string(), + app: "gemini".to_string(), + name: "Test Gemini".to_string(), + homepage: "https://example.com".to_string(), + endpoint: "https://api.example.com".to_string(), + api_key: "test-api-key".to_string(), + model: None, + notes: None, + }; + + let provider = build_provider_from_request(&AppType::Gemini, &request).unwrap(); + + // Verify settings_config structure + let env = provider.settings_config["env"].as_object().unwrap(); + assert_eq!(env["GEMINI_API_KEY"], "test-api-key"); + assert_eq!(env["GOOGLE_GEMINI_BASE_URL"], "https://api.example.com"); + // Model should not be present + assert!(env.get("GEMINI_MODEL").is_none()); + } } diff --git a/src/components/providers/forms/CodexFormFields.tsx b/src/components/providers/forms/CodexFormFields.tsx index 400642a..fefe41c 100644 --- a/src/components/providers/forms/CodexFormFields.tsx +++ b/src/components/providers/forms/CodexFormFields.tsx @@ -26,6 +26,11 @@ interface CodexFormFieldsProps { onEndpointModalToggle: (open: boolean) => void; onCustomEndpointsChange?: (endpoints: string[]) => void; + // Model Name + shouldShowModelField?: boolean; + modelName?: string; + onModelNameChange?: (model: string) => void; + // Speed Test Endpoints speedTestEndpoints: EndpointCandidate[]; } @@ -45,6 +50,9 @@ export function CodexFormFields({ isEndpointModalOpen, onEndpointModalToggle, onCustomEndpointsChange, + shouldShowModelField = true, + modelName = "", + onModelNameChange, speedTestEndpoints, }: CodexFormFieldsProps) { const { t } = useTranslation(); @@ -85,6 +93,33 @@ export function CodexFormFields({ /> )} + {/* Codex Model Name 输入框 */} + {shouldShowModelField && onModelNameChange && ( +
+ + onModelNameChange(e.target.value)} + placeholder={t("codexConfig.modelNamePlaceholder", { + defaultValue: "例如: gpt-5-codex", + })} + className="w-full px-3 py-2 border border-border-default dark:bg-gray-800 dark:text-gray-100 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500/20 dark:focus:ring-blue-400/20 transition-colors" + /> +

+ {t("codexConfig.modelNameHint", { + defaultValue: "指定使用的模型,将自动更新到 config.toml 中", + })} +

+
+ )} + {/* 端点测速弹窗 - Codex */} {shouldShowSpeedTest && isEndpointModalOpen && ( )} @@ -711,17 +718,19 @@ export function ProviderForm({ onEndpointModalToggle={setIsEndpointModalOpen} onCustomEndpointsChange={setDraftCustomEndpoints} shouldShowModelField={true} - model={ - form.watch("settingsConfig") - ? JSON.parse(form.watch("settingsConfig") || "{}")?.env - ?.GEMINI_MODEL || "" - : "" - } + model={geminiModel} onModelChange={(model) => { + // 同时更新 form.settingsConfig 和 geminiEnv const config = JSON.parse(form.watch("settingsConfig") || "{}"); if (!config.env) config.env = {}; config.env.GEMINI_MODEL = model; form.setValue("settingsConfig", JSON.stringify(config, null, 2)); + + // 同步更新 geminiEnv,确保提交时不丢失 + const envObj = envStringToObj(geminiEnv); + envObj.GEMINI_MODEL = model.trim(); + const newEnv = envObjToString(envObj); + handleGeminiEnvChange(newEnv); }} speedTestEndpoints={speedTestEndpoints} /> diff --git a/src/components/providers/forms/hooks/useCodexConfigState.ts b/src/components/providers/forms/hooks/useCodexConfigState.ts index 5f2823e..60436a0 100644 --- a/src/components/providers/forms/hooks/useCodexConfigState.ts +++ b/src/components/providers/forms/hooks/useCodexConfigState.ts @@ -2,6 +2,8 @@ import { useState, useCallback, useEffect, useRef } from "react"; import { extractCodexBaseUrl, setCodexBaseUrl as setCodexBaseUrlInConfig, + extractCodexModelName, + setCodexModelName as setCodexModelNameInConfig, } from "@/utils/providerConfigUtils"; import { normalizeTomlText } from "@/utils/textNormalization"; @@ -20,9 +22,11 @@ export function useCodexConfigState({ initialData }: UseCodexConfigStateProps) { const [codexConfig, setCodexConfigState] = useState(""); const [codexApiKey, setCodexApiKey] = useState(""); const [codexBaseUrl, setCodexBaseUrl] = useState(""); + const [codexModelName, setCodexModelName] = useState(""); const [codexAuthError, setCodexAuthError] = useState(""); const isUpdatingCodexBaseUrlRef = useRef(false); + const isUpdatingCodexModelNameRef = useRef(false); // 初始化 Codex 配置(编辑模式) useEffect(() => { @@ -47,6 +51,12 @@ export function useCodexConfigState({ initialData }: UseCodexConfigStateProps) { setCodexBaseUrl(initialBaseUrl); } + // 提取 Model Name + const initialModelName = extractCodexModelName(configStr); + if (initialModelName) { + setCodexModelName(initialModelName); + } + // 提取 API Key try { if (auth && typeof auth.OPENAI_API_KEY === "string") { @@ -69,6 +79,17 @@ export function useCodexConfigState({ initialData }: UseCodexConfigStateProps) { } }, [codexConfig, codexBaseUrl]); + // 与 TOML 配置保持模型名称同步 + useEffect(() => { + if (isUpdatingCodexModelNameRef.current) { + return; + } + const extracted = extractCodexModelName(codexConfig) || ""; + if (extracted !== codexModelName) { + setCodexModelName(extracted); + } + }, [codexConfig, codexModelName]); + // 获取 API Key(从 auth JSON) const getCodexAuthApiKey = useCallback((authString: string): string => { try { @@ -157,7 +178,26 @@ export function useCodexConfigState({ initialData }: UseCodexConfigStateProps) { [setCodexConfig], ); - // 处理 config 变化(同步 Base URL) + // 处理 Codex Model Name 变化 + const handleCodexModelNameChange = useCallback( + (modelName: string) => { + const trimmed = modelName.trim(); + setCodexModelName(trimmed); + + if (!trimmed) { + return; + } + + isUpdatingCodexModelNameRef.current = true; + setCodexConfig((prev) => setCodexModelNameInConfig(prev, trimmed)); + setTimeout(() => { + isUpdatingCodexModelNameRef.current = false; + }, 0); + }, + [setCodexConfig], + ); + + // 处理 config 变化(同步 Base URL 和 Model Name) const handleCodexConfigChange = useCallback( (value: string) => { // 归一化中文/全角/弯引号,避免 TOML 解析报错 @@ -170,8 +210,15 @@ export function useCodexConfigState({ initialData }: UseCodexConfigStateProps) { setCodexBaseUrl(extracted); } } + + if (!isUpdatingCodexModelNameRef.current) { + const extractedModel = extractCodexModelName(normalized) || ""; + if (extractedModel !== codexModelName) { + setCodexModelName(extractedModel); + } + } }, - [setCodexConfig, codexBaseUrl], + [setCodexConfig, codexBaseUrl, codexModelName], ); // 重置配置(用于预设切换) @@ -186,6 +233,13 @@ export function useCodexConfigState({ initialData }: UseCodexConfigStateProps) { setCodexBaseUrl(baseUrl); } + const modelName = extractCodexModelName(config); + if (modelName) { + setCodexModelName(modelName); + } else { + setCodexModelName(""); + } + // 提取 API Key try { if (auth && typeof auth.OPENAI_API_KEY === "string") { @@ -205,11 +259,13 @@ export function useCodexConfigState({ initialData }: UseCodexConfigStateProps) { codexConfig, codexApiKey, codexBaseUrl, + codexModelName, codexAuthError, setCodexAuth, setCodexConfig, handleCodexApiKeyChange, handleCodexBaseUrlChange, + handleCodexModelNameChange, handleCodexConfigChange, resetCodexConfig, getCodexAuthApiKey, diff --git a/src/components/providers/forms/hooks/useGeminiConfigState.ts b/src/components/providers/forms/hooks/useGeminiConfigState.ts index 4ab96e6..cad1220 100644 --- a/src/components/providers/forms/hooks/useGeminiConfigState.ts +++ b/src/components/providers/forms/hooks/useGeminiConfigState.ts @@ -17,6 +17,7 @@ export function useGeminiConfigState({ const [geminiConfig, setGeminiConfigState] = useState(""); const [geminiApiKey, setGeminiApiKey] = useState(""); const [geminiBaseUrl, setGeminiBaseUrl] = useState(""); + const [geminiModel, setGeminiModel] = useState(""); const [envError, setEnvError] = useState(""); const [configError, setConfigError] = useState(""); @@ -72,21 +73,25 @@ export function useGeminiConfigState({ const configObj = (config as any).config || {}; setGeminiConfigState(JSON.stringify(configObj, null, 2)); - // 提取 API Key 和 Base URL + // 提取 API Key、Base URL 和 Model if (typeof env.GEMINI_API_KEY === "string") { setGeminiApiKey(env.GEMINI_API_KEY); } if (typeof env.GOOGLE_GEMINI_BASE_URL === "string") { setGeminiBaseUrl(env.GOOGLE_GEMINI_BASE_URL); } + if (typeof env.GEMINI_MODEL === "string") { + setGeminiModel(env.GEMINI_MODEL); + } } }, [initialData, envObjToString]); - // 从 geminiEnv 中提取并同步 API Key 和 Base URL + // 从 geminiEnv 中提取并同步 API Key、Base URL 和 Model useEffect(() => { const envObj = envStringToObj(geminiEnv); const extractedKey = envObj.GEMINI_API_KEY || ""; const extractedBaseUrl = envObj.GOOGLE_GEMINI_BASE_URL || ""; + const extractedModel = envObj.GEMINI_MODEL || ""; if (extractedKey !== geminiApiKey) { setGeminiApiKey(extractedKey); @@ -94,7 +99,10 @@ export function useGeminiConfigState({ if (extractedBaseUrl !== geminiBaseUrl) { setGeminiBaseUrl(extractedBaseUrl); } - }, [geminiEnv, envStringToObj]); + if (extractedModel !== geminiModel) { + setGeminiModel(extractedModel); + } + }, [geminiEnv, envStringToObj, geminiApiKey, geminiBaseUrl, geminiModel]); // 验证 Gemini Config JSON const validateGeminiConfig = useCallback((value: string): string => { @@ -181,7 +189,7 @@ export function useGeminiConfigState({ setGeminiEnv(envString); setGeminiConfig(configString); - // 提取 API Key 和 Base URL + // 提取 API Key、Base URL 和 Model if (typeof env.GEMINI_API_KEY === "string") { setGeminiApiKey(env.GEMINI_API_KEY); } else { @@ -193,6 +201,12 @@ export function useGeminiConfigState({ } else { setGeminiBaseUrl(""); } + + if (typeof env.GEMINI_MODEL === "string") { + setGeminiModel(env.GEMINI_MODEL); + } else { + setGeminiModel(""); + } }, [envObjToString, setGeminiEnv, setGeminiConfig], ); @@ -202,6 +216,7 @@ export function useGeminiConfigState({ geminiConfig, geminiApiKey, geminiBaseUrl, + geminiModel, envError, configError, setGeminiEnv, diff --git a/src/utils/providerConfigUtils.ts b/src/utils/providerConfigUtils.ts index 7973b31..e2e73a1 100644 --- a/src/utils/providerConfigUtils.ts +++ b/src/utils/providerConfigUtils.ts @@ -467,3 +467,66 @@ export const setCodexBaseUrl = ( : normalizedText; return `${prefix}${replacementLine}\n`; }; + +// ========== Codex model name utils ========== + +// 从 Codex 的 TOML 配置文本中提取 model 字段(支持单/双引号) +export const extractCodexModelName = ( + configText: string | undefined | null, +): string | undefined => { + try { + const raw = typeof configText === "string" ? configText : ""; + // 归一化中文/全角引号,避免正则提取失败 + const text = normalizeQuotes(raw); + if (!text) return undefined; + + // 匹配 model = "xxx" 或 model = 'xxx' + const m = text.match(/^model\s*=\s*(['"])([^'"]+)\1/m); + return m && m[2] ? m[2] : undefined; + } catch { + return undefined; + } +}; + +// 在 Codex 的 TOML 配置文本中写入或更新 model 字段 +export const setCodexModelName = ( + configText: string, + modelName: string, +): string => { + const trimmed = modelName.trim(); + if (!trimmed) { + return configText; + } + + // 归一化原文本中的引号(既能匹配,也能输出稳定格式) + const normalizedText = normalizeQuotes(configText); + + const replacementLine = `model = "${trimmed}"`; + const pattern = /^model\s*=\s*["']([^"']+)["']/m; + + if (pattern.test(normalizedText)) { + return normalizedText.replace(pattern, replacementLine); + } + + // 如果不存在 model 字段,尝试在 model_provider 之后插入 + // 如果 model_provider 也不存在,则插入到开头 + const providerPattern = /^model_provider\s*=\s*["'][^"']+["']/m; + const match = normalizedText.match(providerPattern); + + if (match && match.index !== undefined) { + // 在 model_provider 行之后插入 + const endOfLine = normalizedText.indexOf("\n", match.index); + if (endOfLine !== -1) { + return ( + normalizedText.slice(0, endOfLine + 1) + + replacementLine + + "\n" + + normalizedText.slice(endOfLine + 1) + ); + } + } + + // 在文件开头插入 + const lines = normalizedText.split("\n"); + return `${replacementLine}\n${lines.join("\n")}`; +};