mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
merge: resolve conflict with main in transcription.py
Keep _post_transcription_with_retry from PR branch, drop inline httpx calls that were replaced by the shared retry helper.
This commit is contained in:
commit
40b4e01b13
@ -1131,3 +1131,23 @@ Disabled skills are excluded from the main agent's skill summary, from always-on
|
|||||||
| Option | Default | Description |
|
| Option | Default | Description |
|
||||||
|--------|---------|-------------|
|
|--------|---------|-------------|
|
||||||
| `agents.defaults.disabledSkills` | `[]` | List of skill directory names to exclude from loading. Applies to both built-in skills and workspace skills. |
|
| `agents.defaults.disabledSkills` | `[]` | List of skill directory names to exclude from loading. Applies to both built-in skills and workspace skills. |
|
||||||
|
|
||||||
|
## Tool Hint Max Length
|
||||||
|
|
||||||
|
Tool hints are the short progress messages shown when the agent calls tools (e.g. `$ cd …/project && npm test`). By default, these are truncated at 40 characters, which can make long commands hard to read.
|
||||||
|
|
||||||
|
Set `agents.defaults.toolHintMaxLength` to control the truncation threshold:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"toolHintMaxLength": 120
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| Option | Default | Description |
|
||||||
|
|--------|---------|-------------|
|
||||||
|
| `agents.defaults.toolHintMaxLength` | `40` | Maximum characters for tool hint display. Range: 20–500. Higher values show more of the command or path; lower values keep hints compact. |
|
||||||
|
|||||||
@ -112,6 +112,11 @@ class _LoopHook(AgentHook):
|
|||||||
|
|
||||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||||
self._loop._current_iteration = context.iteration
|
self._loop._current_iteration = context.iteration
|
||||||
|
logger.debug(
|
||||||
|
"Starting agent loop iteration {} for session {}",
|
||||||
|
context.iteration,
|
||||||
|
self._session_key,
|
||||||
|
)
|
||||||
|
|
||||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||||
if self._on_progress:
|
if self._on_progress:
|
||||||
@ -193,6 +198,7 @@ class AgentLoop:
|
|||||||
context_block_limit: int | None = None,
|
context_block_limit: int | None = None,
|
||||||
max_tool_result_chars: int | None = None,
|
max_tool_result_chars: int | None = None,
|
||||||
provider_retry_mode: str = "standard",
|
provider_retry_mode: str = "standard",
|
||||||
|
tool_hint_max_length: int | None = None,
|
||||||
web_config: WebToolsConfig | None = None,
|
web_config: WebToolsConfig | None = None,
|
||||||
exec_config: ExecToolConfig | None = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
cron_service: CronService | None = None,
|
cron_service: CronService | None = None,
|
||||||
@ -237,6 +243,10 @@ class AgentLoop:
|
|||||||
else defaults.max_tool_result_chars
|
else defaults.max_tool_result_chars
|
||||||
)
|
)
|
||||||
self.provider_retry_mode = provider_retry_mode
|
self.provider_retry_mode = provider_retry_mode
|
||||||
|
self.tool_hint_max_length = (
|
||||||
|
tool_hint_max_length if tool_hint_max_length is not None
|
||||||
|
else defaults.tool_hint_max_length
|
||||||
|
)
|
||||||
self.web_config = web_config or WebToolsConfig()
|
self.web_config = web_config or WebToolsConfig()
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.cron_service = cron_service
|
self.cron_service = cron_service
|
||||||
@ -417,7 +427,7 @@ class AgentLoop:
|
|||||||
logger.warning("MCP connection cancelled (will retry next message)")
|
logger.warning("MCP connection cancelled (will retry next message)")
|
||||||
self._mcp_stacks.clear()
|
self._mcp_stacks.clear()
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
logger.warning("Failed to connect MCP servers (will retry next message): {}", e)
|
||||||
self._mcp_stacks.clear()
|
self._mcp_stacks.clear()
|
||||||
finally:
|
finally:
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
@ -466,12 +476,11 @@ class AgentLoop:
|
|||||||
"""Return the chat id shown in runtime metadata for the model."""
|
"""Return the chat id shown in runtime metadata for the model."""
|
||||||
return str(msg.metadata.get("context_chat_id") or msg.chat_id)
|
return str(msg.metadata.get("context_chat_id") or msg.chat_id)
|
||||||
|
|
||||||
@staticmethod
|
def _tool_hint(self, tool_calls: list) -> str:
|
||||||
def _tool_hint(tool_calls: list) -> str:
|
|
||||||
"""Format tool calls as concise hints with smart abbreviation."""
|
"""Format tool calls as concise hints with smart abbreviation."""
|
||||||
from nanobot.utils.tool_hints import format_tool_hints
|
from nanobot.utils.tool_hints import format_tool_hints
|
||||||
|
|
||||||
return format_tool_hints(tool_calls)
|
return format_tool_hints(tool_calls, max_length=self.tool_hint_max_length)
|
||||||
|
|
||||||
async def _dispatch_command_inline(
|
async def _dispatch_command_inline(
|
||||||
self,
|
self,
|
||||||
@ -644,6 +653,7 @@ class AgentLoop:
|
|||||||
context_block_limit=self.context_block_limit,
|
context_block_limit=self.context_block_limit,
|
||||||
provider_retry_mode=self.provider_retry_mode,
|
provider_retry_mode=self.provider_retry_mode,
|
||||||
progress_callback=on_progress,
|
progress_callback=on_progress,
|
||||||
|
stream_progress_deltas=on_stream is not None,
|
||||||
retry_wait_callback=on_retry_wait,
|
retry_wait_callback=on_retry_wait,
|
||||||
checkpoint_callback=_checkpoint,
|
checkpoint_callback=_checkpoint,
|
||||||
injection_callback=_drain_pending,
|
injection_callback=_drain_pending,
|
||||||
@ -907,6 +917,8 @@ class AgentLoop:
|
|||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
session, pending = self.auto_compact.prepare_session(session, key)
|
session, pending = self.auto_compact.prepare_session(session, key)
|
||||||
|
if pending:
|
||||||
|
logger.info("Memory compact triggered for session {}", key)
|
||||||
|
|
||||||
await self.consolidator.maybe_consolidate_by_tokens(
|
await self.consolidator.maybe_consolidate_by_tokens(
|
||||||
session,
|
session,
|
||||||
@ -919,6 +931,7 @@ class AgentLoop:
|
|||||||
# LLM via the merged prompt. See _persist_subagent_followup.
|
# LLM via the merged prompt. See _persist_subagent_followup.
|
||||||
is_subagent = msg.sender_id == "subagent"
|
is_subagent = msg.sender_id == "subagent"
|
||||||
if is_subagent and self._persist_subagent_followup(session, msg):
|
if is_subagent and self._persist_subagent_followup(session, msg):
|
||||||
|
logger.debug("Subagent result persisted for session {}", key)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._set_tool_context(
|
self._set_tool_context(
|
||||||
channel, chat_id, msg.metadata.get("message_id"),
|
channel, chat_id, msg.metadata.get("message_id"),
|
||||||
|
|||||||
@ -76,6 +76,7 @@ class AgentRunSpec:
|
|||||||
context_block_limit: int | None = None
|
context_block_limit: int | None = None
|
||||||
provider_retry_mode: str = "standard"
|
provider_retry_mode: str = "standard"
|
||||||
progress_callback: Any | None = None
|
progress_callback: Any | None = None
|
||||||
|
stream_progress_deltas: bool = True
|
||||||
retry_wait_callback: Any | None = None
|
retry_wait_callback: Any | None = None
|
||||||
checkpoint_callback: Any | None = None
|
checkpoint_callback: Any | None = None
|
||||||
injection_callback: Any | None = None
|
injection_callback: Any | None = None
|
||||||
@ -261,12 +262,11 @@ class AgentRunner:
|
|||||||
# Snipping may have created new orphans; clean them up.
|
# Snipping may have created new orphans; clean them up.
|
||||||
messages_for_model = self._drop_orphan_tool_results(messages_for_model)
|
messages_for_model = self._drop_orphan_tool_results(messages_for_model)
|
||||||
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
||||||
except Exception as exc:
|
except Exception:
|
||||||
logger.warning(
|
logger.exception(
|
||||||
"Context governance failed on turn {} for {}: {}; applying minimal repair",
|
"Context governance failed on turn {} for {}; applying minimal repair",
|
||||||
iteration,
|
iteration,
|
||||||
spec.session_key or "default",
|
spec.session_key or "default",
|
||||||
exc,
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
messages_for_model = self._drop_orphan_tool_results(messages)
|
messages_for_model = self._drop_orphan_tool_results(messages)
|
||||||
@ -616,6 +616,7 @@ class AgentRunner:
|
|||||||
wants_streaming = hook.wants_streaming()
|
wants_streaming = hook.wants_streaming()
|
||||||
wants_progress_streaming = (
|
wants_progress_streaming = (
|
||||||
not wants_streaming
|
not wants_streaming
|
||||||
|
and spec.stream_progress_deltas
|
||||||
and spec.progress_callback is not None
|
and spec.progress_callback is not None
|
||||||
and getattr(self.provider, "supports_progress_deltas", False) is True
|
and getattr(self.provider, "supports_progress_deltas", False) is True
|
||||||
)
|
)
|
||||||
@ -981,12 +982,11 @@ class AgentRunner:
|
|||||||
result,
|
result,
|
||||||
max_chars=spec.max_tool_result_chars,
|
max_chars=spec.max_tool_result_chars,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception:
|
||||||
logger.warning(
|
logger.exception(
|
||||||
"Tool result persist failed for {} in {}: {}; using raw result",
|
"Tool result persist failed for {} in {}; using raw result",
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
spec.session_key or "default",
|
spec.session_key or "default",
|
||||||
exc,
|
|
||||||
)
|
)
|
||||||
content = result
|
content = result
|
||||||
if isinstance(content, str) and len(content) > spec.max_tool_result_chars:
|
if isinstance(content, str) and len(content) > spec.max_tool_result_chars:
|
||||||
|
|||||||
@ -250,7 +250,7 @@ class SubagentManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
status.phase = "error"
|
status.phase = "error"
|
||||||
status.error = str(e)
|
status.error = str(e)
|
||||||
logger.error("Subagent [{}] failed: {}", task_id, e)
|
logger.exception("Subagent [{}] failed", task_id)
|
||||||
await self._announce_result(task_id, label, task, f"Error: {e}", origin, "error", origin_message_id)
|
await self._announce_result(task_id, label, task, f"Error: {e}", origin, "error", origin_message_id)
|
||||||
|
|
||||||
async def _announce_result(
|
async def _announce_result(
|
||||||
|
|||||||
@ -198,11 +198,10 @@ class MCPToolWrapper(Tool):
|
|||||||
await asyncio.sleep(1) # Brief backoff before retry
|
await asyncio.sleep(1) # Brief backoff before retry
|
||||||
continue
|
continue
|
||||||
# Second transient failure — give up with retry-specific message
|
# Second transient failure — give up with retry-specific message
|
||||||
logger.error(
|
logger.exception(
|
||||||
"MCP tool '{}' failed after retry: {}: {}",
|
"MCP tool '{}' failed after retry: {}",
|
||||||
self._name,
|
self._name,
|
||||||
type(exc).__name__,
|
type(exc).__name__,
|
||||||
exc,
|
|
||||||
)
|
)
|
||||||
return f"(MCP tool call failed after retry: {type(exc).__name__})"
|
return f"(MCP tool call failed after retry: {type(exc).__name__})"
|
||||||
logger.exception(
|
logger.exception(
|
||||||
@ -287,11 +286,10 @@ class MCPResourceWrapper(Tool):
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
continue
|
continue
|
||||||
logger.error(
|
logger.exception(
|
||||||
"MCP resource '{}' failed after retry: {}: {}",
|
"MCP resource '{}' failed after retry: {}",
|
||||||
self._name,
|
self._name,
|
||||||
type(exc).__name__,
|
type(exc).__name__,
|
||||||
exc,
|
|
||||||
)
|
)
|
||||||
return f"(MCP resource read failed after retry: {type(exc).__name__})"
|
return f"(MCP resource read failed after retry: {type(exc).__name__})"
|
||||||
logger.exception(
|
logger.exception(
|
||||||
@ -383,7 +381,7 @@ class MCPPromptWrapper(Tool):
|
|||||||
logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name)
|
logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name)
|
||||||
return "(MCP prompt call was cancelled)"
|
return "(MCP prompt call was cancelled)"
|
||||||
except McpError as exc:
|
except McpError as exc:
|
||||||
logger.error(
|
logger.exception(
|
||||||
"MCP prompt '{}' failed: code={} message={}",
|
"MCP prompt '{}' failed: code={} message={}",
|
||||||
self._name,
|
self._name,
|
||||||
exc.error.code,
|
exc.error.code,
|
||||||
@ -400,11 +398,10 @@ class MCPPromptWrapper(Tool):
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
continue
|
continue
|
||||||
logger.error(
|
logger.exception(
|
||||||
"MCP prompt '{}' failed after retry: {}: {}",
|
"MCP prompt '{}' failed after retry: {}",
|
||||||
self._name,
|
self._name,
|
||||||
type(exc).__name__,
|
type(exc).__name__,
|
||||||
exc,
|
|
||||||
)
|
)
|
||||||
return f"(MCP prompt call failed after retry: {type(exc).__name__})"
|
return f"(MCP prompt call failed after retry: {type(exc).__name__})"
|
||||||
logger.exception(
|
logger.exception(
|
||||||
@ -439,8 +436,8 @@ async def connect_mcp_servers(
|
|||||||
"""Connect to configured MCP servers and register their tools, resources, prompts.
|
"""Connect to configured MCP servers and register their tools, resources, prompts.
|
||||||
|
|
||||||
Returns a dict mapping server name -> its dedicated AsyncExitStack.
|
Returns a dict mapping server name -> its dedicated AsyncExitStack.
|
||||||
Each server gets its own stack and runs in its own task to prevent
|
Each server gets its own stack to prevent cancel scope conflicts
|
||||||
cancel scope conflicts when multiple MCP servers are configured.
|
when multiple MCP servers are configured.
|
||||||
"""
|
"""
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
@ -608,26 +605,20 @@ async def connect_mcp_servers(
|
|||||||
" Hint: this looks like stdio protocol pollution. Make sure the MCP server writes "
|
" Hint: this looks like stdio protocol pollution. Make sure the MCP server writes "
|
||||||
"only JSON-RPC to stdout and sends logs/debug output to stderr instead."
|
"only JSON-RPC to stdout and sends logs/debug output to stderr instead."
|
||||||
)
|
)
|
||||||
logger.error("MCP server '{}': failed to connect: {}{}", name, e, hint)
|
logger.exception("MCP server '{}': failed to connect: {}", name, hint)
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
await server_stack.aclose()
|
await server_stack.aclose()
|
||||||
return name, None
|
return name, None
|
||||||
|
|
||||||
server_stacks: dict[str, AsyncExitStack] = {}
|
server_stacks: dict[str, AsyncExitStack] = {}
|
||||||
|
|
||||||
tasks: list[asyncio.Task] = []
|
|
||||||
for name, cfg in mcp_servers.items():
|
for name, cfg in mcp_servers.items():
|
||||||
task = asyncio.create_task(connect_single_server(name, cfg))
|
try:
|
||||||
tasks.append(task)
|
result = await connect_single_server(name, cfg)
|
||||||
|
except Exception as e:
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
logger.error("MCP server '{}' connection failed: {}", name, e)
|
||||||
|
continue
|
||||||
for i, result in enumerate(results):
|
if result is not None and result[1] is not None:
|
||||||
name = list(mcp_servers.keys())[i]
|
|
||||||
if isinstance(result, BaseException):
|
|
||||||
if not isinstance(result, asyncio.CancelledError):
|
|
||||||
logger.error("MCP server '{}' connection task failed: {}", name, result)
|
|
||||||
elif result is not None and result[1] is not None:
|
|
||||||
server_stacks[result[0]] = result[1]
|
server_stacks[result[0]] = result[1]
|
||||||
|
|
||||||
return server_stacks
|
return server_stacks
|
||||||
|
|||||||
@ -500,10 +500,10 @@ class WebFetchTool(Tool):
|
|||||||
"untrusted": True, "text": text,
|
"untrusted": True, "text": text,
|
||||||
}, ensure_ascii=False)
|
}, ensure_ascii=False)
|
||||||
except httpx.ProxyError as e:
|
except httpx.ProxyError as e:
|
||||||
logger.error("WebFetch proxy error for {}: {}", url, e)
|
logger.exception("WebFetch proxy error for {}", url)
|
||||||
return json.dumps({"error": f"Proxy error: {e}", "url": url}, ensure_ascii=False)
|
return json.dumps({"error": f"Proxy error: {e}", "url": url}, ensure_ascii=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("WebFetch error for {}: {}", url, e)
|
logger.exception("WebFetch error for {}", url)
|
||||||
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
|
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
|
||||||
|
|
||||||
def _to_markdown(self, html_content: str) -> str:
|
def _to_markdown(self, html_content: str) -> str:
|
||||||
|
|||||||
@ -38,6 +38,7 @@ class BaseChannel(ABC):
|
|||||||
bus: The message bus for communication.
|
bus: The message bus for communication.
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.logger = logger.bind(channel=self.name)
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
@ -61,8 +62,8 @@ class BaseChannel(ABC):
|
|||||||
language=self.transcription_language or None,
|
language=self.transcription_language or None,
|
||||||
)
|
)
|
||||||
return await provider.transcribe(file_path)
|
return await provider.transcribe(file_path)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.warning("{}: audio transcription failed: {}", self.name, e)
|
self.logger.exception("Audio transcription failed")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def login(self, force: bool = False) -> bool:
|
async def login(self, force: bool = False) -> bool:
|
||||||
@ -136,7 +137,7 @@ class BaseChannel(ABC):
|
|||||||
else:
|
else:
|
||||||
allow_list = getattr(self.config, "allow_from", [])
|
allow_list = getattr(self.config, "allow_from", [])
|
||||||
if not allow_list:
|
if not allow_list:
|
||||||
logger.warning("{}: allow_from is empty — all access denied", self.name)
|
self.logger.warning("allow_from is empty — all access denied")
|
||||||
return False
|
return False
|
||||||
if "*" in allow_list:
|
if "*" in allow_list:
|
||||||
return True
|
return True
|
||||||
@ -165,10 +166,10 @@ class BaseChannel(ABC):
|
|||||||
session_key: Optional session key override (e.g. thread-scoped sessions).
|
session_key: Optional session key override (e.g. thread-scoped sessions).
|
||||||
"""
|
"""
|
||||||
if not self.is_allowed(sender_id):
|
if not self.is_allowed(sender_id):
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Access denied for sender {} on channel {}. "
|
"Access denied for sender {}. "
|
||||||
"Add them to allowFrom list in config to grant access.",
|
"Add them to allowFrom list in config to grant access.",
|
||||||
sender_id, self.name,
|
sender_id,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from typing import Any
|
|||||||
from urllib.parse import unquote, urljoin, urlparse
|
from urllib.parse import unquote, urljoin, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
@ -113,7 +112,7 @@ class NanobotDingTalkHandler(CallbackHandler):
|
|||||||
content = content + "\n\nReceived files:\n" + file_list
|
content = content + "\n\nReceived files:\n" + file_list
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning(
|
self.channel.logger.warning(
|
||||||
"Received empty or unsupported message type: {}",
|
"Received empty or unsupported message type: {}",
|
||||||
chatbot_msg.message_type,
|
chatbot_msg.message_type,
|
||||||
)
|
)
|
||||||
@ -128,7 +127,7 @@ class NanobotDingTalkHandler(CallbackHandler):
|
|||||||
or message.data.get("openConversationId")
|
or message.data.get("openConversationId")
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
self.channel.logger.info("Received message from {} ({}): {}", sender_name, sender_id, content)
|
||||||
|
|
||||||
# Forward to Nanobot via _on_message (non-blocking).
|
# Forward to Nanobot via _on_message (non-blocking).
|
||||||
# Store reference to prevent GC before task completes.
|
# Store reference to prevent GC before task completes.
|
||||||
@ -146,8 +145,8 @@ class NanobotDingTalkHandler(CallbackHandler):
|
|||||||
|
|
||||||
return AckMessage.STATUS_OK, "OK"
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error processing DingTalk message: {}", e)
|
self.channel.logger.exception("Error processing message")
|
||||||
# Return OK to avoid retry loop from DingTalk server
|
# Return OK to avoid retry loop from DingTalk server
|
||||||
return AckMessage.STATUS_OK, "Error"
|
return AckMessage.STATUS_OK, "Error"
|
||||||
|
|
||||||
@ -204,20 +203,20 @@ class DingTalkChannel(BaseChannel):
|
|||||||
"""Start the DingTalk bot with Stream Mode."""
|
"""Start the DingTalk bot with Stream Mode."""
|
||||||
try:
|
try:
|
||||||
if not DINGTALK_AVAILABLE:
|
if not DINGTALK_AVAILABLE:
|
||||||
logger.error(
|
self.logger.error(
|
||||||
"DingTalk Stream SDK not installed. Run: pip install dingtalk-stream"
|
"Stream SDK not installed. Run: pip install dingtalk-stream"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.client_id or not self.config.client_secret:
|
if not self.config.client_id or not self.config.client_secret:
|
||||||
logger.error("DingTalk client_id and client_secret not configured")
|
self.logger.error("client_id and client_secret not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self._http = httpx.AsyncClient()
|
self._http = httpx.AsyncClient()
|
||||||
|
|
||||||
logger.info(
|
self.logger.info(
|
||||||
"Initializing DingTalk Stream Client with Client ID: {}...",
|
"Initializing Stream Client with Client ID: {}...",
|
||||||
self.config.client_id,
|
self.config.client_id,
|
||||||
)
|
)
|
||||||
credential = Credential(self.config.client_id, self.config.client_secret)
|
credential = Credential(self.config.client_id, self.config.client_secret)
|
||||||
@ -227,20 +226,20 @@ class DingTalkChannel(BaseChannel):
|
|||||||
handler = NanobotDingTalkHandler(self)
|
handler = NanobotDingTalkHandler(self)
|
||||||
self._client.register_callback_handler(ChatbotMessage.TOPIC, handler)
|
self._client.register_callback_handler(ChatbotMessage.TOPIC, handler)
|
||||||
|
|
||||||
logger.info("DingTalk bot started with Stream Mode")
|
self.logger.info("bot started with Stream Mode")
|
||||||
|
|
||||||
# Reconnect loop: restart stream if SDK exits or crashes
|
# Reconnect loop: restart stream if SDK exits or crashes
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
await self._client.start()
|
await self._client.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("DingTalk stream error: {}", e)
|
self.logger.warning("stream error: {}", e)
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.info("Reconnecting DingTalk stream in 5 seconds...")
|
self.logger.info("Reconnecting stream in 5 seconds...")
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.exception("Failed to start DingTalk channel: {}", e)
|
self.logger.exception("Failed to start channel")
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the DingTalk bot."""
|
"""Stop the DingTalk bot."""
|
||||||
@ -266,7 +265,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if not self._http:
|
if not self._http:
|
||||||
logger.warning("DingTalk HTTP client not initialized, cannot refresh token")
|
self.logger.warning("HTTP client not initialized, cannot refresh token")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -277,8 +276,8 @@ class DingTalkChannel(BaseChannel):
|
|||||||
# Expire 60s early to be safe
|
# Expire 60s early to be safe
|
||||||
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
|
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
|
||||||
return self._access_token
|
return self._access_token
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Failed to get DingTalk access token: {}", e)
|
self.logger.exception("Failed to get access token")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -317,8 +316,8 @@ class DingTalkChannel(BaseChannel):
|
|||||||
) -> tuple[bytes, str, str | None]:
|
) -> tuple[bytes, str, str | None]:
|
||||||
ext = Path(filename).suffix.lower()
|
ext = Path(filename).suffix.lower()
|
||||||
if ext in self._ZIP_BEFORE_UPLOAD_EXTS or content_type == "text/html":
|
if ext in self._ZIP_BEFORE_UPLOAD_EXTS or content_type == "text/html":
|
||||||
logger.info(
|
self.logger.info(
|
||||||
"DingTalk does not accept raw HTML attachments, zipping {} before upload",
|
"does not accept raw HTML attachments, zipping {} before upload",
|
||||||
filename,
|
filename,
|
||||||
)
|
)
|
||||||
return self._zip_bytes(filename, data)
|
return self._zip_bytes(filename, data)
|
||||||
@ -327,7 +326,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
def _validate_remote_media_url(self, media_ref: str) -> bool:
|
def _validate_remote_media_url(self, media_ref: str) -> bool:
|
||||||
ok, err = validate_url_target(media_ref)
|
ok, err = validate_url_target(media_ref)
|
||||||
if not ok:
|
if not ok:
|
||||||
logger.warning("DingTalk remote media URL blocked ref={} reason={}", media_ref, err)
|
self.logger.warning("remote media URL blocked ref={} reason={}", media_ref, err)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -343,15 +342,15 @@ class DingTalkChannel(BaseChannel):
|
|||||||
|
|
||||||
def _next_remote_media_url(self, current_url: str, location: str | None) -> str | None:
|
def _next_remote_media_url(self, current_url: str, location: str | None) -> str | None:
|
||||||
if not self.config.allow_remote_media_redirects:
|
if not self.config.allow_remote_media_redirects:
|
||||||
logger.warning("DingTalk media download redirect refused ref={}", current_url)
|
self.logger.warning("media download redirect refused ref={}", current_url)
|
||||||
return None
|
return None
|
||||||
if not location:
|
if not location:
|
||||||
logger.warning("DingTalk media download redirect without Location ref={}", current_url)
|
self.logger.warning("media download redirect without Location ref={}", current_url)
|
||||||
return None
|
return None
|
||||||
next_url = urljoin(current_url, location)
|
next_url = urljoin(current_url, location)
|
||||||
if not self._redirect_host_allowed(current_url, next_url):
|
if not self._redirect_host_allowed(current_url, next_url):
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"DingTalk media download cross-host redirect refused ref={} next={}",
|
"media download cross-host redirect refused ref={} next={}",
|
||||||
current_url,
|
current_url,
|
||||||
next_url,
|
next_url,
|
||||||
)
|
)
|
||||||
@ -382,8 +381,8 @@ class DingTalkChannel(BaseChannel):
|
|||||||
async with stream("GET", current_url, follow_redirects=False) as resp:
|
async with stream("GET", current_url, follow_redirects=False) as resp:
|
||||||
final_ok, final_err = validate_resolved_url(str(resp.url))
|
final_ok, final_err = validate_resolved_url(str(resp.url))
|
||||||
if not final_ok:
|
if not final_ok:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"DingTalk remote media redirect blocked ref={} final={} reason={}",
|
"remote media redirect blocked ref={} final={} reason={}",
|
||||||
media_ref,
|
media_ref,
|
||||||
resp.url,
|
resp.url,
|
||||||
final_err,
|
final_err,
|
||||||
@ -398,8 +397,8 @@ class DingTalkChannel(BaseChannel):
|
|||||||
current_url = next_url
|
current_url = next_url
|
||||||
continue
|
continue
|
||||||
if resp.status_code >= 400:
|
if resp.status_code >= 400:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"DingTalk media download failed status={} ref={}",
|
"media download failed status={} ref={}",
|
||||||
resp.status_code,
|
resp.status_code,
|
||||||
current_url,
|
current_url,
|
||||||
)
|
)
|
||||||
@ -409,15 +408,15 @@ class DingTalkChannel(BaseChannel):
|
|||||||
async for chunk in resp.aiter_bytes():
|
async for chunk in resp.aiter_bytes():
|
||||||
total += len(chunk)
|
total += len(chunk)
|
||||||
if total > DINGTALK_MAX_REMOTE_MEDIA_BYTES:
|
if total > DINGTALK_MAX_REMOTE_MEDIA_BYTES:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"DingTalk media download too large ref={} bytes>{}",
|
"media download too large ref={} bytes>{}",
|
||||||
current_url,
|
current_url,
|
||||||
DINGTALK_MAX_REMOTE_MEDIA_BYTES,
|
DINGTALK_MAX_REMOTE_MEDIA_BYTES,
|
||||||
)
|
)
|
||||||
return None, None
|
return None, None
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
return b"".join(chunks), (resp.headers.get("content-type") or "")
|
return b"".join(chunks), (resp.headers.get("content-type") or "")
|
||||||
logger.warning("DingTalk media download exceeded redirect limit ref={}", media_ref)
|
self.logger.warning("media download exceeded redirect limit ref={}", media_ref)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
current_url = media_ref
|
current_url = media_ref
|
||||||
@ -425,8 +424,8 @@ class DingTalkChannel(BaseChannel):
|
|||||||
resp = await self._http.get(current_url, follow_redirects=False)
|
resp = await self._http.get(current_url, follow_redirects=False)
|
||||||
final_ok, final_err = validate_resolved_url(str(getattr(resp, "url", current_url)))
|
final_ok, final_err = validate_resolved_url(str(getattr(resp, "url", current_url)))
|
||||||
if not final_ok:
|
if not final_ok:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"DingTalk remote media redirect blocked ref={} final={} reason={}",
|
"remote media redirect blocked ref={} final={} reason={}",
|
||||||
media_ref,
|
media_ref,
|
||||||
getattr(resp, "url", current_url),
|
getattr(resp, "url", current_url),
|
||||||
final_err,
|
final_err,
|
||||||
@ -441,27 +440,27 @@ class DingTalkChannel(BaseChannel):
|
|||||||
current_url = next_url
|
current_url = next_url
|
||||||
continue
|
continue
|
||||||
if resp.status_code >= 400:
|
if resp.status_code >= 400:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"DingTalk media download failed status={} ref={}",
|
"media download failed status={} ref={}",
|
||||||
resp.status_code,
|
resp.status_code,
|
||||||
current_url,
|
current_url,
|
||||||
)
|
)
|
||||||
return None, None
|
return None, None
|
||||||
if len(resp.content) > DINGTALK_MAX_REMOTE_MEDIA_BYTES:
|
if len(resp.content) > DINGTALK_MAX_REMOTE_MEDIA_BYTES:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"DingTalk media download too large ref={} bytes>{}",
|
"media download too large ref={} bytes>{}",
|
||||||
current_url,
|
current_url,
|
||||||
DINGTALK_MAX_REMOTE_MEDIA_BYTES,
|
DINGTALK_MAX_REMOTE_MEDIA_BYTES,
|
||||||
)
|
)
|
||||||
return None, None
|
return None, None
|
||||||
return resp.content, (resp.headers.get("content-type") or "")
|
return resp.content, (resp.headers.get("content-type") or "")
|
||||||
logger.warning("DingTalk media download exceeded redirect limit ref={}", media_ref)
|
self.logger.warning("media download exceeded redirect limit ref={}", media_ref)
|
||||||
return None, None
|
return None, None
|
||||||
except httpx.TransportError as e:
|
except httpx.TransportError:
|
||||||
logger.error("DingTalk media download network error ref={} err={}", media_ref, e)
|
self.logger.exception("media download network error ref={}", media_ref)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
self.logger.exception("media download error ref={}", media_ref)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
async def _read_media_bytes(
|
async def _read_media_bytes(
|
||||||
@ -486,13 +485,13 @@ class DingTalkChannel(BaseChannel):
|
|||||||
else:
|
else:
|
||||||
local_path = Path(os.path.expanduser(media_ref))
|
local_path = Path(os.path.expanduser(media_ref))
|
||||||
if not local_path.is_file():
|
if not local_path.is_file():
|
||||||
logger.warning("DingTalk media file not found: {}", local_path)
|
self.logger.warning("media file not found: {}", local_path)
|
||||||
return None, None, None
|
return None, None, None
|
||||||
data = await asyncio.to_thread(local_path.read_bytes)
|
data = await asyncio.to_thread(local_path.read_bytes)
|
||||||
content_type = mimetypes.guess_type(local_path.name)[0]
|
content_type = mimetypes.guess_type(local_path.name)[0]
|
||||||
return data, local_path.name, content_type
|
return data, local_path.name, content_type
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("DingTalk media read error ref={} err={}", media_ref, e)
|
self.logger.exception("media read error ref={}", media_ref)
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
async def _upload_media(
|
async def _upload_media(
|
||||||
@ -514,23 +513,23 @@ class DingTalkChannel(BaseChannel):
|
|||||||
text = resp.text
|
text = resp.text
|
||||||
result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
|
result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
|
||||||
if resp.status_code >= 400:
|
if resp.status_code >= 400:
|
||||||
logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500])
|
self.logger.error("media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500])
|
||||||
return None
|
return None
|
||||||
errcode = result.get("errcode", 0)
|
errcode = result.get("errcode", 0)
|
||||||
if errcode != 0:
|
if errcode != 0:
|
||||||
logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500])
|
self.logger.error("media upload api error type={} errcode={} body={}", media_type, errcode, text[:500])
|
||||||
return None
|
return None
|
||||||
sub = result.get("result") or {}
|
sub = result.get("result") or {}
|
||||||
media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId")
|
media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId")
|
||||||
if not media_id:
|
if not media_id:
|
||||||
logger.error("DingTalk media upload missing media_id body={}", text[:500])
|
self.logger.error("media upload missing media_id body={}", text[:500])
|
||||||
return None
|
return None
|
||||||
return str(media_id)
|
return str(media_id)
|
||||||
except httpx.TransportError as e:
|
except httpx.TransportError:
|
||||||
logger.error("DingTalk media upload network error type={} err={}", media_type, e)
|
self.logger.exception("media upload network error type={}", media_type)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("DingTalk media upload error type={} err={}", media_type, e)
|
self.logger.exception("media upload error type={}", media_type)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _send_batch_message(
|
async def _send_batch_message(
|
||||||
@ -541,7 +540,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
msg_param: dict[str, Any],
|
msg_param: dict[str, Any],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
if not self._http:
|
if not self._http:
|
||||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
self.logger.warning("HTTP client not initialized, cannot send")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
headers = {"x-acs-dingtalk-access-token": token}
|
headers = {"x-acs-dingtalk-access-token": token}
|
||||||
@ -568,7 +567,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
resp = await self._http.post(url, json=payload, headers=headers)
|
resp = await self._http.post(url, json=payload, headers=headers)
|
||||||
body = resp.text
|
body = resp.text
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
|
self.logger.error("send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
@ -576,15 +575,15 @@ class DingTalkChannel(BaseChannel):
|
|||||||
result = {}
|
result = {}
|
||||||
errcode = result.get("errcode")
|
errcode = result.get("errcode")
|
||||||
if errcode not in (None, 0):
|
if errcode not in (None, 0):
|
||||||
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
|
self.logger.error("send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
|
||||||
return False
|
return False
|
||||||
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
|
self.logger.debug("message sent to {} with msgKey={}", chat_id, msg_key)
|
||||||
return True
|
return True
|
||||||
except httpx.TransportError as e:
|
except httpx.TransportError:
|
||||||
logger.error("DingTalk network error sending message msgKey={} err={}", msg_key, e)
|
self.logger.exception("network error sending message msgKey={}", msg_key)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
|
self.logger.exception("Error sending message msgKey={}", msg_key)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
|
async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
|
||||||
@ -610,11 +609,11 @@ class DingTalkChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
if ok:
|
if ok:
|
||||||
return True
|
return True
|
||||||
logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref)
|
self.logger.warning("image url send failed, trying upload fallback: {}", media_ref)
|
||||||
|
|
||||||
data, filename, content_type = await self._read_media_bytes(media_ref)
|
data, filename, content_type = await self._read_media_bytes(media_ref)
|
||||||
if not data:
|
if not data:
|
||||||
logger.error("DingTalk media read failed: {}", media_ref)
|
self.logger.error("media read failed: {}", media_ref)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
filename = filename or self._guess_filename(media_ref, upload_type)
|
filename = filename or self._guess_filename(media_ref, upload_type)
|
||||||
@ -646,7 +645,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
if ok:
|
if ok:
|
||||||
return True
|
return True
|
||||||
logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref)
|
self.logger.warning("image media_id send failed, falling back to file: {}", media_ref)
|
||||||
|
|
||||||
return await self._send_batch_message(
|
return await self._send_batch_message(
|
||||||
token,
|
token,
|
||||||
@ -668,7 +667,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
ok = await self._send_media_ref(token, msg.chat_id, media_ref)
|
ok = await self._send_media_ref(token, msg.chat_id, media_ref)
|
||||||
if ok:
|
if ok:
|
||||||
continue
|
continue
|
||||||
logger.error("DingTalk media send failed for {}", media_ref)
|
self.logger.error("media send failed for {}", media_ref)
|
||||||
# Send visible fallback so failures are observable by the user.
|
# Send visible fallback so failures are observable by the user.
|
||||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||||
await self._send_markdown_text(
|
await self._send_markdown_text(
|
||||||
@ -691,7 +690,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
permission checks before publishing to the bus.
|
permission checks before publishing to the bus.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
self.logger.info("inbound: {} from {}", content, sender_name)
|
||||||
is_group = conversation_type == "2" and conversation_id
|
is_group = conversation_type == "2" and conversation_id
|
||||||
chat_id = f"group:{conversation_id}" if is_group else sender_id
|
chat_id = f"group:{conversation_id}" if is_group else sender_id
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
@ -704,8 +703,8 @@ class DingTalkChannel(BaseChannel):
|
|||||||
"conversation_type": conversation_type,
|
"conversation_type": conversation_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error publishing DingTalk message: {}", e)
|
self.logger.exception("Error publishing message")
|
||||||
|
|
||||||
async def _download_dingtalk_file(
|
async def _download_dingtalk_file(
|
||||||
self,
|
self,
|
||||||
@ -719,7 +718,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
token = await self._get_access_token()
|
token = await self._get_access_token()
|
||||||
if not token or not self._http:
|
if not token or not self._http:
|
||||||
logger.error("DingTalk file download: no token or http client")
|
self.logger.error("file download: no token or http client")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Step 1: Exchange downloadCode for a temporary download URL
|
# Step 1: Exchange downloadCode for a temporary download URL
|
||||||
@ -728,19 +727,19 @@ class DingTalkChannel(BaseChannel):
|
|||||||
payload = {"downloadCode": download_code, "robotCode": self.config.client_id}
|
payload = {"downloadCode": download_code, "robotCode": self.config.client_id}
|
||||||
resp = await self._http.post(api_url, json=payload, headers=headers)
|
resp = await self._http.post(api_url, json=payload, headers=headers)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
logger.error("DingTalk get download URL failed: status={}, body={}", resp.status_code, resp.text)
|
self.logger.error("get download URL failed: status={}, body={}", resp.status_code, resp.text)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
download_url = result.get("downloadUrl")
|
download_url = result.get("downloadUrl")
|
||||||
if not download_url:
|
if not download_url:
|
||||||
logger.error("DingTalk download URL not found in response: {}", result)
|
self.logger.error("download URL not found in response: {}", result)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Step 2: Download the file content
|
# Step 2: Download the file content
|
||||||
file_resp = await self._http.get(download_url, follow_redirects=True)
|
file_resp = await self._http.get(download_url, follow_redirects=True)
|
||||||
if file_resp.status_code != 200:
|
if file_resp.status_code != 200:
|
||||||
logger.error("DingTalk file download failed: status={}", file_resp.status_code)
|
self.logger.error("file download failed: status={}", file_resp.status_code)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Save to media directory (accessible under workspace)
|
# Save to media directory (accessible under workspace)
|
||||||
@ -748,8 +747,8 @@ class DingTalkChannel(BaseChannel):
|
|||||||
download_dir.mkdir(parents=True, exist_ok=True)
|
download_dir.mkdir(parents=True, exist_ok=True)
|
||||||
file_path = download_dir / filename
|
file_path = download_dir / filename
|
||||||
await asyncio.to_thread(file_path.write_bytes, file_resp.content)
|
await asyncio.to_thread(file_path.write_bytes, file_resp.content)
|
||||||
logger.info("DingTalk file saved: {}", file_path)
|
self.logger.info("file saved: {}", file_path)
|
||||||
return str(file_path)
|
return str(file_path)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("DingTalk file download error: {}", e)
|
self.logger.exception("file download error")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from dataclasses import dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
@ -86,12 +85,12 @@ if DISCORD_AVAILABLE:
|
|||||||
|
|
||||||
async def on_ready(self) -> None:
|
async def on_ready(self) -> None:
|
||||||
self._channel._bot_user_id = str(self.user.id) if self.user else None
|
self._channel._bot_user_id = str(self.user.id) if self.user else None
|
||||||
logger.info("Discord bot connected as user {}", self._channel._bot_user_id)
|
self._channel.logger.info("bot connected as user {}", self._channel._bot_user_id)
|
||||||
try:
|
try:
|
||||||
synced = await self.tree.sync()
|
synced = await self.tree.sync()
|
||||||
logger.info("Discord app commands synced: {}", len(synced))
|
self._channel.logger.info("app commands synced: {}", len(synced))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Discord app command sync failed: {}", e)
|
self._channel.logger.warning("app command sync failed: {}", e)
|
||||||
|
|
||||||
async def on_message(self, message: discord.Message) -> None:
|
async def on_message(self, message: discord.Message) -> None:
|
||||||
await self._channel._handle_discord_message(message)
|
await self._channel._handle_discord_message(message)
|
||||||
@ -111,7 +110,7 @@ if DISCORD_AVAILABLE:
|
|||||||
await interaction.response.send_message(text, ephemeral=True)
|
await interaction.response.send_message(text, ephemeral=True)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Discord interaction response failed: {}", e)
|
self._channel.logger.warning("interaction response failed: {}", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _resolve_interaction_channel(
|
async def _resolve_interaction_channel(
|
||||||
@ -126,7 +125,7 @@ if DISCORD_AVAILABLE:
|
|||||||
try:
|
try:
|
||||||
channel = await self.fetch_channel(channel_id)
|
channel = await self.fetch_channel(channel_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Discord interaction channel {} unavailable: {}", channel_id, e)
|
self._channel.logger.warning("interaction channel {} unavailable: {}", channel_id, e)
|
||||||
return None
|
return None
|
||||||
self._channel._remember_channel(channel)
|
self._channel._remember_channel(channel)
|
||||||
return channel
|
return channel
|
||||||
@ -154,7 +153,7 @@ if DISCORD_AVAILABLE:
|
|||||||
channel_id = interaction.channel_id
|
channel_id = interaction.channel_id
|
||||||
|
|
||||||
if channel_id is None:
|
if channel_id is None:
|
||||||
logger.warning("Discord slash command missing channel_id: {}", command_text)
|
self._channel.logger.warning("slash command missing channel_id: {}", command_text)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self._channel.is_allowed(sender_id):
|
if not self._channel.is_allowed(sender_id):
|
||||||
@ -226,8 +225,8 @@ if DISCORD_AVAILABLE:
|
|||||||
error: app_commands.AppCommandError,
|
error: app_commands.AppCommandError,
|
||||||
) -> None:
|
) -> None:
|
||||||
command_name = interaction.command.qualified_name if interaction.command else "?"
|
command_name = interaction.command.qualified_name if interaction.command else "?"
|
||||||
logger.warning(
|
self._channel.logger.warning(
|
||||||
"Discord app command failed user={} channel={} cmd={} error={}",
|
"app command failed user={} channel={} cmd={} error={}",
|
||||||
interaction.user.id,
|
interaction.user.id,
|
||||||
interaction.channel_id,
|
interaction.channel_id,
|
||||||
command_name,
|
command_name,
|
||||||
@ -243,7 +242,7 @@ if DISCORD_AVAILABLE:
|
|||||||
try:
|
try:
|
||||||
channel = await self.fetch_channel(channel_id)
|
channel = await self.fetch_channel(channel_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e)
|
self._channel.logger.warning("channel {} unavailable: {}", msg.chat_id, e)
|
||||||
return
|
return
|
||||||
|
|
||||||
reference, mention_settings = self._build_reply_context(channel, msg.reply_to)
|
reference, mention_settings = self._build_reply_context(channel, msg.reply_to)
|
||||||
@ -281,11 +280,11 @@ if DISCORD_AVAILABLE:
|
|||||||
"""Send a file attachment via discord.py."""
|
"""Send a file attachment via discord.py."""
|
||||||
path = Path(file_path)
|
path = Path(file_path)
|
||||||
if not path.is_file():
|
if not path.is_file():
|
||||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
self._channel.logger.warning("file not found, skipping: {}", file_path)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
self._channel.logger.warning("file too large (>20MB), skipping: {}", path.name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -294,10 +293,10 @@ if DISCORD_AVAILABLE:
|
|||||||
kwargs["reference"] = reference
|
kwargs["reference"] = reference
|
||||||
kwargs["allowed_mentions"] = mention_settings
|
kwargs["allowed_mentions"] = mention_settings
|
||||||
await channel.send(**kwargs)
|
await channel.send(**kwargs)
|
||||||
logger.info("Discord file sent: {}", path.name)
|
self._channel.logger.info("file sent: {}", path.name)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
self._channel.logger.exception("Error sending file {}", path.name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -321,7 +320,7 @@ if DISCORD_AVAILABLE:
|
|||||||
try:
|
try:
|
||||||
message_id = int(reply_to)
|
message_id = int(reply_to)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning("Invalid Discord reply target: {}", reply_to)
|
self._channel.logger.warning("Invalid reply target: {}", reply_to)
|
||||||
return None, mention_settings
|
return None, mention_settings
|
||||||
|
|
||||||
return channel.get_partial_message(message_id), mention_settings
|
return channel.get_partial_message(message_id), mention_settings
|
||||||
@ -385,11 +384,11 @@ class DiscordChannel(BaseChannel):
|
|||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Discord client."""
|
"""Start the Discord client."""
|
||||||
if not DISCORD_AVAILABLE:
|
if not DISCORD_AVAILABLE:
|
||||||
logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]")
|
self.logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.token:
|
if not self.config.token:
|
||||||
logger.error("Discord bot token not configured")
|
self.logger.error("bot token not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -407,8 +406,8 @@ class DiscordChannel(BaseChannel):
|
|||||||
password=self.config.proxy_password,
|
password=self.config.proxy_password,
|
||||||
)
|
)
|
||||||
elif has_user != has_pass:
|
elif has_user != has_pass:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Discord proxy auth incomplete: both proxy_username and "
|
"proxy auth incomplete: both proxy_username and "
|
||||||
"proxy_password must be set; ignoring partial credentials",
|
"proxy_password must be set; ignoring partial credentials",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -418,21 +417,21 @@ class DiscordChannel(BaseChannel):
|
|||||||
proxy=self.config.proxy,
|
proxy=self.config.proxy,
|
||||||
proxy_auth=proxy_auth,
|
proxy_auth=proxy_auth,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Failed to initialize Discord client: {}", e)
|
self.logger.exception("Failed to initialize client")
|
||||||
self._client = None
|
self._client = None
|
||||||
self._running = False
|
self._running = False
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
logger.info("Starting Discord client via discord.py...")
|
self.logger.info("Starting client via discord.py...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._client.start(self.config.token)
|
await self._client.start(self.config.token)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Discord client startup failed: {}", e)
|
self.logger.exception("client startup failed")
|
||||||
finally:
|
finally:
|
||||||
self._running = False
|
self._running = False
|
||||||
await self._reset_runtime_state(close_client=True)
|
await self._reset_runtime_state(close_client=True)
|
||||||
@ -446,15 +445,15 @@ class DiscordChannel(BaseChannel):
|
|||||||
"""Send a message through Discord using discord.py."""
|
"""Send a message through Discord using discord.py."""
|
||||||
client = self._client
|
client = self._client
|
||||||
if client is None or not client.is_ready():
|
if client is None or not client.is_ready():
|
||||||
logger.warning("Discord client not ready; dropping outbound message")
|
self.logger.warning("client not ready; dropping outbound message")
|
||||||
return
|
return
|
||||||
|
|
||||||
is_progress = bool((msg.metadata or {}).get("_progress"))
|
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await client.send_outbound(msg)
|
await client.send_outbound(msg)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending Discord message: {}", e)
|
self.logger.exception("Error sending message")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
if not is_progress:
|
if not is_progress:
|
||||||
@ -467,7 +466,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
"""Progressive Discord delivery: send once, then edit until the stream ends."""
|
"""Progressive Discord delivery: send once, then edit until the stream ends."""
|
||||||
client = self._client
|
client = self._client
|
||||||
if client is None or not client.is_ready():
|
if client is None or not client.is_ready():
|
||||||
logger.warning("Discord client not ready; dropping stream delta")
|
self.logger.warning("client not ready; dropping stream delta")
|
||||||
return
|
return
|
||||||
|
|
||||||
meta = metadata or {}
|
meta = metadata or {}
|
||||||
@ -497,7 +496,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
|
|
||||||
target = await self._resolve_channel(chat_id)
|
target = await self._resolve_channel(chat_id)
|
||||||
if target is None:
|
if target is None:
|
||||||
logger.warning("Discord stream target {} unavailable", chat_id)
|
self.logger.warning("stream target {} unavailable", chat_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
@ -506,7 +505,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
buf.message = await target.send(content=buf.text)
|
buf.message = await target.send(content=buf.text)
|
||||||
buf.last_edit = now
|
buf.last_edit = now
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Discord stream initial send failed: {}", e)
|
self.logger.warning("stream initial send failed: {}", e)
|
||||||
raise
|
raise
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -517,7 +516,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
await buf.message.edit(content=DiscordBotClient._build_chunks(buf.text, [], False)[0])
|
await buf.message.edit(content=DiscordBotClient._build_chunks(buf.text, [], False)[0])
|
||||||
buf.last_edit = now
|
buf.last_edit = now
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Discord stream edit failed: {}", e)
|
self.logger.warning("stream edit failed: {}", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _handle_discord_message(self, message: discord.Message) -> None:
|
async def _handle_discord_message(self, message: discord.Message) -> None:
|
||||||
@ -560,7 +559,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
await message.add_reaction(self.config.read_receipt_emoji)
|
await message.add_reaction(self.config.read_receipt_emoji)
|
||||||
self._pending_reactions[channel_id] = message
|
self._pending_reactions[channel_id] = message
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Failed to add read receipt reaction: {}", e)
|
self.logger.debug("Failed to add read receipt reaction: {}", e)
|
||||||
|
|
||||||
# Delayed working indicator (cosmetic — not tied to subagent lifecycle)
|
# Delayed working indicator (cosmetic — not tied to subagent lifecycle)
|
||||||
async def _delayed_working_emoji() -> None:
|
async def _delayed_working_emoji() -> None:
|
||||||
@ -603,7 +602,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
return await client.fetch_channel(channel_id)
|
return await client.fetch_channel(channel_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Discord channel {} unavailable: {}", chat_id, e)
|
self.logger.warning("channel {} unavailable: {}", chat_id, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _finalize_stream(self, chat_id: str, buf: _StreamBuf) -> None:
|
async def _finalize_stream(self, chat_id: str, buf: _StreamBuf) -> None:
|
||||||
@ -616,12 +615,12 @@ class DiscordChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await buf.message.edit(content=chunks[0])
|
await buf.message.edit(content=chunks[0])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Discord final stream edit failed: {}", e)
|
self.logger.warning("final stream edit failed: {}", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
target = getattr(buf.message, "channel", None) or await self._resolve_channel(chat_id)
|
target = getattr(buf.message, "channel", None) or await self._resolve_channel(chat_id)
|
||||||
if target is None:
|
if target is None:
|
||||||
logger.warning("Discord stream follow-up target {} unavailable", chat_id)
|
self.logger.warning("stream follow-up target {} unavailable", chat_id)
|
||||||
self._stream_bufs.pop(chat_id, None)
|
self._stream_bufs.pop(chat_id, None)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -673,7 +672,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
media_paths.append(str(file_path))
|
media_paths.append(str(file_path))
|
||||||
markers.append(f"[attachment: {file_path.name}]")
|
markers.append(f"[attachment: {file_path.name}]")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to download Discord attachment: {}", e)
|
self.logger.warning("Failed to download attachment: {}", e)
|
||||||
markers.append(f"[attachment: {filename} - download failed]")
|
markers.append(f"[attachment: {filename} - download failed]")
|
||||||
|
|
||||||
return media_paths, markers
|
return media_paths, markers
|
||||||
@ -715,8 +714,8 @@ class DiscordChannel(BaseChannel):
|
|||||||
if bot_user_id is None and self._client and self._client.user:
|
if bot_user_id is None and self._client and self._client.user:
|
||||||
bot_user_id = str(self._client.user.id)
|
bot_user_id = str(self._client.user.id)
|
||||||
if bot_user_id is None:
|
if bot_user_id is None:
|
||||||
logger.debug(
|
self.logger.debug(
|
||||||
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
|
"message in {} ignored (bot identity unavailable)", message.channel.id
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -729,7 +728,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
if self._references_bot_message(message, bot_user_id):
|
if self._references_bot_message(message, bot_user_id):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
|
self.logger.debug("message in {} ignored (bot not mentioned)", message.channel.id)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -759,7 +758,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
self.logger.debug("typing indicator failed for {}: {}", channel_id, e)
|
||||||
return
|
return
|
||||||
|
|
||||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||||
@ -803,6 +802,6 @@ class DiscordChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._client.close()
|
await self._client.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Discord client close failed: {}", e)
|
self.logger.warning("client close failed: {}", e)
|
||||||
self._client = None
|
self._client = None
|
||||||
self._bot_user_id = None
|
self._bot_user_id = None
|
||||||
|
|||||||
@ -128,7 +128,7 @@ class EmailChannel(BaseChannel):
|
|||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start polling IMAP for inbound emails."""
|
"""Start polling IMAP for inbound emails."""
|
||||||
if not self.config.consent_granted:
|
if not self.config.consent_granted:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Email channel disabled: consent_granted is false. "
|
"Email channel disabled: consent_granted is false. "
|
||||||
"Set channels.email.consentGranted=true after explicit user permission."
|
"Set channels.email.consentGranted=true after explicit user permission."
|
||||||
)
|
)
|
||||||
@ -139,12 +139,12 @@ class EmailChannel(BaseChannel):
|
|||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
if not self.config.verify_dkim and not self.config.verify_spf:
|
if not self.config.verify_dkim and not self.config.verify_spf:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Email channel: DKIM and SPF verification are both DISABLED. "
|
"DKIM and SPF verification are both DISABLED. "
|
||||||
"Emails with spoofed From headers will be accepted. "
|
"Emails with spoofed From headers will be accepted. "
|
||||||
"Set verify_dkim=true and verify_spf=true for anti-spoofing protection."
|
"Set verify_dkim=true and verify_spf=true for anti-spoofing protection."
|
||||||
)
|
)
|
||||||
logger.info("Starting Email channel (IMAP polling mode)...")
|
self.logger.info("Starting Email channel (IMAP polling mode)...")
|
||||||
|
|
||||||
poll_seconds = max(5, int(self.config.poll_interval_seconds))
|
poll_seconds = max(5, int(self.config.poll_interval_seconds))
|
||||||
while self._running:
|
while self._running:
|
||||||
@ -167,8 +167,8 @@ class EmailChannel(BaseChannel):
|
|||||||
media=item.get("media") or None,
|
media=item.get("media") or None,
|
||||||
metadata=item.get("metadata", {}),
|
metadata=item.get("metadata", {}),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Email polling error: {}", e)
|
self.logger.exception("Polling error")
|
||||||
|
|
||||||
await asyncio.sleep(poll_seconds)
|
await asyncio.sleep(poll_seconds)
|
||||||
|
|
||||||
@ -179,16 +179,16 @@ class EmailChannel(BaseChannel):
|
|||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send email via SMTP."""
|
"""Send email via SMTP."""
|
||||||
if not self.config.consent_granted:
|
if not self.config.consent_granted:
|
||||||
logger.warning("Skip email send: consent_granted is false")
|
self.logger.warning("Skip email send: consent_granted is false")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.smtp_host:
|
if not self.config.smtp_host:
|
||||||
logger.warning("Email channel SMTP host not configured")
|
self.logger.warning("SMTP host not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
to_addr = msg.chat_id.strip()
|
to_addr = msg.chat_id.strip()
|
||||||
if not to_addr:
|
if not to_addr:
|
||||||
logger.warning("Email channel missing recipient address")
|
self.logger.warning("Missing recipient address")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Determine if this is a reply (recipient has sent us an email before)
|
# Determine if this is a reply (recipient has sent us an email before)
|
||||||
@ -197,7 +197,7 @@ class EmailChannel(BaseChannel):
|
|||||||
|
|
||||||
# autoReplyEnabled only controls automatic replies, not proactive sends
|
# autoReplyEnabled only controls automatic replies, not proactive sends
|
||||||
if is_reply and not self.config.auto_reply_enabled and not force_send:
|
if is_reply and not self.config.auto_reply_enabled and not force_send:
|
||||||
logger.info("Skip automatic email reply to {}: auto_reply_enabled is false", to_addr)
|
self.logger.info("Skip automatic reply to {}: auto_reply_enabled is false", to_addr)
|
||||||
return
|
return
|
||||||
|
|
||||||
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
||||||
@ -220,8 +220,8 @@ class EmailChannel(BaseChannel):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(self._smtp_send, email_msg)
|
await asyncio.to_thread(self._smtp_send, email_msg)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending email to {}: {}", to_addr, e)
|
self.logger.exception("Error sending to {}", to_addr)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _validate_config(self) -> bool:
|
def _validate_config(self) -> bool:
|
||||||
@ -240,7 +240,7 @@ class EmailChannel(BaseChannel):
|
|||||||
missing.append("smtp_password")
|
missing.append("smtp_password")
|
||||||
|
|
||||||
if missing:
|
if missing:
|
||||||
logger.error("Email channel not configured, missing: {}", ', '.join(missing))
|
self.logger.error("Channel not configured, missing: {}", ', '.join(missing))
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -321,7 +321,7 @@ class EmailChannel(BaseChannel):
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if attempt == 1 or not self._is_stale_imap_error(exc):
|
if attempt == 1 or not self._is_stale_imap_error(exc):
|
||||||
raise
|
raise
|
||||||
logger.warning("Email IMAP connection went stale, retrying once: {}", exc)
|
self.logger.warning("IMAP connection went stale, retrying once: {}", exc)
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@ -348,11 +348,11 @@ class EmailChannel(BaseChannel):
|
|||||||
status, _ = client.select(mailbox)
|
status, _ = client.select(mailbox)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if self._is_missing_mailbox_error(exc):
|
if self._is_missing_mailbox_error(exc):
|
||||||
logger.warning("Email mailbox unavailable, skipping poll for {}: {}", mailbox, exc)
|
self.logger.warning("Mailbox unavailable, skipping poll for {}: {}", mailbox, exc)
|
||||||
return messages
|
return messages
|
||||||
raise
|
raise
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
logger.warning("Email mailbox select returned {}, skipping poll for {}", status, mailbox)
|
self.logger.warning("Mailbox select returned {}, skipping poll for {}", status, mailbox)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
status, data = client.search(None, *search_criteria)
|
status, data = client.search(None, *search_criteria)
|
||||||
@ -382,7 +382,7 @@ class EmailChannel(BaseChannel):
|
|||||||
if not sender:
|
if not sender:
|
||||||
continue
|
continue
|
||||||
if self._is_self_address(sender):
|
if self._is_self_address(sender):
|
||||||
logger.info("Email from {} ignored: matches bot-owned address", sender)
|
self.logger.info("From {} ignored: matches bot-owned address", sender)
|
||||||
self._remember_processed_uid(uid, dedupe, cycle_uids)
|
self._remember_processed_uid(uid, dedupe, cycle_uids)
|
||||||
if mark_seen:
|
if mark_seen:
|
||||||
client.store(imap_id, "+FLAGS", "\\Seen")
|
client.store(imap_id, "+FLAGS", "\\Seen")
|
||||||
@ -391,16 +391,16 @@ class EmailChannel(BaseChannel):
|
|||||||
# --- Anti-spoofing: verify Authentication-Results ---
|
# --- Anti-spoofing: verify Authentication-Results ---
|
||||||
spf_pass, dkim_pass = self._check_authentication_results(parsed)
|
spf_pass, dkim_pass = self._check_authentication_results(parsed)
|
||||||
if self.config.verify_spf and not spf_pass:
|
if self.config.verify_spf and not spf_pass:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Email from {} rejected: SPF verification failed "
|
"From {} rejected: SPF verification failed "
|
||||||
"(no 'spf=pass' in Authentication-Results header)",
|
"(no 'spf=pass' in Authentication-Results header)",
|
||||||
sender,
|
sender,
|
||||||
)
|
)
|
||||||
self._remember_processed_uid(uid, dedupe, cycle_uids)
|
self._remember_processed_uid(uid, dedupe, cycle_uids)
|
||||||
continue
|
continue
|
||||||
if self.config.verify_dkim and not dkim_pass:
|
if self.config.verify_dkim and not dkim_pass:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Email from {} rejected: DKIM verification failed "
|
"From {} rejected: DKIM verification failed "
|
||||||
"(no 'dkim=pass' in Authentication-Results header)",
|
"(no 'dkim=pass' in Authentication-Results header)",
|
||||||
sender,
|
sender,
|
||||||
)
|
)
|
||||||
@ -641,7 +641,7 @@ class EmailChannel(BaseChannel):
|
|||||||
|
|
||||||
content_type = part.get_content_type()
|
content_type = part.get_content_type()
|
||||||
if not any(fnmatch(content_type, pat) for pat in allowed_types):
|
if not any(fnmatch(content_type, pat) for pat in allowed_types):
|
||||||
logger.debug("Email attachment skipped (type {}): not in allowed list", content_type)
|
logger.debug("Attachment skipped (type {}): not in allowed list", content_type)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload = part.get_payload(decode=True)
|
payload = part.get_payload(decode=True)
|
||||||
@ -649,7 +649,7 @@ class EmailChannel(BaseChannel):
|
|||||||
continue
|
continue
|
||||||
if len(payload) > max_size:
|
if len(payload) > max_size:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Email attachment skipped: size {} exceeds limit {}",
|
"Attachment skipped: size {} exceeds limit {}",
|
||||||
len(payload),
|
len(payload),
|
||||||
max_size,
|
max_size,
|
||||||
)
|
)
|
||||||
@ -662,9 +662,9 @@ class EmailChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
dest.write_bytes(payload)
|
dest.write_bytes(payload)
|
||||||
saved.append(dest)
|
saved.append(dest)
|
||||||
logger.info("Email attachment saved: {}", dest)
|
logger.info("Attachment saved: {}", dest)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Failed to save email attachment {}: {}", dest, exc)
|
logger.warning("Failed to save attachment {}: {}", dest, exc)
|
||||||
|
|
||||||
return saved
|
return saved
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,6 @@ from typing import Any, Literal
|
|||||||
|
|
||||||
from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1
|
from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1
|
||||||
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
@ -23,6 +22,7 @@ from nanobot.bus.queue import MessageBus
|
|||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
|
from nanobot.utils.logging_bridge import redirect_lib_logging
|
||||||
|
|
||||||
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||||
|
|
||||||
@ -320,15 +320,17 @@ class FeishuChannel(BaseChannel):
|
|||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Feishu bot with WebSocket long connection."""
|
"""Start the Feishu bot with WebSocket long connection."""
|
||||||
if not FEISHU_AVAILABLE:
|
if not FEISHU_AVAILABLE:
|
||||||
logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
|
self.logger.error("SDK not installed. Run: pip install lark-oapi")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.app_id or not self.config.app_secret:
|
if not self.config.app_id or not self.config.app_secret:
|
||||||
logger.error("Feishu app_id and app_secret not configured")
|
self.logger.error("app_id and app_secret not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
import lark_oapi as lark
|
import lark_oapi as lark
|
||||||
|
|
||||||
|
redirect_lib_logging("Lark")
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self._loop = asyncio.get_running_loop()
|
self._loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
@ -390,7 +392,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
self._ws_client.start()
|
self._ws_client.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Feishu WebSocket error: {}", e)
|
self.logger.warning("WebSocket error: {}", e)
|
||||||
if self._running:
|
if self._running:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
finally:
|
finally:
|
||||||
@ -404,12 +406,12 @@ class FeishuChannel(BaseChannel):
|
|||||||
None, self._fetch_bot_open_id
|
None, self._fetch_bot_open_id
|
||||||
)
|
)
|
||||||
if self._bot_open_id:
|
if self._bot_open_id:
|
||||||
logger.info("Feishu bot open_id: {}", self._bot_open_id)
|
self.logger.info("bot open_id: {}", self._bot_open_id)
|
||||||
else:
|
else:
|
||||||
logger.warning("Could not fetch bot open_id; @mention matching may be inaccurate")
|
self.logger.warning("Could not fetch bot open_id; @mention matching may be inaccurate")
|
||||||
|
|
||||||
logger.info("Feishu bot started with WebSocket long connection")
|
self.logger.info("bot started with WebSocket long connection")
|
||||||
logger.info("No public IP required - using WebSocket to receive events")
|
self.logger.info("No public IP required - using WebSocket to receive events")
|
||||||
|
|
||||||
# Keep running until stopped
|
# Keep running until stopped
|
||||||
while self._running:
|
while self._running:
|
||||||
@ -424,7 +426,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86
|
Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86
|
||||||
"""
|
"""
|
||||||
self._running = False
|
self._running = False
|
||||||
logger.info("Feishu bot stopped")
|
self.logger.info("bot stopped")
|
||||||
|
|
||||||
def _fetch_bot_open_id(self) -> str | None:
|
def _fetch_bot_open_id(self) -> str | None:
|
||||||
"""Fetch the bot's own open_id via GET /open-apis/bot/v3/info."""
|
"""Fetch the bot's own open_id via GET /open-apis/bot/v3/info."""
|
||||||
@ -445,10 +447,10 @@ class FeishuChannel(BaseChannel):
|
|||||||
data = json.loads(response.raw.content)
|
data = json.loads(response.raw.content)
|
||||||
bot = (data.get("data") or data).get("bot") or data.get("bot") or {}
|
bot = (data.get("data") or data).get("bot") or data.get("bot") or {}
|
||||||
return bot.get("open_id")
|
return bot.get("open_id")
|
||||||
logger.warning("Failed to get bot info: code={}, msg={}", response.code, response.msg)
|
self.logger.warning("Failed to get bot info: code={}, msg={}", response.code, response.msg)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Error fetching bot info: {}", e)
|
self.logger.warning("Error fetching bot info: {}", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -539,15 +541,15 @@ class FeishuChannel(BaseChannel):
|
|||||||
response = self._client.im.v1.message_reaction.create(request)
|
response = self._client.im.v1.message_reaction.create(request)
|
||||||
|
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Failed to add reaction: code={}, msg={}", response.code, response.msg
|
"Failed to add reaction: code={}, msg={}", response.code, response.msg
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
logger.debug("Added {} reaction to message {}", emoji_type, message_id)
|
self.logger.debug("Added {} reaction to message {}", emoji_type, message_id)
|
||||||
return response.data.reaction_id if response.data else None
|
return response.data.reaction_id if response.data else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Error adding reaction: {}", e)
|
self.logger.warning("Error adding reaction: {}", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> str | None:
|
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> str | None:
|
||||||
@ -579,13 +581,13 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
response = self._client.im.v1.message_reaction.delete(request)
|
response = self._client.im.v1.message_reaction.delete(request)
|
||||||
if response.success():
|
if response.success():
|
||||||
logger.debug("Removed reaction {} from message {}", reaction_id, message_id)
|
self.logger.debug("Removed reaction {} from message {}", reaction_id, message_id)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
self.logger.debug(
|
||||||
"Failed to remove reaction: code={}, msg={}", response.code, response.msg
|
"Failed to remove reaction: code={}, msg={}", response.code, response.msg
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Error removing reaction: {}", e)
|
self.logger.debug("Error removing reaction: {}", e)
|
||||||
|
|
||||||
async def _remove_reaction(self, message_id: str, reaction_id: str) -> None:
|
async def _remove_reaction(self, message_id: str, reaction_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
@ -607,7 +609,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
task.result()
|
task.result()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Background task failed: {}", exc)
|
self.logger.warning("Background task failed: {}", exc)
|
||||||
|
|
||||||
def _on_reaction_added(self, message_id: str, task: asyncio.Task) -> None:
|
def _on_reaction_added(self, message_id: str, task: asyncio.Task) -> None:
|
||||||
"""Callback: store reaction_id after background add-reaction completes."""
|
"""Callback: store reaction_id after background add-reaction completes."""
|
||||||
@ -917,15 +919,15 @@ class FeishuChannel(BaseChannel):
|
|||||||
response = self._client.im.v1.image.create(request)
|
response = self._client.im.v1.image.create(request)
|
||||||
if response.success():
|
if response.success():
|
||||||
image_key = response.data.image_key
|
image_key = response.data.image_key
|
||||||
logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
|
self.logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
|
||||||
return image_key
|
return image_key
|
||||||
else:
|
else:
|
||||||
logger.error(
|
self.logger.error(
|
||||||
"Failed to upload image: code={}, msg={}", response.code, response.msg
|
"Failed to upload image: code={}, msg={}", response.code, response.msg
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error uploading image {}: {}", file_path, e)
|
self.logger.exception("Error uploading image {}", file_path)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _upload_file_sync(self, file_path: str) -> str | None:
|
def _upload_file_sync(self, file_path: str) -> str | None:
|
||||||
@ -951,15 +953,15 @@ class FeishuChannel(BaseChannel):
|
|||||||
response = self._client.im.v1.file.create(request)
|
response = self._client.im.v1.file.create(request)
|
||||||
if response.success():
|
if response.success():
|
||||||
file_key = response.data.file_key
|
file_key = response.data.file_key
|
||||||
logger.debug("Uploaded file {}: {}", file_name, file_key)
|
self.logger.debug("Uploaded file {}: {}", file_name, file_key)
|
||||||
return file_key
|
return file_key
|
||||||
else:
|
else:
|
||||||
logger.error(
|
self.logger.error(
|
||||||
"Failed to upload file: code={}, msg={}", response.code, response.msg
|
"Failed to upload file: code={}, msg={}", response.code, response.msg
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error uploading file {}: {}", file_path, e)
|
self.logger.exception("Error uploading file {}", file_path)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _download_image_sync(
|
def _download_image_sync(
|
||||||
@ -984,12 +986,12 @@ class FeishuChannel(BaseChannel):
|
|||||||
file_data = file_data.read()
|
file_data = file_data.read()
|
||||||
return file_data, response.file_name
|
return file_data, response.file_name
|
||||||
else:
|
else:
|
||||||
logger.error(
|
self.logger.error(
|
||||||
"Failed to download image: code={}, msg={}", response.code, response.msg
|
"Failed to download image: code={}, msg={}", response.code, response.msg
|
||||||
)
|
)
|
||||||
return None, None
|
return None, None
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error downloading image {}: {}", image_key, e)
|
self.logger.exception("Error downloading image {}", image_key)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def _download_file_sync(
|
def _download_file_sync(
|
||||||
@ -1018,7 +1020,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
file_data = file_data.read()
|
file_data = file_data.read()
|
||||||
return file_data, response.file_name
|
return file_data, response.file_name
|
||||||
else:
|
else:
|
||||||
logger.error(
|
self.logger.error(
|
||||||
"Failed to download {}: code={}, msg={}",
|
"Failed to download {}: code={}, msg={}",
|
||||||
resource_type,
|
resource_type,
|
||||||
response.code,
|
response.code,
|
||||||
@ -1026,7 +1028,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
return None, None
|
return None, None
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error downloading {} {}", resource_type, file_key)
|
self.logger.exception("Error downloading {} {}", resource_type, file_key)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
async def _download_and_save_media(
|
async def _download_and_save_media(
|
||||||
@ -1055,10 +1057,10 @@ class FeishuChannel(BaseChannel):
|
|||||||
elif msg_type in ("audio", "file", "media"):
|
elif msg_type in ("audio", "file", "media"):
|
||||||
file_key = content_json.get("file_key")
|
file_key = content_json.get("file_key")
|
||||||
if not file_key:
|
if not file_key:
|
||||||
logger.warning("Feishu {} message missing file_key: {}", msg_type, content_json)
|
self.logger.warning("{} message missing file_key: {}", msg_type, content_json)
|
||||||
return None, f"[{msg_type}: missing file_key]"
|
return None, f"[{msg_type}: missing file_key]"
|
||||||
if not message_id:
|
if not message_id:
|
||||||
logger.warning("Feishu {} message missing message_id", msg_type)
|
self.logger.warning("{} message missing message_id", msg_type)
|
||||||
return None, f"[{msg_type}: missing message_id]"
|
return None, f"[{msg_type}: missing message_id]"
|
||||||
|
|
||||||
data, filename = await loop.run_in_executor(
|
data, filename = await loop.run_in_executor(
|
||||||
@ -1066,7 +1068,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
logger.warning("Feishu {} download failed: file_key={}", msg_type, file_key)
|
self.logger.warning("{} download failed: file_key={}", msg_type, file_key)
|
||||||
return None, f"[{msg_type}: download failed]"
|
return None, f"[{msg_type}: download failed]"
|
||||||
|
|
||||||
if not filename:
|
if not filename:
|
||||||
@ -1082,7 +1084,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
file_path = media_dir / filename
|
file_path = media_dir / filename
|
||||||
file_path.write_bytes(data)
|
file_path.write_bytes(data)
|
||||||
path_str = str(file_path)
|
path_str = str(file_path)
|
||||||
logger.debug("Downloaded {} to {}", msg_type, path_str)
|
self.logger.debug("Downloaded {} to {}", msg_type, path_str)
|
||||||
return path_str, f"[{msg_type}: {path_str}]"
|
return path_str, f"[{msg_type}: {path_str}]"
|
||||||
|
|
||||||
return None, f"[{msg_type}: download failed]"
|
return None, f"[{msg_type}: download failed]"
|
||||||
@ -1100,8 +1102,8 @@ class FeishuChannel(BaseChannel):
|
|||||||
request = GetMessageRequest.builder().message_id(message_id).build()
|
request = GetMessageRequest.builder().message_id(message_id).build()
|
||||||
response = self._client.im.v1.message.get(request)
|
response = self._client.im.v1.message.get(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.debug(
|
self.logger.debug(
|
||||||
"Feishu: could not fetch parent message {}: code={}, msg={}",
|
"could not fetch parent message {}: code={}, msg={}",
|
||||||
message_id,
|
message_id,
|
||||||
response.code,
|
response.code,
|
||||||
response.msg,
|
response.msg,
|
||||||
@ -1133,7 +1135,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
text = text[: self._REPLY_CONTEXT_MAX_LEN] + "..."
|
text = text[: self._REPLY_CONTEXT_MAX_LEN] + "..."
|
||||||
return f"[Reply to: {text}]"
|
return f"[Reply to: {text}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Feishu: error fetching parent message {}: {}", message_id, e)
|
self.logger.debug("error fetching parent message {}: {}", message_id, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str, *, reply_in_thread: bool = False) -> bool:
|
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str, *, reply_in_thread: bool = False) -> bool:
|
||||||
@ -1157,18 +1159,18 @@ class FeishuChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
response = self._client.im.v1.message.reply(request)
|
response = self._client.im.v1.message.reply(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.error(
|
self.logger.error(
|
||||||
"Failed to reply to Feishu message {}: code={}, msg={}, log_id={}",
|
"Failed to reply to message {}: code={}, msg={}, log_id={}",
|
||||||
parent_message_id,
|
parent_message_id,
|
||||||
response.code,
|
response.code,
|
||||||
response.msg,
|
response.msg,
|
||||||
response.get_log_id(),
|
response.get_log_id(),
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
logger.debug("Feishu reply sent to message {}", parent_message_id)
|
self.logger.debug("reply sent to message {}", parent_message_id)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error replying to Feishu message {}: {}", parent_message_id, e)
|
self.logger.exception("Error replying to message {}", parent_message_id)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _should_use_reply_in_thread(self, metadata: dict[str, Any]) -> bool:
|
def _should_use_reply_in_thread(self, metadata: dict[str, Any]) -> bool:
|
||||||
@ -1207,8 +1209,8 @@ class FeishuChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
response = self._client.im.v1.message.create(request)
|
response = self._client.im.v1.message.create(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.error(
|
self.logger.error(
|
||||||
"Failed to send Feishu {} message: code={}, msg={}, log_id={}",
|
"Failed to send {} message: code={}, msg={}, log_id={}",
|
||||||
msg_type,
|
msg_type,
|
||||||
response.code,
|
response.code,
|
||||||
response.msg,
|
response.msg,
|
||||||
@ -1216,10 +1218,10 @@ class FeishuChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
msg_id = getattr(response.data, "message_id", None)
|
msg_id = getattr(response.data, "message_id", None)
|
||||||
logger.debug("Feishu {} message sent to {}: {}", msg_type, receive_id, msg_id)
|
self.logger.debug("{} message sent to {}: {}", msg_type, receive_id, msg_id)
|
||||||
return msg_id
|
return msg_id
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
self.logger.exception("Error sending {} message", msg_type)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _create_streaming_card_sync(
|
def _create_streaming_card_sync(
|
||||||
@ -1259,7 +1261,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
response = self._client.cardkit.v1.card.create(request)
|
response = self._client.cardkit.v1.card.create(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Failed to create streaming card: code={}, msg={}", response.code, response.msg
|
"Failed to create streaming card: code={}, msg={}", response.code, response.msg
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@ -1279,12 +1281,12 @@ class FeishuChannel(BaseChannel):
|
|||||||
) is not None
|
) is not None
|
||||||
if sent:
|
if sent:
|
||||||
return card_id
|
return card_id
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Created streaming card {} but failed to send it to {}", card_id, chat_id
|
"Created streaming card {} but failed to send it to {}", card_id, chat_id
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Error creating streaming card: {}", e)
|
self.logger.warning("Error creating streaming card: {}", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _stream_update_text_sync(self, card_id: str, content: str, sequence: int) -> bool:
|
def _stream_update_text_sync(self, card_id: str, content: str, sequence: int) -> bool:
|
||||||
@ -1309,7 +1311,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
response = self._client.cardkit.v1.card_element.content(request)
|
response = self._client.cardkit.v1.card_element.content(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Failed to stream-update card {}: code={}, msg={}",
|
"Failed to stream-update card {}: code={}, msg={}",
|
||||||
card_id,
|
card_id,
|
||||||
response.code,
|
response.code,
|
||||||
@ -1318,7 +1320,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Error stream-updating card {}: {}", card_id, e)
|
self.logger.warning("Error stream-updating card {}: {}", card_id, e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _close_streaming_mode_sync(self, card_id: str, sequence: int) -> bool:
|
def _close_streaming_mode_sync(self, card_id: str, sequence: int) -> bool:
|
||||||
@ -1346,7 +1348,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
response = self._client.cardkit.v1.card.settings(request)
|
response = self._client.cardkit.v1.card.settings(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Failed to close streaming on card {}: code={}, msg={}",
|
"Failed to close streaming on card {}: code={}, msg={}",
|
||||||
card_id,
|
card_id,
|
||||||
response.code,
|
response.code,
|
||||||
@ -1355,7 +1357,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Error closing streaming on card {}: {}", card_id, e)
|
self.logger.warning("Error closing streaming on card {}: {}", card_id, e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def send_delta(
|
async def send_delta(
|
||||||
@ -1416,7 +1418,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
buf.sequence,
|
buf.sequence,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Streaming card {} final update failed, falling back to regular card",
|
"Streaming card {} final update failed, falling back to regular card",
|
||||||
buf.card_id,
|
buf.card_id,
|
||||||
)
|
)
|
||||||
@ -1484,7 +1486,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send a message through Feishu, including media (images/files) if present."""
|
"""Send a message through Feishu, including media (images/files) if present."""
|
||||||
if not self._client:
|
if not self._client:
|
||||||
logger.warning("Feishu client not initialized")
|
self.logger.warning("client not initialized")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1566,7 +1568,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
for file_path in msg.media:
|
for file_path in msg.media:
|
||||||
if not os.path.isfile(file_path):
|
if not os.path.isfile(file_path):
|
||||||
logger.warning("Media file not found: {}", file_path)
|
self.logger.warning("Media file not found: {}", file_path)
|
||||||
continue
|
continue
|
||||||
ext = os.path.splitext(file_path)[1].lower()
|
ext = os.path.splitext(file_path)[1].lower()
|
||||||
if ext in self._IMAGE_EXTS:
|
if ext in self._IMAGE_EXTS:
|
||||||
@ -1622,8 +1624,8 @@ class FeishuChannel(BaseChannel):
|
|||||||
json.dumps(card, ensure_ascii=False),
|
json.dumps(card, ensure_ascii=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending Feishu message: {}", e)
|
self.logger.exception("Error sending message")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _on_message_sync(self, data: Any) -> None:
|
def _on_message_sync(self, data: Any) -> None:
|
||||||
@ -1641,8 +1643,8 @@ class FeishuChannel(BaseChannel):
|
|||||||
message = event.message
|
message = event.message
|
||||||
sender = event.sender
|
sender = event.sender
|
||||||
|
|
||||||
logger.debug("Feishu raw message: {}", message.content)
|
self.logger.debug("raw message: {}", message.content)
|
||||||
logger.debug("Feishu mentions: {}", getattr(message, "mentions", None))
|
self.logger.debug("mentions: {}", getattr(message, "mentions", None))
|
||||||
|
|
||||||
message_id = message.message_id
|
message_id = message.message_id
|
||||||
|
|
||||||
@ -1659,7 +1661,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if chat_type == "group" and not self._is_group_message_for_bot(message):
|
if chat_type == "group" and not self._is_group_message_for_bot(message):
|
||||||
logger.debug("Feishu: skipping group message (not mentioned)")
|
self.logger.debug("skipping group message (not mentioned)")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Deduplication check
|
# Deduplication check
|
||||||
@ -1784,8 +1786,8 @@ class FeishuChannel(BaseChannel):
|
|||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error processing Feishu message: {}", e)
|
self.logger.exception("Error processing message")
|
||||||
|
|
||||||
def _on_reaction_created(self, data: Any) -> None:
|
def _on_reaction_created(self, data: Any) -> None:
|
||||||
"""Ignore reaction events so they do not generate SDK noise."""
|
"""Ignore reaction events so they do not generate SDK noise."""
|
||||||
@ -1801,7 +1803,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
def _on_bot_p2p_chat_entered(self, data: Any) -> None:
|
def _on_bot_p2p_chat_entered(self, data: Any) -> None:
|
||||||
"""Ignore p2p-enter events when a user opens a bot chat."""
|
"""Ignore p2p-enter events when a user opens a bot chat."""
|
||||||
logger.debug("Bot entered p2p chat (user opened chat window)")
|
self.logger.debug("Bot entered p2p chat (user opened chat window)")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -174,8 +174,8 @@ class ChannelManager:
|
|||||||
"""Start a channel and log any exceptions."""
|
"""Start a channel and log any exceptions."""
|
||||||
try:
|
try:
|
||||||
await channel.start()
|
await channel.start()
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Failed to start channel {}: {}", name, e)
|
logger.exception("Failed to start channel {}", name)
|
||||||
|
|
||||||
async def start_all(self) -> None:
|
async def start_all(self) -> None:
|
||||||
"""Start all channels and the outbound dispatcher."""
|
"""Start all channels and the outbound dispatcher."""
|
||||||
@ -230,8 +230,8 @@ class ChannelManager:
|
|||||||
try:
|
try:
|
||||||
await channel.stop()
|
await channel.stop()
|
||||||
logger.info("Stopped {} channel", name)
|
logger.info("Stopped {} channel", name)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error stopping {}: {}", name, e)
|
logger.exception("Error stopping {}", name)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fingerprint_content(content: str) -> str:
|
def _fingerprint_content(content: str) -> str:
|
||||||
@ -392,9 +392,9 @@ class ChannelManager:
|
|||||||
raise # Propagate cancellation for graceful shutdown
|
raise # Propagate cancellation for graceful shutdown
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if attempt == max_attempts - 1:
|
if attempt == max_attempts - 1:
|
||||||
logger.error(
|
logger.exception(
|
||||||
"Failed to send to {} after {} attempts: {} - {}",
|
"Failed to send to {} after {} attempts",
|
||||||
msg.channel, max_attempts, type(e).__name__, e
|
msg.channel, max_attempts
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
delay = _SEND_RETRY_DELAYS[min(attempt, len(_SEND_RETRY_DELAYS) - 1)]
|
delay = _SEND_RETRY_DELAYS[min(attempt, len(_SEND_RETRY_DELAYS) - 1)]
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import time
|
import time
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
@ -10,7 +9,6 @@ from dataclasses import dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, TypeAlias
|
from typing import Any, Literal, TypeAlias
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -47,6 +45,7 @@ from nanobot.channels.base import BaseChannel
|
|||||||
from nanobot.config.paths import get_data_dir, get_media_dir
|
from nanobot.config.paths import get_data_dir, get_media_dir
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
from nanobot.utils.helpers import safe_filename
|
from nanobot.utils.helpers import safe_filename
|
||||||
|
from nanobot.utils.logging_bridge import redirect_lib_logging
|
||||||
|
|
||||||
TYPING_NOTICE_TIMEOUT_MS = 30_000
|
TYPING_NOTICE_TIMEOUT_MS = 30_000
|
||||||
# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
|
# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
|
||||||
@ -178,28 +177,6 @@ def _build_matrix_text_content(
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
class _NioLoguruHandler(logging.Handler):
|
|
||||||
"""Route matrix-nio stdlib logs into Loguru."""
|
|
||||||
|
|
||||||
def emit(self, record: logging.LogRecord) -> None:
|
|
||||||
try:
|
|
||||||
level = logger.level(record.levelname).name
|
|
||||||
except ValueError:
|
|
||||||
level = record.levelno
|
|
||||||
frame, depth = logging.currentframe(), 2
|
|
||||||
while frame and frame.f_code.co_filename == logging.__file__:
|
|
||||||
frame, depth = frame.f_back, depth + 1
|
|
||||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
|
||||||
|
|
||||||
|
|
||||||
def _configure_nio_logging_bridge() -> None:
|
|
||||||
"""Bridge matrix-nio logs to Loguru (idempotent)."""
|
|
||||||
nio_logger = logging.getLogger("nio")
|
|
||||||
if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
|
|
||||||
nio_logger.handlers = [_NioLoguruHandler()]
|
|
||||||
nio_logger.propagate = False
|
|
||||||
|
|
||||||
|
|
||||||
class MatrixConfig(Base):
|
class MatrixConfig(Base):
|
||||||
"""Matrix (Element) channel configuration."""
|
"""Matrix (Element) channel configuration."""
|
||||||
|
|
||||||
@ -259,7 +236,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
"""Start Matrix client and begin sync loop."""
|
"""Start Matrix client and begin sync loop."""
|
||||||
self._running = True
|
self._running = True
|
||||||
self._started_at_ms = int(time.time() * 1000)
|
self._started_at_ms = int(time.time() * 1000)
|
||||||
_configure_nio_logging_bridge()
|
redirect_lib_logging("nio", level="WARNING")
|
||||||
|
|
||||||
self.store_path = get_data_dir() / "matrix-store"
|
self.store_path = get_data_dir() / "matrix-store"
|
||||||
self.store_path.mkdir(parents=True, exist_ok=True)
|
self.store_path.mkdir(parents=True, exist_ok=True)
|
||||||
@ -283,15 +260,15 @@ class MatrixChannel(BaseChannel):
|
|||||||
self._register_response_callbacks()
|
self._register_response_callbacks()
|
||||||
|
|
||||||
if not self.config.e2ee_enabled:
|
if not self.config.e2ee_enabled:
|
||||||
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
|
self.logger.warning("E2EE disabled; encrypted rooms may be undecryptable.")
|
||||||
|
|
||||||
if self.config.password:
|
if self.config.password:
|
||||||
if self.config.access_token or self.config.device_id:
|
if self.config.access_token or self.config.device_id:
|
||||||
logger.warning("Password-based Matrix login active; access_token and device_id fields will be ignored.")
|
self.logger.warning("Password-based login active; access_token and device_id fields will be ignored.")
|
||||||
|
|
||||||
create_new_session = True
|
create_new_session = True
|
||||||
if self.session_path.exists():
|
if self.session_path.exists():
|
||||||
logger.info("Found session.json at {}; attempting to use existing session...", self.session_path)
|
self.logger.info("Found session.json at {}; attempting to use existing session...", self.session_path)
|
||||||
try:
|
try:
|
||||||
with open(self.session_path, "r", encoding="utf-8") as f:
|
with open(self.session_path, "r", encoding="utf-8") as f:
|
||||||
session = json.load(f)
|
session = json.load(f)
|
||||||
@ -299,20 +276,20 @@ class MatrixChannel(BaseChannel):
|
|||||||
self.client.access_token = session["access_token"]
|
self.client.access_token = session["access_token"]
|
||||||
self.client.device_id = session["device_id"]
|
self.client.device_id = session["device_id"]
|
||||||
self.client.load_store()
|
self.client.load_store()
|
||||||
logger.info("Successfully loaded from existing session")
|
self.logger.info("Successfully loaded from existing session")
|
||||||
create_new_session = False
|
create_new_session = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to load from existing session: {}", e)
|
self.logger.warning("Failed to load from existing session: {}", e)
|
||||||
logger.info("Falling back to password login...")
|
self.logger.info("Falling back to password login...")
|
||||||
|
|
||||||
if create_new_session:
|
if create_new_session:
|
||||||
logger.info("Using password login...")
|
self.logger.info("Using password login...")
|
||||||
resp = await self.client.login(self.config.password)
|
resp = await self.client.login(self.config.password)
|
||||||
if isinstance(resp, LoginResponse):
|
if isinstance(resp, LoginResponse):
|
||||||
logger.info("Logged in using a password; saving details to disk")
|
self.logger.info("Logged in using a password; saving details to disk")
|
||||||
self._write_session_to_disk(resp)
|
self._write_session_to_disk(resp)
|
||||||
else:
|
else:
|
||||||
logger.error("Failed to log in: {}", resp)
|
self.logger.error("Failed to log in: {}", resp)
|
||||||
return
|
return
|
||||||
|
|
||||||
elif self.config.access_token and self.config.device_id:
|
elif self.config.access_token and self.config.device_id:
|
||||||
@ -321,12 +298,12 @@ class MatrixChannel(BaseChannel):
|
|||||||
self.client.access_token = self.config.access_token
|
self.client.access_token = self.config.access_token
|
||||||
self.client.device_id = self.config.device_id
|
self.client.device_id = self.config.device_id
|
||||||
self.client.load_store()
|
self.client.load_store()
|
||||||
logger.info("Successfully loaded from existing session")
|
self.logger.info("Successfully loaded from existing session")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to load from existing session: {}", e)
|
self.logger.warning("Failed to load from existing session: {}", e)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning("Unable to load a Matrix session due to missing password, access_token, or device_id; encryption may not work")
|
self.logger.warning("Unable to load a session due to missing password, access_token, or device_id; encryption may not work")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._sync_task = asyncio.create_task(self._sync_loop())
|
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||||
@ -358,9 +335,9 @@ class MatrixChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
with open(self.session_path, "w", encoding="utf-8") as f:
|
with open(self.session_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(session, f, indent=2)
|
json.dump(session, f, indent=2)
|
||||||
logger.info("Session saved to {}", self.session_path)
|
self.logger.info("Session saved to {}", self.session_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to save session: {}", e)
|
self.logger.warning("Failed to save session: {}", e)
|
||||||
|
|
||||||
def _is_workspace_path_allowed(self, path: Path) -> bool:
|
def _is_workspace_path_allowed(self, path: Path) -> bool:
|
||||||
"""Check path is inside workspace (when restriction enabled)."""
|
"""Check path is inside workspace (when restriction enabled)."""
|
||||||
@ -598,14 +575,14 @@ class MatrixChannel(BaseChannel):
|
|||||||
def _log_response_error(self, label: str, response: Any) -> None:
|
def _log_response_error(self, label: str, response: Any) -> None:
|
||||||
"""Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
|
"""Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
|
||||||
is_fatal = self._is_fatal_auth_response(response)
|
is_fatal = self._is_fatal_auth_response(response)
|
||||||
(logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
|
(self.logger.error if is_fatal else self.logger.warning)("{} failed: {}", label, response)
|
||||||
|
|
||||||
async def _on_sync_error(self, response: SyncError) -> None:
|
async def _on_sync_error(self, response: SyncError) -> None:
|
||||||
self._log_response_error("sync", response)
|
self._log_response_error("sync", response)
|
||||||
if self._is_fatal_auth_response(response):
|
if self._is_fatal_auth_response(response):
|
||||||
# Auth errors won't recover by retry; stop the sync loop instead of
|
# Auth errors won't recover by retry; stop the sync loop instead of
|
||||||
# spamming the homeserver every 2s (#1851).
|
# spamming the homeserver every 2s (#1851).
|
||||||
logger.error("Matrix authentication failed irrecoverably; stopping sync loop")
|
self.logger.error("Authentication failed irrecoverably; stopping sync loop")
|
||||||
self._running = False
|
self._running = False
|
||||||
if self.client:
|
if self.client:
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
@ -625,7 +602,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
response = await self.client.room_typing(room_id=room_id, typing_state=typing,
|
response = await self.client.room_typing(room_id=room_id, typing_state=typing,
|
||||||
timeout=TYPING_NOTICE_TIMEOUT_MS)
|
timeout=TYPING_NOTICE_TIMEOUT_MS)
|
||||||
if isinstance(response, RoomTypingError):
|
if isinstance(response, RoomTypingError):
|
||||||
logger.debug("Matrix typing failed for {}: {}", room_id, response)
|
self.logger.debug("typing failed for {}: {}", room_id, response)
|
||||||
|
|
||||||
async def _start_typing_keepalive(self, room_id: str) -> None:
|
async def _start_typing_keepalive(self, room_id: str) -> None:
|
||||||
"""Start periodic typing refresh (spec-recommended keepalive)."""
|
"""Start periodic typing refresh (spec-recommended keepalive)."""
|
||||||
@ -796,7 +773,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
return None
|
return None
|
||||||
response = await self.client.download(mxc=mxc_url)
|
response = await self.client.download(mxc=mxc_url)
|
||||||
if isinstance(response, DownloadError):
|
if isinstance(response, DownloadError):
|
||||||
logger.warning("Matrix download failed for {}: {}", mxc_url, response)
|
self.logger.warning("download failed for {}: {}", mxc_url, response)
|
||||||
return None
|
return None
|
||||||
body = getattr(response, "body", None)
|
body = getattr(response, "body", None)
|
||||||
if isinstance(body, (bytes, bytearray)):
|
if isinstance(body, (bytes, bytearray)):
|
||||||
@ -821,7 +798,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
return decrypt_attachment(ciphertext, key, sha256, iv)
|
return decrypt_attachment(ciphertext, key, sha256, iv)
|
||||||
except (EncryptionError, ValueError, TypeError):
|
except (EncryptionError, ValueError, TypeError):
|
||||||
logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
|
self.logger.warning("decrypt failed for event {}", getattr(event, "event_id", ""))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _fetch_media_attachment(
|
async def _fetch_media_attachment(
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from datetime import datetime
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@ -303,7 +302,7 @@ class MochatChannel(BaseChannel):
|
|||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start Mochat channel workers and websocket connection."""
|
"""Start Mochat channel workers and websocket connection."""
|
||||||
if not self.config.claw_token:
|
if not self.config.claw_token:
|
||||||
logger.error("Mochat claw_token not configured")
|
self.logger.error("claw_token not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
@ -348,7 +347,7 @@ class MochatChannel(BaseChannel):
|
|||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send outbound message to session or panel."""
|
"""Send outbound message to session or panel."""
|
||||||
if not self.config.claw_token:
|
if not self.config.claw_token:
|
||||||
logger.warning("Mochat claw_token missing, skip send")
|
self.logger.warning("claw_token missing, skip send")
|
||||||
return
|
return
|
||||||
|
|
||||||
parts = ([msg.content.strip()] if msg.content and msg.content.strip() else [])
|
parts = ([msg.content.strip()] if msg.content and msg.content.strip() else [])
|
||||||
@ -360,7 +359,7 @@ class MochatChannel(BaseChannel):
|
|||||||
|
|
||||||
target = resolve_mochat_target(msg.chat_id)
|
target = resolve_mochat_target(msg.chat_id)
|
||||||
if not target.id:
|
if not target.id:
|
||||||
logger.warning("Mochat outbound target is empty")
|
self.logger.warning("outbound target is empty")
|
||||||
return
|
return
|
||||||
|
|
||||||
is_panel = (target.is_panel or target.id in self._panel_set) and not target.id.startswith("session_")
|
is_panel = (target.is_panel or target.id in self._panel_set) and not target.id.startswith("session_")
|
||||||
@ -371,8 +370,8 @@ class MochatChannel(BaseChannel):
|
|||||||
else:
|
else:
|
||||||
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
||||||
content, msg.reply_to)
|
content, msg.reply_to)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Failed to send Mochat message: {}", e)
|
self.logger.exception("Failed to send message")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# ---- config / init helpers ---------------------------------------------
|
# ---- config / init helpers ---------------------------------------------
|
||||||
@ -395,7 +394,7 @@ class MochatChannel(BaseChannel):
|
|||||||
|
|
||||||
async def _start_socket_client(self) -> bool:
|
async def _start_socket_client(self) -> bool:
|
||||||
if not SOCKETIO_AVAILABLE:
|
if not SOCKETIO_AVAILABLE:
|
||||||
logger.warning("python-socketio not installed, Mochat using polling fallback")
|
self.logger.warning("python-socketio not installed, using polling fallback")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
serializer = "default"
|
serializer = "default"
|
||||||
@ -403,7 +402,7 @@ class MochatChannel(BaseChannel):
|
|||||||
if MSGPACK_AVAILABLE:
|
if MSGPACK_AVAILABLE:
|
||||||
serializer = "msgpack"
|
serializer = "msgpack"
|
||||||
else:
|
else:
|
||||||
logger.warning("msgpack not installed but socket_disable_msgpack=false; using JSON")
|
self.logger.warning("msgpack not installed but socket_disable_msgpack=false; using JSON")
|
||||||
|
|
||||||
client = socketio.AsyncClient(
|
client = socketio.AsyncClient(
|
||||||
reconnection=True,
|
reconnection=True,
|
||||||
@ -416,7 +415,7 @@ class MochatChannel(BaseChannel):
|
|||||||
@client.event
|
@client.event
|
||||||
async def connect() -> None:
|
async def connect() -> None:
|
||||||
self._ws_connected, self._ws_ready = True, False
|
self._ws_connected, self._ws_ready = True, False
|
||||||
logger.info("Mochat websocket connected")
|
self.logger.info("websocket connected")
|
||||||
subscribed = await self._subscribe_all()
|
subscribed = await self._subscribe_all()
|
||||||
self._ws_ready = subscribed
|
self._ws_ready = subscribed
|
||||||
await (self._stop_fallback_workers() if subscribed else self._ensure_fallback_workers())
|
await (self._stop_fallback_workers() if subscribed else self._ensure_fallback_workers())
|
||||||
@ -426,12 +425,12 @@ class MochatChannel(BaseChannel):
|
|||||||
if not self._running:
|
if not self._running:
|
||||||
return
|
return
|
||||||
self._ws_connected = self._ws_ready = False
|
self._ws_connected = self._ws_ready = False
|
||||||
logger.warning("Mochat websocket disconnected")
|
self.logger.warning("websocket disconnected")
|
||||||
await self._ensure_fallback_workers()
|
await self._ensure_fallback_workers()
|
||||||
|
|
||||||
@client.event
|
@client.event
|
||||||
async def connect_error(data: Any) -> None:
|
async def connect_error(data: Any) -> None:
|
||||||
logger.error("Mochat websocket connect error: {}", data)
|
self.logger.error("websocket connect error: {}", data)
|
||||||
|
|
||||||
@client.on("claw.session.events")
|
@client.on("claw.session.events")
|
||||||
async def on_session_events(payload: dict[str, Any]) -> None:
|
async def on_session_events(payload: dict[str, Any]) -> None:
|
||||||
@ -457,8 +456,8 @@ class MochatChannel(BaseChannel):
|
|||||||
wait_timeout=max(1.0, self.config.socket_connect_timeout_ms / 1000.0),
|
wait_timeout=max(1.0, self.config.socket_connect_timeout_ms / 1000.0),
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Failed to connect Mochat websocket: {}", e)
|
self.logger.exception("Failed to connect websocket")
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
self._socket = None
|
self._socket = None
|
||||||
@ -493,7 +492,7 @@ class MochatChannel(BaseChannel):
|
|||||||
"limit": self.config.watch_limit,
|
"limit": self.config.watch_limit,
|
||||||
})
|
})
|
||||||
if not ack.get("result"):
|
if not ack.get("result"):
|
||||||
logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
|
self.logger.error("subscribeSessions failed: {}", ack.get('message', 'unknown error'))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
data = ack.get("data")
|
data = ack.get("data")
|
||||||
@ -515,7 +514,7 @@ class MochatChannel(BaseChannel):
|
|||||||
return True
|
return True
|
||||||
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
||||||
if not ack.get("result"):
|
if not ack.get("result"):
|
||||||
logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
|
self.logger.error("subscribePanels failed: {}", ack.get('message', 'unknown error'))
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -537,7 +536,7 @@ class MochatChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._refresh_targets(subscribe_new=self._ws_ready)
|
await self._refresh_targets(subscribe_new=self._ws_ready)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Mochat refresh failed: {}", e)
|
self.logger.warning("refresh failed: {}", e)
|
||||||
if self._fallback_mode:
|
if self._fallback_mode:
|
||||||
await self._ensure_fallback_workers()
|
await self._ensure_fallback_workers()
|
||||||
|
|
||||||
@ -551,7 +550,7 @@ class MochatChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
response = await self._post_json("/api/claw/sessions/list", {})
|
response = await self._post_json("/api/claw/sessions/list", {})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Mochat listSessions failed: {}", e)
|
self.logger.warning("listSessions failed: {}", e)
|
||||||
return
|
return
|
||||||
|
|
||||||
sessions = response.get("sessions")
|
sessions = response.get("sessions")
|
||||||
@ -585,7 +584,7 @@ class MochatChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
response = await self._post_json("/api/claw/groups/get", {})
|
response = await self._post_json("/api/claw/groups/get", {})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Mochat getWorkspaceGroup failed: {}", e)
|
self.logger.warning("getWorkspaceGroup failed: {}", e)
|
||||||
return
|
return
|
||||||
|
|
||||||
raw_panels = response.get("panels")
|
raw_panels = response.get("panels")
|
||||||
@ -647,7 +646,7 @@ class MochatChannel(BaseChannel):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
|
self.logger.warning("watch fallback error ({}): {}", session_id, e)
|
||||||
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
||||||
|
|
||||||
async def _panel_poll_worker(self, panel_id: str) -> None:
|
async def _panel_poll_worker(self, panel_id: str) -> None:
|
||||||
@ -674,7 +673,7 @@ class MochatChannel(BaseChannel):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
|
self.logger.warning("panel polling error ({}): {}", panel_id, e)
|
||||||
await asyncio.sleep(sleep_s)
|
await asyncio.sleep(sleep_s)
|
||||||
|
|
||||||
# ---- inbound event processing ------------------------------------------
|
# ---- inbound event processing ------------------------------------------
|
||||||
@ -885,7 +884,7 @@ class MochatChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
data = json.loads(self._cursor_path.read_text("utf-8"))
|
data = json.loads(self._cursor_path.read_text("utf-8"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to read Mochat cursor file: {}", e)
|
self.logger.warning("Failed to read cursor file: {}", e)
|
||||||
return
|
return
|
||||||
cursors = data.get("cursors") if isinstance(data, dict) else None
|
cursors = data.get("cursors") if isinstance(data, dict) else None
|
||||||
if isinstance(cursors, dict):
|
if isinstance(cursors, dict):
|
||||||
@ -901,7 +900,7 @@ class MochatChannel(BaseChannel):
|
|||||||
"cursors": self._session_cursor,
|
"cursors": self._session_cursor,
|
||||||
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to save Mochat cursor file: {}", e)
|
self.logger.warning("Failed to save cursor file: {}", e)
|
||||||
|
|
||||||
# ---- HTTP helpers ------------------------------------------------------
|
# ---- HTTP helpers ------------------------------------------------------
|
||||||
|
|
||||||
|
|||||||
@ -32,7 +32,6 @@ except ImportError: # pragma: no cover
|
|||||||
fcntl = None
|
fcntl = None
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
@ -134,16 +133,16 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Teams webhook listener."""
|
"""Start the Teams webhook listener."""
|
||||||
if not MSTEAMS_AVAILABLE:
|
if not MSTEAMS_AVAILABLE:
|
||||||
logger.error("PyJWT not installed. Run: pip install nanobot-ai[msteams]")
|
self.logger.error("PyJWT not installed. Run: pip install nanobot-ai[msteams]")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.app_id or not self.config.app_password:
|
if not self.config.app_id or not self.config.app_password:
|
||||||
logger.error("MSTeams app_id/app_password not configured")
|
self.logger.error("app_id/app_password not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.validate_inbound_auth:
|
if not self.config.validate_inbound_auth:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"MSTeams inbound auth validation was explicitly DISABLED in config. "
|
"Inbound auth validation was explicitly DISABLED in config. "
|
||||||
"Anyone who knows the webhook URL can send messages as any user. "
|
"Anyone who knows the webhook URL can send messages as any user. "
|
||||||
"Only disable this for local development or controlled testing."
|
"Only disable this for local development or controlled testing."
|
||||||
)
|
)
|
||||||
@ -166,7 +165,7 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
raw = self.rfile.read(length) if length > 0 else b"{}"
|
raw = self.rfile.read(length) if length > 0 else b"{}"
|
||||||
payload = json.loads(raw.decode("utf-8"))
|
payload = json.loads(raw.decode("utf-8"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("MSTeams invalid request body: {}", e)
|
channel.logger.warning("Invalid request body: {}", e)
|
||||||
self.send_response(400)
|
self.send_response(400)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
return
|
return
|
||||||
@ -180,7 +179,7 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
fut.result(timeout=15)
|
fut.result(timeout=15)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("MSTeams inbound auth validation failed: {}", e)
|
channel.logger.warning("Inbound auth validation failed: {}", e)
|
||||||
self.send_response(401)
|
self.send_response(401)
|
||||||
self.send_header("Content-Type", "application/json")
|
self.send_header("Content-Type", "application/json")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
@ -193,7 +192,7 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
fut.result(timeout=15)
|
fut.result(timeout=15)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("MSTeams activity handling failed: {}", e)
|
channel.logger.warning("Activity handling failed: {}", e)
|
||||||
|
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header("Content-Type", "application/json")
|
self.send_header("Content-Type", "application/json")
|
||||||
@ -211,8 +210,8 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
self._server_thread.start()
|
self._server_thread.start()
|
||||||
|
|
||||||
logger.info(
|
self.logger.info(
|
||||||
"MSTeams webhook listening on http://{}:{}{}",
|
"Webhook listening on http://{}:{}{}",
|
||||||
self.config.host,
|
self.config.host,
|
||||||
self.config.port,
|
self.config.port,
|
||||||
self.config.path,
|
self.config.path,
|
||||||
@ -261,10 +260,10 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
resp = await self._http.post(base_url, headers=headers, json=payload)
|
resp = await self._http.post(base_url, headers=headers, json=payload)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
logger.info("MSTeams message sent to {}", ref.conversation_id)
|
self.logger.info("Message sent to {}", ref.conversation_id)
|
||||||
self._touch_conversation_ref(str(msg.chat_id), persist=True)
|
self._touch_conversation_ref(str(msg.chat_id), persist=True)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("MSTeams send failed: {}", e)
|
self.logger.exception("Send failed")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _handle_activity(self, activity: dict[str, Any]) -> None:
|
async def _handle_activity(self, activity: dict[str, Any]) -> None:
|
||||||
@ -291,18 +290,18 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
|
|
||||||
# DM-only MVP: ignore group/channel traffic for now
|
# DM-only MVP: ignore group/channel traffic for now
|
||||||
if conversation_type and conversation_type not in ("personal", ""):
|
if conversation_type and conversation_type not in ("personal", ""):
|
||||||
logger.debug("MSTeams ignoring non-DM conversation {}", conversation_type)
|
self.logger.debug("Ignoring non-DM conversation {}", conversation_type)
|
||||||
return
|
return
|
||||||
|
|
||||||
text = self._sanitize_inbound_text(activity)
|
text = self._sanitize_inbound_text(activity)
|
||||||
if not text:
|
if not text:
|
||||||
text = self.config.mention_only_response.strip()
|
text = self.config.mention_only_response.strip()
|
||||||
if not text:
|
if not text:
|
||||||
logger.debug("MSTeams ignoring empty message after Teams text sanitization")
|
self.logger.debug("Ignoring empty message after Teams text sanitization")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.is_allowed(sender_id):
|
if not self.is_allowed(sender_id):
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Access denied for sender {} on channel {}. "
|
"Access denied for sender {} on channel {}. "
|
||||||
"Add them to allowFrom list in config to grant access.",
|
"Add them to allowFrom list in config to grant access.",
|
||||||
sender_id, self.name,
|
sender_id, self.name,
|
||||||
@ -554,7 +553,7 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
if isinstance(loaded, dict):
|
if isinstance(loaded, dict):
|
||||||
main_data = loaded
|
main_data = loaded
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to load MSTeams conversation refs: {}", e)
|
self.logger.warning("Failed to load conversation refs: {}", e)
|
||||||
|
|
||||||
if meta_exists:
|
if meta_exists:
|
||||||
try:
|
try:
|
||||||
@ -562,7 +561,7 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
if isinstance(loaded_meta, dict):
|
if isinstance(loaded_meta, dict):
|
||||||
meta_data = loaded_meta
|
meta_data = loaded_meta
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to load MSTeams conversation refs metadata: {}", e)
|
self.logger.warning("Failed to load conversation refs metadata: {}", e)
|
||||||
|
|
||||||
return main_data, meta_data, meta_exists
|
return main_data, meta_data, meta_exists
|
||||||
|
|
||||||
@ -660,8 +659,8 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
|
|
||||||
for key in keys_to_drop:
|
for key in keys_to_drop:
|
||||||
self._conversation_refs.pop(key, None)
|
self._conversation_refs.pop(key, None)
|
||||||
logger.info(
|
self.logger.info(
|
||||||
"MSTeams pruned {} stale/unsupported conversation refs (ttl={} days)",
|
"Pruned {} stale/unsupported conversation refs (ttl={} days)",
|
||||||
len(keys_to_drop),
|
len(keys_to_drop),
|
||||||
ttl_days,
|
ttl_days,
|
||||||
)
|
)
|
||||||
@ -742,7 +741,7 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
self._write_json_atomically(self._refs_path, refs_data)
|
self._write_json_atomically(self._refs_path, refs_data)
|
||||||
self._write_json_atomically(self._refs_meta_path, refs_meta)
|
self._write_json_atomically(self._refs_meta_path, refs_meta)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to save MSTeams conversation refs: {}", e)
|
self.logger.warning("Failed to save conversation refs: {}", e)
|
||||||
|
|
||||||
def _save_refs(self, *, prune: bool = True) -> None:
|
def _save_refs(self, *, prune: bool = True) -> None:
|
||||||
"""Persist conversation references."""
|
"""Persist conversation references."""
|
||||||
|
|||||||
@ -38,7 +38,7 @@ from nanobot.bus.events import OutboundMessage
|
|||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
from nanobot.security.network import validate_url_target
|
from nanobot.utils.logging_bridge import redirect_lib_logging
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
@ -187,24 +187,25 @@ class QQChannel(BaseChannel):
|
|||||||
root = Path.home() / ".nanobot" / "media" / "qq"
|
root = Path.home() / ".nanobot" / "media" / "qq"
|
||||||
|
|
||||||
root.mkdir(parents=True, exist_ok=True)
|
root.mkdir(parents=True, exist_ok=True)
|
||||||
logger.info("QQ media directory: {}", str(root))
|
self.logger.info("media directory: {}", str(root))
|
||||||
return root
|
return root
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the QQ bot with auto-reconnect loop."""
|
"""Start the QQ bot with auto-reconnect loop."""
|
||||||
|
redirect_lib_logging("botpy", level="WARNING")
|
||||||
if not QQ_AVAILABLE:
|
if not QQ_AVAILABLE:
|
||||||
logger.error("QQ SDK not installed. Run: pip install qq-botpy")
|
self.logger.error("SDK not installed. Run: pip install qq-botpy")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.app_id or not self.config.secret:
|
if not self.config.app_id or not self.config.secret:
|
||||||
logger.error("QQ app_id and secret not configured")
|
self.logger.error("app_id and secret not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
|
self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
|
||||||
|
|
||||||
self._client = _make_bot_class(self)()
|
self._client = _make_bot_class(self)()
|
||||||
logger.info("QQ bot started (C2C & Group supported)")
|
self.logger.info("bot started (C2C & Group supported)")
|
||||||
await self._run_bot()
|
await self._run_bot()
|
||||||
|
|
||||||
async def _run_bot(self) -> None:
|
async def _run_bot(self) -> None:
|
||||||
@ -213,9 +214,9 @@ class QQChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("QQ bot error: {}", e)
|
self.logger.warning("bot error: {}", e)
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.info("Reconnecting QQ bot in 5 seconds...")
|
self.logger.info("Reconnecting bot in 5 seconds...")
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
@ -231,7 +232,7 @@ class QQChannel(BaseChannel):
|
|||||||
await self._http.close()
|
await self._http.close()
|
||||||
self._http = None
|
self._http = None
|
||||||
|
|
||||||
logger.info("QQ bot stopped")
|
self.logger.info("bot stopped")
|
||||||
|
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
# Outbound (send)
|
# Outbound (send)
|
||||||
@ -241,7 +242,7 @@ class QQChannel(BaseChannel):
|
|||||||
"""Send attachments first, then text."""
|
"""Send attachments first, then text."""
|
||||||
try:
|
try:
|
||||||
if not self._client:
|
if not self._client:
|
||||||
logger.warning("QQ client not initialized")
|
self.logger.warning("client not initialized")
|
||||||
return
|
return
|
||||||
|
|
||||||
msg_id = msg.metadata.get("message_id")
|
msg_id = msg.metadata.get("message_id")
|
||||||
@ -281,7 +282,7 @@ class QQChannel(BaseChannel):
|
|||||||
# Network / transport errors — propagate so ChannelManager can retry
|
# Network / transport errors — propagate so ChannelManager can retry
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error sending QQ message to chat_id={}", msg.chat_id)
|
self.logger.exception("Error sending message to chat_id={}", msg.chat_id)
|
||||||
|
|
||||||
async def _send_text_only(
|
async def _send_text_only(
|
||||||
self,
|
self,
|
||||||
@ -339,7 +340,7 @@ class QQChannel(BaseChannel):
|
|||||||
srv_send_msg=False,
|
srv_send_msg=False,
|
||||||
)
|
)
|
||||||
if not media_obj:
|
if not media_obj:
|
||||||
logger.error("QQ media upload failed: empty response")
|
self.logger.error("media upload failed: empty response")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._msg_seq += 1
|
self._msg_seq += 1
|
||||||
@ -360,15 +361,15 @@ class QQChannel(BaseChannel):
|
|||||||
media=media_obj,
|
media=media_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("QQ media sent: {}", filename)
|
self.logger.info("media sent: {}", filename)
|
||||||
return True
|
return True
|
||||||
except (aiohttp.ClientError, OSError) as e:
|
except (aiohttp.ClientError, OSError) as e:
|
||||||
# Network / transport errors — propagate for retry by caller
|
# Network / transport errors — propagate for retry by caller
|
||||||
logger.warning("QQ send media network error filename={} err={}", filename, e)
|
self.logger.warning("send media network error filename={} err={}", filename, e)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# API-level or other non-network errors — return False so send() can fallback
|
# API-level or other non-network errors — return False so send() can fallback
|
||||||
logger.error("QQ send media failed filename={} err={}", filename, e)
|
self.logger.exception("send media failed filename={}", filename)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _read_media_bytes(self, media_ref: str) -> tuple[bytes | None, str | None]:
|
async def _read_media_bytes(self, media_ref: str) -> tuple[bytes | None, str | None]:
|
||||||
@ -389,19 +390,19 @@ class QQChannel(BaseChannel):
|
|||||||
local_path = Path(os.path.expanduser(media_ref))
|
local_path = Path(os.path.expanduser(media_ref))
|
||||||
|
|
||||||
if not local_path.is_file():
|
if not local_path.is_file():
|
||||||
logger.warning("QQ outbound media file not found: {}", str(local_path))
|
self.logger.warning("outbound media file not found: {}", str(local_path))
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
data = await asyncio.to_thread(local_path.read_bytes)
|
data = await asyncio.to_thread(local_path.read_bytes)
|
||||||
return data, local_path.name
|
return data, local_path.name
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("QQ outbound media read error ref={} err={}", media_ref, e)
|
self.logger.warning("outbound media read error ref={} err={}", media_ref, e)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Remote URL
|
# Remote URL
|
||||||
ok, err = validate_url_target(media_ref)
|
ok, err = validate_url_target(media_ref)
|
||||||
if not ok:
|
if not ok:
|
||||||
logger.warning("QQ outbound media URL validation failed url={} err={}", media_ref, err)
|
self.logger.warning("outbound media URL validation failed url={} err={}", media_ref, err)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
if not self._http:
|
if not self._http:
|
||||||
@ -409,8 +410,8 @@ class QQChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
async with self._http.get(media_ref, allow_redirects=True) as resp:
|
async with self._http.get(media_ref, allow_redirects=True) as resp:
|
||||||
if resp.status >= 400:
|
if resp.status >= 400:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"QQ outbound media download failed status={} url={}",
|
"outbound media download failed status={} url={}",
|
||||||
resp.status,
|
resp.status,
|
||||||
media_ref,
|
media_ref,
|
||||||
)
|
)
|
||||||
@ -421,7 +422,7 @@ class QQChannel(BaseChannel):
|
|||||||
filename = os.path.basename(urlparse(media_ref).path) or "file.bin"
|
filename = os.path.basename(urlparse(media_ref).path) or "file.bin"
|
||||||
return data, filename
|
return data, filename
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("QQ outbound media download error url={} err={}", media_ref, e)
|
self.logger.warning("outbound media download error url={} err={}", media_ref, e)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# https://github.com/tencent-connect/botpy/issues/198
|
# https://github.com/tencent-connect/botpy/issues/198
|
||||||
@ -525,7 +526,7 @@ class QQChannel(BaseChannel):
|
|||||||
content=self.config.ack_message,
|
content=self.config.ack_message,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("QQ ack message failed for chat_id={}", chat_id)
|
self.logger.debug("ack message failed for chat_id={}", chat_id)
|
||||||
|
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=user_id,
|
sender_id=user_id,
|
||||||
@ -538,7 +539,7 @@ class QQChannel(BaseChannel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?"))
|
self.logger.exception("Error handling inbound message id={}", getattr(data, "id", "?"))
|
||||||
|
|
||||||
async def _handle_attachments(
|
async def _handle_attachments(
|
||||||
self,
|
self,
|
||||||
@ -557,7 +558,7 @@ class QQChannel(BaseChannel):
|
|||||||
filename = getattr(att, "filename", None) or ""
|
filename = getattr(att, "filename", None) or ""
|
||||||
ctype = getattr(att, "content_type", None) or ""
|
ctype = getattr(att, "content_type", None) or ""
|
||||||
|
|
||||||
logger.info("Downloading file from QQ: {}", filename or url)
|
self.logger.info("Downloading file: {}", filename or url)
|
||||||
local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
|
local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
|
||||||
|
|
||||||
att_meta.append(
|
att_meta.append(
|
||||||
@ -608,7 +609,7 @@ class QQChannel(BaseChannel):
|
|||||||
allow_redirects=True,
|
allow_redirects=True,
|
||||||
) as resp:
|
) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
logger.warning("QQ download failed: status={} url={}", resp.status, url)
|
self.logger.warning("download failed: status={} url={}", resp.status, url)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
ctype = (resp.headers.get("Content-Type") or "").lower()
|
ctype = (resp.headers.get("Content-Type") or "").lower()
|
||||||
@ -662,8 +663,8 @@ class QQChannel(BaseChannel):
|
|||||||
continue
|
continue
|
||||||
downloaded += len(chunk)
|
downloaded += len(chunk)
|
||||||
if downloaded > max_bytes:
|
if downloaded > max_bytes:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"QQ download exceeded max_bytes={} url={} -> abort",
|
"download exceeded max_bytes={} url={} -> abort",
|
||||||
max_bytes,
|
max_bytes,
|
||||||
url,
|
url,
|
||||||
)
|
)
|
||||||
@ -675,11 +676,11 @@ class QQChannel(BaseChannel):
|
|||||||
# Atomic rename
|
# Atomic rename
|
||||||
await asyncio.to_thread(os.replace, tmp_path, target)
|
await asyncio.to_thread(os.replace, tmp_path, target)
|
||||||
tmp_path = None # mark as moved
|
tmp_path = None # mark as moved
|
||||||
logger.info("QQ file saved: {}", str(target))
|
self.logger.info("file saved: {}", str(target))
|
||||||
return str(target)
|
return str(target)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("QQ download error: {}", e)
|
self.logger.exception("download error")
|
||||||
return None
|
return None
|
||||||
finally:
|
finally:
|
||||||
# Cleanup partial file
|
# Cleanup partial file
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||||
@ -84,10 +83,10 @@ class SlackChannel(BaseChannel):
|
|||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Slack Socket Mode client."""
|
"""Start the Slack Socket Mode client."""
|
||||||
if not self.config.bot_token or not self.config.app_token:
|
if not self.config.bot_token or not self.config.app_token:
|
||||||
logger.error("Slack bot/app token not configured")
|
self.logger.error("bot/app token not configured")
|
||||||
return
|
return
|
||||||
if self.config.mode != "socket":
|
if self.config.mode != "socket":
|
||||||
logger.error("Unsupported Slack mode: {}", self.config.mode)
|
self.logger.error("Unsupported mode: {}", self.config.mode)
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
@ -104,11 +103,11 @@ class SlackChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
auth = await self._web_client.auth_test()
|
auth = await self._web_client.auth_test()
|
||||||
self._bot_user_id = auth.get("user_id")
|
self._bot_user_id = auth.get("user_id")
|
||||||
logger.info("Slack bot connected as {}", self._bot_user_id)
|
self.logger.info("bot connected as {}", self._bot_user_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Slack auth_test failed: {}", e)
|
self.logger.warning("auth_test failed: {}", e)
|
||||||
|
|
||||||
logger.info("Starting Slack Socket Mode client...")
|
self.logger.info("Starting Socket Mode client...")
|
||||||
await self._socket_client.connect()
|
await self._socket_client.connect()
|
||||||
|
|
||||||
while self._running:
|
while self._running:
|
||||||
@ -121,13 +120,13 @@ class SlackChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._socket_client.close()
|
await self._socket_client.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Slack socket close failed: {}", e)
|
self.logger.warning("socket close failed: {}", e)
|
||||||
self._socket_client = None
|
self._socket_client = None
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send a message through Slack."""
|
"""Send a message through Slack."""
|
||||||
if not self._web_client:
|
if not self._web_client:
|
||||||
logger.warning("Slack client not running")
|
self.logger.warning("client not running")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
target_chat_id = await self._resolve_target_chat_id(msg.chat_id)
|
target_chat_id = await self._resolve_target_chat_id(msg.chat_id)
|
||||||
@ -162,16 +161,16 @@ class SlackChannel(BaseChannel):
|
|||||||
file=media_path,
|
file=media_path,
|
||||||
thread_ts=thread_ts_param,
|
thread_ts=thread_ts_param,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Failed to upload file {}: {}", media_path, e)
|
self.logger.exception("Failed to upload file {}", media_path)
|
||||||
|
|
||||||
# Update reaction emoji when the final (non-progress) response is sent
|
# Update reaction emoji when the final (non-progress) response is sent
|
||||||
if not (msg.metadata or {}).get("_progress"):
|
if not (msg.metadata or {}).get("_progress"):
|
||||||
event = slack_meta.get("event", {})
|
event = slack_meta.get("event", {})
|
||||||
await self._update_react_emoji(origin_chat_id, event.get("ts"))
|
await self._update_react_emoji(origin_chat_id, event.get("ts"))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending Slack message: {}", e)
|
self.logger.exception("Error sending message")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _resolve_target_chat_id(self, target: str) -> str:
|
async def _resolve_target_chat_id(self, target: str) -> str:
|
||||||
@ -328,8 +327,8 @@ class SlackChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Debug: log basic event shape
|
# Debug: log basic event shape
|
||||||
logger.debug(
|
self.logger.debug(
|
||||||
"Slack event: type={} subtype={} user={} channel={} channel_type={} text={}",
|
"event: type={} subtype={} user={} channel={} channel_type={} text={}",
|
||||||
event_type,
|
event_type,
|
||||||
subtype,
|
subtype,
|
||||||
sender_id,
|
sender_id,
|
||||||
@ -371,7 +370,7 @@ class SlackChannel(BaseChannel):
|
|||||||
timestamp=event.get("ts"),
|
timestamp=event.get("ts"),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Slack reactions_add failed: {}", e)
|
self.logger.debug("reactions_add failed: {}", e)
|
||||||
|
|
||||||
# Thread-scoped session key whenever the user is in a real thread
|
# Thread-scoped session key whenever the user is in a real thread
|
||||||
# (raw_thread_ts is set). DM threads get their own session, separate
|
# (raw_thread_ts is set). DM threads get their own session, separate
|
||||||
@ -420,7 +419,7 @@ class SlackChannel(BaseChannel):
|
|||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error handling Slack message from {}", sender_id)
|
self.logger.exception("Error handling message from {}", sender_id)
|
||||||
|
|
||||||
async def _download_slack_file(self, file_info: dict[str, Any]) -> tuple[str | None, str]:
|
async def _download_slack_file(self, file_info: dict[str, Any]) -> tuple[str | None, str]:
|
||||||
"""Download a Slack private file to the local media directory."""
|
"""Download a Slack private file to the local media directory."""
|
||||||
@ -453,7 +452,7 @@ class SlackChannel(BaseChannel):
|
|||||||
path.write_bytes(response.content)
|
path.write_bytes(response.content)
|
||||||
return str(path), marker
|
return str(path), marker
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to download Slack file {}: {}", file_id, e)
|
self.logger.warning("Failed to download file {}: {}", file_id, e)
|
||||||
return None, self._download_failure_marker(marker_type, name, "download failed")
|
return None, self._download_failure_marker(marker_type, name, "download failed")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -500,7 +499,7 @@ class SlackChannel(BaseChannel):
|
|||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error handling Slack button click from {}", sender_id)
|
self.logger.exception("Error handling button click from {}", sender_id)
|
||||||
|
|
||||||
async def _with_thread_context(
|
async def _with_thread_context(
|
||||||
self,
|
self,
|
||||||
@ -537,7 +536,7 @@ class SlackChannel(BaseChannel):
|
|||||||
limit=max(1, self.config.thread_context_limit),
|
limit=max(1, self.config.thread_context_limit),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Slack thread context unavailable for {}: {}", key, e)
|
self.logger.warning("thread context unavailable for {}: {}", key, e)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
lines = self._format_thread_context(
|
lines = self._format_thread_context(
|
||||||
@ -597,7 +596,7 @@ class SlackChannel(BaseChannel):
|
|||||||
timestamp=ts,
|
timestamp=ts,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Slack reactions_remove failed: {}", e)
|
self.logger.debug("reactions_remove failed: {}", e)
|
||||||
if self.config.done_emoji:
|
if self.config.done_emoji:
|
||||||
try:
|
try:
|
||||||
await self._web_client.reactions_add(
|
await self._web_client.reactions_add(
|
||||||
@ -606,7 +605,7 @@ class SlackChannel(BaseChannel):
|
|||||||
timestamp=ts,
|
timestamp=ts,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Slack done reaction failed: {}", e)
|
self.logger.debug("done reaction failed: {}", e)
|
||||||
|
|
||||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||||
if channel_type == "im":
|
if channel_type == "im":
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from dataclasses import dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from telegram import (
|
from telegram import (
|
||||||
BotCommand,
|
BotCommand,
|
||||||
@ -320,7 +319,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Telegram bot with long polling."""
|
"""Start the Telegram bot with long polling."""
|
||||||
if not self.config.token:
|
if not self.config.token:
|
||||||
logger.error("Telegram bot token not configured")
|
self.logger.error("bot token not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
@ -382,11 +381,11 @@ class TelegramChannel(BaseChannel):
|
|||||||
if self.config.inline_keyboards:
|
if self.config.inline_keyboards:
|
||||||
self._app.add_handler(CallbackQueryHandler(self._on_callback_query))
|
self._app.add_handler(CallbackQueryHandler(self._on_callback_query))
|
||||||
allowed_updates = ["message", "callback_query"]
|
allowed_updates = ["message", "callback_query"]
|
||||||
logger.debug("Telegram inline keyboards enabled")
|
self.logger.debug("inline keyboards enabled")
|
||||||
else:
|
else:
|
||||||
allowed_updates = ["message"]
|
allowed_updates = ["message"]
|
||||||
|
|
||||||
logger.info("Starting Telegram bot (polling mode)...")
|
self.logger.info("Starting bot (polling mode)...")
|
||||||
|
|
||||||
# Initialize and start polling
|
# Initialize and start polling
|
||||||
await self._app.initialize()
|
await self._app.initialize()
|
||||||
@ -396,13 +395,13 @@ class TelegramChannel(BaseChannel):
|
|||||||
bot_info = await self._app.bot.get_me()
|
bot_info = await self._app.bot.get_me()
|
||||||
self._bot_user_id = getattr(bot_info, "id", None)
|
self._bot_user_id = getattr(bot_info, "id", None)
|
||||||
self._bot_username = getattr(bot_info, "username", None)
|
self._bot_username = getattr(bot_info, "username", None)
|
||||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
self.logger.info("bot @{} connected", bot_info.username)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
||||||
logger.debug("Telegram bot commands registered")
|
self.logger.debug("bot commands registered")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to register bot commands: {}", e)
|
self.logger.warning("Failed to register bot commands: {}", e)
|
||||||
|
|
||||||
# Start polling (this runs until stopped)
|
# Start polling (this runs until stopped)
|
||||||
await self._app.updater.start_polling(
|
await self._app.updater.start_polling(
|
||||||
@ -429,7 +428,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._media_group_buffers.clear()
|
self._media_group_buffers.clear()
|
||||||
|
|
||||||
if self._app:
|
if self._app:
|
||||||
logger.info("Stopping Telegram bot...")
|
self.logger.info("Stopping bot...")
|
||||||
await self._app.updater.stop()
|
await self._app.updater.stop()
|
||||||
await self._app.stop()
|
await self._app.stop()
|
||||||
await self._app.shutdown()
|
await self._app.shutdown()
|
||||||
@ -456,7 +455,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send a message through Telegram."""
|
"""Send a message through Telegram."""
|
||||||
if not self._app:
|
if not self._app:
|
||||||
logger.warning("Telegram bot not running")
|
self.logger.warning("bot not running")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Only stop typing indicator and remove reaction for final responses
|
# Only stop typing indicator and remove reaction for final responses
|
||||||
@ -469,7 +468,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
chat_id = int(msg.chat_id)
|
chat_id = int(msg.chat_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.error("Invalid chat_id: {}", msg.chat_id)
|
self.logger.exception("Invalid chat_id: {}", msg.chat_id)
|
||||||
return
|
return
|
||||||
reply_to_message_id = msg.metadata.get("message_id")
|
reply_to_message_id = msg.metadata.get("message_id")
|
||||||
message_thread_id = msg.metadata.get("message_thread_id")
|
message_thread_id = msg.metadata.get("message_thread_id")
|
||||||
@ -533,9 +532,9 @@ class TelegramChannel(BaseChannel):
|
|||||||
**extra,
|
**extra,
|
||||||
**send_kwargs,
|
**send_kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
filename = media_path.rsplit("/", 1)[-1]
|
filename = media_path.rsplit("/", 1)[-1]
|
||||||
logger.error("Failed to send media {}: {}", media_path, e)
|
self.logger.exception("Failed to send media {}", media_path)
|
||||||
await self._app.bot.send_message(
|
await self._app.bot.send_message(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
text=f"[Failed to send: {filename}]",
|
text=f"[Failed to send: {filename}]",
|
||||||
@ -572,8 +571,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
if attempt == _SEND_MAX_RETRIES:
|
if attempt == _SEND_MAX_RETRIES:
|
||||||
raise
|
raise
|
||||||
delay = _SEND_RETRY_BASE_DELAY * (2 ** (attempt - 1))
|
delay = _SEND_RETRY_BASE_DELAY * (2 ** (attempt - 1))
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Telegram timeout (attempt {}/{}), retrying in {:.1f}s",
|
"timeout (attempt {}/{}), retrying in {:.1f}s",
|
||||||
attempt, _SEND_MAX_RETRIES, delay,
|
attempt, _SEND_MAX_RETRIES, delay,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
@ -581,8 +580,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
if attempt == _SEND_MAX_RETRIES:
|
if attempt == _SEND_MAX_RETRIES:
|
||||||
raise
|
raise
|
||||||
delay = float(e.retry_after)
|
delay = float(e.retry_after)
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"Telegram Flood Control (attempt {}/{}), retrying in {:.1f}s",
|
"Flood Control (attempt {}/{}), retrying in {:.1f}s",
|
||||||
attempt, _SEND_MAX_RETRIES, delay,
|
attempt, _SEND_MAX_RETRIES, delay,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
@ -607,7 +606,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
**(thread_kwargs or {}),
|
**(thread_kwargs or {}),
|
||||||
)
|
)
|
||||||
except BadRequest as e:
|
except BadRequest as e:
|
||||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
self.logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||||
try:
|
try:
|
||||||
await self._call_with_retry(
|
await self._call_with_retry(
|
||||||
self._app.bot.send_message,
|
self._app.bot.send_message,
|
||||||
@ -617,8 +616,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
reply_markup=reply_markup,
|
reply_markup=reply_markup,
|
||||||
**(thread_kwargs or {}),
|
**(thread_kwargs or {}),
|
||||||
)
|
)
|
||||||
except Exception as e2:
|
except Exception:
|
||||||
logger.error("Error sending Telegram message: {}", e2)
|
self.logger.exception("Error sending message")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -666,10 +665,10 @@ class TelegramChannel(BaseChannel):
|
|||||||
# Network errors (TimedOut, NetworkError) should propagate immediately
|
# Network errors (TimedOut, NetworkError) should propagate immediately
|
||||||
# to avoid doubling connection demand during pool exhaustion.
|
# to avoid doubling connection demand during pool exhaustion.
|
||||||
if self._is_not_modified_error(e):
|
if self._is_not_modified_error(e):
|
||||||
logger.debug("Final stream edit already applied for {}", chat_id)
|
self.logger.debug("Final stream edit already applied for {}", chat_id)
|
||||||
self._stream_bufs.pop(chat_id, None)
|
self._stream_bufs.pop(chat_id, None)
|
||||||
return
|
return
|
||||||
logger.debug("Final stream edit failed (HTML), trying plain: {}", e)
|
self.logger.debug("Final stream edit failed (HTML), trying plain: {}", e)
|
||||||
# Fall back to raw markdown (not HTML) so users don't see raw tags.
|
# Fall back to raw markdown (not HTML) so users don't see raw tags.
|
||||||
primary_plain = split_message(raw_text, TELEGRAM_MAX_MESSAGE_LEN)[0] if len(raw_text) > TELEGRAM_MAX_MESSAGE_LEN else raw_text
|
primary_plain = split_message(raw_text, TELEGRAM_MAX_MESSAGE_LEN)[0] if len(raw_text) > TELEGRAM_MAX_MESSAGE_LEN else raw_text
|
||||||
try:
|
try:
|
||||||
@ -680,9 +679,9 @@ class TelegramChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
if self._is_not_modified_error(e2):
|
if self._is_not_modified_error(e2):
|
||||||
logger.debug("Final stream plain edit already applied for {}", chat_id)
|
self.logger.debug("Final stream plain edit already applied for {}", chat_id)
|
||||||
else:
|
else:
|
||||||
logger.warning("Final stream edit failed: {}", e2)
|
self.logger.warning("Final stream edit failed: {}", e2)
|
||||||
raise # Let ChannelManager handle retry
|
raise # Let ChannelManager handle retry
|
||||||
for extra_html_chunk in extra_html_chunks:
|
for extra_html_chunk in extra_html_chunks:
|
||||||
try:
|
try:
|
||||||
@ -724,7 +723,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
buf.message_id = sent.message_id
|
buf.message_id = sent.message_id
|
||||||
buf.last_edit = now
|
buf.last_edit = now
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Stream initial send failed: {}", e)
|
self.logger.warning("Stream initial send failed: {}", e)
|
||||||
raise # Let ChannelManager handle retry
|
raise # Let ChannelManager handle retry
|
||||||
elif (now - buf.last_edit) >= self.config.stream_edit_interval:
|
elif (now - buf.last_edit) >= self.config.stream_edit_interval:
|
||||||
if len(buf.text) > TELEGRAM_MAX_MESSAGE_LEN:
|
if len(buf.text) > TELEGRAM_MAX_MESSAGE_LEN:
|
||||||
@ -743,7 +742,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
if self._is_not_modified_error(e):
|
if self._is_not_modified_error(e):
|
||||||
buf.last_edit = now
|
buf.last_edit = now
|
||||||
return
|
return
|
||||||
logger.warning("Stream edit failed: {}", e)
|
self.logger.warning("Stream edit failed: {}", e)
|
||||||
raise # Let ChannelManager handle retry
|
raise # Let ChannelManager handle retry
|
||||||
|
|
||||||
async def _flush_stream_overflow(
|
async def _flush_stream_overflow(
|
||||||
@ -769,7 +768,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if not self._is_not_modified_error(e):
|
if not self._is_not_modified_error(e):
|
||||||
logger.warning("Stream overflow edit failed: {}", e)
|
self.logger.warning("Stream overflow edit failed: {}", e)
|
||||||
raise
|
raise
|
||||||
for chunk in chunks[1:-1]:
|
for chunk in chunks[1:-1]:
|
||||||
await self._call_with_retry(
|
await self._call_with_retry(
|
||||||
@ -903,12 +902,12 @@ class TelegramChannel(BaseChannel):
|
|||||||
if media_type in ("voice", "audio"):
|
if media_type in ("voice", "audio"):
|
||||||
transcription = await self.transcribe_audio(file_path)
|
transcription = await self.transcribe_audio(file_path)
|
||||||
if transcription:
|
if transcription:
|
||||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
self.logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||||
return [path_str], [f"[transcription: {transcription}]"]
|
return [path_str], [f"[transcription: {transcription}]"]
|
||||||
return [path_str], [f"[{media_type}: {path_str}]"]
|
return [path_str], [f"[{media_type}: {path_str}]"]
|
||||||
return [path_str], [f"[{media_type}: {path_str}]"]
|
return [path_str], [f"[{media_type}: {path_str}]"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to download message media: {}", e)
|
self.logger.warning("Failed to download message media: {}", e)
|
||||||
if add_failure_content:
|
if add_failure_content:
|
||||||
return [], [f"[{media_type}: download failed]"]
|
return [], [f"[{media_type}: download failed]"]
|
||||||
return [], []
|
return [], []
|
||||||
@ -1056,7 +1055,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
media_paths.extend(current_media_paths)
|
media_paths.extend(current_media_paths)
|
||||||
content_parts.extend(current_media_parts)
|
content_parts.extend(current_media_parts)
|
||||||
if current_media_paths:
|
if current_media_paths:
|
||||||
logger.debug("Downloaded message media to {}", current_media_paths[0])
|
self.logger.debug("Downloaded message media to {}", current_media_paths[0])
|
||||||
|
|
||||||
# Reply context: text and/or media from the replied-to message
|
# Reply context: text and/or media from the replied-to message
|
||||||
reply = getattr(message, "reply_to_message", None)
|
reply = getattr(message, "reply_to_message", None)
|
||||||
@ -1065,13 +1064,13 @@ class TelegramChannel(BaseChannel):
|
|||||||
reply_media, reply_media_parts = await self._download_message_media(reply)
|
reply_media, reply_media_parts = await self._download_message_media(reply)
|
||||||
if reply_media:
|
if reply_media:
|
||||||
media_paths = reply_media + media_paths
|
media_paths = reply_media + media_paths
|
||||||
logger.debug("Attached replied-to media: {}", reply_media[0])
|
self.logger.debug("Attached replied-to media: {}", reply_media[0])
|
||||||
tag = reply_ctx or (f"[Reply to: {reply_media_parts[0]}]" if reply_media_parts else None)
|
tag = reply_ctx or (f"[Reply to: {reply_media_parts[0]}]" if reply_media_parts else None)
|
||||||
if tag:
|
if tag:
|
||||||
content_parts.insert(0, tag)
|
content_parts.insert(0, tag)
|
||||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||||
|
|
||||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
self.logger.debug("message from {}: {}...", sender_id, content[:50])
|
||||||
|
|
||||||
str_chat_id = str(chat_id)
|
str_chat_id = str(chat_id)
|
||||||
metadata = self._build_message_metadata(message, user)
|
metadata = self._build_message_metadata(message, user)
|
||||||
@ -1150,7 +1149,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
reaction=[ReactionTypeEmoji(emoji=emoji)],
|
reaction=[ReactionTypeEmoji(emoji=emoji)],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Telegram reaction failed: {}", e)
|
self.logger.debug("reaction failed: {}", e)
|
||||||
|
|
||||||
async def _remove_reaction(self, chat_id: str, message_id: int) -> None:
|
async def _remove_reaction(self, chat_id: str, message_id: int) -> None:
|
||||||
"""Remove emoji reaction from a message (best-effort, non-blocking)."""
|
"""Remove emoji reaction from a message (best-effort, non-blocking)."""
|
||||||
@ -1163,7 +1162,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
reaction=[],
|
reaction=[],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Telegram reaction removal failed: {}", e)
|
self.logger.debug("reaction removal failed: {}", e)
|
||||||
|
|
||||||
async def _typing_loop(self, chat_id: str) -> None:
|
async def _typing_loop(self, chat_id: str) -> None:
|
||||||
"""Repeatedly send 'typing' action until cancelled."""
|
"""Repeatedly send 'typing' action until cancelled."""
|
||||||
@ -1173,7 +1172,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
await self._app.bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
await self._app.bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
||||||
await asyncio.sleep(4)
|
await asyncio.sleep(4)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
self.logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _format_telegram_error(exc: Exception) -> str:
|
def _format_telegram_error(exc: Exception) -> str:
|
||||||
@ -1193,18 +1192,18 @@ class TelegramChannel(BaseChannel):
|
|||||||
"""Keep long-polling network failures to a single readable line."""
|
"""Keep long-polling network failures to a single readable line."""
|
||||||
summary = self._format_telegram_error(exc)
|
summary = self._format_telegram_error(exc)
|
||||||
if isinstance(exc, (NetworkError, TimedOut)):
|
if isinstance(exc, (NetworkError, TimedOut)):
|
||||||
logger.warning("Telegram polling network issue: {}", summary)
|
self.logger.warning("polling network issue: {}", summary)
|
||||||
else:
|
else:
|
||||||
logger.error("Telegram polling error: {}", summary)
|
self.logger.error("polling error: {}", summary)
|
||||||
|
|
||||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Log polling / handler errors instead of silently swallowing them."""
|
"""Log polling / handler errors instead of silently swallowing them."""
|
||||||
summary = self._format_telegram_error(context.error)
|
summary = self._format_telegram_error(context.error)
|
||||||
|
|
||||||
if isinstance(context.error, (NetworkError, TimedOut)):
|
if isinstance(context.error, (NetworkError, TimedOut)):
|
||||||
logger.warning("Telegram network issue: {}", summary)
|
self.logger.warning("network issue: {}", summary)
|
||||||
else:
|
else:
|
||||||
logger.error("Telegram error: {}", summary)
|
self.logger.error("error: {}", summary)
|
||||||
|
|
||||||
def _get_extension(
|
def _get_extension(
|
||||||
self,
|
self,
|
||||||
@ -1265,7 +1264,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
chat_id = query.message.chat_id if query.message else None
|
chat_id = query.message.chat_id if query.message else None
|
||||||
sender_id = self._sender_id(user)
|
sender_id = self._sender_id(user)
|
||||||
if not chat_id:
|
if not chat_id:
|
||||||
logger.warning("Callback query without chat_id")
|
self.logger.warning("Callback query without chat_id")
|
||||||
return
|
return
|
||||||
if not self.is_allowed(sender_id):
|
if not self.is_allowed(sender_id):
|
||||||
return
|
return
|
||||||
@ -1274,7 +1273,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
if query.message:
|
if query.message:
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
await query.message.edit_reply_markup(reply_markup=None)
|
await query.message.edit_reply_markup(reply_markup=None)
|
||||||
logger.debug("Inline button tap from {}: {}", sender_id, button_label)
|
self.logger.debug("Inline button tap from {}: {}", sender_id, button_label)
|
||||||
self._start_typing(str(chat_id))
|
self._start_typing(str(chat_id))
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
|
|||||||
@ -448,7 +448,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
except ConnectionClosed:
|
except ConnectionClosed:
|
||||||
self._cleanup_connection(connection)
|
self._cleanup_connection(connection)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("websocket: failed to send {} event: {}", event, e)
|
self.logger.warning("failed to send {} event: {}", event, e)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_config(cls) -> dict[str, Any]:
|
def default_config(cls) -> dict[str, Any]:
|
||||||
@ -464,7 +464,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
return None
|
return None
|
||||||
if not cert or not key:
|
if not cert or not key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
|
"ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
|
||||||
)
|
)
|
||||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||||
@ -501,14 +501,14 @@ class WebSocketChannel(BaseChannel):
|
|||||||
if not _issue_route_secret_matches(request.headers, secret):
|
if not _issue_route_secret_matches(request.headers, secret):
|
||||||
return connection.respond(401, "Unauthorized")
|
return connection.respond(401, "Unauthorized")
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"websocket: token_issue_path is set but token_issue_secret is empty; "
|
"token_issue_path is set but token_issue_secret is empty; "
|
||||||
"any client can obtain connection tokens — set token_issue_secret for production."
|
"any client can obtain connection tokens — set token_issue_secret for production."
|
||||||
)
|
)
|
||||||
self._purge_expired_issued_tokens()
|
self._purge_expired_issued_tokens()
|
||||||
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
||||||
logger.error(
|
self.logger.error(
|
||||||
"websocket: too many outstanding issued tokens ({}), rejecting issuance",
|
"too many outstanding issued tokens ({}), rejecting issuance",
|
||||||
len(self._issued_tokens),
|
len(self._issued_tokens),
|
||||||
)
|
)
|
||||||
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
||||||
@ -821,7 +821,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
staged = media_dir / f"{uuid.uuid4().hex[:12]}-{safe_name}"
|
staged = media_dir / f"{uuid.uuid4().hex[:12]}-{safe_name}"
|
||||||
shutil.copyfile(path, staged)
|
shutil.copyfile(path, staged)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
logger.warning("websocket: failed to stage outbound media {}: {}", path, exc)
|
self.logger.warning("failed to stage outbound media {}: {}", path, exc)
|
||||||
return None
|
return None
|
||||||
signed = self._sign_media_path(staged)
|
signed = self._sign_media_path(staged)
|
||||||
if signed is None:
|
if signed is None:
|
||||||
@ -917,7 +917,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
body = candidate.read_bytes()
|
body = candidate.read_bytes()
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.warning("websocket static: failed to read {}: {}", candidate, e)
|
self.logger.warning("static: failed to read {}: {}", candidate, e)
|
||||||
return _http_error(500, "Internal Server Error")
|
return _http_error(500, "Internal Server Error")
|
||||||
ctype, _ = mimetypes.guess_type(candidate.name)
|
ctype, _ = mimetypes.guess_type(candidate.name)
|
||||||
if ctype is None:
|
if ctype is None:
|
||||||
@ -972,7 +972,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
async def handler(connection: ServerConnection) -> None:
|
async def handler(connection: ServerConnection) -> None:
|
||||||
await self._connection_loop(connection)
|
await self._connection_loop(connection)
|
||||||
|
|
||||||
logger.info(
|
self.logger.info(
|
||||||
"WebSocket server listening on {}://{}:{}{}",
|
"WebSocket server listening on {}://{}:{}{}",
|
||||||
scheme,
|
scheme,
|
||||||
self.config.host,
|
self.config.host,
|
||||||
@ -980,7 +980,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
self.config.path,
|
self.config.path,
|
||||||
)
|
)
|
||||||
if self.config.token_issue_path:
|
if self.config.token_issue_path:
|
||||||
logger.info(
|
self.logger.info(
|
||||||
"WebSocket token issue route: {}://{}:{}{}",
|
"WebSocket token issue route: {}://{}:{}{}",
|
||||||
scheme,
|
scheme,
|
||||||
self.config.host,
|
self.config.host,
|
||||||
@ -1014,7 +1014,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
if not client_id:
|
if not client_id:
|
||||||
client_id = f"anon-{uuid.uuid4().hex[:12]}"
|
client_id = f"anon-{uuid.uuid4().hex[:12]}"
|
||||||
elif len(client_id) > 128:
|
elif len(client_id) > 128:
|
||||||
logger.warning("websocket: client_id too long ({} chars), truncating", len(client_id))
|
self.logger.warning("client_id too long ({} chars), truncating", len(client_id))
|
||||||
client_id = client_id[:128]
|
client_id = client_id[:128]
|
||||||
|
|
||||||
default_chat_id = str(uuid.uuid4())
|
default_chat_id = str(uuid.uuid4())
|
||||||
@ -1039,7 +1039,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
raw = raw.decode("utf-8")
|
raw = raw.decode("utf-8")
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
logger.warning("websocket: ignoring non-utf8 binary frame")
|
self.logger.warning("ignoring non-utf8 binary frame")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
envelope = _parse_envelope(raw)
|
envelope = _parse_envelope(raw)
|
||||||
@ -1057,7 +1057,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
metadata={"remote": getattr(connection, "remote_address", None)},
|
metadata={"remote": getattr(connection, "remote_address", None)},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("websocket connection ended: {}", e)
|
self.logger.debug("connection ended: {}", e)
|
||||||
finally:
|
finally:
|
||||||
self._cleanup_connection(connection)
|
self._cleanup_connection(connection)
|
||||||
|
|
||||||
@ -1097,8 +1097,8 @@ class WebSocketChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
Path(p).unlink(missing_ok=True)
|
Path(p).unlink(missing_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"websocket: failed to unlink partial media {}: {}", p, exc
|
"failed to unlink partial media {}: {}", p, exc
|
||||||
)
|
)
|
||||||
return [], reason
|
return [], reason
|
||||||
|
|
||||||
@ -1122,7 +1122,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
except FileSizeExceeded:
|
except FileSizeExceeded:
|
||||||
return _abort("size")
|
return _abort("size")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("websocket: media decode failed: {}", exc)
|
self.logger.warning("media decode failed: {}", exc)
|
||||||
return _abort("decode")
|
return _abort("decode")
|
||||||
if saved is None:
|
if saved is None:
|
||||||
return _abort("decode")
|
return _abort("decode")
|
||||||
@ -1204,7 +1204,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._server_task
|
await self._server_task
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("websocket: server task error during shutdown: {}", e)
|
self.logger.warning("server task error during shutdown: {}", e)
|
||||||
self._server_task = None
|
self._server_task = None
|
||||||
self._subs.clear()
|
self._subs.clear()
|
||||||
self._conn_chats.clear()
|
self._conn_chats.clear()
|
||||||
@ -1218,16 +1218,16 @@ class WebSocketChannel(BaseChannel):
|
|||||||
await connection.send(raw)
|
await connection.send(raw)
|
||||||
except ConnectionClosed:
|
except ConnectionClosed:
|
||||||
self._cleanup_connection(connection)
|
self._cleanup_connection(connection)
|
||||||
logger.warning("websocket{}connection gone", label)
|
self.logger.warning("connection gone{}", label)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("websocket{}send failed: {}", label, e)
|
self.logger.exception("send failed{}", label)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
# Snapshot the subscriber set so ConnectionClosed cleanups mid-iteration are safe.
|
# Snapshot the subscriber set so ConnectionClosed cleanups mid-iteration are safe.
|
||||||
conns = list(self._subs.get(msg.chat_id, ()))
|
conns = list(self._subs.get(msg.chat_id, ()))
|
||||||
if not conns:
|
if not conns:
|
||||||
logger.warning("websocket: no active subscribers for chat_id={}", msg.chat_id)
|
self.logger.warning("no active subscribers for chat_id={}", msg.chat_id)
|
||||||
return
|
return
|
||||||
# Signal that the agent has fully finished processing the current turn.
|
# Signal that the agent has fully finished processing the current turn.
|
||||||
if msg.metadata.get("_turn_end"):
|
if msg.metadata.get("_turn_end"):
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from collections import OrderedDict
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
@ -103,11 +102,11 @@ class WecomChannel(BaseChannel):
|
|||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the WeCom bot with WebSocket long connection."""
|
"""Start the WeCom bot with WebSocket long connection."""
|
||||||
if not WECOM_AVAILABLE:
|
if not WECOM_AVAILABLE:
|
||||||
logger.error("WeCom SDK not installed. Run: pip install nanobot-ai[wecom]")
|
self.logger.error("SDK not installed. Run: pip install nanobot-ai[wecom]")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.bot_id or not self.config.secret:
|
if not self.config.bot_id or not self.config.secret:
|
||||||
logger.error("WeCom bot_id and secret not configured")
|
self.logger.error("bot_id and secret not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
from wecom_aibot_sdk import WSClient, generate_req_id
|
from wecom_aibot_sdk import WSClient, generate_req_id
|
||||||
@ -137,8 +136,8 @@ class WecomChannel(BaseChannel):
|
|||||||
self._client.on("message.mixed", self._on_mixed_message)
|
self._client.on("message.mixed", self._on_mixed_message)
|
||||||
self._client.on("event.enter_chat", self._on_enter_chat)
|
self._client.on("event.enter_chat", self._on_enter_chat)
|
||||||
|
|
||||||
logger.info("WeCom bot starting with WebSocket long connection")
|
self.logger.info("bot starting with WebSocket long connection")
|
||||||
logger.info("No public IP required - using WebSocket to receive events")
|
self.logger.info("No public IP required - using WebSocket to receive events")
|
||||||
|
|
||||||
# Connect
|
# Connect
|
||||||
await self._client.connect_async()
|
await self._client.connect_async()
|
||||||
@ -152,24 +151,24 @@ class WecomChannel(BaseChannel):
|
|||||||
self._running = False
|
self._running = False
|
||||||
if self._client:
|
if self._client:
|
||||||
await self._client.disconnect()
|
await self._client.disconnect()
|
||||||
logger.info("WeCom bot stopped")
|
self.logger.info("bot stopped")
|
||||||
|
|
||||||
async def _on_connected(self, frame: Any) -> None:
|
async def _on_connected(self, frame: Any) -> None:
|
||||||
"""Handle WebSocket connected event."""
|
"""Handle WebSocket connected event."""
|
||||||
logger.info("WeCom WebSocket connected")
|
self.logger.info("WebSocket connected")
|
||||||
|
|
||||||
async def _on_authenticated(self, frame: Any) -> None:
|
async def _on_authenticated(self, frame: Any) -> None:
|
||||||
"""Handle authentication success event."""
|
"""Handle authentication success event."""
|
||||||
logger.info("WeCom authenticated successfully")
|
self.logger.info("authenticated successfully")
|
||||||
|
|
||||||
async def _on_disconnected(self, frame: Any) -> None:
|
async def _on_disconnected(self, frame: Any) -> None:
|
||||||
"""Handle WebSocket disconnected event."""
|
"""Handle WebSocket disconnected event."""
|
||||||
reason = frame.body if hasattr(frame, 'body') else str(frame)
|
reason = frame.body if hasattr(frame, 'body') else str(frame)
|
||||||
logger.warning("WeCom WebSocket disconnected: {}", reason)
|
self.logger.warning("WebSocket disconnected: {}", reason)
|
||||||
|
|
||||||
async def _on_error(self, frame: Any) -> None:
|
async def _on_error(self, frame: Any) -> None:
|
||||||
"""Handle error event."""
|
"""Handle error event."""
|
||||||
logger.error("WeCom error: {}", frame)
|
self.logger.error("error: {}", frame)
|
||||||
|
|
||||||
async def _on_text_message(self, frame: Any) -> None:
|
async def _on_text_message(self, frame: Any) -> None:
|
||||||
"""Handle text message."""
|
"""Handle text message."""
|
||||||
@ -212,8 +211,8 @@ class WecomChannel(BaseChannel):
|
|||||||
"msgtype": "text",
|
"msgtype": "text",
|
||||||
"text": {"content": self.config.welcome_message},
|
"text": {"content": self.config.welcome_message},
|
||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error handling enter_chat: {}", e)
|
self.logger.exception("Error handling enter_chat")
|
||||||
|
|
||||||
async def _process_message(self, frame: Any, msg_type: str) -> None:
|
async def _process_message(self, frame: Any, msg_type: str) -> None:
|
||||||
"""Process incoming message and forward to bus."""
|
"""Process incoming message and forward to bus."""
|
||||||
@ -228,7 +227,7 @@ class WecomChannel(BaseChannel):
|
|||||||
|
|
||||||
# Ensure body is a dict
|
# Ensure body is a dict
|
||||||
if not isinstance(body, dict):
|
if not isinstance(body, dict):
|
||||||
logger.warning("Invalid body type: {}", type(body))
|
self.logger.warning("Invalid body type: {}", type(body))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Extract message info
|
# Extract message info
|
||||||
@ -350,8 +349,8 @@ class WecomChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error processing WeCom message: {}", e)
|
self.logger.exception("Error processing message")
|
||||||
|
|
||||||
async def _download_and_save_media(
|
async def _download_and_save_media(
|
||||||
self,
|
self,
|
||||||
@ -370,12 +369,12 @@ class WecomChannel(BaseChannel):
|
|||||||
data, fname = await self._client.download_file(file_url, aes_key)
|
data, fname = await self._client.download_file(file_url, aes_key)
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
logger.warning("Failed to download media from WeCom")
|
self.logger.warning("Failed to download media")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if len(data) > WECOM_UPLOAD_MAX_BYTES:
|
if len(data) > WECOM_UPLOAD_MAX_BYTES:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"WeCom inbound media too large: {} bytes (max {})",
|
"inbound media too large: {} bytes (max {})",
|
||||||
len(data),
|
len(data),
|
||||||
WECOM_UPLOAD_MAX_BYTES,
|
WECOM_UPLOAD_MAX_BYTES,
|
||||||
)
|
)
|
||||||
@ -388,11 +387,11 @@ class WecomChannel(BaseChannel):
|
|||||||
|
|
||||||
file_path = media_dir / filename
|
file_path = media_dir / filename
|
||||||
await asyncio.to_thread(file_path.write_bytes, data)
|
await asyncio.to_thread(file_path.write_bytes, data)
|
||||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
self.logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||||
return str(file_path)
|
return str(file_path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error downloading media: {}", e)
|
self.logger.exception("Error downloading media")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _upload_media_ws(
|
async def _upload_media_ws(
|
||||||
@ -445,11 +444,11 @@ class WecomChannel(BaseChannel):
|
|||||||
"md5": md5_hash,
|
"md5": md5_hash,
|
||||||
}, "aibot_upload_media_init")
|
}, "aibot_upload_media_init")
|
||||||
if resp.errcode != 0:
|
if resp.errcode != 0:
|
||||||
logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg)
|
self.logger.warning("upload init failed ({}): {}", resp.errcode, resp.errmsg)
|
||||||
return None, None
|
return None, None
|
||||||
upload_id = resp.body.get("upload_id") if resp.body else None
|
upload_id = resp.body.get("upload_id") if resp.body else None
|
||||||
if not upload_id:
|
if not upload_id:
|
||||||
logger.warning("WeCom upload init: no upload_id in response")
|
self.logger.warning("upload init: no upload_id in response")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Step 2: send chunks
|
# Step 2: send chunks
|
||||||
@ -461,7 +460,7 @@ class WecomChannel(BaseChannel):
|
|||||||
"base64_data": base64.b64encode(chunk).decode(),
|
"base64_data": base64.b64encode(chunk).decode(),
|
||||||
}, "aibot_upload_media_chunk")
|
}, "aibot_upload_media_chunk")
|
||||||
if resp.errcode != 0:
|
if resp.errcode != 0:
|
||||||
logger.warning("WeCom upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg)
|
self.logger.warning("upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Step 3: finish
|
# Step 3: finish
|
||||||
@ -470,29 +469,29 @@ class WecomChannel(BaseChannel):
|
|||||||
"upload_id": upload_id,
|
"upload_id": upload_id,
|
||||||
}, "aibot_upload_media_finish")
|
}, "aibot_upload_media_finish")
|
||||||
if resp.errcode != 0:
|
if resp.errcode != 0:
|
||||||
logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg)
|
self.logger.warning("upload finish failed ({}): {}", resp.errcode, resp.errmsg)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
media_id = resp.body.get("media_id") if resp.body else None
|
media_id = resp.body.get("media_id") if resp.body else None
|
||||||
if not media_id:
|
if not media_id:
|
||||||
logger.warning("WeCom upload finish: no media_id in response body={}", resp.body)
|
self.logger.warning("upload finish: no media_id in response body={}", resp.body)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
suffix = "..." if len(media_id) > 16 else ""
|
suffix = "..." if len(media_id) > 16 else ""
|
||||||
logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix)
|
self.logger.debug("uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix)
|
||||||
return media_id, media_type
|
return media_id, media_type
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.warning("WeCom upload skipped for {}: {}", file_path, e)
|
self.logger.warning("upload skipped for {}: {}", file_path, e)
|
||||||
return None, None
|
return None, None
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e)
|
self.logger.exception("_upload_media_ws error for {}", file_path)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send a message through WeCom."""
|
"""Send a message through WeCom."""
|
||||||
if not self._client:
|
if not self._client:
|
||||||
logger.warning("WeCom client not initialized")
|
self.logger.warning("client not initialized")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -505,7 +504,7 @@ class WecomChannel(BaseChannel):
|
|||||||
# Send media files via WebSocket upload
|
# Send media files via WebSocket upload
|
||||||
for file_path in msg.media or []:
|
for file_path in msg.media or []:
|
||||||
if not os.path.isfile(file_path):
|
if not os.path.isfile(file_path):
|
||||||
logger.warning("WeCom media file not found: {}", file_path)
|
self.logger.warning("media file not found: {}", file_path)
|
||||||
continue
|
continue
|
||||||
media_id, media_type = await self._upload_media_ws(self._client, file_path)
|
media_id, media_type = await self._upload_media_ws(self._client, file_path)
|
||||||
if media_id:
|
if media_id:
|
||||||
@ -519,7 +518,7 @@ class WecomChannel(BaseChannel):
|
|||||||
"msgtype": media_type,
|
"msgtype": media_type,
|
||||||
media_type: {"media_id": media_id},
|
media_type: {"media_id": media_id},
|
||||||
})
|
})
|
||||||
logger.debug("WeCom sent {} → {}", media_type, msg.chat_id)
|
self.logger.debug("sent {} → {}", media_type, msg.chat_id)
|
||||||
else:
|
else:
|
||||||
content += f"\n[file upload failed: {os.path.basename(file_path)}]"
|
content += f"\n[file upload failed: {os.path.basename(file_path)}]"
|
||||||
|
|
||||||
@ -537,8 +536,8 @@ class WecomChannel(BaseChannel):
|
|||||||
content,
|
content,
|
||||||
finish=not is_progress,
|
finish=not is_progress,
|
||||||
)
|
)
|
||||||
logger.debug(
|
self.logger.debug(
|
||||||
"WeCom {} sent to {}",
|
"{} sent to {}",
|
||||||
"progress" if is_progress else "message",
|
"progress" if is_progress else "message",
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
)
|
)
|
||||||
@ -548,7 +547,7 @@ class WecomChannel(BaseChannel):
|
|||||||
"msgtype": "markdown",
|
"msgtype": "markdown",
|
||||||
"markdown": {"content": content},
|
"markdown": {"content": content},
|
||||||
})
|
})
|
||||||
logger.info("WeCom proactive send to {}", msg.chat_id)
|
self.logger.info("proactive send to {}", msg.chat_id)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id)
|
self.logger.exception("Error sending message to chat_id={}", msg.chat_id)
|
||||||
|
|||||||
@ -366,14 +366,14 @@ class WeixinChannel(BaseChannel):
|
|||||||
if base_url:
|
if base_url:
|
||||||
self.config.base_url = base_url
|
self.config.base_url = base_url
|
||||||
self._save_state()
|
self._save_state()
|
||||||
logger.info(
|
self.logger.info(
|
||||||
"WeChat login successful! bot_id={} user_id={}",
|
"login successful! bot_id={} user_id={}",
|
||||||
bot_id,
|
bot_id,
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.error("Login confirmed but no bot_token in response")
|
self.logger.error("Login confirmed but no bot_token in response")
|
||||||
return False
|
return False
|
||||||
elif status == "scaned_but_redirect":
|
elif status == "scaned_but_redirect":
|
||||||
redirect_host = str(status_data.get("redirect_host", "") or "").strip()
|
redirect_host = str(status_data.get("redirect_host", "") or "").strip()
|
||||||
@ -387,7 +387,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
elif status == "expired":
|
elif status == "expired":
|
||||||
refresh_count += 1
|
refresh_count += 1
|
||||||
if refresh_count > MAX_QR_REFRESH_COUNT:
|
if refresh_count > MAX_QR_REFRESH_COUNT:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"QR code expired too many times ({}/{}), giving up.",
|
"QR code expired too many times ({}/{}), giving up.",
|
||||||
refresh_count - 1,
|
refresh_count - 1,
|
||||||
MAX_QR_REFRESH_COUNT,
|
MAX_QR_REFRESH_COUNT,
|
||||||
@ -401,8 +401,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("WeChat QR login failed: {}", e)
|
self.logger.exception("QR login failed")
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -469,11 +469,11 @@ class WeixinChannel(BaseChannel):
|
|||||||
self._token = self.config.token
|
self._token = self.config.token
|
||||||
elif not self._load_state():
|
elif not self._load_state():
|
||||||
if not await self._qr_login():
|
if not await self._qr_login():
|
||||||
logger.error("WeChat login failed. Run 'nanobot channels login weixin' to authenticate.")
|
self.logger.error("login failed. Run 'nanobot channels login weixin' to authenticate.")
|
||||||
self._running = False
|
self._running = False
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("WeChat channel starting with long-poll...")
|
self.logger.info("channel starting with long-poll...")
|
||||||
|
|
||||||
consecutive_failures = 0
|
consecutive_failures = 0
|
||||||
while self._running:
|
while self._running:
|
||||||
@ -551,8 +551,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED:
|
if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED:
|
||||||
self._pause_session()
|
self._pause_session()
|
||||||
remaining = self._session_pause_remaining_s()
|
remaining = self._session_pause_remaining_s()
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"WeChat session expired (errcode {}). Pausing {} min.",
|
"session expired (errcode {}). Pausing {} min.",
|
||||||
errcode,
|
errcode,
|
||||||
max((remaining + 59) // 60, 1),
|
max((remaining + 59) // 60, 1),
|
||||||
)
|
)
|
||||||
@ -759,8 +759,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
if not content:
|
if not content:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
self.logger.info(
|
||||||
"WeChat inbound: from={} items={} bodyLen={}",
|
"inbound: from={} items={} bodyLen={}",
|
||||||
from_user_id,
|
from_user_id,
|
||||||
",".join(str(i.get("type", 0)) for i in item_list),
|
",".join(str(i.get("type", 0)) for i in item_list),
|
||||||
len(content),
|
len(content),
|
||||||
@ -843,8 +843,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
and self._is_retryable_media_download_error(e)
|
and self._is_retryable_media_download_error(e)
|
||||||
)
|
)
|
||||||
if should_fallback:
|
if should_fallback:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"WeChat media download failed via full_url, falling back to encrypt_query_param: type={} err={}",
|
"media download failed via full_url, falling back to encrypt_query_param: type={} err={}",
|
||||||
media_type,
|
media_type,
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
@ -869,8 +869,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
file_path.write_bytes(data)
|
file_path.write_bytes(data)
|
||||||
return str(file_path)
|
return str(file_path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error downloading WeChat media: {}", e)
|
self.logger.exception("Error downloading media")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -940,7 +940,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
if not self._client or not self._token:
|
if not self._client or not self._token:
|
||||||
logger.warning("WeChat client not initialized or not authenticated")
|
self.logger.warning("client not initialized or not authenticated")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
self._assert_session_active()
|
self._assert_session_active()
|
||||||
@ -954,8 +954,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
content = msg.content.strip()
|
content = msg.content.strip()
|
||||||
ctx_token = self._context_tokens.get(msg.chat_id, "")
|
ctx_token = self._context_tokens.get(msg.chat_id, "")
|
||||||
if not ctx_token:
|
if not ctx_token:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"WeChat: no context_token for chat_id={}, cannot send",
|
"no context_token for chat_id={}, cannot send",
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -980,14 +980,13 @@ class WeixinChannel(BaseChannel):
|
|||||||
for media_path in (msg.media or []):
|
for media_path in (msg.media or []):
|
||||||
try:
|
try:
|
||||||
await self._send_media_file(msg.chat_id, media_path, ctx_token)
|
await self._send_media_file(msg.chat_id, media_path, ctx_token)
|
||||||
except (httpx.TimeoutException, httpx.TransportError) as net_err:
|
except (httpx.TimeoutException, httpx.TransportError):
|
||||||
# Network/transport errors: do NOT fall back to text —
|
# Network/transport errors: do NOT fall back to text —
|
||||||
# the text send would also likely fail, and the outer
|
# the text send would also likely fail, and the outer
|
||||||
# except will re-raise so ChannelManager retries properly.
|
# except will re-raise so ChannelManager retries properly.
|
||||||
logger.error(
|
self.logger.opt(exception=True).warning(
|
||||||
"Network error sending WeChat media {}: {}",
|
"Network error sending media {}",
|
||||||
media_path,
|
media_path,
|
||||||
net_err,
|
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
except httpx.HTTPStatusError as http_err:
|
except httpx.HTTPStatusError as http_err:
|
||||||
@ -998,27 +997,26 @@ class WeixinChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
if status_code >= 500:
|
if status_code >= 500:
|
||||||
# Server-side / retryable HTTP error — same as network.
|
# Server-side / retryable HTTP error — same as network.
|
||||||
logger.error(
|
self.logger.exception(
|
||||||
"Server error ({} {}) sending WeChat media {}: {}",
|
"Server error ({} {}) sending media {}",
|
||||||
status_code,
|
status_code,
|
||||||
http_err.response.reason_phrase
|
http_err.response.reason_phrase
|
||||||
if http_err.response is not None
|
if http_err.response is not None
|
||||||
else "",
|
else "",
|
||||||
media_path,
|
media_path,
|
||||||
http_err,
|
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
# 4xx client errors are NOT retryable — fall back to text.
|
# 4xx client errors are NOT retryable — fall back to text.
|
||||||
filename = Path(media_path).name
|
filename = Path(media_path).name
|
||||||
logger.error("Failed to send WeChat media {}: {}", media_path, http_err)
|
self.logger.exception("Failed to send media {}", media_path)
|
||||||
await self._send_text(
|
await self._send_text(
|
||||||
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
|
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# Non-network errors (format, file-not-found, etc.):
|
# Non-network errors (format, file-not-found, etc.):
|
||||||
# notify the user via text fallback.
|
# notify the user via text fallback.
|
||||||
filename = Path(media_path).name
|
filename = Path(media_path).name
|
||||||
logger.error("Failed to send WeChat media {}: {}", media_path, e)
|
self.logger.exception("Failed to send media {}", media_path)
|
||||||
# Notify user about failure via text
|
# Notify user about failure via text
|
||||||
await self._send_text(
|
await self._send_text(
|
||||||
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
|
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
|
||||||
@ -1031,8 +1029,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
chunks = split_message(content, WEIXIN_MAX_MESSAGE_LEN)
|
chunks = split_message(content, WEIXIN_MAX_MESSAGE_LEN)
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
await self._send_text(msg.chat_id, chunk, ctx_token)
|
await self._send_text(msg.chat_id, chunk, ctx_token)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending WeChat message: {}", e)
|
self.logger.exception("Error sending message")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
if typing_keepalive_task:
|
if typing_keepalive_task:
|
||||||
@ -1056,7 +1054,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING)
|
await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("WeChat typing indicator start failed for {}: {}", chat_id, e)
|
self.logger.debug("typing indicator start failed for {}: {}", chat_id, e)
|
||||||
return
|
return
|
||||||
|
|
||||||
stop_event = asyncio.Event()
|
stop_event = asyncio.Event()
|
||||||
@ -1095,7 +1093,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._send_typing(chat_id, ticket, TYPING_STATUS_CANCEL)
|
await self._send_typing(chat_id, ticket, TYPING_STATUS_CANCEL)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("WeChat typing clear failed for {}: {}", chat_id, e)
|
self.logger.debug("typing clear failed for {}: {}", chat_id, e)
|
||||||
|
|
||||||
async def _send_text(
|
async def _send_text(
|
||||||
self,
|
self,
|
||||||
@ -1130,8 +1128,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
data = await self._api_post("ilink/bot/sendmessage", body)
|
data = await self._api_post("ilink/bot/sendmessage", body)
|
||||||
errcode = data.get("errcode", 0)
|
errcode = data.get("errcode", 0)
|
||||||
if errcode and errcode != 0:
|
if errcode and errcode != 0:
|
||||||
logger.warning(
|
self.logger.warning(
|
||||||
"WeChat send error (code {}): {}",
|
"send error (code {}): {}",
|
||||||
errcode,
|
errcode,
|
||||||
data.get("errmsg", ""),
|
data.get("errmsg", ""),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -99,15 +99,15 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
bridge_dir = _ensure_bridge_setup()
|
bridge_dir = _ensure_bridge_setup()
|
||||||
except RuntimeError as e:
|
except RuntimeError:
|
||||||
logger.error("{}", e)
|
self.logger.exception("bridge setup failed")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
env = {**os.environ}
|
env = {**os.environ}
|
||||||
env["BRIDGE_TOKEN"] = self._effective_bridge_token()
|
env["BRIDGE_TOKEN"] = self._effective_bridge_token()
|
||||||
env["AUTH_DIR"] = str(_bridge_token_path().parent)
|
env["AUTH_DIR"] = str(_bridge_token_path().parent)
|
||||||
|
|
||||||
logger.info("Starting WhatsApp bridge for QR login...")
|
self.logger.info("Starting WhatsApp bridge for QR login...")
|
||||||
try:
|
try:
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
[shutil.which("npm"), "start"], cwd=bridge_dir, check=True, env=env
|
[shutil.which("npm"), "start"], cwd=bridge_dir, check=True, env=env
|
||||||
@ -123,7 +123,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
|
|
||||||
bridge_url = self.config.bridge_url
|
bridge_url = self.config.bridge_url
|
||||||
|
|
||||||
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
self.logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
@ -135,24 +135,24 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
json.dumps({"type": "auth", "token": self._effective_bridge_token()})
|
json.dumps({"type": "auth", "token": self._effective_bridge_token()})
|
||||||
)
|
)
|
||||||
self._connected = True
|
self._connected = True
|
||||||
logger.info("Connected to WhatsApp bridge")
|
self.logger.info("Connected to WhatsApp bridge")
|
||||||
|
|
||||||
# Listen for messages
|
# Listen for messages
|
||||||
async for message in ws:
|
async for message in ws:
|
||||||
try:
|
try:
|
||||||
await self._handle_bridge_message(message)
|
await self._handle_bridge_message(message)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error handling bridge message: {}", e)
|
self.logger.exception("Error handling bridge message")
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self._ws = None
|
self._ws = None
|
||||||
logger.warning("WhatsApp bridge connection error: {}", e)
|
self.logger.warning("WhatsApp bridge connection error: {}", e)
|
||||||
|
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.info("Reconnecting in 5 seconds...")
|
self.logger.info("Reconnecting in 5 seconds...")
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
@ -167,7 +167,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send a message through WhatsApp."""
|
"""Send a message through WhatsApp."""
|
||||||
if not self._ws or not self._connected:
|
if not self._ws or not self._connected:
|
||||||
logger.warning("WhatsApp bridge not connected")
|
self.logger.warning("WhatsApp bridge not connected")
|
||||||
return
|
return
|
||||||
|
|
||||||
chat_id = msg.chat_id
|
chat_id = msg.chat_id
|
||||||
@ -176,8 +176,8 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
payload = {"type": "send", "to": chat_id, "text": msg.content}
|
payload = {"type": "send", "to": chat_id, "text": msg.content}
|
||||||
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending WhatsApp message: {}", e)
|
self.logger.exception("Error sending message")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
for media_path in msg.media or []:
|
for media_path in msg.media or []:
|
||||||
@ -191,8 +191,8 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
"fileName": media_path.rsplit("/", 1)[-1],
|
"fileName": media_path.rsplit("/", 1)[-1],
|
||||||
}
|
}
|
||||||
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Error sending WhatsApp media {}: {}", media_path, e)
|
self.logger.exception("Error sending media {}", media_path)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _handle_bridge_message(self, raw: str) -> None:
|
async def _handle_bridge_message(self, raw: str) -> None:
|
||||||
@ -200,7 +200,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
data = json.loads(raw)
|
data = json.loads(raw)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
self.logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
||||||
return
|
return
|
||||||
|
|
||||||
msg_type = data.get("type")
|
msg_type = data.get("type")
|
||||||
@ -253,7 +253,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
if phone_id and lid_id:
|
if phone_id and lid_id:
|
||||||
self._lid_to_phone[lid_id] = phone_id
|
self._lid_to_phone[lid_id] = phone_id
|
||||||
|
|
||||||
logger.info("Sender phone={} lid={} → sender_id={}", phone_id or "(empty)", lid_id or "(empty)", sender_id)
|
self.logger.info("Sender phone={} lid={} → sender_id={}", phone_id or "(empty)", lid_id or "(empty)", sender_id)
|
||||||
|
|
||||||
# Extract media paths (images/documents/videos downloaded by the bridge)
|
# Extract media paths (images/documents/videos downloaded by the bridge)
|
||||||
media_paths = data.get("media") or []
|
media_paths = data.get("media") or []
|
||||||
@ -261,11 +261,11 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
# Handle voice transcription if it's a voice message
|
# Handle voice transcription if it's a voice message
|
||||||
if content == "[Voice Message]":
|
if content == "[Voice Message]":
|
||||||
if media_paths:
|
if media_paths:
|
||||||
logger.info("Transcribing voice message from {}...", sender_id)
|
self.logger.info("Transcribing voice message from {}...", sender_id)
|
||||||
transcription = await self.transcribe_audio(media_paths[0])
|
transcription = await self.transcribe_audio(media_paths[0])
|
||||||
if transcription:
|
if transcription:
|
||||||
content = transcription
|
content = transcription
|
||||||
logger.info("Transcribed voice from {}: {}...", sender_id, transcription[:50])
|
self.logger.info("Transcribed voice from {}: {}...", sender_id, transcription[:50])
|
||||||
else:
|
else:
|
||||||
content = "[Voice Message: Transcription failed]"
|
content = "[Voice Message: Transcription failed]"
|
||||||
else:
|
else:
|
||||||
@ -294,7 +294,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
elif msg_type == "status":
|
elif msg_type == "status":
|
||||||
# Connection status update
|
# Connection status update
|
||||||
status = data.get("status")
|
status = data.get("status")
|
||||||
logger.info("WhatsApp status: {}", status)
|
self.logger.info("Status: {}", status)
|
||||||
|
|
||||||
if status == "connected":
|
if status == "connected":
|
||||||
self._connected = True
|
self._connected = True
|
||||||
@ -303,10 +303,10 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
|
|
||||||
elif msg_type == "qr":
|
elif msg_type == "qr":
|
||||||
# QR code for authentication
|
# QR code for authentication
|
||||||
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
self.logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
||||||
|
|
||||||
elif msg_type == "error":
|
elif msg_type == "error":
|
||||||
logger.error("WhatsApp bridge error: {}", data.get("error"))
|
self.logger.error("Bridge error: {}", data.get("error"))
|
||||||
|
|
||||||
|
|
||||||
def _ensure_bridge_setup() -> Path:
|
def _ensure_bridge_setup() -> Path:
|
||||||
|
|||||||
@ -21,6 +21,22 @@ if sys.platform == "win32":
|
|||||||
|
|
||||||
import typer
|
import typer
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
# Remove default handler and re-add with unified nanobot format
|
||||||
|
logger.remove()
|
||||||
|
_log_handler_id = logger.add(
|
||||||
|
sys.stderr,
|
||||||
|
format=(
|
||||||
|
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
|
||||||
|
"<level>{level: <5}</level> | "
|
||||||
|
"<cyan>{extra[channel]}</cyan> | "
|
||||||
|
"<level>{message}</level>"
|
||||||
|
),
|
||||||
|
level="INFO",
|
||||||
|
colorize=None,
|
||||||
|
filter=lambda record: record["extra"].setdefault("channel", "-") or True,
|
||||||
|
)
|
||||||
|
|
||||||
from prompt_toolkit import PromptSession, print_formatted_text
|
from prompt_toolkit import PromptSession, print_formatted_text
|
||||||
from prompt_toolkit.application import run_in_terminal
|
from prompt_toolkit.application import run_in_terminal
|
||||||
from prompt_toolkit.formatted_text import ANSI, HTML
|
from prompt_toolkit.formatted_text import ANSI, HTML
|
||||||
@ -541,6 +557,7 @@ def serve(
|
|||||||
context_block_limit=runtime_config.agents.defaults.context_block_limit,
|
context_block_limit=runtime_config.agents.defaults.context_block_limit,
|
||||||
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
|
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
|
||||||
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
|
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
|
||||||
|
tool_hint_max_length=runtime_config.agents.defaults.tool_hint_max_length,
|
||||||
web_config=runtime_config.tools.web,
|
web_config=runtime_config.tools.web,
|
||||||
exec_config=runtime_config.tools.exec,
|
exec_config=runtime_config.tools.exec,
|
||||||
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
||||||
@ -597,9 +614,19 @@ def gateway(
|
|||||||
):
|
):
|
||||||
"""Start the nanobot gateway."""
|
"""Start the nanobot gateway."""
|
||||||
if verbose:
|
if verbose:
|
||||||
import logging
|
logger.remove(_log_handler_id)
|
||||||
|
logger.add(
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
sys.stderr,
|
||||||
|
format=(
|
||||||
|
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
|
||||||
|
"<level>{level: <5}</level> | "
|
||||||
|
"<cyan>{extra[channel]}</cyan> | "
|
||||||
|
"<level>{message}</level>"
|
||||||
|
),
|
||||||
|
level="DEBUG",
|
||||||
|
colorize=None,
|
||||||
|
filter=lambda record: record["extra"].setdefault("channel", "-") or True,
|
||||||
|
)
|
||||||
cfg = _load_runtime_config(config, workspace)
|
cfg = _load_runtime_config(config, workspace)
|
||||||
_run_gateway(cfg, port=port)
|
_run_gateway(cfg, port=port)
|
||||||
|
|
||||||
@ -655,6 +682,7 @@ def _run_gateway(
|
|||||||
context_block_limit=config.agents.defaults.context_block_limit,
|
context_block_limit=config.agents.defaults.context_block_limit,
|
||||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||||
|
tool_hint_max_length=config.agents.defaults.tool_hint_max_length,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
@ -1047,6 +1075,7 @@ def agent(
|
|||||||
context_block_limit=config.agents.defaults.context_block_limit,
|
context_block_limit=config.agents.defaults.context_block_limit,
|
||||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||||
|
tool_hint_max_length=config.agents.defaults.tool_hint_max_length,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
|
|||||||
@ -840,7 +840,7 @@ def _get_channel_info() -> dict[str, tuple[str, type[BaseModel]]]:
|
|||||||
display_name = getattr(channel_cls, "display_name", name.capitalize())
|
display_name = getattr(channel_cls, "display_name", name.capitalize())
|
||||||
result[name] = (display_name, config_cls)
|
result[name] = (display_name, config_cls)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(f"Failed to load channel module: {name}")
|
logger.warning("Failed to load channel module: {}", name)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ def load_config(config_path: Path | None = None) -> Config:
|
|||||||
data = _migrate_config(data)
|
data = _migrate_config(data)
|
||||||
config = Config.model_validate(data)
|
config = Config.model_validate(data)
|
||||||
except (json.JSONDecodeError, ValueError, pydantic.ValidationError) as e:
|
except (json.JSONDecodeError, ValueError, pydantic.ValidationError) as e:
|
||||||
logger.warning(f"Failed to load config from {path}: {e}")
|
logger.warning("Failed to load config from {}: {}", path, e)
|
||||||
logger.warning("Using default configuration.")
|
logger.warning("Using default configuration.")
|
||||||
|
|
||||||
_apply_ssrf_whitelist(config)
|
_apply_ssrf_whitelist(config)
|
||||||
|
|||||||
@ -81,6 +81,13 @@ class AgentDefaults(Base):
|
|||||||
max_concurrent_subagents: int = Field(default=1, ge=1)
|
max_concurrent_subagents: int = Field(default=1, ge=1)
|
||||||
max_tool_result_chars: int = 16_000
|
max_tool_result_chars: int = 16_000
|
||||||
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
||||||
|
tool_hint_max_length: int = Field(
|
||||||
|
default=40,
|
||||||
|
ge=20,
|
||||||
|
le=500,
|
||||||
|
validation_alias=AliasChoices("toolHintMaxLength"),
|
||||||
|
serialization_alias="toolHintMaxLength",
|
||||||
|
) # Max characters for tool hint display (e.g. "$ cd …/project && npm test")
|
||||||
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
||||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||||
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
||||||
|
|||||||
@ -156,7 +156,7 @@ class CronService:
|
|||||||
updated_at_ms=j.get("updatedAtMs", 0),
|
updated_at_ms=j.get("updatedAtMs", 0),
|
||||||
delete_after_run=j.get("deleteAfterRun", False),
|
delete_after_run=j.get("deleteAfterRun", False),
|
||||||
))
|
))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# Preserve the corrupt file for forensic recovery instead of
|
# Preserve the corrupt file for forensic recovery instead of
|
||||||
# letting the next save overwrite it with an empty job list.
|
# letting the next save overwrite it with an empty job list.
|
||||||
backup = self.store_path.with_suffix(
|
backup = self.store_path.with_suffix(
|
||||||
@ -164,12 +164,11 @@ class CronService:
|
|||||||
)
|
)
|
||||||
with suppress(OSError):
|
with suppress(OSError):
|
||||||
self.store_path.rename(backup)
|
self.store_path.rename(backup)
|
||||||
logger.error(
|
logger.exception(
|
||||||
"Failed to load cron store at {}: {}. "
|
"Failed to load cron store at {}. "
|
||||||
"Corrupt file preserved at {}. "
|
"Corrupt file preserved at {}. "
|
||||||
"Refusing to overwrite to avoid data loss.",
|
"Refusing to overwrite to avoid data loss.",
|
||||||
self.store_path,
|
self.store_path,
|
||||||
e,
|
|
||||||
backup,
|
backup,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@ -202,8 +201,8 @@ class CronService:
|
|||||||
else:
|
else:
|
||||||
_update(action.get("params", {}))
|
_update(action.get("params", {}))
|
||||||
changed = True
|
changed = True
|
||||||
except Exception as exp:
|
except Exception:
|
||||||
logger.debug(f"load action line error: {exp}")
|
logger.exception("load action line error")
|
||||||
continue
|
continue
|
||||||
self._store.jobs = list(jobs_map.values())
|
self._store.jobs = list(jobs_map.values())
|
||||||
if self._running and changed:
|
if self._running and changed:
|
||||||
@ -434,7 +433,7 @@ class CronService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
job.state.last_status = "error"
|
job.state.last_status = "error"
|
||||||
job.state.last_error = str(e)
|
job.state.last_error = str(e)
|
||||||
logger.error("Cron: job '{}' failed: {}", job.name, e)
|
logger.exception("Cron: job '{}' failed", job.name)
|
||||||
|
|
||||||
end_ms = _now_ms()
|
end_ms = _now_ms()
|
||||||
job.state.last_run_at_ms = start_ms
|
job.state.last_run_at_ms = start_ms
|
||||||
|
|||||||
@ -144,8 +144,8 @@ class HeartbeatService:
|
|||||||
await self._tick()
|
await self._tick()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Heartbeat error: {}", e)
|
logger.exception("Heartbeat error")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_deliverable(response: str) -> bool:
|
def _is_deliverable(response: str) -> bool:
|
||||||
|
|||||||
@ -76,6 +76,7 @@ class Nanobot:
|
|||||||
context_block_limit=defaults.context_block_limit,
|
context_block_limit=defaults.context_block_limit,
|
||||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||||
provider_retry_mode=defaults.provider_retry_mode,
|
provider_retry_mode=defaults.provider_retry_mode,
|
||||||
|
tool_hint_max_length=defaults.tool_hint_max_length,
|
||||||
web_config=config.tools.web,
|
web_config=config.tools.web,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
|
|||||||
@ -93,7 +93,7 @@ def _extract_pdf(path: Path) -> str:
|
|||||||
pages.append(f"--- Page {i} ---\n{text}")
|
pages.append(f"--- Page {i} ---\n{text}")
|
||||||
return _truncate("\n\n".join(pages), _MAX_TEXT_LENGTH)
|
return _truncate("\n\n".join(pages), _MAX_TEXT_LENGTH)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to extract PDF {}: {}", path, e)
|
logger.exception("Failed to extract PDF {}", path)
|
||||||
return f"[error: failed to extract PDF: {e!s}]"
|
return f"[error: failed to extract PDF: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ def _extract_docx(path: Path) -> str:
|
|||||||
paragraphs: list[str] = [p.text for p in doc.paragraphs if p.text.strip()]
|
paragraphs: list[str] = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||||
return _truncate("\n\n".join(paragraphs), _MAX_TEXT_LENGTH)
|
return _truncate("\n\n".join(paragraphs), _MAX_TEXT_LENGTH)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to extract DOCX {}: {}", path, e)
|
logger.exception("Failed to extract DOCX {}", path)
|
||||||
return f"[error: failed to extract DOCX: {e!s}]"
|
return f"[error: failed to extract DOCX: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
@ -135,7 +135,7 @@ def _extract_xlsx(path: Path) -> str:
|
|||||||
finally:
|
finally:
|
||||||
wb.close()
|
wb.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to extract XLSX {}: {}", path, e)
|
logger.exception("Failed to extract XLSX {}", path)
|
||||||
return f"[error: failed to extract XLSX: {e!s}]"
|
return f"[error: failed to extract XLSX: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
@ -156,7 +156,7 @@ def _extract_pptx(path: Path) -> str:
|
|||||||
slides.append(f"--- Slide {i} ---\n" + "\n".join(slide_text))
|
slides.append(f"--- Slide {i} ---\n" + "\n".join(slide_text))
|
||||||
return _truncate("\n\n".join(slides), _MAX_TEXT_LENGTH)
|
return _truncate("\n\n".join(slides), _MAX_TEXT_LENGTH)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to extract PPTX {}: {}", path, e)
|
logger.exception("Failed to extract PPTX {}", path)
|
||||||
return f"[error: failed to extract PPTX: {e!s}]"
|
return f"[error: failed to extract PPTX: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ def _extract_text_file(path: Path) -> str:
|
|||||||
content = path.read_text(encoding="latin-1")
|
content = path.read_text(encoding="latin-1")
|
||||||
return _truncate(content, _MAX_TEXT_LENGTH)
|
return _truncate(content, _MAX_TEXT_LENGTH)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to read text file {}: {}", path, e)
|
logger.exception("Failed to read text file {}", path)
|
||||||
return f"[error: failed to read file: {e!s}]"
|
return f"[error: failed to read file: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -113,7 +113,7 @@ class GitStore:
|
|||||||
logger.info("Git store initialized at {}", self._workspace)
|
logger.info("Git store initialized at {}", self._workspace)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Git store init failed for {}", self._workspace)
|
logger.exception("Git store init failed for {}", self._workspace)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# -- daily operations ------------------------------------------------------
|
# -- daily operations ------------------------------------------------------
|
||||||
@ -149,7 +149,7 @@ class GitStore:
|
|||||||
logger.debug("Git auto-commit: {} ({})", sha, message)
|
logger.debug("Git auto-commit: {} ({})", sha, message)
|
||||||
return sha
|
return sha
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Git auto-commit failed: {}", message)
|
logger.exception("Git auto-commit failed: {}", message)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# -- internal helpers ------------------------------------------------------
|
# -- internal helpers ------------------------------------------------------
|
||||||
@ -243,7 +243,7 @@ class GitStore:
|
|||||||
|
|
||||||
return entries
|
return entries
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Git log failed")
|
logger.exception("Git log failed")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def line_ages(self, file_path: str) -> list[LineAge]:
|
def line_ages(self, file_path: str) -> list[LineAge]:
|
||||||
@ -266,7 +266,7 @@ class GitStore:
|
|||||||
|
|
||||||
annotated = porcelain.annotate(str(self._workspace), file_path)
|
annotated = porcelain.annotate(str(self._workspace), file_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Git line_ages annotate failed for {}", file_path)
|
logger.exception("Git line_ages annotate failed for {}", file_path)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if not annotated:
|
if not annotated:
|
||||||
@ -296,7 +296,7 @@ class GitStore:
|
|||||||
)
|
)
|
||||||
return out.getvalue().decode("utf-8", errors="replace")
|
return out.getvalue().decode("utf-8", errors="replace")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Git diff_commits failed")
|
logger.exception("Git diff_commits failed")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def find_commit(self, short_sha: str, max_entries: int = 20) -> CommitInfo | None:
|
def find_commit(self, short_sha: str, max_entries: int = 20) -> CommitInfo | None:
|
||||||
@ -367,7 +367,7 @@ class GitStore:
|
|||||||
msg = f"revert: undo {commit}"
|
msg = f"revert: undo {commit}"
|
||||||
return self.auto_commit(msg)
|
return self.auto_commit(msg)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Git revert failed for {}", commit)
|
logger.exception("Git revert failed for {}", commit)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -268,8 +268,8 @@ def maybe_persist_tool_result(
|
|||||||
bucket = ensure_dir(root / safe_filename(session_key or "default"))
|
bucket = ensure_dir(root / safe_filename(session_key or "default"))
|
||||||
try:
|
try:
|
||||||
_cleanup_tool_result_buckets(root, bucket)
|
_cleanup_tool_result_buckets(root, bucket)
|
||||||
except Exception as exc:
|
except Exception:
|
||||||
logger.warning("Failed to clean stale tool result buckets in {}: {}", root, exc)
|
logger.exception("Failed to clean stale tool result buckets in {}", root)
|
||||||
path = bucket / f"{safe_filename(tool_call_id)}.{suffix}"
|
path = bucket / f"{safe_filename(tool_call_id)}.{suffix}"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
if suffix == "json" and isinstance(content, list):
|
if suffix == "json" and isinstance(content, list):
|
||||||
@ -540,6 +540,6 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
|
|||||||
)
|
)
|
||||||
gs.init()
|
gs.init()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to initialize git store for {}", workspace)
|
logger.exception("Failed to initialize git store for {}", workspace)
|
||||||
|
|
||||||
return added
|
return added
|
||||||
|
|||||||
47
nanobot/utils/logging_bridge.py
Normal file
47
nanobot/utils/logging_bridge.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
"""Utilities for redirecting stdlib logging to loguru."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class _LoguruBridge(logging.Handler):
|
||||||
|
"""Route stdlib log records into loguru with consistent formatting."""
|
||||||
|
|
||||||
|
_LEVEL_MAP: dict[int, str] = {
|
||||||
|
logging.DEBUG: "DEBUG",
|
||||||
|
logging.INFO: "INFO",
|
||||||
|
logging.WARNING: "WARNING",
|
||||||
|
logging.ERROR: "ERROR",
|
||||||
|
logging.CRITICAL: "CRITICAL",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, lib_name: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.lib_name = lib_name
|
||||||
|
|
||||||
|
def emit(self, record: logging.LogRecord) -> None:
|
||||||
|
level = self._LEVEL_MAP.get(record.levelno, "INFO")
|
||||||
|
frame, depth = logging.currentframe(), 2
|
||||||
|
while frame and frame.f_code.co_filename == logging.__file__:
|
||||||
|
frame, depth = frame.f_back, depth + 1
|
||||||
|
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||||
|
level, "[{lib}] {message}", lib=self.lib_name, message=record.getMessage()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def redirect_lib_logging(name: str, level: str | None = None) -> None:
|
||||||
|
"""Redirect stdlib logging from *name* into loguru.
|
||||||
|
|
||||||
|
Adds a bridge handler if one is not already present and disables
|
||||||
|
propagation so messages are not duplicated. When *level* is None the
|
||||||
|
handler does not filter — loguru's own level controls visibility.
|
||||||
|
"""
|
||||||
|
lib_logger = logging.getLogger(name)
|
||||||
|
if not any(isinstance(h, _LoguruBridge) for h in lib_logger.handlers):
|
||||||
|
handler = _LoguruBridge(name)
|
||||||
|
if level is not None:
|
||||||
|
handler.setLevel(getattr(logging, level.upper(), logging.WARNING))
|
||||||
|
lib_logger.handlers = [handler]
|
||||||
|
lib_logger.propagate = False
|
||||||
@ -27,7 +27,7 @@ _PATH_IN_CMD_RE = re.compile(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def format_tool_hints(tool_calls: list) -> str:
|
def format_tool_hints(tool_calls: list, max_length: int = 40) -> str:
|
||||||
"""Format tool calls as concise hints with smart abbreviation."""
|
"""Format tool calls as concise hints with smart abbreviation."""
|
||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
return ""
|
return ""
|
||||||
@ -36,11 +36,11 @@ def format_tool_hints(tool_calls: list) -> str:
|
|||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
fmt = _TOOL_FORMATS.get(tc.name)
|
fmt = _TOOL_FORMATS.get(tc.name)
|
||||||
if fmt:
|
if fmt:
|
||||||
formatted.append(_fmt_known(tc, fmt))
|
formatted.append(_fmt_known(tc, fmt, max_length))
|
||||||
elif tc.name.startswith("mcp_"):
|
elif tc.name.startswith("mcp_"):
|
||||||
formatted.append(_fmt_mcp(tc))
|
formatted.append(_fmt_mcp(tc, max_length))
|
||||||
else:
|
else:
|
||||||
formatted.append(_fmt_fallback(tc))
|
formatted.append(_fmt_fallback(tc, max_length))
|
||||||
|
|
||||||
hints = []
|
hints = []
|
||||||
for hint in formatted:
|
for hint in formatted:
|
||||||
@ -80,26 +80,28 @@ def _extract_arg(tc, key_args: list[str]) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _fmt_known(tc, fmt: tuple) -> str:
|
def _fmt_known(tc, fmt: tuple, max_length: int = 40) -> str:
|
||||||
"""Format a registered tool using its template."""
|
"""Format a registered tool using its template."""
|
||||||
val = _extract_arg(tc, fmt[0])
|
val = _extract_arg(tc, fmt[0])
|
||||||
if val is None:
|
if val is None:
|
||||||
return tc.name
|
return tc.name
|
||||||
if fmt[2]: # is_path
|
if fmt[2]: # is_path
|
||||||
val = abbreviate_path(val)
|
val = abbreviate_path(val, max_len=max_length)
|
||||||
elif fmt[3]: # is_command
|
elif fmt[3]: # is_command
|
||||||
val = _abbreviate_command(val)
|
val = _abbreviate_command(val, max_len=max_length)
|
||||||
return fmt[1].format(val)
|
return fmt[1].format(val)
|
||||||
|
|
||||||
|
|
||||||
def _abbreviate_command(cmd: str, max_len: int = 40) -> str:
|
def _abbreviate_command(cmd: str, max_len: int = 40) -> str:
|
||||||
"""Abbreviate paths in a command string, then truncate."""
|
"""Abbreviate paths in a command string, then truncate."""
|
||||||
|
path_max = max(max_len // 2, 25)
|
||||||
|
|
||||||
def _replace_path(match: re.Match[str]) -> str:
|
def _replace_path(match: re.Match[str]) -> str:
|
||||||
if match.group("double") is not None:
|
if match.group("double") is not None:
|
||||||
return f'"{abbreviate_path(match.group("double"), max_len=25)}"'
|
return f'"{abbreviate_path(match.group("double"), max_len=path_max)}"'
|
||||||
if match.group("single") is not None:
|
if match.group("single") is not None:
|
||||||
return f"'{abbreviate_path(match.group('single'), max_len=25)}'"
|
return f"'{abbreviate_path(match.group('single'), max_len=path_max)}'"
|
||||||
return abbreviate_path(match.group("bare"), max_len=25)
|
return abbreviate_path(match.group("bare"), max_len=path_max)
|
||||||
|
|
||||||
abbreviated = _PATH_IN_CMD_RE.sub(_replace_path, cmd)
|
abbreviated = _PATH_IN_CMD_RE.sub(_replace_path, cmd)
|
||||||
if len(abbreviated) <= max_len:
|
if len(abbreviated) <= max_len:
|
||||||
@ -107,7 +109,7 @@ def _abbreviate_command(cmd: str, max_len: int = 40) -> str:
|
|||||||
return abbreviated[:max_len - 1] + "\u2026"
|
return abbreviated[:max_len - 1] + "\u2026"
|
||||||
|
|
||||||
|
|
||||||
def _fmt_mcp(tc) -> str:
|
def _fmt_mcp(tc, max_length: int = 40) -> str:
|
||||||
"""Format MCP tool as server::tool."""
|
"""Format MCP tool as server::tool."""
|
||||||
name = tc.name
|
name = tc.name
|
||||||
if "__" in name:
|
if "__" in name:
|
||||||
@ -125,13 +127,13 @@ def _fmt_mcp(tc) -> str:
|
|||||||
val = next((v for v in args.values() if isinstance(v, str) and v), None)
|
val = next((v for v in args.values() if isinstance(v, str) and v), None)
|
||||||
if val is None:
|
if val is None:
|
||||||
return f"{server}::{tool}"
|
return f"{server}::{tool}"
|
||||||
return f'{server}::{tool}("{abbreviate_path(val, 40)}")'
|
return f'{server}::{tool}("{abbreviate_path(val, max_length)}")'
|
||||||
|
|
||||||
|
|
||||||
def _fmt_fallback(tc) -> str:
|
def _fmt_fallback(tc, max_length: int = 40) -> str:
|
||||||
"""Original formatting logic for unregistered tools."""
|
"""Original formatting logic for unregistered tools."""
|
||||||
args = _get_args(tc)
|
args = _get_args(tc)
|
||||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||||
if not isinstance(val, str):
|
if not isinstance(val, str):
|
||||||
return tc.name
|
return tc.name
|
||||||
return f'{tc.name}("{abbreviate_path(val, 40)}")' if len(val) > 40 else f'{tc.name}("{val}")'
|
return f'{tc.name}("{abbreviate_path(val, max_length)}")' if len(val) > max_length else f'{tc.name}("{val}")'
|
||||||
|
|||||||
@ -130,11 +130,44 @@ class TestToolEventProgress:
|
|||||||
assert finish["result"] == "file.txt"
|
assert finish["result"] == "file.txt"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bus_progress_streams_provider_deltas_for_codex_style_provider(
|
async def test_non_streaming_channel_does_not_publish_codex_progress_deltas(
|
||||||
self,
|
self,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Providers that opt in can stream content deltas through _progress messages."""
|
"""Non-streaming channels should get one final reply, not token progress spam."""
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.supports_progress_deltas = True
|
||||||
|
provider.get_default_model.return_value = "openai-codex/gpt-5.5"
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello", tool_calls=[]))
|
||||||
|
provider.chat_stream_with_retry = AsyncMock()
|
||||||
|
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="openai-codex/gpt-5.5")
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await loop._dispatch(InboundMessage(
|
||||||
|
channel="whatsapp",
|
||||||
|
sender_id="u1",
|
||||||
|
chat_id="chat1",
|
||||||
|
content="say hello",
|
||||||
|
))
|
||||||
|
|
||||||
|
outbound = []
|
||||||
|
while bus.outbound_size > 0:
|
||||||
|
outbound.append(await bus.consume_outbound())
|
||||||
|
|
||||||
|
assert [m.content for m in outbound] == ["Hello"]
|
||||||
|
assert not any(m.metadata.get("_progress") for m in outbound)
|
||||||
|
assert not any(m.metadata.get("_streamed") for m in outbound)
|
||||||
|
provider.chat_stream_with_retry.assert_not_awaited()
|
||||||
|
provider.chat_with_retry.assert_awaited_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streaming_channel_streams_provider_deltas_for_codex_style_provider(
|
||||||
|
self,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Streaming channels still receive provider deltas through _stream_delta messages."""
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.supports_progress_deltas = True
|
provider.supports_progress_deltas = True
|
||||||
@ -156,18 +189,27 @@ class TestToolEventProgress:
|
|||||||
sender_id="u1",
|
sender_id="u1",
|
||||||
chat_id="chat1",
|
chat_id="chat1",
|
||||||
content="say hello",
|
content="say hello",
|
||||||
|
metadata={"_wants_stream": True},
|
||||||
))
|
))
|
||||||
|
|
||||||
outbound = []
|
outbound = []
|
||||||
while bus.outbound_size > 0:
|
while bus.outbound_size > 0:
|
||||||
outbound.append(await bus.consume_outbound())
|
outbound.append(await bus.consume_outbound())
|
||||||
|
|
||||||
progress = [m for m in outbound if m.metadata.get("_progress")]
|
deltas = [m for m in outbound if m.metadata.get("_stream_delta")]
|
||||||
final = [m for m in outbound if not m.metadata.get("_progress")]
|
stream_end = [m for m in outbound if m.metadata.get("_stream_end")]
|
||||||
|
final = [
|
||||||
|
m for m in outbound
|
||||||
|
if not m.metadata.get("_stream_delta")
|
||||||
|
and not m.metadata.get("_stream_end")
|
||||||
|
and not m.metadata.get("_turn_end")
|
||||||
|
]
|
||||||
|
|
||||||
assert [m.content for m in progress] == ["Hel", "lo"]
|
assert [m.content for m in deltas] == ["Hel", "lo"]
|
||||||
assert final[-2].content == "Hello"
|
assert len(stream_end) == 1
|
||||||
assert (final[-1].metadata or {}).get("_turn_end") is True
|
assert final[-1].content == "Hello"
|
||||||
|
assert final[-1].metadata.get("_streamed") is True
|
||||||
|
assert outbound[-1].metadata.get("_turn_end") is True
|
||||||
provider.chat_with_retry.assert_not_awaited()
|
provider.chat_with_retry.assert_not_awaited()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -197,8 +239,12 @@ class TestToolEventProgress:
|
|||||||
loop.tools.prepare_call = MagicMock(return_value=(None, {"path": "foo.txt"}, None))
|
loop.tools.prepare_call = MagicMock(return_value=(None, {"path": "foo.txt"}, None))
|
||||||
loop.tools.execute = AsyncMock(return_value="ok")
|
loop.tools.execute = AsyncMock(return_value="ok")
|
||||||
|
|
||||||
|
streamed: list[str] = []
|
||||||
progress: list[tuple[str, bool, list[dict] | None]] = []
|
progress: list[tuple[str, bool, list[dict] | None]] = []
|
||||||
|
|
||||||
|
async def on_stream(delta: str) -> None:
|
||||||
|
streamed.append(delta)
|
||||||
|
|
||||||
async def on_progress(
|
async def on_progress(
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
@ -207,14 +253,15 @@ class TestToolEventProgress:
|
|||||||
) -> None:
|
) -> None:
|
||||||
progress.append((content, tool_hint, tool_events))
|
progress.append((content, tool_hint, tool_events))
|
||||||
|
|
||||||
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
final_content, _, _, _, _ = await loop._run_agent_loop(
|
||||||
|
[],
|
||||||
|
on_progress=on_progress,
|
||||||
|
on_stream=on_stream,
|
||||||
|
)
|
||||||
|
|
||||||
assert final_content == "Done"
|
assert final_content == "Done"
|
||||||
assert [item[0] for item in progress[:3]] == [
|
assert streamed == ["I will", " inspect it."]
|
||||||
"I will",
|
assert progress[0][0] == 'custom_tool("foo.txt")'
|
||||||
" inspect it.",
|
|
||||||
'custom_tool("foo.txt")',
|
|
||||||
]
|
|
||||||
assert all(item[0] != "I will inspect it." for item in progress)
|
assert all(item[0] != "I will inspect it." for item in progress)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -643,7 +643,7 @@ def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path):
|
|||||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError("busy")),
|
lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError("busy")),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.utils.helpers.logger.warning",
|
"nanobot.utils.helpers.logger.exception",
|
||||||
lambda message, *args: warnings.append(message.format(*args)),
|
lambda message, *args: warnings.append(message.format(*args)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
79
tests/agent/test_runner_progress_deltas.py
Normal file
79
tests/agent/test_runner_progress_deltas.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""Tests for provider progress delta routing in the shared runner."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_can_disable_provider_progress_delta_streaming():
|
||||||
|
"""AgentLoop disables token progress streaming for non-streaming channels."""
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.supports_progress_deltas = True
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=LLMResponse(content="done", tool_calls=[], usage={})
|
||||||
|
)
|
||||||
|
provider.chat_stream_with_retry = AsyncMock()
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
progress_cb = AsyncMock()
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
progress_callback=progress_cb,
|
||||||
|
stream_progress_deltas=False,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.final_content == "done"
|
||||||
|
provider.chat_with_retry.assert_awaited_once()
|
||||||
|
provider.chat_stream_with_retry.assert_not_awaited()
|
||||||
|
progress_cb.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_streams_provider_progress_deltas_by_default():
|
||||||
|
"""Direct runner users keep the existing opt-in provider progress behavior."""
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.supports_progress_deltas = True
|
||||||
|
|
||||||
|
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
|
||||||
|
await on_content_delta("he")
|
||||||
|
await on_content_delta("llo")
|
||||||
|
return LLMResponse(content="hello", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||||
|
provider.chat_with_retry = AsyncMock()
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
progress_cb = AsyncMock()
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
progress_callback=progress_cb,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.final_content == "hello"
|
||||||
|
assert [call.args[0] for call in progress_cb.await_args_list] == ["he", "llo"]
|
||||||
|
provider.chat_with_retry.assert_not_awaited()
|
||||||
@ -8,9 +8,9 @@ def _tc(name: str, args) -> ToolCallRequest:
|
|||||||
return ToolCallRequest(id="c1", name=name, arguments=args)
|
return ToolCallRequest(id="c1", name=name, arguments=args)
|
||||||
|
|
||||||
|
|
||||||
def _hint(calls):
|
def _hint(calls, max_length=40):
|
||||||
"""Shortcut for format_tool_hints."""
|
"""Shortcut for format_tool_hints."""
|
||||||
return format_tool_hints(calls)
|
return format_tool_hints(calls, max_length=max_length)
|
||||||
|
|
||||||
|
|
||||||
class TestToolHintKnownTools:
|
class TestToolHintKnownTools:
|
||||||
@ -254,3 +254,59 @@ class TestToolHintMixedFolding:
|
|||||||
assert "\u00d7" not in result
|
assert "\u00d7" not in result
|
||||||
parts = result.split(", ")
|
parts = result.split(", ")
|
||||||
assert len(parts) == 5
|
assert len(parts) == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolHintMaxLength:
|
||||||
|
"""Test max_length parameter controls truncation of tool hints."""
|
||||||
|
|
||||||
|
def test_exec_default_truncates_at_40(self):
|
||||||
|
cmd = "cd /very/long/path/to/some/project && npm run build && npm test"
|
||||||
|
result = _hint([_tc("exec", {"command": cmd})], max_length=40)
|
||||||
|
assert len(result) <= 50 # "$ " prefix + 40 + ellipsis
|
||||||
|
assert "\u2026" in result
|
||||||
|
|
||||||
|
def test_exec_larger_max_length_shows_more(self):
|
||||||
|
cmd = "cd /very/long/path/to/some/project && npm run build && npm test"
|
||||||
|
short = _hint([_tc("exec", {"command": cmd})], max_length=40)
|
||||||
|
long = _hint([_tc("exec", {"command": cmd})], max_length=120)
|
||||||
|
assert len(long) > len(short)
|
||||||
|
assert "npm test" in long
|
||||||
|
|
||||||
|
def test_exec_max_length_120_shows_full_command(self):
|
||||||
|
cmd = "cd /home/user/project && npm install && npm run build"
|
||||||
|
result = _hint([_tc("exec", {"command": cmd})], max_length=120)
|
||||||
|
assert "npm run build" in result
|
||||||
|
|
||||||
|
def test_fallback_respects_max_length(self):
|
||||||
|
long_val = "a" * 100
|
||||||
|
result = _hint([_tc("custom_tool", {"data": long_val})], max_length=60)
|
||||||
|
assert "\u2026" in result
|
||||||
|
result_40 = _hint([_tc("custom_tool", {"data": long_val})], max_length=40)
|
||||||
|
assert len(result) > len(result_40)
|
||||||
|
|
||||||
|
def test_mcp_respects_max_length(self):
|
||||||
|
long_url = "https://example.com/very/long/path/to/resource"
|
||||||
|
result = _hint([_tc("mcp_github__fetch", {"url": long_url})], max_length=80)
|
||||||
|
result_40 = _hint([_tc("mcp_github__fetch", {"url": long_url})], max_length=40)
|
||||||
|
assert len(result) >= len(result_40)
|
||||||
|
|
||||||
|
def test_path_type_respects_max_length(self):
|
||||||
|
"""Path-type tools (read_file, write_file, etc.) should honor max_length."""
|
||||||
|
long_path = "/home/user/.local/share/uv/tools/nanobot/agent/loop.py"
|
||||||
|
short = _hint([_tc("read_file", {"path": long_path})], max_length=40)
|
||||||
|
long = _hint([_tc("read_file", {"path": long_path})], max_length=120)
|
||||||
|
assert len(long) > len(short)
|
||||||
|
|
||||||
|
def test_edit_path_respects_max_length(self):
|
||||||
|
"""edit (is_path=True) should honor max_length, not stay hardcoded at 40."""
|
||||||
|
long_path = "/home/user/projects/nanobot/src/agent/loop.py"
|
||||||
|
short = _hint([_tc("edit", {"file_path": long_path})], max_length=40)
|
||||||
|
long = _hint([_tc("edit", {"file_path": long_path})], max_length=120)
|
||||||
|
assert len(long) > len(short)
|
||||||
|
|
||||||
|
def test_list_dir_path_respects_max_length(self):
|
||||||
|
"""list_dir (is_path=True) should honor max_length."""
|
||||||
|
long_path = "/home/user/.local/share/uv/tools/nanobot/"
|
||||||
|
short = _hint([_tc("list_dir", {"path": long_path})], max_length=40)
|
||||||
|
long = _hint([_tc("list_dir", {"path": long_path})], max_length=120)
|
||||||
|
assert len(long) > len(short)
|
||||||
|
|||||||
@ -306,17 +306,19 @@ async def test_on_error_logs_network_issues_as_warning(monkeypatch) -> None:
|
|||||||
recorded: list[tuple[str, str]] = []
|
recorded: list[tuple[str, str]] = []
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.channels.telegram.logger.warning",
|
channel.logger,
|
||||||
|
"warning",
|
||||||
lambda message, error: recorded.append(("warning", message.format(error))),
|
lambda message, error: recorded.append(("warning", message.format(error))),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.channels.telegram.logger.error",
|
channel.logger,
|
||||||
|
"error",
|
||||||
lambda message, error: recorded.append(("error", message.format(error))),
|
lambda message, error: recorded.append(("error", message.format(error))),
|
||||||
)
|
)
|
||||||
|
|
||||||
await channel._on_error(object(), SimpleNamespace(error=NetworkError("proxy disconnected")))
|
await channel._on_error(object(), SimpleNamespace(error=NetworkError("proxy disconnected")))
|
||||||
|
|
||||||
assert recorded == [("warning", "Telegram network issue: proxy disconnected")]
|
assert recorded == [("warning", "network issue: proxy disconnected")]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -330,13 +332,14 @@ async def test_on_error_summarizes_empty_network_error(monkeypatch) -> None:
|
|||||||
recorded: list[tuple[str, str]] = []
|
recorded: list[tuple[str, str]] = []
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.channels.telegram.logger.warning",
|
channel.logger,
|
||||||
|
"warning",
|
||||||
lambda message, error: recorded.append(("warning", message.format(error))),
|
lambda message, error: recorded.append(("warning", message.format(error))),
|
||||||
)
|
)
|
||||||
|
|
||||||
await channel._on_error(object(), SimpleNamespace(error=NetworkError("")))
|
await channel._on_error(object(), SimpleNamespace(error=NetworkError("")))
|
||||||
|
|
||||||
assert recorded == [("warning", "Telegram network issue: NetworkError")]
|
assert recorded == [("warning", "network issue: NetworkError")]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -348,17 +351,19 @@ async def test_on_error_keeps_non_network_exceptions_as_error(monkeypatch) -> No
|
|||||||
recorded: list[tuple[str, str]] = []
|
recorded: list[tuple[str, str]] = []
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.channels.telegram.logger.warning",
|
channel.logger,
|
||||||
|
"warning",
|
||||||
lambda message, error: recorded.append(("warning", message.format(error))),
|
lambda message, error: recorded.append(("warning", message.format(error))),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.channels.telegram.logger.error",
|
channel.logger,
|
||||||
|
"error",
|
||||||
lambda message, error: recorded.append(("error", message.format(error))),
|
lambda message, error: recorded.append(("error", message.format(error))),
|
||||||
)
|
)
|
||||||
|
|
||||||
await channel._on_error(object(), SimpleNamespace(error=RuntimeError("boom")))
|
await channel._on_error(object(), SimpleNamespace(error=RuntimeError("boom")))
|
||||||
|
|
||||||
assert recorded == [("error", "Telegram error: boom")]
|
assert recorded == [("error", "error: boom")]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -835,7 +835,7 @@ async def test_start_logs_install_hint_when_pyjwt_missing(make_channel, monkeypa
|
|||||||
ch = make_channel()
|
ch = make_channel()
|
||||||
errors = []
|
errors = []
|
||||||
monkeypatch.setattr(msteams_module, "MSTEAMS_AVAILABLE", False)
|
monkeypatch.setattr(msteams_module, "MSTEAMS_AVAILABLE", False)
|
||||||
monkeypatch.setattr(msteams_module.logger, "error", lambda message, *args: errors.append(message.format(*args)))
|
monkeypatch.setattr(ch.logger, "error", lambda message, *args: errors.append(message.format(*args)))
|
||||||
|
|
||||||
await ch.start()
|
await ch.start()
|
||||||
|
|
||||||
|
|||||||
@ -467,7 +467,7 @@ async def test_connect_mcp_servers_logs_stdio_pollution_hint(
|
|||||||
yield # pragma: no cover
|
yield # pragma: no cover
|
||||||
|
|
||||||
monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _broken_stdio_client)
|
monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _broken_stdio_client)
|
||||||
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.error", _error)
|
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.exception", _error)
|
||||||
|
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stacks = await connect_mcp_servers({"gh": MCPServerConfig(command="github-mcp")}, registry)
|
stacks = await connect_mcp_servers({"gh": MCPServerConfig(command="github-mcp")}, registry)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user