MIT 6.1810 Lab: Copy-on-Write Fork for xv6

lab网址:https://pdos.csail.mit.edu/6.828/2022/labs/cow.html

xv6Book:https://pdos.csail.mit.edu/6.828/2022/xv6/book-riscv-rev3.pdf

Implement copy-on-write fork

这部分需要我们实现写时拷贝,题目给出解决方案为,当fork时,将父子进程的页表项都设置为只度,当发生写错误时,在处理函数中完成拷贝操作。

修改uvmcopy

uvmcopy通过复制页表项,将父进程的虚拟地址空间复制给子进程,这里只复制从0开始的部分,其他部分由proc_pagetablemappages映射,如TRAMPOLINETRAPFRAME,还有一些保存在内核中,如内核栈,这些都是父子进程独立的,不会被共用。共用的部分应当包括用户程序的代码、数据、栈、堆。

设置标志位不可写*pte &= ~PTE_W;等,添加物理页引用计数,注意要将原先不可写的页表项区分开。

// Given a parent process's page table, copy
// its memory into a child's page table.
// Copies both the page table and the
// physical memory.
// returns 0 on success, -1 on failure.
// frees any allocated pages on failure.
int
uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
{
  pte_t *pte;
  uint64 pa, i;
  uint flags;

  for(i = 0; i < sz; i += PGSIZE){
    if((pte = walk(old, i, 0)) == 0)
      panic("uvmcopy: pte should exist");
    if((*pte & PTE_V) == 0)
      panic("uvmcopy: page not present");
    //both parent and child pte disable write
+    if(*pte & PTE_W){
+      *pte &= ~PTE_W;
+      *pte |= PTE_COW;
+    }
+    pa = PTE2PA(*pte);
+    flags = PTE_FLAGS(*pte);
+    if(mappages(new, i, PGSIZE, (uint64)pa, flags) != 0){
+      goto err;
+    }
+    //adjust ref count
+    incrrefcnt(pa);
+  }
  return 0;

 err:
  uvmunmap(new, 0, i / PGSIZE, 1);
  return -1;
}

修改usertrap

当发生写错误时,会引发中断,我们可以在usertrap中检测识别,在这里完成写时拷贝的工作。作者发现自己之前对写时拷贝的理解有误,之前以为发生写操作时,会复制所有的页面,其实写时拷贝是精确到页的。

注意uvmcopy复制地址空间时,创建了新的页表项,页表项没有共用,只是页表项的内容相同,从而父子进程共用相同的地址空间。

发生中断的时候怎样知道是哪个页面发生了写错误呢?可以通过读取寄存器,其中scause标记了页面写错误,stval标记错误的地址。作者原先通过r_scause() == 15识别写时拷贝,但是发现这种方法不能和原始写错误区别开,最后无法通过usertests,题目暗示可以通过保留位的方式来识别,即在页表项设置一个COW保留位,如何发生写错误时,判断错误地址的页表项是否设置了保留位。

//
// handle an interrupt, exception, or system call from user space.
// called from trampoline.S
//
void
usertrap(void)
{
  int which_dev = 0;

  if((r_sstatus() & SSTATUS_SPP) != 0)
    panic("usertrap: not from user mode");

  // send interrupts and exceptions to kerneltrap(),
  // since we're now in the kernel.
  w_stvec((uint64)kernelvec);

  struct proc *p = myproc();
  
  // save user program counter.
  p->trapframe->epc = r_sepc();
  
  if(r_scause() == 8){
    // system call

    if(killed(p))
      exit(-1);

    // sepc points to the ecall instruction,
    // but we want to return to the next instruction.
    p->trapframe->epc += 4;

    // an interrupt will change sepc, scause, and sstatus,
    // so enable only now that we're done with those registers.
    intr_on();

    syscall();
+  }else if(r_scause() == 15){
+    uint64 addr = r_stval();
+    pte_t *pte = walk(p->pagetable,addr,0);
+    uint64 pa = PTE2PA(*pte);
+    void* mem;
+    if((mem = kalloc()) == 0){
+      setkilled(p);
+      exit(-1);
+    }
+    memmove(mem, (char*)pa, PGSIZE);
+    kfree((void *)pa);
+    uint flags = PTE_FLAGS(*pte);
+    flags |= PTE_W;
+    *pte = PA2PTE(mem);
+    *pte |= flags;
  }else if((which_dev = devintr()) != 0){
    // ok
  } else {
    printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid);
    printf("            sepc=%p stval=%p\n", r_sepc(), r_stval());
    setkilled(p);
  }

  if(killed(p))
    exit(-1);

  // give up the CPU if this is a timer interrupt.
  if(which_dev == 2)
    yield();

  usertrapret();
}

物理页引用计数

考虑到物理页的共用,需要给物理页设置引用计数。作者本来想直接在物理页的数据结构中加一个字段,但是没想到xv6是用空闲物理页自身链成的链表,一旦一个物理页被使用,那么这个物理页的数据结构就不存在了。因此只能尝试在kmem中定义的数组作为引用计数。

diff --git a/kernel/kalloc.c b/kernel/kalloc.c
index 0699e7e..cd9905c 100644
--- a/kernel/kalloc.c
+++ b/kernel/kalloc.c
@@ -21,11 +21,14 @@ struct run {
 struct {
   struct spinlock lock;
   struct run *freelist;
+  int refcount[(PHYSTOP-KERNBASE)/PGSIZE];
 } kmem;
 
 void
 kinit()
 {
+  for(int i=0;i<(PHYSTOP-KERNBASE)/PGSIZE;i++)
+    kmem.refcount[i]=1;
   initlock(&kmem.lock, "kmem");
   freerange(end, (void*)PHYSTOP);
 }
@@ -51,14 +54,22 @@ kfree(void *pa)
   if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP)
     panic("kfree");
 
+  //ref count
+  acquire(&kmem.lock);
+  kmem.refcount[((uint64)(pa-KERNBASE))/PGSIZE-1]--;
+  if(kmem.refcount[((uint64)(pa-KERNBASE))/PGSIZE-1] > 0){
+    release(&kmem.lock);
+    return;
+  }
+  release(&kmem.lock);
+
   // Fill with junk to catch dangling refs.
   memset(pa, 1, PGSIZE);
-
   r = (struct run*)pa;
-
   acquire(&kmem.lock);
   r->next = kmem.freelist;
   kmem.freelist = r;
+  kmem.refcount[((uint64)(pa-KERNBASE))/PGSIZE-1]=1;
   release(&kmem.lock);
 }
 
@@ -74,6 +85,7 @@ kalloc(void)
   r = kmem.freelist;
   if(r)
     kmem.freelist = r->next;
+    
   release(&kmem.lock);
 
   if(r)

+void incrrefcnt(uint64 pa){
+  acquire(&kmem.lock);
+  kmem.refcount[((uint64)(pa-KERNBASE))/PGSIZE-1]++;
+  release(&kmem.lock); 
+}

修改copyout

copyout是内核通过页表向进程的虚拟地址空间写数据,因此同样需要这个函数,完成相应的更改操作。检测要写入的页的引用计数是否大于1,若大于则创建新页。

// Copy from kernel to user.
// Copy len bytes from src to virtual address dstva in a given page table.
// Return 0 on success, -1 on error.
int
copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)
{
  uint64 n, va0, pa0;

  while(len > 0){
    va0 = PGROUNDDOWN(dstva);
    pa0 = walkaddr(pagetable, va0);
    if(pa0 == 0)
      return -1;
+    if(showrefcnt(pa0) >= 1){
+      uint64 opa;
+      pte_t *pte = walk(pagetable,va0,0);
+      opa = PTE2PA(*pte);
+      void* mem;
+      if((mem = kalloc()) == 0){
+        return -1;
+      }
+      memmove(mem, (char*)opa, PGSIZE);
+      kfree((void *)opa);
+      uint flags = PTE_FLAGS(*pte);
+      flags |= PTE_W;
+      *pte = PA2PTE(mem);
+      *pte |= flags;
+      pa0 = (uint64)mem;
 +   }
    n = PGSIZE - (dstva - va0);
    if(n > len)
      n = len;
    memmove((void *)(pa0 + (dstva - va0)), src, n);

    len -= n;
    src += n;
    dstva = va0 + PGSIZE;
  }
  return 0;
}

结果

这一节调试难度较大,COW作用在每个进程上,牵涉的内容太多,通过所有测试需要大量的时间来调试。

以下是全部的git diff

diff --git a/kernel/defs.h b/kernel/defs.h
index a3c962b..38f5f29 100644
--- a/kernel/defs.h
+++ b/kernel/defs.h
@@ -63,6 +63,8 @@ void            ramdiskrw(struct buf*);
 void*           kalloc(void);
 void            kfree(void *);
 void            kinit(void);
+void            incrrefcnt(uint64);
+int             showrefcnt(uint64 pa);
 
 // log.c
 void            initlog(int, struct superblock*);
diff --git a/kernel/kalloc.c b/kernel/kalloc.c
index 0699e7e..535fa9f 100644
--- a/kernel/kalloc.c
+++ b/kernel/kalloc.c
@@ -21,11 +21,14 @@ struct run {
 struct {
   struct spinlock lock;
   struct run *freelist;
+  int refcount[128*1024*1024/PGSIZE];
 } kmem;
 
 void
 kinit()
 {
+  for(int i=0; i < 128*1024*1024/PGSIZE ; i++)
+    kmem.refcount[i]=1;
   initlock(&kmem.lock, "kmem");
   freerange(end, (void*)PHYSTOP);
 }
@@ -35,8 +38,9 @@ freerange(void *pa_start, void *pa_end)
 {
   char *p;
   p = (char*)PGROUNDUP((uint64)pa_start);
-  for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE)
+  for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE){
     kfree(p);
+  }
 }
 
 // Free the page of physical memory pointed at by pa,
@@ -51,11 +55,19 @@ kfree(void *pa)
   if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP)
     panic("kfree");
 
+  //ref count
+  acquire(&kmem.lock);
+  uint64 cnt = ((uint64)pa-KERNBASE)/PGSIZE;
+  kmem.refcount[cnt]--;
+  if(kmem.refcount[cnt] > 0){
+    release(&kmem.lock);
+    return;
+  }
+  release(&kmem.lock);
+
   // Fill with junk to catch dangling refs.
   memset(pa, 1, PGSIZE);
-
   r = (struct run*)pa;
-
   acquire(&kmem.lock);
   r->next = kmem.freelist;
   kmem.freelist = r;
@@ -72,11 +84,30 @@ kalloc(void)
 
   acquire(&kmem.lock);
   r = kmem.freelist;
-  if(r)
+  if(r){
     kmem.freelist = r->next;
+    uint64 cnt = ((uint64)r-KERNBASE)/PGSIZE;
+    kmem.refcount[cnt]=1;
+  }
   release(&kmem.lock);
 
   if(r)
     memset((char*)r, 5, PGSIZE); // fill with junk
   return (void*)r;
 }
+
+void incrrefcnt(uint64 pa){
+  acquire(&kmem.lock);
+  uint64 cnt = ((uint64)pa-KERNBASE)/PGSIZE;
+  kmem.refcount[cnt]++;
+  release(&kmem.lock); 
+}
+
+int showrefcnt(uint64 pa){
+  int cnt;
+  acquire(&kmem.lock);
+  uint64 _cnt = ((uint64)pa-KERNBASE)/PGSIZE;
+  cnt = kmem.refcount[_cnt];
+  release(&kmem.lock);
+  return cnt;
+}
diff --git a/kernel/riscv.h b/kernel/riscv.h
index 20a01db..24e8105 100644
--- a/kernel/riscv.h
+++ b/kernel/riscv.h
@@ -343,6 +343,7 @@ typedef uint64 *pagetable_t; // 512 PTEs
 #define PTE_W (1L << 2)
 #define PTE_X (1L << 3)
 #define PTE_U (1L << 4) // user can access
+#define PTE_COW (1L << 6) // cow bit
 
 // shift a physical address to the right place for a PTE.
 #define PA2PTE(pa) ((((uint64)pa) >> 12) << 10)
diff --git a/kernel/trap.c b/kernel/trap.c
index 512c850..3a7687e 100644
--- a/kernel/trap.c
+++ b/kernel/trap.c
@@ -65,7 +65,31 @@ usertrap(void)
     intr_on();
 
     syscall();
-  } else if((which_dev = devintr()) != 0){
+  }else if(r_scause() == 15){
+    uint64 addr = r_stval();
+    if(addr > p->sz){
+      setkilled(p);
+      goto kill;
+    }
+    pte_t *pte = walk(p->pagetable,addr,0);
+    if((*pte & PTE_COW) ==0){
+      setkilled(p);
+      goto kill;
+    }
+    uint64 pa = PTE2PA(*pte);
+    void* mem;
+    if((mem = kalloc()) == 0){
+      setkilled(p);
+      goto kill;
+    }
+    memmove(mem, (char*)pa, PGSIZE);
+    kfree((void *)pa);
+    uint flags = PTE_FLAGS(*pte);
+    flags |= PTE_W;
+    flags &= ~PTE_COW;
+    *pte = PA2PTE(mem);
+    *pte |= flags;
+  }else if((which_dev = devintr()) != 0){
     // ok
   } else {
     printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid);
@@ -73,6 +97,7 @@ usertrap(void)
     setkilled(p);
   }
 
+kill:
   if(killed(p))
     exit(-1);
 
diff --git a/kernel/vm.c b/kernel/vm.c
index 9f69783..edad921 100644
--- a/kernel/vm.c
+++ b/kernel/vm.c
@@ -308,22 +308,24 @@ uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
   pte_t *pte;
   uint64 pa, i;
   uint flags;
-  char *mem;
 
   for(i = 0; i < sz; i += PGSIZE){
     if((pte = walk(old, i, 0)) == 0)
       panic("uvmcopy: pte should exist");
     if((*pte & PTE_V) == 0)
       panic("uvmcopy: page not present");
+    //both parent and child pte disable write
+    if(*pte & PTE_W){
+      *pte &= ~PTE_W;
+      *pte |= PTE_COW;
+    }
     pa = PTE2PA(*pte);
     flags = PTE_FLAGS(*pte);
-    if((mem = kalloc()) == 0)
-      goto err;
-    memmove(mem, (char*)pa, PGSIZE);
-    if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){
-      kfree(mem);
+    if(mappages(new, i, PGSIZE, (uint64)pa, flags) != 0){
       goto err;
     }
+    //adjust ref count
+    incrrefcnt(pa);
   }
   return 0;
 
@@ -358,6 +360,22 @@ copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)
     pa0 = walkaddr(pagetable, va0);
     if(pa0 == 0)
       return -1;
+    if(showrefcnt(pa0) > 1){
+      uint64 opa;
+      pte_t *pte = walk(pagetable,va0,0);
+      opa = PTE2PA(*pte);
+      void* mem;
+      if((mem = kalloc()) == 0){
+        return -1;
+      }
+      memmove(mem, (char*)opa, PGSIZE);
+      kfree((void *)opa);
+      uint flags = PTE_FLAGS(*pte);
+      flags |= PTE_W;
+      *pte = PA2PTE(mem);
+      *pte |= flags;
+      pa0 = (uint64)mem;
+    }
     n = PGSIZE - (dstva - va0);
     if(n > len)
       n = len;
posted @ 2024-02-06 16:15  benoqtr  阅读(58)  评论(0编辑  收藏  举报