137 lines
3.6 KiB
Python
137 lines
3.6 KiB
Python
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
import signal
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
from nio import (
|
||
|
|
AsyncClient,
|
||
|
|
AsyncClientConfig,
|
||
|
|
LoginResponse,
|
||
|
|
RoomMessageText,
|
||
|
|
)
|
||
|
|
|
||
|
|
from config import (
|
||
|
|
MATRIX_HOMESERVER,
|
||
|
|
MATRIX_USER_ID,
|
||
|
|
MATRIX_ACCESS_TOKEN,
|
||
|
|
MATRIX_DEVICE_ID,
|
||
|
|
LOG_LEVEL,
|
||
|
|
ConfigValidator,
|
||
|
|
)
|
||
|
|
from callbacks import Callbacks
|
||
|
|
from utils import setup_logging
|
||
|
|
|
||
|
|
# Filesystem locations. Both are relative paths, so they resolve against the
# process's current working directory when they are used.
CREDENTIALS_FILE: Path = Path("credentials.json")  # JSON written by save_credentials()
STORE_PATH: Path = Path("nio_store")  # created in main(); passed to AsyncClient as store_path

# Module-wide logger, built by the project's setup_logging helper from LOG_LEVEL.
logger = setup_logging(LOG_LEVEL)
|
||
|
|
|
||
|
|
|
||
|
|
def save_credentials(resp, homeserver):
    """Persist a login session to CREDENTIALS_FILE as pretty-printed JSON.

    Args:
        resp: Object exposing ``user_id``, ``device_id`` and ``access_token``
            (presumably a ``nio.LoginResponse`` -- confirm against callers).
        homeserver: Homeserver URL stored alongside the session data.

    The file contains a bearer access token. It is therefore created with
    owner-only permissions (0o600), and an already existing file is tightened
    to 0o600, before the token is written into it.
    """
    data = {
        "homeserver": homeserver,
        "user_id": resp.user_id,
        "device_id": resp.device_id,
        "access_token": resp.access_token,
    }
    # Create the file owner-only if it is missing, and force 0o600 on an
    # existing file too, so the token never lands in a group/world-readable
    # file. write_text() truncates in place and keeps these permissions.
    CREDENTIALS_FILE.touch(mode=0o600, exist_ok=True)
    CREDENTIALS_FILE.chmod(0o600)
    CREDENTIALS_FILE.write_text(json.dumps(data, indent=2), encoding="utf-8")
    logger.info("Credentials saved to %s", CREDENTIALS_FILE)
|
||
|
|
|
||
|
|
|
||
|
|
def trust_devices(client: AsyncClient):
    """Verify (trust) every device currently held in the client's device store.

    Devices that the olm machine already reports as verified are skipped.
    Only devices known at call time are covered, because the store is filled
    by syncing. This is why the caller runs this after the initial sync.
    When no olm machine is loaded, a warning is logged and nothing else
    happens.
    """
    olm_machine = client.olm
    if not olm_machine:
        logger.warning("Olm not loaded, skipping device trust")
        return

    # device_store maps user_id -> {device_id: device}; only the devices matter here.
    for _user_id, user_devices in client.device_store.items():
        for device in user_devices.values():
            if olm_machine.is_device_verified(device):
                continue
            client.verify_device(device)

    logger.info("Trusted all known devices")
|
||
|
|
|
||
|
|
|
||
|
|
def _create_client() -> "AsyncClient":
    """Build the E2EE-enabled client and restore the configured session.

    The session comes from MATRIX_ACCESS_TOKEN instead of a password login.
    The olm store is deliberately not loaded here. The caller loads it inside
    its try/finally so the client is still closed if loading fails.
    """
    client_config = AsyncClientConfig(
        store_sync_tokens=True,
        encryption_enabled=True,
        store_name="matrixbot",
    )
    client = AsyncClient(
        MATRIX_HOMESERVER,
        MATRIX_USER_ID,
        device_id=MATRIX_DEVICE_ID,
        config=client_config,
        store_path=str(STORE_PATH),
    )
    # Restore access token (no password login needed)
    client.access_token = MATRIX_ACCESS_TOKEN
    client.user_id = MATRIX_USER_ID
    client.device_id = MATRIX_DEVICE_ID
    return client


def _install_shutdown_handlers(shutdown_event: asyncio.Event) -> None:
    """Set *shutdown_event* when the process receives SIGTERM or SIGINT.

    Must be called while the event loop is running, because it uses
    ``asyncio.get_running_loop()``.
    """
    loop = asyncio.get_running_loop()

    def _signal_handler():
        logger.info("Shutdown signal received")
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, _signal_handler)


async def _run_bot(client: "AsyncClient") -> None:
    """Catch up, trust devices, then serve events until asked to stop.

    Returns normally after a SIGTERM/SIGINT-triggered shutdown. Calls
    ``sys.exit(1)`` in two cases: the initial sync fails, or the background
    sync loop ends on its own by returning or raising. Exiting in the second
    case avoids leaving the process alive but no longer handling events.
    """
    # Load the olm/e2ee store if it exists
    client.load_store()

    callbacks = Callbacks(client)
    client.add_event_callback(callbacks.message, RoomMessageText)

    # Graceful shutdown
    shutdown_event = asyncio.Event()
    _install_shutdown_handlers(shutdown_event)

    logger.info("Starting initial sync...")

    # Do a first sync to catch up, then mark startup complete so we only
    # process new messages going forward.
    sync_resp = await client.sync(timeout=30000, full_state=True)
    # A response without next_batch (presumably a nio error response) is fatal.
    if not hasattr(sync_resp, "next_batch"):
        logger.error("Initial sync failed: %s", sync_resp)
        sys.exit(1)  # main()'s finally block still closes the client
    callbacks.startup_sync_token = sync_resp.next_batch
    logger.info("Initial sync complete, token: %s", sync_resp.next_batch[:20])

    # Trust devices after initial sync loads the device store
    trust_devices(client)

    logger.info("Bot ready as %s — listening for commands", MATRIX_USER_ID)

    # Run sync_forever in a task so it can be cancelled on shutdown. Race it
    # against the shutdown signal so that a dead sync loop is noticed
    # instead of waiting for a signal that may never come.
    sync_task = asyncio.create_task(
        client.sync_forever(timeout=30000, full_state=False)
    )
    shutdown_task = asyncio.create_task(shutdown_event.wait())
    done, _pending = await asyncio.wait(
        {sync_task, shutdown_task}, return_when=asyncio.FIRST_COMPLETED
    )

    if shutdown_task in done:
        sync_task.cancel()
        try:
            await sync_task
        except asyncio.CancelledError:
            pass
        return

    # sync_forever is meant to run until cancelled. If it finished first, the
    # bot has stopped receiving events, so exit non-zero to let a supervisor
    # notice and restart it. Nothing in this module cancels sync_task, so
    # exception() returns the error (or None if it returned normally).
    shutdown_task.cancel()
    exc = sync_task.exception()
    logger.error("Sync loop stopped unexpectedly", exc_info=exc)
    sys.exit(1)


async def main():
    """Validate config, run the bot until a shutdown signal, then clean up.

    Exits with status 1 (via ``sys.exit``) in any of these cases:
    configuration validation reports errors, the initial sync fails, or the
    background sync loop stops on its own. Once the client has been created,
    it is closed on every exit path.
    """
    errors = ConfigValidator.validate()
    if errors:
        for e in errors:
            logger.error(e)
        sys.exit(1)

    STORE_PATH.mkdir(exist_ok=True)

    client = _create_client()
    try:
        await _run_bot(client)
    finally:
        # Runs on normal shutdown, on sys.exit(), and on unexpected errors,
        # so client.close() is never skipped.
        await client.close()

    logger.info("Bot shut down cleanly")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Drive the async entry point on a fresh event loop. main() handles its
    # own signal-driven shutdown and calls sys.exit(1) on fatal errors.
    bot = main()
    asyncio.run(bot)
|