您的位置:首页 > 理论基础 > 计算机网络

打造先进的内存KV数据库-5 TCP侦听

2016-01-05 23:46 549 查看

TCP侦听

作为支持集群的数据库,必然要与多个客户端交换信息,而不可能让数据库与所有客户端共享地址空间(尽管那样性能更好),所以需要使用TCP协议交换数据(UDP协议不保证可靠传输,因此弃用)。C语言的TCP库其实还不错,但在高并发和并行处理方面不如Go,而且并发锁机制比较难写,所以用Go编写了服务器和客户端,服务器再通过cgo调用C实现的数据库库。目前版本还没有任何身份验证,之后会加上。

代码实现

//server.go
package main
// #cgo LDFLAGS: -L ./lib -lmonkeyS
// #include "./lib/core.h"
// #include <stdlib.h>
import "C"
import (
"unsafe"
_"fmt"
"net"
"strings"
)

// main creates the default "monkey" database, listens for clients on TCP
// port 1517 (no host given, so every local address), and serves each
// accepted connection in its own goroutine. Any failure while resolving,
// listening or accepting panics and stops the server.
//
// NOTE(review): panicking on an Accept error means a transient failure
// (for example running out of file descriptors) takes the whole database
// down; logging the error and retrying would be more robust.
func main() {
	const servicePort = ":1517"

	// The C library expects a NUL-terminated database name.
	defaultDB := append([]byte("monkey"), 0)
	C.CreateDB((*C.char)(unsafe.Pointer(&defaultDB[0])))

	addr, err := net.ResolveTCPAddr("tcp4", servicePort)
	if err != nil {
		panic(err)
	}
	listener, err := net.ListenTCP("tcp", addr)
	if err != nil {
		panic(err)
	}
	for {
		conn, acceptErr := listener.Accept()
		if acceptErr != nil {
			panic(acceptErr)
		}
		go Handler(conn)
	}
}

// Handler serves one client connection. It starts on the default "monkey"
// database, then repeatedly reads one request frame and hands its payload to
// TranslateMessage, which writes the reply. A frame is a 4-byte big-endian
// length that counts the header itself, followed by the command text.
// The connection is closed when the client disconnects, a read fails, or
// the client sends an invalid frame.
func Handler(conn net.Conn) {
	defer conn.Close()

	// Per-connection "current database". C needs a NUL-terminated name.
	name := append([]byte("monkey"), 0)
	db := C.SwitchDB((*C.char)(unsafe.Pointer(&name[0])))
	for {
		message, ok := readFrame(conn)
		if !ok {
			return
		}
		TranslateMessage(conn, &db, message)
	}
}

// readFrame reads one request frame from conn and returns its payload with a
// single NUL byte appended. That way the last space-separated argument
// already forms a valid C string.
//
// ok is false on any read error, or when the header announces a length
// below 4 (shorter than the header itself) or above 64 MiB.
func readFrame(conn net.Conn) (payload []byte, ok bool) {
	// A corrupt or hostile header must not make the server allocate up
	// to 4 GiB.
	const maxFrame = 64 << 20

	var header [4]byte
	if !readFull(conn, header[:]) {
		return nil, false
	}
	total := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
	if total < 4 || total > maxFrame {
		return nil, false
	}
	payload = make([]byte, total-4+1) // the extra zero byte is the C terminator
	if !readFull(conn, payload[:total-4]) {
		return nil, false
	}
	return payload, true
}

// readFull reads exactly len(buf) bytes from conn, looping over short reads.
// TCP does not preserve the sender's write boundaries. It reports whether
// buf was completely filled.
func readFull(conn net.Conn, buf []byte) bool {
	for filled := 0; filled < len(buf); {
		n, err := conn.Read(buf[filled:])
		filled += n
		if err != nil {
			return filled == len(buf)
		}
	}
	return true
}

// TranslateMessage executes one client command against the connection's
// current database and writes the framed reply to conn. The reply frame is
// a 4-byte big-endian length that counts the header itself, followed by the
// reply bytes.
//
// message holds the command text. Trailing NUL bytes (read-buffer padding or
// a terminator added by the reader) are ignored. Words are separated by
// single spaces:
//
//	set <key> <value>
//	get <key>
//	delete <key> | remove <key>
//	createdb <name>
//	switchdb <name>
//	dropdb <name>
//	listdb            (matched case-insensitively)
//
// createdb, switchdb and dropdb may replace *db. Unknown commands get an
// empty reply. A command with too few arguments gets "missing argument\n",
// and a key command issued with no current database gets
// "no database selected\n"; previously both crashed the whole server.
func TranslateMessage(conn net.Conn, db **C.Database, message []byte) {
	params := strings.Split(strings.TrimRight(string(message), "\x00"), " ")
	verb := params[0] // Split always yields at least one element

	// How many arguments each verb needs, and whether it touches *db.
	wantArgs, needsDB := 0, false
	switch verb {
	case "set":
		wantArgs, needsDB = 2, true
	case "get", "delete", "remove":
		wantArgs, needsDB = 1, true
	case "createdb", "switchdb", "dropdb":
		wantArgs = 1
	}

	var response []byte
	switch {
	case len(params)-1 < wantArgs:
		response = []byte("missing argument\n")
	case needsDB && *db == nil:
		response = []byte("no database selected\n")
	case verb == "set":
		// NOTE(review): assumes C.Set reads the value as a NUL-terminated
		// string — TODO confirm against core.h.
		r := C.Set(&(*db).tIndex, cString(params[1]), unsafe.Pointer(cString(params[2])))
		response = cBytes(unsafe.Pointer(&r.msg[0]), len(r.msg))
	case verb == "get":
		r := C.Get(&(*db).tIndex, cString(params[1]))
		if int(r.code) == 0 {
			response = cBytes(unsafe.Pointer(r.pData), -1)
		}
	case verb == "delete" || verb == "remove":
		r := C.Delete(&(*db).tIndex, cString(params[1]))
		response = cBytes(unsafe.Pointer(&r.msg[0]), len(r.msg))
	case verb == "createdb":
		if d := C.CreateDB(cString(params[1])); d != nil {
			*db = d
			response = []byte("Already exist,switched\n")
		} else {
			response = []byte("Created\n")
		}
	case verb == "switchdb":
		if d := C.SwitchDB(cString(params[1])); d != nil {
			*db = d
			response = []byte("ok\n")
		} else {
			response = []byte("fail\n")
		}
	case verb == "dropdb":
		*db = C.DropDB(cString(params[1]))
	case strings.EqualFold("listdb", verb):
		if r := C.ListDB(); r != nil {
			response = cBytes(unsafe.Pointer(r), 1024)
			C.free(unsafe.Pointer(r))
		}
	}

	total := uint32(len(response) + 4)
	frame := append([]byte{byte(total >> 24), byte(total >> 16), byte(total >> 8), byte(total)}, response...)
	// A failed write is not reported here. A broken connection shows up as
	// a read error when the caller reads the next request.
	conn.Write(frame)
}

// cString returns a NUL-terminated copy of s in Go memory. The pointer may be
// passed to C for the duration of one call; cgo's rules forbid C from keeping
// it afterwards.
func cString(s string) *C.char {
	b := make([]byte, len(s)+1)
	copy(b, s)
	return (*C.char)(unsafe.Pointer(&b[0]))
}

// cBytes copies the NUL-terminated C string at p into a Go slice and keeps
// the terminating NUL, since the client stops printing there.
//
// At most limit bytes are examined; a negative limit means no bound. A nil p
// yields nil.
func cBytes(p unsafe.Pointer, limit int) []byte {
	if p == nil {
		return nil
	}
	var out []byte
	for i := 0; limit < 0 || i < limit; i++ {
		c := *(*byte)(unsafe.Pointer(uintptr(p) + uintptr(i)))
		out = append(out, c)
		if c == 0 {
			break
		}
	}
	return out
}


//Client.go
package main
import "net"
import "fmt"
// main is an interactive command-line client for the monkey server at
// 127.0.0.1:1517. For each command it:
//   - reads a command word (and its arguments, if any) from standard input;
//   - sends it as one frame: a 4-byte big-endian length that counts the
//     header itself, followed by the space-separated command text;
//   - prints the reply payload up to its first NUL byte.
//
// Typing "exit" quits. Unrecognised words are sent as an empty command.
// Network errors end the session with a message instead of looping forever.
//
// NOTE(review): when standard input reaches EOF, Scanf keeps failing and
// leaves buf1 empty, so the loop keeps sending empty commands. Detecting
// io.EOF here would need an extra import.
func main() {
	tcpAddr, err := net.ResolveTCPAddr("tcp4", "127.0.0.1:1517")
	if err != nil {
		panic(err)
	}
	conn, err := net.DialTCP("tcp", nil, tcpAddr)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	for {
		var buf1, buf2, buf3, buf string
		fmt.Print("monkey>")
		fmt.Scanf("%s", &buf1)
		switch buf1 {
		case "set":
			fmt.Scanf("%s", &buf2)
			fmt.Scanf("%s", &buf3)
			buf = buf1 + " " + buf2 + " " + buf3
		case "get", "remove", "delete", "createdb", "switchdb", "dropdb":
			fmt.Scanf("%s", &buf2)
			buf = buf1 + " " + buf2
		case "listdb":
			buf = buf1 + " "
		case "exit":
			fmt.Println("Bye!")
			return
		}

		if err := writeFrame(conn, []byte(buf)); err != nil {
			fmt.Println("send failed:", err)
			return
		}
		reply, err := readReply(conn)
		if err != nil {
			fmt.Println("receive failed:", err)
			return
		}
		// Replies are C strings. Print the raw bytes up to the NUL so that
		// multi-byte UTF-8 text is not mangled byte-by-byte.
		end := 0
		for end < len(reply) && reply[end] != 0 {
			end++
		}
		fmt.Print(string(reply[:end]))
		fmt.Print("\n")
	}
}

// writeFrame sends payload to conn, prefixed with its 4-byte big-endian
// length (which includes the 4 header bytes).
func writeFrame(conn net.Conn, payload []byte) error {
	total := uint32(len(payload) + 4)
	frame := append([]byte{byte(total >> 24), byte(total >> 16), byte(total >> 8), byte(total)}, payload...)
	_, err := conn.Write(frame)
	return err
}

// readReply reads one reply frame from conn and returns its payload, i.e.
// the bytes after the 4-byte header. Replies of any size are supported.
func readReply(conn net.Conn) ([]byte, error) {
	var header [4]byte
	if err := readExactly(conn, header[:]); err != nil {
		return nil, err
	}
	total := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
	if total < 4 {
		return nil, fmt.Errorf("invalid reply length %d", total)
	}
	payload := make([]byte, total-4)
	if err := readExactly(conn, payload); err != nil {
		return nil, err
	}
	return payload, nil
}

// readExactly fills buf from conn, looping over short reads. TCP does not
// preserve the boundaries of the sender's writes.
func readExactly(conn net.Conn, buf []byte) error {
	for filled := 0; filled < len(buf); {
		n, err := conn.Read(buf[filled:])
		filled += n
		if err != nil && filled < len(buf) {
			return err
		}
	}
	return nil
}


修正:上述代码存在严重问题:

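发送超过1K(1024字节)的数据时,接收端无法正确接收。原因是TCP是字节流,一次Read不一定对应对端的一次Write。而上面的代码每次读入1024字节后,无论实际读到多少,都把整个缓冲区(包括未填充的零字节)拼接进消息,并且在读取循环中忽略了错误。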

改进代码如下。注意:该版本把消息按1024字节分块,先发送块数,接收端再按块计数重组,因此仍然依赖"一次Read恰好对应一次Write"这一假设,而TCP并不保证这一点。

//tcp.go
package tcp
import "net"
import "fmt"

// ok reports whether bytes begins with the NUL-terminated reply "ok", that
// is, the bytes 'o', 'k', 0. Inputs shorter than three bytes report false
// instead of panicking with an index out of range.
func ok(bytes []byte) bool {
	return len(bytes) >= 3 && bytes[0] == 'o' && bytes[1] == 'k' && bytes[2] == 0
}

// bytes4uint decodes the first four bytes of bytes as a big-endian uint32.
// Bytes after the fourth are ignored; fewer than four bytes panics.
func bytes4uint(bytes []byte) uint32 {
	return uint32(bytes[0])<<24 |
		uint32(bytes[1])<<16 |
		uint32(bytes[2])<<8 |
		uint32(bytes[3])
}

// uint32bytes encodes n as a freshly allocated 4-byte big-endian slice. It is
// the inverse of bytes4uint and is used for the message header.
func uint32bytes(n uint32) []byte {
	return []byte{
		byte(n >> 24),
		byte(n >> 16),
		byte(n >> 8),
		byte(n),
	}
}

// TCPSession wraps one TCP connection with two goroutines, both started by
// Init:
//   - Send, the writer, takes []byte values from ToSend and writes them.
//   - Recv, the reader, pushes each raw read onto Received.
//
// SendMessage and ReadMessage layer the chunked message framing on top of
// these channels.
//
// NOTE(review): Closed is written and read by the Send and Recv goroutines
// with no synchronisation, which is a data race. Making it safe would need a
// different field type (e.g. an atomic flag) and therefore an API change.
type TCPSession struct {
	Conn     *net.TCPConn
	ToSend   chan interface{} // data waiting to be written to Conn
	Received chan interface{} // data read from Conn
	Closed   bool             // set once a read or write on Conn has failed
}

// Init prepares the session for use. It creates the unbuffered ToSend and
// Received channels, then launches the Send (writer) and Recv (reader)
// goroutines. Both goroutines use s.Conn right away, so it should already be
// set when Init is called.
func (s *TCPSession) Init() {
	s.ToSend, s.Received = make(chan interface{}), make(chan interface{})

	go s.Send()
	go s.Recv()
}

// Send is the session's writer goroutine (started by Init). It writes every
// []byte taken from s.ToSend to the connection, in order. On the first write
// error it:
//   - marks the session closed;
//   - closes the connection, releasing the socket and unblocking any pending
//     Read in the reader goroutine;
//   - exits.
//
// NOTE(review): Closed is shared with Recv without synchronisation. Once the
// session is closed nothing drains s.ToSend, so this goroutine (if already
// blocked on the channel) and any caller of SendMessage can block forever.
// Fixing that needs a new field such as a done channel.
func (s *TCPSession) Send() {
	for !s.Closed {
		// Only []byte values are queued by SendMessage. Any other type is a
		// programming error and panics here, as before.
		buf := (<-s.ToSend).([]byte)
		if _, err := s.Conn.Write(buf); err != nil {
			s.Closed = true
			s.Conn.Close() // release the socket instead of leaking it
			return
		}
	}
}

// Recv is the session's reader goroutine (started by Init). Each Read from
// the connection is delivered on s.Received as a fresh 1024-byte slice.
//
// Bytes past what Read returned stay zero. The slice is deliberately not
// trimmed to the byte count, because ReadMessage inspects index 4 of the
// first chunk and treats a zero there as "no data merged with the header".
//
// On the first read error (including the peer closing the connection) it
// marks the session closed, closes the connection and exits.
//
// NOTE(review): Closed is shared with Send without synchronisation, and the
// send on s.Received blocks until someone calls ReadMessage.
func (s *TCPSession) Recv() {
	for !s.Closed {
		buf := make([]byte, 1024)
		if _, err := s.Conn.Read(buf); err != nil {
			s.Closed = true
			s.Conn.Close() // release the socket; the stream is finished or broken
			return
		}
		s.Received <- buf
	}
}

// SendMessage queues one message for the Send goroutine using this package's
// chunked framing:
//   - first, a 4-byte big-endian count of 1024-byte chunks;
//   - then the payload cut into that many chunks (only the last may be
//     shorter).
//
// An empty payload is announced with a count of 0 and nothing else. The
// chunks are sub-slices of bytes, not copies. With the unbuffered channel
// created by Init, each send blocks until Send takes it.
func (s *TCPSession) SendMessage(bytes []byte) {
	const chunkSize = 1024

	chunks := (len(bytes) + chunkSize - 1) / chunkSize // ceil(len / chunkSize)
	s.ToSend <- uint32bytes(uint32(chunks))

	rest := bytes
	for len(rest) > chunkSize {
		s.ToSend <- rest[:chunkSize]
		rest = rest[chunkSize:]
	}
	if len(rest) > 0 {
		s.ToSend <- rest // final, possibly short, chunk
	}
}

// ReadMessage blocks until the chunks of one message produced by SendMessage
// have arrived on s.Received, and returns them concatenated. The result is
// never nil.
//
// The first value received carries the 4-byte big-endian chunk count. If the
// byte right after the count is non-zero, the first data chunk is assumed to
// have arrived in the same read as the header. Its bytes from index 4 on are
// then taken as data, and one fewer chunk is awaited.
//
// NOTE(review): this only works if every read on this side returns exactly
// one write from the sender, which TCP does not guarantee. Other
// limitations:
//   - a payload starting with a 0 byte defeats the merge check;
//   - each chunk keeps the reader's zero padding up to 1024 bytes;
//   - if the session closes, s.Received is never closed, so this call blocks
//     forever.
func (s *TCPSession) ReadMessage() []byte {
	first := (<-s.Received).([]byte)
	pending := bytes4uint(first) // chunks still expected

	message := []byte{}
	if first[4] != 0 { // header and first data chunk were merged in one read
		message = first[4:]
		pending--
	}
	for ; pending > 0; pending-- {
		message = append(message, (<-s.Received).([]byte)...)
	}
	return message
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: