MIT-6.S081 | Lab-Multithreading(2021)

336 阅读7分钟

Lab Multithreading

Uthread: switching between threads

​ 实现用户态的线程切换,同内核态的线程切换比较,步骤减少了许多,不用多考虑陷阱帧(trapframe)的切换问题。只需要考虑栈指针、返回地址以及被调用者寄存器(callee register)

​ 整个过程:第一次调用thread_schedule();之后,会离开thread[0](也即是进程main),之后就一直在三个线程thread[1] ~ thread[3]之间一直互相切换,直到exit(0)

UPROGS=\
	$U/_cat\
	$U/_echo\
	...
	$U/_uthread\ # insert _uthread

要实现的内容:

  • thread_switch(): 这个函数和内核中的 swtch() 完全一样,用于切换处理器的上下文。和内核中相同,因为执行这个函数的过程是一个正常的函数调用,所以我们不需要保存和交换调用者保存的寄存器。
  • thread_create() :这个函数是用于创建新的用户线程的。参考内核态多线程的实现。我们调用 swtch() 后,决定跳转位置的是 ra 寄存器,决定恢复出来的被调用者保存寄存器的是 sp 寄存器。所以,在这个函数中,我们应该合理的设置 ra 寄存器,使得第一次执行用户函数时,是这个函数的第一条语句。
  • thread_schedule():参考内核中的实现,这个函数和内核中的 scheduler() 的作用相同。也就是在当前进程调用 yield() 后,找到一个 RUNNABLE 的进程,然后执行这个进程。在 thread_schedule() 中,我们会需要调用 thread_switch() 来切换处理器的上下文。

​ 首先thread_switch(),uthread.c中没有上下文属性,所以要添加,需要的寄存器和swtch一样。

struct context
{
  uint64 ra; // 它用于存储返回地址。
  uint64 sp; // 它用于存储堆栈指针。
  // callee-saved registers
  uint64 s0;
  uint64 s1;
  uint64 s2;
  uint64 s3;
  uint64 s4;
  uint64 s5;
  uint64 s6;
  uint64 s7;
  uint64 s8;
  uint64 s9;
  uint64 s10;
  uint64 s11;
};


struct thread {
  char       stack[STACK_SIZE]; /* the thread's stack */
  int        state;             /* FREE, RUNNING, RUNNABLE */
  struct context context;
};

然后 thread_switch() 差不多就可以直接把 swtch() 中的东西抄过来了:

.text

	/*
         * save the old thread's registers,
         * restore the new thread's registers.
         */

	.globl thread_switch
	 // a0 是老的上下文,a1 是新的
thread_switch:
	/* YOUR CODE HERE */
	sd ra, 0(a0)
    sd sp, 8(a0)
    sd s0, 16(a0)
    sd s1, 24(a0)
    sd s2, 32(a0)
    sd s3, 40(a0)
    sd s4, 48(a0)
    sd s5, 56(a0)
    sd s6, 64(a0)
    sd s7, 72(a0)
    sd s8, 80(a0)
    sd s9, 88(a0)
    sd s10, 96(a0)
    sd s11, 104(a0)

    ld ra, 0(a1)
    ld sp, 8(a1)
    ld s0, 16(a1)
    ld s1, 24(a1)
    ld s2, 32(a1)
    ld s3, 40(a1)
    ld s4, 48(a1)
    ld s5, 56(a1)
    ld s6, 64(a1)
    ld s7, 72(a1)
    ld s8, 80(a1)
    ld s9, 88(a1)
    ld s10, 96(a1)
    ld s11, 104(a1)
	ret    /* return to ra */

​ 接下来是 thread_create()。实现这个函数主要需要思考如何设置 ra 和 sp 寄存器。因为用户进程一开始的时候是没有使用寄存器的,所以如何设置上下文中的其他寄存器是无所谓的。

​ 参考kernel/proc.c/allocproc(),使线程恢复被调用者寄存器之后,能够直接跳到函数头部,同时,栈恢复为对应线程的栈。所以ra和sp的设置的顺序参考allocproc即可。

// kernel/proc.c/allocproc()
  memset(&p->context, 0, sizeof(p->context));
  p->context.ra = (uint64)forkret;
  p->context.sp = p->kstack + PGSIZE;

  return p;
//thread_create
void 
thread_create(void (*func)())
{
  struct thread *t;

  for (t = all_thread; t < all_thread + MAX_THREAD; t++) {
    if (t->state == FREE) break;
  }
  t->state = RUNNABLE;
  // YOUR CODE HERE
  t->context.ra = (uint64) func;
  t->context.sp = (uint64)(t->stack + STACK_SIZE);
}

接下来处理 thread_schedule():很明显我们要交换current_threadnext_thread() 的上下文。

void 
thread_schedule(void)
{
  struct thread *t, *next_thread;

  /* Find another runnable thread. */
  next_thread = 0;
  t = current_thread + 1;
  for(int i = 0; i < MAX_THREAD; i++){
    if(t >= all_thread + MAX_THREAD)  //是否到末尾即最后一个线程
      t = all_thread;                 //回到开头
    if(t->state == RUNNABLE) {
      next_thread = t;
      break;
    }
    t = t + 1;
  }

  if (next_thread == 0)
  { // 检查是否没有找到可运行的线程。
    printf("thread_schedule: no runnable threads\n");
    exit(-1);
  }
  // 检查是否current_thread与 不同next_thread,意味着需要线程切换。
  if (current_thread != next_thread) {         /* switch threads?  */
    next_thread->state = RUNNING;
    t = current_thread;
    current_thread = next_thread;
    /* YOUR CODE HERE
     * Invoke thread_switch to switch from t to next_thread:
     * thread_switch(??, ??);
     */
      //交换上下文
    thread_switch((uint64)&t->context, (uint64)&next_thread->context);
  } else
    next_thread = 0;
}

Using threads

​ 一个散列表(哈希表)的程序,然后做一些更改,使得这个程序在多线程的环境下也可用。

最关键的有三个函数 insert()put()get()

static void 
insert(int key, int value, struct entry **p, struct entry *n)
{
  struct entry *e = malloc(sizeof(struct entry));
  e->key = key;
  e->value = value;
  e->next = n;
  *p = e; // 把 p table[i] 的起始点改成 e
}
static 
void put(int key, int value)
{
  // is the key already present?
  struct entry *e = 0;
  for (e = table[i]; e != 0; e = e->next) {
    if (e->key == key)
      break;
  }
  if(e){
    // update the existing key.
    e->value = value;
  } else {
    // the new is new.
    insert(key, value, &table[i], table[i]); // 在 table[i] 的最前面插入一个 key val 对
  }
}

​ 其实就是尝试在散列表中添加一个键值对。这个函数会先尝试查找散列表中是否存在某个 key 如果存在,就用 value 替代掉原来和 key 对应的值。如果不存在,就调用 insert() 函数插入该键值对。

static struct entry*
get(int key)
{
  int i = key % NBUCKET;
  struct entry *e = 0;
  for (e = table[i]; e != 0; e = e->next) {
    if (e->key == key) break;
  }
  return e;
}

​ 当有两个键 k1 和 k2,他们属于散列表中的同一链表,并且链表中都还不存在这两个键值对。现在有两个线程 t1 和 t2,它们分别尝试在该链表中插入这两个键值。

那么有如下的可能情况:

​ t1 先检查了链表中不存在 k1,于是准备调用 insert() 在链表前插入键值对。

​ 这个时候,线程调度器切换到了 t2(也可能是在多核环境下,两个线程并行执行,但是 t2 比 t1 快)。

​ 然后 t2 也发现了链表中不存在 k2,所以调用 insert() 插入。插入之后,k2 成了链表的第一个元素。

​ 随后 t1 也真正的插入了 k1。但是,因为 t1 并不知道 t2 已经把 k2 插入到了开头,于是在其认为的链表开头(k2 所处位置)插入了 k1,k2 就被覆盖掉了,于是造成了键值对丢失。

​ 这样的情况下,我们需要通过加锁来解决问题。

所以在put里修改。

static 
void put(int key, int value)
{
  pthread_mutex_lock(&lock); // first_mutex
  int i = key % NBUCKET;
  
  // is the key already present?
  pthread_mutex_unlock(&lock); // second_mutex
  struct entry *e = 0;
  for (e = table[i]; e != 0; e = e->next) {
    if (e->key == key)
      break;
  }
  pthread_mutex_lock(&lock); // second_mutex
  if(e){
    // update the existing key.
    e->value = value;
  } else {
    // the new is new.
    insert(key, value, &table[i], table[i]);
  }
  pthread_mutex_unlock(&lock); // first_mutex
}

first_mutex先上锁再解锁是防止双线程同时修改keyvalue

second_mutex先解锁在上锁是用来加速put()的,因为使用第二对锁的这段代码不需要修改key或者value,只是访问了key,而不是修改。

pthread_mutex_t lock;  //全局
int
main(int argc, char *argv[])
{
  pthread_t *tha;
  void *value;
  double t1, t0;
  pthread_mutex_init(&lock, NULL);  //初始化

Barrier

实现同步屏障。

同步屏障(Barrier)是并行计算中的一种同步方法。对于一群进程或线程,程序中的一个同步屏障意味着任何线程/进程执行到此后必须等待,直到所有线程/进程都到达此点才可继续执行下文。

In this assignment you'll implement a barrier: a point in an application at which all participating threads must wait until all other participating threads reach that point too. You'll use pthread condition variables, which are a sequence coordination technique similar to xv6's sleep and wakeup.

也就是说:每一轮都有同样数目的线程来到,我们会阻塞所有线程,直到这一轮的最后一个线程进来,这时,我们将这一轮除最后一个线程外的所有线程唤醒,然后进入下一轮。

  • 进入下一轮之前记得将bstate.nthread = 0,使得下一轮能够重新计数
  • bstate.round++;
pthread_cond_wait(&cond, &mutex); // 在 cond 上进入休眠状态,释放互斥锁,唤醒时获取
pthread_cond_broadcast(&cond); // 唤醒每个在 cond 上休眠的线程

正好满足需求

static void 
barrier()
{
  // YOUR CODE HERE
  //
  // Block until all threads have called barrier() and
  // then increment bstate.round.
  //
  pthread_mutex_lock(&bstate.barrier_mutex);
  // next thread
  bstate.nthread++;
   /* 最后一个线程来了,我们进入下一轮,唤醒所有线程 */
  if(bstate.nthread == nthread)
  {
    bstate.nthread = 0; //进入下一轮
    bstate.round++;
    pthread_cond_broadcast(&bstate.barrier_cond); // 唤醒沉睡的线程
  }
  else{
    // 如果没有全部到达 barrier 的位置,就等待
    pthread_cond_wait(&bstate.barrier_cond, &bstate.barrier_mutex);
  }
  pthread_mutex_unlock(&bstate.barrier_mutex);
}