"""
Base classes for PitchLense MCP risk analysis tools.
Provides abstract base classes and common functionality for all risk analyzers.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from fastmcp import FastMCP
import json
from ..models.risk_models import RiskCategory, RiskLevel, StartupData
from ..utils.json_extractor import extract_json_from_response
class BaseLLM(ABC):
    """
    Common interface that every LLM provider integration implements.

    Concrete providers supply a blocking ``predict`` call and an asynchronous
    ``predict_stream`` call. The base class itself keeps no state.
    """

    def __init__(self):
        """Set up the provider wrapper; the base class stores nothing."""

    @abstractmethod
    def predict(
        self,
        system_message: str,
        user_message: str,
        image_base64: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Run a single, non-streaming completion.

        Args:
            system_message: Instruction that frames the model's behaviour.
            user_message: The end-user prompt.
            image_base64: Optional base64-encoded image sent with the prompt.

        Returns:
            Dictionary holding the model response and usage information.
        """
        ...

    @abstractmethod
    async def predict_stream(self, user_message: str):
        """
        Stream a completion for ``user_message``.

        Args:
            user_message: The end-user prompt.

        Yields:
            Successive chunks of the streamed response.
        """
        ...
class BaseRiskAnalyzer(ABC):
    """
    Abstract base class for all risk analyzers.

    Subclasses supply a category-specific prompt template and indicator list.
    This class sends the rendered prompt to the LLM client and turns the reply
    into a result dictionary, or into a fallback/error dictionary when the
    reply cannot be used.
    """

    def __init__(self, llm_client, category_name: str):
        """
        Initialize the base risk analyzer.

        Args:
            llm_client: LLM client instance for analysis. ``analyze`` calls its
                ``predict(system_message=..., user_message=...)`` method and
                reads the ``"response"`` key of the returned dict.
            category_name: Name of the risk category; echoed into fallback and
                error results.
        """
        self.llm_client = llm_client
        self.category_name = category_name
        self.risk_indicators = []

    @abstractmethod
    def get_analysis_prompt(self) -> str:
        """
        Get the analysis prompt template for this risk category.

        The template is rendered with ``str.format(startup_data=...)``. It
        should reference ``{startup_data}``, and any literal braces (e.g. a
        JSON example of the expected output) must be doubled as ``{{``/``}}``.

        Returns:
            String containing the analysis prompt template.
        """

    @abstractmethod
    def get_risk_indicators(self) -> List[str]:
        """
        Get the list of risk indicators for this category.

        Returns:
            List of risk indicator names.
        """

    def analyze(self, startup_data: str) -> Dict[str, Any]:
        """
        Perform risk analysis for the given startup data.

        The prompt template from ``get_analysis_prompt`` is filled with
        ``startup_data`` and sent to the LLM client. The reply's ``"response"``
        text is parsed as JSON, and a parsed dict is returned as-is. If parsing
        fails, a reply containing "error" or "failed" is reported as an error;
        otherwise a fallback result is returned. Any exception raised along the
        way is converted into an error result instead of propagating.

        Args:
            startup_data: String containing comprehensive startup information.

        Returns:
            The parsed analysis dict, or the dict built by
            ``_create_fallback_response`` / ``_create_error_response``.
        """
        try:
            prompt = self.get_analysis_prompt()
            # Unescaped literal braces in the template raise KeyError/ValueError
            # here, which the outer handler reports as an error result.
            full_prompt = prompt.format(startup_data=startup_data)
            result = self.llm_client.predict(
                system_message="You are an expert startup risk analyst.",
                user_message=full_prompt,
            )

            # `or ""` also covers a key that is present but set to None.
            response_text = result.get("response") or ""
            if not response_text.strip():
                return self._create_error_response("LLM returned an empty response")

            # Parse before any keyword check: a genuine analysis routinely
            # mentions words like "failed" or "error", so keyword matching on
            # a parseable reply would discard valid results.
            analysis_result = extract_json_from_response(response_text)
            if isinstance(analysis_result, dict):
                return analysis_result

            # Unparseable reply: decide between an LLM-side error message and
            # free text that simply wasn't valid JSON.
            lowered = response_text.lower()
            if "error" in lowered or "failed" in lowered:
                return self._create_error_response(f"LLM returned error: {response_text}")
            return self._create_fallback_response(response_text, "JSON extraction failed")
        except Exception as e:
            # Deliberate boundary: the analyzer always returns a result dict.
            return self._create_error_response(str(e))

    def _create_fallback_response(self, raw_response: str, error_msg: str = "") -> Dict[str, Any]:
        """
        Create a fallback response when JSON parsing fails.

        Args:
            raw_response: Raw response text from the LLM.
            error_msg: Description of the parsing failure.

        Returns:
            Fallback response dictionary. Its summary embeds at most the first
            200 characters of ``raw_response``, followed by "..." only when the
            text was actually truncated.
        """
        preview = raw_response[:200]
        if len(raw_response) > 200:
            preview += "..."
        return {
            "category_name": self.category_name,
            "overall_risk_level": "unknown",
            # Placeholder score; presumably the midpoint of the scoring scale
            # used by the prompt templates (TODO confirm).
            "category_score": 5,
            "indicators": [],
            "summary": f"Analysis completed but JSON parsing failed. Error: {error_msg}. Raw response: {preview}",
        }

    def _create_error_response(self, error_message: str) -> Dict[str, Any]:
        """
        Create an error response when analysis fails.

        Args:
            error_message: Description of the failure (exception text or an
                LLM-reported error).

        Returns:
            Error response dictionary with an ``"error"`` key and a score of 0.
        """
        return {
            "error": error_message,
            "category_name": self.category_name,
            "overall_risk_level": "unknown",
            "category_score": 0,
            "indicators": [],
            "summary": f"Analysis failed due to error: {error_message}",
        }