Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions bot/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Group(Base):
filters: Mapped[list["Filter"]] = relationship(back_populates="group", cascade="all, delete-orphan")
blacklists: Mapped[list["Blacklist"]] = relationship(back_populates="group", cascade="all, delete-orphan")
warn_filters: Mapped[list["WarnFilter"]] = relationship(back_populates="group", cascade="all, delete-orphan")
notes: Mapped[list["Note"]] = relationship(back_populates="group", cascade="all, delete-orphan")


class GroupSettings(Base):
Expand Down Expand Up @@ -137,6 +138,22 @@ class Blacklist(Base):
group: Mapped["Group"] = relationship(back_populates="blacklists")


class Note(Base):
__tablename__ = "notes"

id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
group_id: Mapped[int] = mapped_column(
BigInteger, ForeignKey("groups_.telegram_id", ondelete="CASCADE")
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
content: Mapped[str | None] = mapped_column(Text)
file_id: Mapped[str | None] = mapped_column(String(255))
file_type: Mapped[str | None] = mapped_column(String(50))
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)

group: Mapped["Group"] = relationship(back_populates="notes")


class RssFeed(Base):
__tablename__ = "rss_feeds"

Expand Down
44 changes: 43 additions & 1 deletion bot/database/repo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy import select, delete, func
from bot.database.engine import async_session
from bot.database.models import User, Group, GroupSettings, Warning, StickerPack, Filter, Blacklist, RssFeed, WarnFilter
from bot.database.models import User, Group, GroupSettings, Warning, StickerPack, Filter, Blacklist, RssFeed, WarnFilter, Note


class Repository:
Expand Down Expand Up @@ -322,3 +322,45 @@ async def get_warn_filters(group_id: int) -> list[WarnFilter]:
select(WarnFilter).where(WarnFilter.group_id == group_id)
)
return list(result.all())

@staticmethod
async def add_note(group_id: int, name: str, content: str, file_id: str = None, file_type: str = None) -> Note:
async with async_session() as session:
name = name.lower()
existing = await session.scalar(
select(Note).where(Note.group_id == group_id, Note.name == name)
)
if existing:
existing.content = content
existing.file_id = file_id
existing.file_type = file_type
else:
existing = Note(group_id=group_id, name=name, content=content, file_id=file_id, file_type=file_type)
session.add(existing)
await session.commit()
await session.refresh(existing)
return existing

@staticmethod
async def get_note(group_id: int, name: str) -> Note | None:
async with async_session() as session:
return await session.scalar(
select(Note).where(Note.group_id == group_id, Note.name == name.lower())
)

@staticmethod
async def remove_note(group_id: int, name: str) -> bool:
async with async_session() as session:
result = await session.execute(
delete(Note).where(Note.group_id == group_id, Note.name == name.lower())
)
await session.commit()
return result.rowcount > 0

@staticmethod
async def get_notes(group_id: int) -> list[Note]:
async with async_session() as session:
result = await session.scalars(
select(Note).where(Note.group_id == group_id)
)
return list(result.all())
2 changes: 1 addition & 1 deletion bot/plugins/general/afk.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def register(app: Application):
afk,
), group=AFK_GROUP)
app.add_handler(MessageHandler(
filters.ALL & ~filters.COMMAND & filters.ChatType.GROUPS & ~filters.StatusUpdate & ~filters.UpdateType.EDITED_MESSAGE,
filters.ALL & ~filters.COMMAND & filters.ChatType.GROUPS & ~filters.StatusUpdate.ALL,
no_longer_afk,
), group=AFK_GROUP)
app.add_handler(MessageHandler(
Expand Down
186 changes: 186 additions & 0 deletions bot/plugins/group/notes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import Application, CommandHandler, MessageHandler, CallbackQueryHandler, filters, ContextTypes
from bot.database.repo import Repository
from bot.logger import get_logger
from bot.utils.decorators import group_only, admin_only

logger = get_logger(__name__)


@group_only
@admin_only
async def save_note(update: Update, context: ContextTypes.DEFAULT_TYPE):
args = context.args
if not args:
await update.effective_message.reply_text("Usage: /save <note_name> (or reply to a message)")
return

note_name = args[0].lower()
content = ""
file_id = None
file_type = None

reply = update.effective_message.reply_to_message

if reply:
if reply.text:
content = reply.text
elif reply.caption:
content = reply.caption

if reply.photo:
file_id = reply.photo[-1].file_id
file_type = "photo"
elif reply.video:
file_id = reply.video.file_id
file_type = "video"
elif reply.sticker:
file_id = reply.sticker.file_id
file_type = "sticker"
elif reply.document:
file_id = reply.document.file_id
file_type = "document"
elif reply.audio:
file_id = reply.audio.file_id
file_type = "audio"
elif reply.voice:
file_id = reply.voice.file_id
file_type = "voice"
elif reply.animation:
file_id = reply.animation.file_id
file_type = "animation"
else:
if len(args) < 2:
await update.effective_message.reply_text("Please provide content or reply to a message.")
return
content = " ".join(args[1:])

await Repository.upsert_group(update.effective_chat.id, update.effective_chat.title)
await Repository.add_note(
group_id=update.effective_chat.id,
name=note_name,
content=content,
file_id=file_id,
file_type=file_type
)

logger.info("Note '#%s' saved in chat %s", note_name, update.effective_chat.id)
await update.effective_message.reply_text(f"Note <code>#{note_name}</code> saved!", parse_mode="HTML")


@group_only
async def get_note(update: Update, context: ContextTypes.DEFAULT_TYPE):
args = context.args
if not args:
await update.effective_message.reply_text("Usage: /get <note_name>")
return

note_name = args[0].lower()
note = await Repository.get_note(update.effective_chat.id, note_name)

if not note:
await update.effective_message.reply_text("Note not found.")
return

await _send_note(update, note)


@group_only
@admin_only
async def clear_note(update: Update, context: ContextTypes.DEFAULT_TYPE):
args = context.args
if not args:
await update.effective_message.reply_text("Usage: /clear <note_name>")
return

note_name = args[0].lower()
deleted = await Repository.remove_note(update.effective_chat.id, note_name)

if deleted:
logger.info("Note '#%s' deleted in chat %s", note_name, update.effective_chat.id)
await update.effective_message.reply_text(f"Note <code>#{note_name}</code> deleted.", parse_mode="HTML")
else:
await update.effective_message.reply_text("Note not found.")


@group_only
async def list_notes(update: Update, context: ContextTypes.DEFAULT_TYPE):
notes = await Repository.get_notes(update.effective_chat.id)
if not notes:
await update.effective_message.reply_text("No notes saved in this group.")
return

text = f"✨ <b>Notes for {update.effective_chat.title}</b>\n\nYou can use <code>#notename</code> to recall them.\n"

buttons = []
sorted_notes = sorted(notes, key=lambda x: x.name)

current_row = []
for note in sorted_notes:
current_row.append(InlineKeyboardButton(f"📎 {note.name}", callback_data=f"get_note:{note.name}"))
if len(current_row) == 2:
buttons.append(current_row)
current_row = []
if current_row:
buttons.append(current_row)

await update.effective_message.reply_text(text, parse_mode="HTML", reply_markup=InlineKeyboardMarkup(buttons))


async def note_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
query = update.callback_query
note_name = query.data.split(":")[1]
note = await Repository.get_note(update.effective_chat.id, note_name)

if note:
await _send_note(update, note)
await query.answer()
else:
await query.answer("Note not found.", show_alert=True)


async def hashtag_listener(update: Update, context: ContextTypes.DEFAULT_TYPE):
if not update.effective_message or not update.effective_message.text:
return

text = update.effective_message.text
if not text.startswith("#") or len(text) < 2:
return

note_name = text.split()[0][1:].lower()
note = await Repository.get_note(update.effective_chat.id, note_name)

if note:
await _send_note(update, note)


async def _send_note(update: Update, note):
msg = update.effective_message
reply_to = msg.reply_to_message.message_id if msg.reply_to_message else msg.message_id

if note.file_id:
if note.file_type == "photo":
await msg.reply_photo(note.file_id, caption=note.content, reply_to_message_id=reply_to)
elif note.file_type == "video":
await msg.reply_video(note.file_id, caption=note.content, reply_to_message_id=reply_to)
elif note.file_type == "sticker":
await msg.reply_sticker(note.file_id, reply_to_message_id=reply_to)
elif note.file_type == "document":
await msg.reply_document(note.file_id, caption=note.content, reply_to_message_id=reply_to)
elif note.file_type == "audio":
await msg.reply_audio(note.file_id, caption=note.content, reply_to_message_id=reply_to)
elif note.file_type == "voice":
await msg.reply_voice(note.file_id, caption=note.content, reply_to_message_id=reply_to)
elif note.file_type == "animation":
await msg.reply_animation(note.file_id, caption=note.content, reply_to_message_id=reply_to)
else:
await msg.reply_text(note.content, reply_to_message_id=reply_to)


def register(app: Application):
app.add_handler(CommandHandler("save", save_note))
app.add_handler(CommandHandler("get", get_note))
app.add_handler(CommandHandler("clear", clear_note))
app.add_handler(CommandHandler("notes", list_notes))
app.add_handler(CallbackQueryHandler(note_callback, pattern=r"^get_note:"))
app.add_handler(MessageHandler(filters.TEXT & filters.ChatType.GROUPS, hashtag_listener))