Question: how do I implement single-threaded concurrency with Python's yield?


Problem background:

There are 100+ campuses, each with its own SQLite database, and every one of those databases has to be backed up to Alibaba Cloud OSS.
Data layout (the campus folders are demo1, demo2, ..., demo127):
  • /www
    • /wwwroot
      • /demo1
        • a.file
        • b.file
        • c.file
        • db.sqlite3
        • ...
      • /demo2
        • a.file
        • b.file
        • c.file
        • db.sqlite3
        • ...
      • ...

I already have working code in two forms: a single-threaded version and a multi-threaded version. What I want, though, is to implement the concurrency with coroutines (yield / yield from). The single-threaded version takes roughly 29s, the multi-threaded version roughly 9s, and my own yield version roughly 31s, so the rewrite was worse than not bothering. All three are listed below.

Requirements

I'm hoping someone can help me optimize the coroutine code. There are two I/O steps: reading the local file, and the Alibaba OSS upload ems_oss.put_object(remote_file, data). Both are I/O, and both are what I'd like handled concurrently.
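
Worth keeping in mind: yield only suspends and resumes Python frames; it does not make a blocking call non-blocking, so send() runs the coroutine body, blocking I/O included, all the way to the next yield before it returns. A minimal demonstration of that, with time.sleep standing in for the file read and the OSS upload (fake_upload is an illustrative name, not part of the scripts below):

# Demo: plain send()-driven coroutines do not overlap blocking work.
import time

def fake_upload(name):
    data = yield            # pause until the driver sends the payload
    time.sleep(1)           # stands in for the blocking file read / OSS upload
    print(name, "uploaded", len(data), "bytes")

start = time.time()
workers = [fake_upload("demo{}".format(i)) for i in range(3)]
for w in workers:
    w.send(None)            # prime: run each coroutine up to its first yield
for w in workers:
    try:
        w.send(b"payload")  # blocks for the whole sleep before moving on
    except StopIteration:
        pass
print("total {:.1f}s".format(time.time() - start))  # ~3.0s, not ~1.0s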

# Multi-threaded version
import os
import oss2
from datetime import datetime
from dateutil.relativedelta import relativedelta
from concurrent.futures import ThreadPoolExecutor, as_completed, wait, FIRST_COMPLETED, ProcessPoolExecutor


def ems_oss2():
    auth = oss2.Auth('aTAasINNBhdfZHLfnD', 'AqOQaeLfaAoAsavyPGahgdptra37')  # note: not real credentials
    bucket = oss2.Bucket(auth, 'http://oss-country-area-internal.aliyuncs.com', 'bucket_name')  # note: not the real endpoint/bucket
    return bucket


def upload_db(second_dir, db_path):
    remote_file = 'old_sqlite/{}/{}/{}/{}.sqlite3'.format(
                        datetime.now().strftime("%Y-%m"),
                        datetime.now().strftime("%d"),
                        datetime.now().strftime("%H"), 
                        second_dir)
    ems_oss = ems_oss2()
    with open(db_path, 'rb') as f:
        data = f.read()
    if not ems_oss.object_exists(remote_file):
        ems_oss.put_object(remote_file, data)

def delete_db():
    old_data = datetime.now() - relativedelta(weeks=3)
    remote_file = 'old_sqlite/{}/{}/{}/'.format(
                        old_data.strftime("%Y-%m"),
                        old_data.strftime("%d"),
                        old_data.strftime("%H"))
    ems_oss = ems_oss2()
    key_list = [i.key for i in oss2.ObjectIteratorV2(ems_oss, prefix=remote_file)]
    if key_list:
        ems_oss.batch_delete_objects(key_list)

if __name__ == "__main__":
    import time
    time_start = time.time()
    BASE_PATH = '/www/wwwroot'
    first_dir = os.listdir(BASE_PATH)
    count = 0
    all_task = []
    executor = ThreadPoolExecutor(max_workers=20)
    for _ in first_dir:
        second_dir = os.path.join(BASE_PATH, _)
        if os.path.isdir(second_dir):
            # os.chdir(second_dir)
            db_path = os.path.join(second_dir, "db.sqlite3")
            if os.path.exists(db_path):
                print('count= '+str(count), second_dir)
                count += 1
                all_task.append(executor.submit(upload_db, _, db_path))
            else:
                print('no db.sqlite3')
            # os.chdir(BASE_PATH)
    print(all_task)
    print(executor)
    for future in as_completed(all_task):
        data = future.result()
        print("{} is finshed".format(data))
    delete_db()
    time_end = time.time()
    print("一共耗时{}:{}".format((time_end-time_start)//60, (time_end-time_start)%60))

The multi-threaded version usually finishes in about 9s.
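
An aside that applies to all three versions: every call to upload_db() builds a fresh oss2.Bucket via ems_oss2(), and object_exists() adds a second round trip per file. Constructing the bucket once and reusing it might shave some per-file overhead whatever the concurrency model, though whether one Bucket can safely be shared across 20 threads is worth checking against the oss2 documentation. A sketch (upload_db_reuse is an illustrative name; the oss2 calls are the same ones used above):

# Sketch: build the OSS bucket once and reuse it for every file, instead of
# calling ems_oss2() inside each upload. Relies on ems_oss2() defined above.
from datetime import datetime

BUCKET = ems_oss2()

def upload_db_reuse(second_dir, db_path, bucket=BUCKET):
    remote_file = 'old_sqlite/{}/{}/{}/{}.sqlite3'.format(
        datetime.now().strftime("%Y-%m"),
        datetime.now().strftime("%d"),
        datetime.now().strftime("%H"),
        second_dir)
    with open(db_path, 'rb') as f:
        data = f.read()
    if not bucket.object_exists(remote_file):
        bucket.put_object(remote_file, data)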

# Single-threaded version
import os
import oss2
from datetime import datetime
from dateutil.relativedelta import relativedelta


def ems_oss2():
    auth = oss2.Auth('aTAasINNBhdfZHLfnD', 'AqOQaeLfaAoAsavyPGahgdptra37')  # note: not real credentials
    bucket = oss2.Bucket(auth, 'http://oss-country-area-internal.aliyuncs.com', 'bucket_name')  # note: not the real endpoint/bucket
    return bucket



def upload_db(second_dir, db_path):
    remote_file = 'old_sqlite/{}/{}/{}/{}.sqlite3'.format(
                        datetime.now().strftime("%Y-%m"),
                        datetime.now().strftime("%d"),
                        datetime.now().strftime("%H"), 
                        second_dir)
    ems_oss = ems_oss2()
    with open(db_path, 'rb') as f:
        data = f.read()
    if not ems_oss.object_exists(remote_file):
        ems_oss.put_object(remote_file, data)

def delete_db():
    old_data = datetime.now() - relativedelta(weeks=3)
    remote_file = 'old_sqlite/{}/{}/{}/'.format(
                        old_data.strftime("%Y-%m"),
                        old_data.strftime("%d"),
                        old_data.strftime("%H"))
    ems_oss = ems_oss2()
    key_list = [i.key for i in oss2.ObjectIteratorV2(ems_oss, prefix=remote_file)]
    if key_list:
        ems_oss.batch_delete_objects(key_list)

if __name__ == "__main__":
    import time
    time_start = time.time()
    BASE_PATH = '/www/wwwroot'
    first_dir = os.listdir(BASE_PATH)
    count = 0
    for _ in first_dir:
        second_dir = os.path.join(BASE_PATH, _)
        if os.path.isdir(second_dir):
            # os.chdir(second_dir)
            db_path = os.path.join(second_dir, "db.sqlite3")
            if os.path.exists(db_path):
                print('count= '+str(count), second_dir)
                count += 1
                upload_db(_, db_path)
            else:
                print('no db.sqlite3')
            # os.chdir(BASE_PATH)
    delete_db()
    time_end = time.time()
    print("一共耗时{}:{}".format((time_end-time_start)//60, (time_end-time_start)%60))

The single-threaded version usually takes about 29s.

# Coroutine (yield) version
import os
import oss2
from datetime import datetime
from dateutil.relativedelta import relativedelta
from concurrent.futures import ThreadPoolExecutor, as_completed, wait, FIRST_COMPLETED, ProcessPoolExecutor


def ems_oss2():
    auth = oss2.Auth('aTAasINNBhdfZHLfnD', 'AqOQaeLfaAoAsavyPGahgdptra37')  # note: not real credentials
    bucket = oss2.Bucket(auth, 'http://oss-country-area-internal.aliyuncs.com', 'bucket_name')  # note: not the real endpoint/bucket
    return bucket


def yield_read_db(db_path):
    with open(db_path, 'rb') as f:
        yield f.read()

def yield_upload_db(second_dir, db_path):
    remote_file = 'old_sqlite/{}/{}/{}/{}.sqlite3'.format(
                        datetime.now().strftime("%Y-%m"),
                        datetime.now().strftime("%d"),
                        datetime.now().strftime("%H"),
                        second_dir)

    ems_oss = ems_oss2()
    data = yield
    if not ems_oss.object_exists(remote_file):
        if not data:
            print(remote_file, "received no data")
        ems_oss.put_object(remote_file, data)

def delete_db():
    old_data = datetime.now() - relativedelta(weeks=3)
    remote_file = 'old_sqlite/{}/{}/{}/'.format(
                        old_data.strftime("%Y-%m"),
                        old_data.strftime("%d"),
                        old_data.strftime("%H"))
    ems_oss = ems_oss2()
    key_list = [i.key for i in oss2.ObjectIteratorV2(ems_oss, prefix=remote_file)]
    if key_list:
        ems_oss.batch_delete_objects(key_list)



if __name__ == "__main__":
    import time
    time_start = time.time()
    BASE_PATH = '/www/wwwroot'
    first_dir = os.listdir(BASE_PATH)
    count = 0
    db_path_list = []
    for _ in first_dir:
        # find every campus folder (e.g. njat.mtsori.com) under the top-level directory, then the path to its db.sqlite3
        second_dir = os.path.join(BASE_PATH, _)
        if os.path.isdir(second_dir):
            # os.chdir(second_dir)
            db_path = os.path.join(second_dir, "db.sqlite3")
            if os.path.exists(db_path):
                print('count= '+str(count), second_dir)
                count += 1
                db_path_list.append((_, db_path))
            else:
                print('no db.sqlite3')
            # os.chdir(BASE_PATH)
    db_path_list_length = len(db_path_list)
    for i in range(db_path_list_length//10 + 1):
        # todo: read_data
        # instantiate up to 10 reader generators at once (effectively the producers)
        once_yield_list = [yield_read_db(db_path_list[_+10*i][1]) for _ in range(10) if _+10*i < db_path_list_length]
        # instantiate up to 10 upload coroutines at once (effectively the consumers)
        once_upload_list = [yield_upload_db(db_path_list[_+10*i][0], db_path_list[_+10*i][1]) for _ in range(10) if _+10*i < db_path_list_length]
        data_list = [once_yield_list[j].send(None) for j in range(len(once_yield_list))]  # prime the producers and collect the file data
        [once_upload_list[j].send(None) for j in range(len(once_yield_list))]  # prime the consumers
        for j in range(len(once_yield_list)):
            try:
                once_upload_list[j].send(data_list[j])  # let the consumer consume
            except StopIteration:
                pass
        # [once_upload_list[j].send(data_list[j]) for j in range(len(once_yield_list))]
        # for j in range(len(once_yield_list)):
        #     data = once_yield_list[j].send(None)  # prime the producer
        #     if data:
        #         once_upload_list[j].send(None)    # prime the consumer
                # try:
                #     once_upload_list[j].send(data)  # let the consumer consume
                # except StopIteration:
                #     pass


    time_end = time.time()
    print("一共耗时{}:{}".format((time_end-time_start)//60, (time_end-time_start)%60))

This garbage coroutine code takes about 31s.
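
The 31s result is not surprising: each send() still runs the blocking open().read(), object_exists() and put_object() calls to completion before the next coroutine gets a turn, so the work is the same as in the single-threaded version, in the same order, plus generator overhead. To actually overlap the I/O while keeping the yield style, the blocking steps have to run somewhere else while the coroutine is suspended: either a natively non-blocking client driven by an event loop (the asyncio route that yield from-style coroutines grew into), or a small driver that runs the yielded blocking calls on a thread pool. A sketch of the second option, assuming upload_task and run_tasks as illustrative names and ems_oss2() as the helper defined above:

# Sketch: coroutines yield no-argument blocking callables; a driver runs the
# callables on a thread pool and resumes whichever coroutine's step finishes
# first. The coroutines themselves still run in the main thread.
from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait
from datetime import datetime

def upload_task(second_dir, db_path):
    remote_file = 'old_sqlite/{}/{}/{}/{}.sqlite3'.format(
        datetime.now().strftime("%Y-%m"),
        datetime.now().strftime("%d"),
        datetime.now().strftime("%H"),
        second_dir)
    ems_oss = ems_oss2()
    # each yield hands one blocking step to the driver and pauses until the
    # driver sends back that step's result
    data = yield (lambda: open(db_path, 'rb').read())
    exists = yield (lambda: ems_oss.object_exists(remote_file))
    if not exists:
        yield (lambda: ems_oss.put_object(remote_file, data))

def run_tasks(coros, max_workers=20):
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        waiting = {}                           # future -> coroutine waiting on it
        def step(coro, value=None):
            try:
                blocking_call = coro.send(value)   # run up to the next yield
            except StopIteration:
                return                         # this coroutine is finished
            waiting[pool.submit(blocking_call)] = coro
        for coro in coros:
            step(coro)                         # prime: submit each first step
        while waiting:
            done, _ = wait(waiting, return_when=FIRST_COMPLETED)
            for fut in done:
                step(waiting.pop(fut), fut.result())

# usage, reusing the db_path_list built above:
# run_tasks([upload_task(name, path) for name, path in db_path_list])

This is essentially the trick asyncio's run_in_executor applies to libraries such as oss2 that expose no async API: the thread pool absorbs the blocking calls, and yield is what lets a single driver loop keep up to max_workers uploads in flight.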