fix: hf hosted inference check (#1128)

This commit is contained in:
takatost
2023-09-09 00:29:48 +08:00
committed by GitHub
parent 681eb1cfcc
commit c4d8bdc3db
3 changed files with 69 additions and 4 deletions

View File

@@ -1,6 +1,5 @@
from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
class HuggingfaceHubModel(BaseLLM):
@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
streaming=streaming
)
else:
client = HuggingFaceHub(
client = HuggingFaceHubLLM(
repo_id=self.name,
task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs,
@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
if 'baichuan' in self.name.lower():
return False
return True
return True
else:
return False