跳转至

3793. 查找高词元使用量的用户

题目描述

表:prompts

+-------------+---------+
| Column Name | Type    |
+-------------+---------+
| user_id     | int     |
| prompt      | varchar |
| tokens      | int     |
+-------------+---------+
(user_id, prompt) 是这张表的主键(值互不相同)。
每一行表示一个用户提交给 AI 系统的提示词以及所消耗的词元数量。

根据下列要求编写一个解决方案来分析 AI 提示词的使用模式

  • 对每一个用户,计算他们提交的 提示词的总数
  • 对每个用户,计算 每个提示词所使用的平均词元数(舍入到 2 位小数)。
  • 仅包含 至少提交了 3 个提示词 的用户。
  • 仅包含那些 至少提交过一个提示词 且其中的 tokens 数量 超过 自己平均词元使用量的用户。

返回结果表按 平均词元数 降序 排序,然后按 user_id 升序 排序。

结果格式如下所示。

 

示例:

输入:

prompts 表:

+---------+--------------------------+--------+
| user_id | prompt                   | tokens |
+---------+--------------------------+--------+
| 1       | Write a blog outline     | 120    |
| 1       | Generate SQL query       | 80     |
| 1       | Summarize an article     | 200    |
| 2       | Create resume bullet     | 60     |
| 2       | Improve LinkedIn bio     | 70     |
| 3       | Explain neural networks  | 300    |
| 3       | Generate interview Q&A   | 250    |
| 3       | Write cover letter       | 180    |
| 3       | Optimize Python code     | 220    |
+---------+--------------------------+--------+

输出:

+---------+---------------+------------+
| user_id | prompt_count  | avg_tokens |
+---------+---------------+------------+
| 3       | 4             | 237.5      |
| 1       | 3             | 133.33     |
+---------+---------------+------------+

解释:

  • 用户 1:
    • 总提示词数 = 3
    • 平均词元数 = (120 + 80 + 200) / 3 = 133.33
    • 有一个提示词为 200 个词元,这超过了平均值
    • 包含在结果中
  • 用户 2:
    • 总提示词数 = 2(少于所需的最小值)
    • 从结果中排除
  • 用户 3:
    • 总提示词数 = 4
    • 平均词元数 = (300 + 250 + 180 + 220) / 4 = 237.5
    • 有包含 300 和 250 个词元的提示词,都大于平均数
    • 包含在结果中

结果表按 avg_tokens 降序排序,然后按 user_id 升序排序。

解法

方法一:分组统计

我们首先将数组按照 user_id 进行分组统计,计算每个用户的提示词数量 prompt_count、平均令牌数 avg_tokens 以及最大令牌数 max_tokens。然后筛选出满足条件的用户,即提示词数量不少于 3 且存在提示词的令牌数大于平均令牌数的用户。最后按照平均令牌数降序和用户 ID 升序排序输出结果。

1
2
3
4
5
6
7
8
9
# Write your MySQL query statement below
SELECT
    user_id,
    COUNT(1) AS prompt_count,
    ROUND(AVG(tokens), 2) AS avg_tokens
FROM prompts
GROUP BY user_id
HAVING prompt_count >= 3 AND MAX(tokens) > avg_tokens
ORDER BY avg_tokens DESC, user_id;
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
import pandas as pd


def find_users_with_high_tokens(prompts: pd.DataFrame) -> pd.DataFrame:
    df = prompts.groupby("user_id", as_index=False).agg(
        prompt_count=("user_id", "size"),
        avg_tokens=("tokens", "mean"),
        max_tokens=("tokens", "max"),
    )

    df["avg_tokens"] = df["avg_tokens"].round(2)

    df = df[(df["prompt_count"] >= 3) & (df["max_tokens"] > df["avg_tokens"])]

    df = (
        df.sort_values(["avg_tokens", "user_id"], ascending=[False, True])
        .loc[:, ["user_id", "prompt_count", "avg_tokens"]]
        .reset_index(drop=True)
    )

    return df

评论