SQL Summary

A summary of commonly used SQL

Day-to-day SQL usage notes

The count function returns the number of rows that match a specified condition

1. COUNT(column_name) returns the number of values in the specified column (NULL values are not counted):
  SELECT COUNT(column_name) FROM table_name;
2. COUNT(*) returns the number of records in the table:
  SELECT COUNT(*) FROM table_name;
3. COUNT(DISTINCT column_name) returns the number of distinct values in the specified column:
  SELECT COUNT(DISTINCT column_name) FROM table_name;
  
  Note: COUNT(DISTINCT) works in Oracle and Microsoft SQL Server, but cannot be used in Microsoft Access.
1. count(student_id) returns the number of rows whose student_id is not NULL; NULL values are not counted.
2. If you need the number of students whose score is 90, use count(score = 90 or null); to also include absentees (score is NULL), you would need count(score = 90 or score is null or null).
Why add the "or null" part?
Because score = 90 returns True when the score is 90 and False otherwise, and COUNT skips only NULL; both True and False would be counted. "or null" turns the False case into NULL so those rows are not counted.
3. There are two other ways to count rows under a condition.
To get the number of students whose score is above 90:
(1) sum(case when score > 90 then 1 else 0 end)
(2) count(if(score > 90, 1, null))
In engines whose if() accepts two arguments (e.g. Presto/Trino), the null can be omitted: if score > 90 the expression returns 1, otherwise it returns null by default.
count(if(score > 90, 1))
In short, conditional counting works because COUNT ignores NULL.
4. To count NULL rows as well, use count(1), which counts every row.
e.g. using the score column to count how many students there are in total, including those who missed the exam.
Source: https://blog.csdn.net/qq_41011449/article/details/115025766
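A small worked sketch of the counting variants above, assuming a hypothetical table exam_scores(student_id, score) where a NULL score means the student missed the exam:

select  count(student_id)                           as n_rows_with_id,  -- rows whose student_id is not NULL
        count(score)                                as n_took_exam,     -- NULL scores (absentees) are skipped
        count(1)                                    as n_all_rows,      -- every row, including NULL scores
        count(score = 90 or null)                   as n_score_90,      -- only rows where score = 90
        sum(case when score > 90 then 1 else 0 end) as n_above_90_a,
        count(if(score > 90, 1, null))              as n_above_90_b
from    exam_scores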

The sum function computes a total

  • The SUM() function returns the total of a numeric column.

avg computes the average

length / size (Hive functions) return a length (length for strings, size for arrays and maps)
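A minimal sketch of these functions, assuming a hypothetical table scores(student_id string, score double, tags array<string>):

-- aggregates over the whole table
select  sum(score) as total_score,
        avg(score) as avg_score
from    scores

-- per-row: length of a string, size of an array (Hive)
select  student_id,
        length(student_id) as id_length,
        size(tags)         as tag_count
from    scores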

alter table: drop a partition of a Hive table

/* Drop a partition of a table: alter table table_name drop partition (version = 'test') */
alter table
        ad_yuntu.yuntu_tagging_filter_data_with_industry drop partition (industry = 'clothing')

explode: split a list into rows

select  topic_id,
        keyword_data.keyword
from    ad_yuntu.topic_keywords_count_2 t
lateral view
        explode(t.topic_keywords) kw as keyword_data
where   t.version = '2'
/* Pull each item of the topic_keywords list out as its own row */
/* keyword_data is a struct; use .keyword to access its keyword attribute */

lateral view: generate a virtual table containing one or more rows

# Lateral view is usually combined with a generator function such as explode to produce a virtual table containing one or more rows
# See the explode example above
# If the column being exploded can be NULL, use lateral view outer instead to avoid losing those rows (see the sketch below)
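A minimal sketch of the outer variant, reusing the (assumed) topic_keywords column from the explode example above; rows whose array is NULL are kept, with keyword_data returned as NULL:

select  topic_id,
        keyword_data
from    ad_yuntu.topic_keywords_count_2 t
lateral view outer
        explode(t.topic_keywords) kw as keyword_data
where   t.version = '2'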

row_number: number each row

Randomly sample 200 rows per cluster_id

select  *
from    (
            select  cluster_id,
                    cluster_keywords,
                    topic_id,
                    topic_name,
                    topic_keywords,
                    row_number() over (
                        partition by
                                cluster_id
                        order by
                                rand()
                    ) as rn
            from    ad_yuntu.topic_cluster_relation
            where   app_id = 1128
            and     industry_id = 5
            and     version = 'kmeans'
        ) t
where   t.rn <= 200

A more advanced use of row_number: combine it with where to filter rows

Compared with deduplicating via group by: group by can only deduplicate, while row_number plus where lets you choose which row to keep according to a condition.

The inner row_number filters rows that share both brand_id and product_name, keeping only the one with the largest product_gmv.

The outer row_number ranks rows within each brand_id by product_gmv in descending order and keeps only the top ten by GMV.

select  *
from    (
            select  brand_id,
                    product_id,
                    product_name,
                    shop_name,
                    first_name_new,
                    product_gmv,
                    row_number() over (
                        partition by
                                brand_id
                        order by
                                product_gmv desc
                    ) as rn
            from    (
                        select  *,
                                row_number() over (
                                    partition by
                                            brand_id,
                                            product_name
                                    order by
                                            product_gmv desc
                                ) as rn
                        from    joiner_data
                    ) t
            where   t.rn <= 1
        ) t2
where   t2.rn <= 10

collect_set builds a deduplicated list / collect_list builds a list without deduplication

Gather related values into a collection

SELECT
    topic_id_1,
    collect_set(topic_id_2)
from
    ad_yuntu.topic_similarity
where
    industry_id = 5
    and app_id = 1128
    and version = 'v1'
group by
    topic_id_1
limit
    100000000

if provides conditional logic

if(cost > 0, 1, null) as flag
# flag column: if cost is greater than 0, flag is 1; otherwise flag is null

coalesce returns the first non-null value

coalesce(expression_1, expression_2, ..., expression_n) evaluates the expressions in order and stops at the first non-null value, which it returns. If all expressions are null, the result is null. coalesce is useful because most expressions that involve a null value would otherwise evaluate to null.
select coalesce(success_cnt, 1) from table1
--when success_cnt is null, 1 is returned; otherwise the actual value of success_cnt is returned
select coalesce(success_cnt,period,1) from table2
--when success_cnt is not null, the actual value of success_cnt is returned regardless of whether period is null (because success_cnt is the first argument)
--when success_cnt is null and period is not null, the actual value of period is returned; only when both success_cnt and period are null is 1 returned

nvl returns a value depending on whether the input is null

nvl(gender, '')
# if gender is not null, the value is gender; otherwise the value is ''

concat / concat_ws: concatenate string columns

select  p_date,
        main_brand_id,
        level_1_industry_id,
        ad_id,
        trigger_point,
        concat(
            nvl(gender, ''),
            '#',
            nvl(age, ''),
            '#',
            nvl(city, '')
        ) as target_val,
        cost,
        new_cost
from    ad_measure.ba_real_ta_data
where   p_date between '20220301' and '20220531'
and     trigger_point is not null
and     cost > 0
concat_ws( '-', '2022', '08', '11')
# returns 2022-08-11

group_concat concatenates the values within each group produced by group by (MySQL)

group_concat(distinct col1 order by col2 separator '_')
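A hedged MySQL-style sketch, assuming a hypothetical table user_tags(user_id, tag):

select  user_id,
        group_concat(distinct tag order by tag separator '_') as tag_list
from    user_tags
group by
        user_id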

split: split a string

split('a,b,c,d,e', ',') # split the first argument using the second argument as the delimiter
# | is a regex metacharacter, so it must be escaped with backslashes
cast(split(bidlandscape_redis_history_data, '\\|')[0] as double) as budget

percentile_approx: value at a given percentile position

percentile_approx(new_cost, 0.5)
# returns the median of new_cost (the value at the 0.5 * 100% position)

substr (same usage as substring): extract part of a string

SUBSTR(str, pos, len): starting at position pos, take len characters

substr('string', 1, 3): take a 3-character substring starting at position 1 from the left.
Result: str
substr('string', -1, 3): take a 3-character substring starting at the 1st position from the right; there are not 3 characters to the right of that position, so the result is only: g
substr('string', -3, 3): take a 3-character substring starting at the 3rd position from the right.
Result: ing
SUBSTR(str, pos): take everything from position pos to the end.

substr('string', 4): take everything from position 4 (counting from the left) to the end.
Result: ing
Source: https://blog.csdn.net/u012973218/article/details/71374314

round, floor, ceil, ceiling: operations on floating-point numbers

# round rounds to the given number of decimal places
# floor returns the largest integer less than or equal to the argument (rounds down)
# ceil and ceiling both return the smallest integer greater than or equal to the argument (round up)
round(3.1415926, 2)  # returns 3.14
floor(3.1415926)  # returns 3
ceil(3.1415926)  # returns 4
ceiling(3.1415926)  # returns 4

Per-group maximum / minimum: first_value / last_value

select
    *,
    -- the person with the smallest cost in each group
    first_value(name) over (PARTITION BY department ORDER BY cost) as min_cost_user,
    -- the person with the largest cost in each group; the frame clause is required because
    -- the default frame with ORDER BY only reaches the current row, so last_value would
    -- otherwise just return the current row's name
    last_value(name) over (
        PARTITION BY department
        ORDER BY cost
        ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
    ) as max_cost_user
from table

Converting between data types: cast

# convert A to string type
select cast(A as string) as A
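A few more casts that come up often (a sketch; some_table and its columns are hypothetical):

select  cast('123' as bigint)  as str_to_int,
        cast('3.14' as double) as str_to_double,
        cast(score as int)     as truncated_score,  -- casting to int truncates the decimal part
        cast(p_date as string) as date_as_str
from    some_table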

unnest and array_agg: merge lists

SELECT
    title,
    array_agg(DISTINCT tag)
FROM(
        SELECT
            title,
            unnest(tags)
        FROM
            my_test
    ) AS t(title, tag)
GROUP BY
    title

unix_timestamp converts a string to a 10-digit (seconds) timestamp value

# Note: yyyy-MM-dd HH:mm:ss
unix_timestamp('20200802','yyyyMMdd') # returns a bigint
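Two more forms worth remembering (a sketch): with no format argument the default pattern is 'yyyy-MM-dd HH:mm:ss', and with no arguments at all the current epoch seconds are returned.

unix_timestamp('2020-08-02 10:30:00')  # default format 'yyyy-MM-dd HH:mm:ss'
unix_timestamp()                       # current time as epoch seconds (bigint)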

from_unixtime(bigint unixtime,string format)

  • Hive

Converts a timestamp in seconds to a formatted time string; the output format is specified via format. unixtime must be the 10-digit seconds value; a 13-digit millisecond value is not accepted.

  • MySQL

    •   SELECT from_unixtime(1156219870, "%Y-%m-%d %H:%i:%s");
  • Example

    • with city_data as (
          select  get_json_object(get_json_object(data, '$.audience'), '$.city') as city
          from    ad_dim.dim_brand_stock
          where   p_date = '${date}'
          and     from_unixtime(unix_timestamp(create_time, 'yyyy-MM-dd'), 'yyyyMMdd') between '${date-180}' and '${date}'
      )
      select  city,
              count(1) as cnt
      from    city_data
      group by
              city
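Since from_unixtime only accepts the 10-digit seconds value, a 13-digit millisecond timestamp has to be divided by 1000 first; a minimal sketch (ms_ts is a hypothetical bigint column):

from_unixtime(cast(ms_ts / 1000 as bigint), 'yyyy-MM-dd HH:mm:ss')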
      

current_timestamp: the current time with millisecond precision

blog.csdn.net/mingming205…
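A minimal sketch: in Hive, current_timestamp returns the current time as a timestamp with fractional (millisecond) seconds, while unix_timestamp() gives the same instant as epoch seconds.

select  current_timestamp() as now_ts,       -- e.g. 2022-08-11 10:23:45.123
        unix_timestamp()    as now_seconds   -- 10-digit epoch seconds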

get_json_object: extract data from JSON

# The first argument is the JSON object; in the second argument, $ denotes the JSON root, and . or [] is used to read objects or arrays. If the JSON is invalid, null is returned. Only one value can be returned per call.
Example:
data = {
    "store": {
        "fruit": [{"size":8,"name":"apple"}, {"size":9,"name":"pear"}],
        "bicycle": {"price":20,"color":"yellow"}
    },
    "email": "amy@123.com",
    "owner": "amy"
}
--get a top-level field
select get_json_object(data,"$.owner") from table
--get a nested field
select get_json_object(data,"$.store.fruit") from table
--get an array element with [ ]
select get_json_object(data,"$.store.fruit[0]") from table
--get all name values
select get_json_object('[[{"id":"123","name":"苹果"}],[{"id":"456","name":"香蕉"}],[{"id":"789","name":"西瓜"}]]', '$.[*].[*].name');

rand: generate a random number

Often used after order by to shuffle rows randomly.

select  main_brand_id,
        level_1_industry_id,
        tag,
        feature
from    ad_measure.brand_budget_allocation_v2_enhanced
where   p_date = '20220531'
and     size(feature) > 0
order by
        rand()

with ... as: give a subquery an alias (CTE)

Only one with keyword is needed; subsequent CTEs are separated by commas.

with click_cpa as (
    select  main_brand_id,
            level_1_industry_id,
            trigger_point,
            avg(new_cost / click_count) as avg_cpa
    from    new_v2_data
    where   click_count >= 5
    group by
            main_brand_id,
            level_1_industry_id,
            trigger_point
),
play_5s_cpa as (
    select  main_brand_id,
            level_1_industry_id,
            trigger_point,
            avg(new_cost / play_5s_count) as avg_cpa
    from    new_v2_data
    where   play_5s_count >= 5
    group by
            main_brand_id,
            level_1_industry_id,
            trigger_point
)
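The with block above only defines the CTEs; a query has to follow to actually use them. A minimal sketch joining the two on their shared keys (columns as defined above):

select  a.main_brand_id,
        a.level_1_industry_id,
        a.trigger_point,
        a.avg_cpa as click_cpa,
        b.avg_cpa as play_5s_cpa
from    click_cpa a
join    play_5s_cpa b
on      a.main_brand_id = b.main_brand_id
and     a.level_1_industry_id = b.level_1_industry_id
and     a.trigger_point = b.trigger_point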

union: combine multiple subqueries

select  *
    from    (
                select  *,
                        'a3' as label_name
                from    a3_cpa
                union
                select  *,
                        'click' as label_name
                from    click_cpa
                union
                select  *,
                        'play_5s' as label_name
                from    play_5s_cpa
                union
                select  *,
                        'play_over' as label_name
                from    play_over_cpa
                union
                select  *,
                        'interact' as label_name
                from    interact_cpa
            ) t

union deduplicates automatically; union all does not

blog.csdn.net/weixin_4238…
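A quick sketch of the difference using literal rows:

select 1 as v union select 1 as v;       -- one row: 1
select 1 as v union all select 1 as v;   -- two rows: 1 and 1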

percentile_approx: compute multiple percentiles at once

select  label_name,
        percentile_approx(avg_cpa, array(0.1, 0.2, 0.3, 0.7, 0.8, 0.9))
from    union_cpa
group by
        label_name

Ranking with row_number / rank / dense_rank

row_number() never repeats a number for ties; rows are numbered strictly in order (1, 2, 3);

rank() repeats the number for ties and leaves gaps, so the total count is preserved and you get results like 1, 1, 3;

dense_rank() repeats the number for ties without leaving gaps, so you get results like 1, 1, 2.

select
    *,
    row_number() over ( partition by department order by cost desc ) as row_number_result,
    rank() over ( partition by department order by cost desc) as rank_result,
    dense_rank() over (partition by department order by cost desc) as dense_rank_result
from table
 

ARRAY_CONTAINS: check whether a list contains a given value

Check whether the element '服饰' appears in the main_industry_list_cn list

select  main_industry_list_cn
from    ad_yuntu.ecom_dim_brand_info_df_latest
where   array_contains(main_industry_list_cn, '服饰')

Common string-processing functions: length / trim / lower / upper

--length(string A) returns the length of string A
select length('abcedfg')  -- returns 7

--trim(string A) removes the spaces on both ends of the string
select trim(' abc ')  -- returns 'abc'

--lower(string A) / lcase(string A) returns the lowercase form; useful when you are unsure whether the source column has consistent casing
select lower('abSEd')  -- returns absed

--upper(string A) / ucase(string A) returns the uppercase form; useful when you are unsure whether the source column has consistent casing
select upper('abSEd')  -- returns ABSED

with rollup: add summary rows to grouped statistics

blog.csdn.net/qq_40591233…

SELECT coalesce(name, '总数'), SUM(signin) as signin_count 
FROM  employee_tbl 
GROUP BY name 
WITH ROLLUP
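WITH ROLLUP appends an extra summary row per grouping level in which the grouping column is NULL, which is why coalesce(name, '总数') relabels it. A minimal sketch with inline literal rows:

SELECT coalesce(name, 'total') AS name_or_total, SUM(signin) AS signin_count
FROM (
    SELECT 'a' AS name, 1 AS signin
    UNION ALL
    SELECT 'a' AS name, 2 AS signin
    UNION ALL
    SELECT 'b' AS name, 4 AS signin
) t
GROUP BY name WITH ROLLUP
-- yields: a = 3, b = 4, plus a 'total' row = 7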

After group by name, keep only the row with the largest id

Number the rows with row_number(), then filter with where.

stackoverflow.com/questions/1…

WITH ranked_messages AS (
  SELECT m.*, ROW_NUMBER() OVER (PARTITION BY name ORDER BY id DESC) AS rn
  FROM messages AS m
)
SELECT * FROM ranked_messages WHERE rn = 1

The substring_index() function extracts part of a string by delimiter

zhuanlan.zhihu.com/p/109778760

    substring_index(str, delim, count)
    str: the string to process
    delim: the delimiter
    count: the count of delimiter occurrences

Example:

e.g. str = www.wiki.com

then substring_index(str, '.', 1) gives: www

substring_index(str, '.', 2) gives: www.wiki

In other words, when count is positive, the delimiters are counted from the left and everything to the left of the Nth delimiter is returned;

conversely, when count is negative, the delimiters are counted from the right and everything to the right of the Nth delimiter is returned.

For example:

substring_index(str, '.', -2) gives: wiki.com

What if you want the middle part, wiki?

It is simple; just cut from both directions:

first take everything to the right of the 2nd delimiter counting from the right, then take everything to the left of the 1st delimiter counting from the left:

substring_index(substring_index(str,'.',-2),'.',1);

Combining having with where to quickly filter on counts

select  brand_word,
        brand_id,
        industry_list
from    all_data
where   label = '0'
group by
        brand_word,
        brand_id,
        industry_list
having  count(*) >= 10

Use lateral view and explode to parse maps and lists

The difference between having and where

-- original query (may produce a large amount of intermediate data)
SELECT user_id, COUNT(*) 
FROM clicks 
GROUP BY user_id 
HAVING COUNT(*) > 1000;

-- optimized query
SELECT user_id, cnt 
FROM (
    SELECT user_id, COUNT(*) as cnt 
    FROM clicks 
    GROUP BY user_id
) t 
WHERE cnt > 1000;

Parsing the keys and values of a map

saved_images: {
"https://cdn.pixabay.com/303587__340.jpg": "Rf91t4c2UQmHpt",
"https://cdn.pixabay.com/4285323__340.jpg": "RfAax31DrDD"
}
SELECT
    DISTINCT
    img_url,
    img_tos
from
    dm_content.bee_doc_hourly
    LATERAL VIEW explode(saved_images) x as img_url, img_tos

Splitting a list

{
    'info': [
        {'totalcount': 1, 'name': 'name1'},
        {'totalcount': 2, 'name': 'name2'},
    ]
}
SELECT
    get_json_object(coll,'$.totalcount') as `数量`,
    get_json_object(coll,'$.name') as  `名称`
FROM
    dm_content.bee_doc_hourly LATERAL VIEW explode(json_split(parsed_content ['info'])) x as coll
WHERE
    date = '20190919'

Several ways in Hive to check whether a string contains a substring: like / rlike / regexp / locate / instr

blog.csdn.net/weixin_4642…

SELECT
    brand_id,
    brand_name,
    brand_name_cn,
    brand_name_en,
    brand_p_level,
    scene,
    branch,
    date
FROM
    ad_yuntu.model_building_brand_info
WHERE
    scene = 'brand_list'
    and branch = 'ecom_p456'
    and date = '20230102'
    and (
        brand_name like '%梓意%'
        or brand_name rlike '梓意'
        or brand_name regexp '梓意'
        or locate('梓意', brand_name) > 0
        or instr(brand_name, '梓意') > 0
    )

regexp_extract: parse a string with a regular expression

regexp_extract(str, regexp[, idx])
-- str is the string or column to parse.
-- regexp is the regular expression.
-- idx chooses which part of the match to return; the default is 1.
-- 0 returns the entire match.
-- 1 returns the first capture group (the first pair of parentheses), and so on.

-- extract an integer
CAST(regexp_extract(req, 'external_action=([0-9]+),', 1) AS bigint) = 492
-- extract a floating-point number
cast(regexp_extract(rsp, 'bid_max=([0-9]+.?[0-9]+),', 1) as double) / 100000 as bid_max,

Left join|right join|inner join

Keep only the rows that exist on the left side (left-only, i.e. an anti-join)

select  a.item_id
from    star_data_1 a
left join
        star_data_2 b
on      a.item_id = b.item_id
where   b.item_id is null  -- the filter can be written directly in where like this

Reading and writing tables in Dorado can be done with either PySpark or HSQL

HSQL example:

# https://data.bytedance.net/dorado/development/node/108211588?project=cn_310&version=-1#Node
set bytequery.sql.cartesian.product.check.enabled=false;

with all_data as (
        select  cast(nvl(item_id, -1) as bigint) as item_id,
                lower(nvl(item_title, '')) as title
        from    ad_yuntu.dwd_brdm_item_title_di
        where   app_type = 'aweme'
        and     p_date = '${date}'
    ),

search_words as (
    select brand_name, brand_id, lower(word) as word
    from ad_yuntu_dev.brdm_brand_name_word_tmp
    where p_date='20221207'
)

insert OVERWRITE table ad_yuntu_dev.brdm_simple_brand_tagging_data_tmp partition(p_date='${date}', source='douyin')
select a.item_id, a.title, b.brand_id, b.brand_name, b.word
from all_data a,
search_words b 
where instr(a.title, b.word)>0

PySpark example:

# https://data.bytedance.net/dorado/development/node/109478015?project=cn_895&version=-1
# -*- coding: utf-8 -*-

from common_utils.spark_util import get_spark_context
from rpc_utils.yuntu_tagging.client import YuntuModelTaggingChoiceClient

brand_ner_content = """
    with all_data as (
        select  cast(nvl(item_id, -1) as bigint) as item_id,
                nvl(item_title, '') as title,
                '' as content,
                p_date
        from    ad_yuntu.dwd_brdm_item_title_di
        where   app_type = '{app_type}'
        and     p_date between '{start_date}' and '{end_date}'
    )
    select  title,
            content,
            p_date,
            collect_list(item_id) as item_id_list
    from    all_data
    group by
            title,
            content,
            p_date
"""

SQL_DICT = {
    'brand_ner': {
        'douyin': brand_ner_content,
        'toutiao': brand_ner_content,
        'huoshan': brand_ner_content
    }
}

def get_data_sdf(hc, args):
    import datetime
    temp_date = datetime.datetime.strptime(args.date, '%Y%m%d')
    start_time = (temp_date + datetime.timedelta(days=-6)).strftime('%Y%m%d')
    start_time = str(start_time)
    source_map = {
        'douyin': 'aweme',
        'toutiao': 'toutiao',
        'huoshan': 'hotsoon_pure'
    }
    sql = SQL_DICT[args.task][args.source].format(app_type=source_map[args.source],
        start_date=start_time, end_date=args.date)
    print("[MAIN_INFO] Data SQL:\n", sql)
    return hc.sql(sql)

def write_to_hive(hc, table, args):
    sql = """
        insert overwrite table ad_yuntu.yuntu_model_tagging partition(
            task = '{task}',
            scene = '{scene}',
            branch = '{branch}',
            source = '{source}',
            version = '{version}',
            date = '{date}'
        )
        select  item_id,
                title,
                content,
                tags
        from    {table}
    """.format(task=args.task, scene=args.scene, branch=args.branch,
               source=args.source, version=args.version, date=args.date,
               table=table)
    print("[MAIN_INFO] Write SQL:\n", sql)
    hc.sql(sql)

def get_model_tags(client_obj, data_json_list, task, source, version):
    import time
    import json
    if task == 'brand_ner':
        resp_list = client_obj.get_multi_branch_brand(
            data_json_list, source=source, version=version)
        time.sleep(1)
    else:
        raise NotImplementedError('%s' % task)
    if len(resp_list) > 0 and len(resp_list[0]) == 0:
        resp_list = [{'branch_outputs': []} for _ in data_json_list]
    # Output format:
    # {
    #     'branch_name': [
    #         {'brand_name': 'x', 'start_index': 0, 'end_index': 3},
    #         ...
    #     ],
    #     ...
    # }
    tags_list = []
    for resp in resp_list:
        tags = dict()
        for branch_output in resp['branch_outputs']:
            if 'branch' in branch_output:
                tags[branch_output['branch']] = [
                    json.dumps(brand, ensure_ascii=False, separators=(',', ':'))
                    for brand in branch_output['brands']
                ]
        tags_list.append(tags)
    return tags_list

def pred_rdd_data(rows, task, source, version, batch_size, cluster, idc_list):
    import sys
    sys.path.insert(0, '/opt/tiger/business_tagging_platform')
    from pyspark.sql import Row
    row_construct = Row('item_id', 'title', 'content', 'tags')
    # str_caller = 'ad.va.yuntu_brand_tagging_' + str(source)
    str_caller = 'ad.va.test'
    # set the service-discovery IDCs according to the GPU data centers ['hl', 'lf', 'lq']
    client_obj = YuntuModelTaggingChoiceClient(
        'ad.ms.yuntu_model_tagging', cluster, idc_list,
        caller=str_caller, timeout=5, transport='framed'
    )
    data_json_list = []
    all_item_id_list = []
    for row in rows:
        all_item_id_list.append(row.item_id_list)
        data_json_list.append({
            'item_id': row.item_id_list[0],
            'title': row.title if row.title else '',
            'content': row.content if row.content else ''
        })
        if len(all_item_id_list) >= batch_size:
            tags_list = get_model_tags(
                client_obj, data_json_list, task, source, version
            )
            for idx, tags in enumerate(tags_list):
                data_json = data_json_list[idx]
                item_id_list = all_item_id_list[idx]
                for item_id in item_id_list:
                    yield row_construct(
                        item_id, data_json['title'], data_json['content'], tags
                    )
            all_item_id_list = []
            data_json_list = []
    if len(all_item_id_list) > 0:
        tags_list = get_model_tags(
            client_obj, data_json_list, task, source, version
        )
        for idx, tags in enumerate(tags_list):
            data_json = data_json_list[idx]
            item_id_list = all_item_id_list[idx]
            for item_id in item_id_list:
                yield row_construct(
                    item_id, data_json['title'], data_json['content'], tags
                )

def execute_model_tagging(sc, hc, args):
    # 1. read the source table
    ori_data_sdf = get_data_sdf(hc, args)
    print('==================================')
    print('get data done')
    # ori_data_sdf = ori_data_sdf.limit(1000000)  # test
    # 2. process the data
    tag_data_rdd = ori_data_sdf.rdd.repartition(3000).mapPartitions(
        lambda rows: pred_rdd_data(
            rows, task=args.task, source=args.source,
            version=args.version, batch_size=args.batch_size,
            cluster=args.cluster, idc_list=args.idc_list)
    )
    # 3. write the result table
    if not tag_data_rdd.isEmpty():
        from pyspark.sql.types import StructField, StructType, \
            LongType, ArrayType, StringType, MapType
        schema = StructType([
            StructField('item_id', LongType(), False),
            StructField('title', StringType()),
            StructField('content', StringType()),
            StructField('tags', MapType(StringType(), ArrayType(StringType())))
        ])
        tag_data_sdf = tag_data_rdd.toDF(schema=schema)
        table_name = '%s_table' % args.task_mark
        tag_data_sdf.registerTempTable(table_name)
        write_to_hive(hc, table_name, args)

def main():
    import argparse
    args = argparse.ArgumentParser().parse_args()
    args.date = '${date}'
    args.task = '{{task}}'
    args.scene = 'brand_list'
    args.branch = 'ecom_p456'
    args.source = '{{source}}'
    args.version = '{{version}}'
    args.yarn_user = '{{yarn_user}}'
    args.cluster = 'display'
    args.idc_list = ['hl']
    args.batch_size = 8
    args.boost_ratio = 10
    args.cores = 3
    args.executors = 50
    args.task_mark = '%s_%s_%s_%s_%s_%s' % (
        args.task, args.scene, args.branch,
        args.source, args.version, args.date
    )
    print('[MAIN_INFO] parameters: %s' % args)
    import time
    start_time = time.time()
    sc, hc = get_spark_context(
        task_name='yuntu_model_tagging_%s_%s' % (args.task_mark, args.yarn_user),
        params={
            'spark.yarn.appMasterEnv.YARN_CONTAINER_RUNTIME_TYPE': 'docker',
            'spark.executorEnv.YARN_CONTAINER_RUNTIME_TYPE': 'docker',
            # concurrency settings
            'spark.vcore.boost.ratio': args.boost_ratio,
            'spark.executor.cores': args.cores,
            'spark.dynamicAllocation.initialExecutors': args.executors,
            'spark.dynamicAllocation.minExecutors': args.executors,
            'spark.dynamicAllocation.maxExecutors': args.executors,
            # memory settings
            'spark.driver.memory': '20g',
            'spark.executor.memory': '20g',
            'spark.yarn.executor.memoryOverhead': '19g',
            'spark.yarn.driver.memoryOverhead': '19g',
            'spark.yarn.am.memoryOverhead': '19g',
        }
    )
    sc.setLogLevel('WARN')
    execute_model_tagging(sc, hc, args)
    print("[MAIN_INFO] token time %ss" % (time.time() - start_time))

if __name__ == "__main__":
    main()