update mssql drive to last working version 20180314172330-6a30f4e59a44 (#7306)

pull/6845/head
Antoine GIRARD 5 years ago committed by Lunny Xiao
parent aeb8f7aad8
commit 1e46eedce7
  1. 2
      go.mod
  2. 4
      go.sum
  3. 174
      vendor/github.com/denisenkom/go-mssqldb/README.md
  4. 45
      vendor/github.com/denisenkom/go-mssqldb/appveyor.yml
  5. 145
      vendor/github.com/denisenkom/go-mssqldb/buf.go
  6. 616
      vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go
  7. 93
      vendor/github.com/denisenkom/go-mssqldb/bulkcopy_sql.go
  8. 39
      vendor/github.com/denisenkom/go-mssqldb/collation.go
  9. 22
      vendor/github.com/denisenkom/go-mssqldb/decimal.go
  10. 12
      vendor/github.com/denisenkom/go-mssqldb/doc.go
  11. 8
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/charset.go
  12. 20
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/collation.go
  13. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1250.go
  14. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1251.go
  15. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1252.go
  16. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1253.go
  17. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1254.go
  18. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1255.go
  19. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1256.go
  20. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1257.go
  21. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1258.go
  22. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp437.go
  23. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp850.go
  24. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp874.go
  25. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp932.go
  26. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp936.go
  27. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp949.go
  28. 2
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp950.go
  29. 21
      vendor/github.com/denisenkom/go-mssqldb/log.go
  30. 611
      vendor/github.com/denisenkom/go-mssqldb/mssql.go
  31. 11
      vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3.go
  32. 11
      vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3pre.go
  33. 91
      vendor/github.com/denisenkom/go-mssqldb/mssql_go18.go
  34. 64
      vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go
  35. 12
      vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go
  36. 2
      vendor/github.com/denisenkom/go-mssqldb/net.go
  37. 2
      vendor/github.com/denisenkom/go-mssqldb/ntlm.go
  38. 40
      vendor/github.com/denisenkom/go-mssqldb/parser.go
  39. 2
      vendor/github.com/denisenkom/go-mssqldb/sspi_windows.go
  40. 358
      vendor/github.com/denisenkom/go-mssqldb/tds.go
  41. 387
      vendor/github.com/denisenkom/go-mssqldb/token.go
  42. 53
      vendor/github.com/denisenkom/go-mssqldb/token_string.go
  43. 16
      vendor/github.com/denisenkom/go-mssqldb/tran.go
  44. 679
      vendor/github.com/denisenkom/go-mssqldb/types.go
  45. 74
      vendor/github.com/denisenkom/go-mssqldb/uniqueidentifier.go
  46. 3
      vendor/modules.txt

@ -140,4 +140,4 @@ require (
xorm.io/core v0.6.3
)
replace github.com/denisenkom/go-mssqldb => github.com/denisenkom/go-mssqldb v0.0.0-20161128230840-e32ca5036449
replace github.com/denisenkom/go-mssqldb => github.com/denisenkom/go-mssqldb v0.0.0-20180314172330-6a30f4e59a44

@ -60,8 +60,8 @@ github.com/cznic/strutil v0.0.0-20181122101858-275e90344537/go.mod h1:AHHPPPXTw0
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisenkom/go-mssqldb v0.0.0-20161128230840-e32ca5036449 h1:JpA+YMG4JLW8nzLmU05mTiuB0O17xHGxpWolEZ0zDuA=
github.com/denisenkom/go-mssqldb v0.0.0-20161128230840-e32ca5036449/go.mod h1:xN/JuLBIz4bjkxNmByTiV1IbhfnYb6oo99phBn4Eqhc=
github.com/denisenkom/go-mssqldb v0.0.0-20180314172330-6a30f4e59a44 h1:x0uHqLQTSEL9LKic8sWDt3ASkq07ve5ojIIUl5uF64M=
github.com/denisenkom/go-mssqldb v0.0.0-20180314172330-6a30f4e59a44/go.mod h1:xN/JuLBIz4bjkxNmByTiV1IbhfnYb6oo99phBn4Eqhc=
github.com/dgrijalva/jwt-go v0.0.0-20161101193935-9ed569b5d1ac h1:xrQJVwQCGqDvOO7/0+RyIq5J2M3Q4ZF7Ug/BMQtML1E=
github.com/dgrijalva/jwt-go v0.0.0-20161101193935-9ed569b5d1ac/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 h1:aaQcKT9WumO6JEJcRyTqFVq4XUZiUcKR2/GI31TOcz8=

@ -1,78 +1,130 @@
# A pure Go MSSQL driver for Go's database/sql package
[![GoDoc](https://godoc.org/github.com/denisenkom/go-mssqldb?status.svg)](http://godoc.org/github.com/denisenkom/go-mssqldb)
[![Build status](https://ci.appveyor.com/api/projects/status/jrln8cs62wj9i0a2?svg=true)](https://ci.appveyor.com/project/denisenkom/go-mssqldb)
[![codecov](https://codecov.io/gh/denisenkom/go-mssqldb/branch/master/graph/badge.svg)](https://codecov.io/gh/denisenkom/go-mssqldb)
## Install
go get github.com/denisenkom/go-mssqldb
Requires Go 1.8 or above.
## Tests
Install with `go get github.com/denisenkom/go-mssqldb` .
`go test` is used for testing. A running instance of MSSQL server is required.
Environment variables are used to pass login information.
## Connection Parameters and DSN
Example:
The recommended connection string uses a URL format:
`sqlserver://username:password@host/instance?param1=value&param2=value`
Other supported formats are listed below.
### Common parameters:
env HOST=localhost SQLUSER=sa SQLPASSWORD=sa DATABASE=test go test
## Connection Parameters
* "server" - host or host\instance (default localhost)
* "port" - used only when there is no instance in server (default 1433)
* "failoverpartner" - host or host\instance (default is no partner).
* "failoverport" - used only when there is no instance in failoverpartner (default 1433)
* "user id" - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used.
* "password"
* "database"
* "connection timeout" - in seconds (default is 30)
* "dial timeout" - in seconds (default is 5)
* "keepAlive" - in seconds; 0 to disable (default is 0)
* "log" - logging flags (default 0/no logging, 63 for full logging)
* `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used.
* `password`
* `database`
* `connection timeout` - in seconds (default is 30)
* `dial timeout` - in seconds (default is 5)
* `encrypt`
* `disable` - Data send between client and server is not encrypted.
* `false` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true` - Data sent between client and server is encrypted.
* `keepAlive` - in seconds; 0 to disable (default is 30)
* `app name` - The application name (default is go-mssqldb)
### Connection parameters for ODBC and ADO style connection strings:
* `server` - host or host\instance (default localhost)
* `port` - used only when there is no instance in server (default 1433)
### Less common parameters:
* `failoverpartner` - host or host\instance (default is no partner).
* `failoverport` - used only when there is no instance in failoverpartner (default 1433)
* `packet size` - in bytes; 512 to 32767 (default is 4096)
* Encrypted connections have a maximum packet size of 16383 bytes
* Further information on usage: https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
* `log` - logging flags (default 0/no logging, 63 for full logging)
* 1 log errors
* 2 log messages
* 4 log rows affected
* 8 trace sql statements
* 16 log statement parameters
* 32 log transaction begin/end
* "encrypt"
* disable - Data send between client and server is not encrypted.
* false - Data sent between client and server is not encrypted beyond the login packet. (Default)
* true - Data sent between client and server is encrypted.
* "TrustServerCertificate"
* `TrustServerCertificate`
* false - Server certificate is checked. Default is false if encypt is specified.
* true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
* "certificate" - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates.
* "hostNameInCertificate" - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
* "ServerSPN" - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
* "Workstation ID" - The workstation name (default is the host name)
* "app name" - The application name (default is go-mssqldb)
* "ApplicationIntent" - Can be given the value "ReadOnly" to initiate a read-only connection to an Availability Group listener.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates.
* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
* `Workstation ID` - The workstation name (default is the host name)
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener.
Example:
### The connection string can be specified in one of three formats:
```go
db, err := sql.Open("mssql", "server=localhost;user id=sa")
```
## Statement Parameters
1. URL: with `sqlserver` scheme. username and password appears before the host. Any instance appears as
the first segment in the path. All other options are query parameters. Examples:
In the SQL statement text, literals may be replaced by a parameter that matches one of the following:
* `sqlserver://username:password@host/instance?param1=value&param2=value`
* `sqlserver://username:password@host:port?param1=value&param2=value`
* `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance.
* `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass.
* `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30"` // port 1234 on localhost.
* `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass"
* ?
* ?nnn
* :nnn
* $nnn
A string of this format can be constructed using the `URL` type in the `net/url` package.
where nnn represents an integer that specifies a 1-indexed positional parameter. Ex:
```go
query := url.Values{}
query.Add("connection timeout", "30")
u := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(username, password),
Host: fmt.Sprintf("%s:%d", hostname, port),
// Path: instance, // if connecting to an instance instead of a port
RawQuery: query.Encode(),
}
db, err := sql.Open("sqlserver", u.String())
```
2. ADO: `key=value` pairs separated by `;`. Values may not contain `;`, leading and trailing whitespace is ignored.
Examples:
* `server=localhost\\SQLExpress;user id=sa;database=master;connection timeout=30`
* `server=localhost;user id=sa;database=master;connection timeout=30`
3. ODBC: Prefix with `odbc`, `key=value` pairs separated by `;`. Allow `;` by wrapping
values in `{}`. Examples:
* `odbc:server=localhost\\SQLExpress;user id=sa;database=master;connection timeout=30`
* `odbc:server=localhost;user id=sa;database=master;connection timeout=30`
* `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar"
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar "
* `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar"
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar"
## Executing Stored Procedures
To run a stored procedure, set the query text to the procedure name:
```go
db.Query("SELECT * FROM t WHERE a = ?3, b = ?2, c = ?1", "x", "y", "z")
var account = "abc"
_, err := db.ExecContext(ctx, "sp_RunMe",
sql.Named("ID", 123),
sql.Out{Dest{sql.Named("Account", &account)}
)
```
will expand to roughly
## Statement Parameters
```sql
SELECT * FROM t WHERE a = 'z', b = 'y', c = 'x'
```
The `sqlserver` driver uses normal MS SQL Server syntax and expects parameters in
the sql query to be in the form of either `@Name` or `@p1` to `@pN` (ordinal position).
```go
db.QueryContext(ctx, `select * from t where ID = @ID and Name = @p2;`, sql.Named("ID", 6), "Bob")
```
## Features
@ -87,6 +139,34 @@ SELECT * FROM t WHERE a = 'z', b = 'y', c = 'x'
* Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas.
* Supports query notifications
## Tests
`go test` is used for testing. A running instance of MSSQL server is required.
Environment variables are used to pass login information.
Example:
env SQLSERVER_DSN=sqlserver://user:pass@hostname/instance?database=test1 go test
## Deprecated
These features still exist in the driver, but they are are deprecated.
### Query Parameter Token Replace (driver "mssql")
If you use the driver name "mssql" (rather then "sqlserver" the SQL text
will be loosly parsed and an attempt to extract identifiers using one of
* ?
* ?nnn
* :nnn
* $nnn
will be used. This is not recommended with SQL Server.
There is at least one existing `won't fix` issue with the query parsing.
Use the native "@Name" parameters instead with the "sqlserver" driver name.
## Known Issues
* SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled.

@ -0,0 +1,45 @@
version: 1.0.{build}
os: Windows Server 2012 R2
clone_folder: c:\gopath\src\github.com\denisenkom\go-mssqldb
environment:
GOPATH: c:\gopath
HOST: localhost
SQLUSER: sa
SQLPASSWORD: Password12!
DATABASE: test
GOVERSION: 110
matrix:
- GOVERSION: 18
SQLINSTANCE: SQL2016
- GOVERSION: 110
SQLINSTANCE: SQL2016
- SQLINSTANCE: SQL2014
- SQLINSTANCE: SQL2012SP1
- SQLINSTANCE: SQL2008R2SP2
install:
- set GOROOT=c:\go%GOVERSION%
- set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH%
- go version
- go env
build_script:
- go build
before_test:
# setup SQL Server
- ps: |
$instanceName = $env:SQLINSTANCE
Start-Service "MSSQL`$$instanceName"
Start-Service "SQLBrowser"
- sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;"
- sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version"
- pip install codecov
test_script:
- go test -race -coverprofile=coverage.txt -covermode=atomic
- codecov -f coverage.txt

@ -2,12 +2,14 @@ package mssql
import (
"encoding/binary"
"io"
"errors"
"io"
)
type packetType uint8
type header struct {
PacketType uint8
PacketType packetType
Status uint8
Size uint16
Spid uint16
@ -15,55 +17,84 @@ type header struct {
Pad uint8
}
// tdsBuffer reads and writes TDS packets of data to the transport.
// The write and read buffers are separate to make sending attn signals
// possible without locks. Currently attn signals are only sent during
// reads, not writes.
type tdsBuffer struct {
buf []byte
pos uint16
transport io.ReadWriteCloser
size uint16
transport io.ReadWriteCloser
packetSize int
// Write fields.
wbuf []byte
wpos int
wPacketSeq byte
wPacketType packetType
// Read fields.
rbuf []byte
rpos int
rsize int
final bool
packet_type uint8
afterFirst func()
rPacketType packetType
// afterFirst is assigned to right after tdsBuffer is created and
// before the first use. It is executed after the first packet is
// written and then removed.
afterFirst func()
}
func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer {
buf := make([]byte, bufsize)
w := new(tdsBuffer)
w.buf = buf
w.pos = 8
w.transport = transport
w.size = 0
return w
func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
return &tdsBuffer{
packetSize: int(bufsize),
wbuf: make([]byte, 1<<16),
rbuf: make([]byte, 1<<16),
rpos: 8,
transport: transport,
}
}
func (rw *tdsBuffer) ResizeBuffer(packetSize int) {
rw.packetSize = packetSize
}
func (w *tdsBuffer) PackageSize() int {
return w.packetSize
}
func (w *tdsBuffer) flush() (err error) {
// writing packet size
binary.BigEndian.PutUint16(w.buf[2:], w.pos)
// Write packet size.
w.wbuf[0] = byte(w.wPacketType)
binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos))
w.wbuf[6] = w.wPacketSeq
// writing packet into underlying transport
if _, err = w.transport.Write(w.buf[:w.pos]); err != nil {
// Write packet into underlying transport.
if _, err = w.transport.Write(w.wbuf[:w.wpos]); err != nil {
return err
}
// It is possible to create a whole new buffer after a flush.
// Useful for debugging. Normally reuse the buffer.
// w.wbuf = make([]byte, 1<<16)
// execute afterFirst hook if it is set
// Execute afterFirst hook if it is set.
if w.afterFirst != nil {
w.afterFirst()
w.afterFirst = nil
}
w.pos = 8
// packet number
w.buf[6] += 1
w.wpos = 8
w.wPacketSeq++
return nil
}
func (w *tdsBuffer) Write(p []byte) (total int, err error) {
total = 0
for {
copied := copy(w.buf[w.pos:], p)
w.pos += uint16(copied)
copied := copy(w.wbuf[w.wpos:w.packetSize], p)
w.wpos += copied
total += copied
if copied == len(p) {
break
return
}
if err = w.flush(); err != nil {
return
@ -74,66 +105,64 @@ func (w *tdsBuffer) Write(p []byte) (total int, err error) {
}
func (w *tdsBuffer) WriteByte(b byte) error {
if int(w.pos) == len(w.buf) {
if int(w.wpos) == len(w.wbuf) {
if err := w.flush(); err != nil {
return err
}
}
w.buf[w.pos] = b
w.pos += 1
w.wbuf[w.wpos] = b
w.wpos += 1
return nil
}
func (w *tdsBuffer) BeginPacket(packet_type byte) {
w.buf[0] = packet_type
w.buf[1] = 0 // packet is incomplete
w.buf[4] = 0 // spid
w.buf[5] = 0
w.buf[6] = 1 // packet id
w.buf[7] = 0 // window
w.pos = 8
func (w *tdsBuffer) BeginPacket(packetType packetType) {
w.wbuf[1] = 0 // Packet is incomplete. This byte is set again in FinishPacket.
w.wpos = 8
w.wPacketSeq = 1
w.wPacketType = packetType
}
func (w *tdsBuffer) FinishPacket() error {
w.buf[1] = 1 // this is last packet
w.wbuf[1] = 1 // Mark this as the last packet in the message.
return w.flush()
}
var headerSize = binary.Size(header{})
func (r *tdsBuffer) readNextPacket() error {
header := header{}
h := header{}
var err error
err = binary.Read(r.transport, binary.BigEndian, &header)
err = binary.Read(r.transport, binary.BigEndian, &h)
if err != nil {
return err
}
offset := uint16(binary.Size(header))
if int(header.Size) > len(r.buf) {
if int(h.Size) > len(r.rbuf) {
return errors.New("Invalid packet size, it is longer than buffer size")
}
if int(offset) > int(header.Size) {
if headerSize > int(h.Size) {
return errors.New("Invalid packet size, it is shorter than header size")
}
_, err = io.ReadFull(r.transport, r.buf[offset:header.Size])
_, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size])
if err != nil {
return err
}
r.pos = offset
r.size = header.Size
r.final = header.Status != 0
r.packet_type = header.PacketType
r.rpos = headerSize
r.rsize = int(h.Size)
r.final = h.Status != 0
r.rPacketType = h.PacketType
return nil
}
func (r *tdsBuffer) BeginRead() (uint8, error) {
func (r *tdsBuffer) BeginRead() (packetType, error) {
err := r.readNextPacket()
if err != nil {
return 0, err
}
return r.packet_type, nil
return r.rPacketType, nil
}
func (r *tdsBuffer) ReadByte() (res byte, err error) {
if r.pos == r.size {
if r.rpos == r.rsize {
if r.final {
return 0, io.EOF
}
@ -142,8 +171,8 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) {
return 0, err
}
}
res = r.buf[r.pos]
r.pos++
res = r.rbuf[r.rpos]
r.rpos++
return res, nil
}
@ -207,7 +236,7 @@ func (r *tdsBuffer) readUcs2(numchars int) string {
func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
copied = 0
err = nil
if r.pos == r.size {
if r.rpos == r.rsize {
if r.final {
return 0, io.EOF
}
@ -216,7 +245,7 @@ func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
return
}
}
copied = copy(buf, r.buf[r.pos:r.size])
r.pos += uint16(copied)
copied = copy(buf, r.rbuf[r.rpos:r.rsize])
r.rpos += copied
return
}

@ -0,0 +1,616 @@
package mssql
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"math"
"reflect"
"strconv"
"strings"
"time"
)
type Bulk struct {
cn *Conn
metadata []columnStruct
bulkColumns []columnStruct
columnsName []string
tablename string
numRows int
headerSent bool
Options BulkOptions
Debug bool
}
type BulkOptions struct {
CheckConstraints bool
FireTriggers bool
KeepNulls bool
KilobytesPerBatch int
RowsPerBatch int
Order []string
Tablock bool
}
type DataValue interface{}
func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
b := Bulk{cn: cn, tablename: table, headerSent: false, columnsName: columns}
b.Debug = false
return &b
}
func (b *Bulk) sendBulkCommand() (err error) {
//get table columns info
err = b.getMetadata()
if err != nil {
return err
}
//match the columns
for _, colname := range b.columnsName {
var bulkCol *columnStruct
for _, m := range b.metadata {
if m.ColName == colname {
bulkCol = &m
break
}
}
if bulkCol != nil {
if bulkCol.ti.TypeId == typeUdt {
//send udt as binary
bulkCol.ti.TypeId = typeBigVarBin
}
b.bulkColumns = append(b.bulkColumns, *bulkCol)
b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
} else {
return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
}
}
//create the bulk command
//columns definitions
var col_defs bytes.Buffer
for i, col := range b.bulkColumns {
if i != 0 {
col_defs.WriteString(", ")
}
col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti))
}
//options
var with_opts []string
if b.Options.CheckConstraints {
with_opts = append(with_opts, "CHECK_CONSTRAINTS")
}
if b.Options.FireTriggers {
with_opts = append(with_opts, "FIRE_TRIGGERS")
}
if b.Options.KeepNulls {
with_opts = append(with_opts, "KEEP_NULLS")
}
if b.Options.KilobytesPerBatch > 0 {
with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch))
}
if b.Options.RowsPerBatch > 0 {
with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch))
}
if len(b.Options.Order) > 0 {
with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ",")))
}
if b.Options.Tablock {
with_opts = append(with_opts, "TABLOCK")
}
var with_part string
if len(with_opts) > 0 {
with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ","))
}
query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)
stmt, err := b.cn.Prepare(query)
if err != nil {
return fmt.Errorf("Prepare failed: %s", err.Error())
}
b.dlogf(query)
_, err = stmt.Exec(nil)
if err != nil {
return err
}
b.headerSent = true
var buf = b.cn.sess.buf
buf.BeginPacket(packBulkLoadBCP)
// send the columns metadata
columnMetadata := b.createColMetadata()
_, err = buf.Write(columnMetadata)
return
}
// AddRow immediately writes the row to the destination table.
// The arguments are the row values in the order they were specified.
func (b *Bulk) AddRow(row []interface{}) (err error) {
if !b.headerSent {
err = b.sendBulkCommand()
if err != nil {
return
}
}
if len(row) != len(b.bulkColumns) {
return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
len(row), len(b.bulkColumns))
}
bytes, err := b.makeRowData(row)
if err != nil {
return
}
_, err = b.cn.sess.buf.Write(bytes)
if err != nil {
return
}
b.numRows = b.numRows + 1
return
}
func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
buf.WriteByte(byte(tokenRow))
var logcol bytes.Buffer
for i, col := range b.bulkColumns {
if b.Debug {
logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i]))
}
param, err := b.makeParam(row[i], col)
if err != nil {
return nil, fmt.Errorf("bulkcopy: %s", err.Error())
}
if col.ti.Writer == nil {
return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x",
col.ColName, col.ti.TypeId)
}
err = col.ti.Writer(buf, param.ti, param.buffer)
if err != nil {
return nil, fmt.Errorf("bulkcopy: %s", err.Error())
}
}
b.dlogf("row[%d] %s\n", b.numRows, logcol.String())
return buf.Bytes(), nil
}
func (b *Bulk) Done() (rowcount int64, err error) {
if b.headerSent == false {
//no rows had been sent
return 0, nil
}
var buf = b.cn.sess.buf
buf.WriteByte(byte(tokenDone))
binary.Write(buf, binary.LittleEndian, uint16(doneFinal))
binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd
if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0
} else {
binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0
}
buf.FinishPacket()
tokchan := make(chan tokenStruct, 5)
go processResponse(context.Background(), b.cn.sess, tokchan, nil)
var rowCount int64
for token := range tokchan {
switch token := token.(type) {
case doneStruct:
if token.Status&doneCount != 0 {
rowCount = int64(token.RowCount)
}
if token.isError() {
return 0, token.getError()
}
case error:
return 0, b.cn.checkBadConn(token)
}
}
return rowCount, nil
}
func (b *Bulk) createColMetadata() []byte {
buf := new(bytes.Buffer)
buf.WriteByte(byte(tokenColMetadata)) // token
binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count
for i, col := range b.bulkColumns {
if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0?
} else {
binary.Write(buf, binary.LittleEndian, uint16(col.UserType))
}
binary.Write(buf, binary.LittleEndian, uint16(col.Flags))
writeTypeInfo(buf, &b.bulkColumns[i].ti)
if col.ti.TypeId == typeNText ||
col.ti.TypeId == typeText ||
col.ti.TypeId == typeImage {
tablename_ucs2 := str2ucs2(b.tablename)
binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2))
buf.Write(tablename_ucs2)
}
colname_ucs2 := str2ucs2(col.ColName)
buf.WriteByte(uint8(len(colname_ucs2) / 2))
buf.Write(colname_ucs2)
}
return buf.Bytes()
}
func (b *Bulk) getMetadata() (err error) {
stmt, err := b.cn.Prepare("SET FMTONLY ON")
if err != nil {
return
}
_, err = stmt.Exec(nil)
if err != nil {
return
}
//get columns info
stmt, err = b.cn.Prepare(fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
if err != nil {
return
}
stmt2 := stmt.(*Stmt)
cols, err := stmt2.QueryMeta()
if err != nil {
return fmt.Errorf("get columns info failed: %v", err.Error())
}
b.metadata = cols
if b.Debug {
for _, col := range b.metadata {
b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
col.Flags, col.ti.Collation.LcidAndFlags)
}
}
return nil
}
// QueryMeta is almost the same as mssql.Stmt.Query, but returns all the columns info.
func (s *Stmt) QueryMeta() (cols []columnStruct, err error) {
if err = s.sendQuery(nil); err != nil {
return
}
tokchan := make(chan tokenStruct, 5)
go processResponse(context.Background(), s.c.sess, tokchan, s.c.outs)
s.c.clearOuts()
loop:
for tok := range tokchan {
switch token := tok.(type) {
case doneStruct:
break loop
case []columnStruct:
cols = token
break loop
case error:
return nil, s.c.checkBadConn(token)
}
}
return cols, nil
}
func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) {
res.ti.Size = col.ti.Size
res.ti.TypeId = col.ti.TypeId
if val == nil {
res.ti.Size = 0
return
}
switch col.ti.TypeId {
case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN:
var intvalue int64
switch val := val.(type) {
case int:
intvalue = int64(val)
case int32:
intvalue = int64(val)
case int64:
intvalue = val
default:
err = fmt.Errorf("mssql: invalid type for int column")
return
}
res.buffer = make([]byte, res.ti.Size)
if col.ti.Size == 1 {
res.buffer[0] = byte(intvalue)
} else if col.ti.Size == 2 {
binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue))
} else if col.ti.Size == 4 {
binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue))
} else if col.ti.Size == 8 {
binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue))
}
case typeFlt4, typeFlt8, typeFltN:
var floatvalue float64
switch val := val.(type) {
case float32:
floatvalue = float64(val)
case float64:
floatvalue = val
case int:
floatvalue = float64(val)
case int64:
floatvalue = float64(val)
default:
err = fmt.Errorf("mssql: invalid type for float column: %s", val)
return
}
if col.ti.Size == 4 {
res.buffer = make([]byte, 4)
binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue)))
} else if col.ti.Size == 8 {
res.buffer = make([]byte, 8)
binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue))
}
case typeNVarChar, typeNText, typeNChar:
switch val := val.(type) {
case string:
res.buffer = str2ucs2(val)
case []byte:
res.buffer = val
default:
err = fmt.Errorf("mssql: invalid type for nvarchar column: %s", val)
return
}
res.ti.Size = len(res.buffer)
case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar:
switch val := val.(type) {
case string:
res.buffer = []byte(val)
case []byte:
res.buffer = val
default:
err = fmt.Errorf("mssql: invalid type for varchar column: %s", val)
return
}
res.ti.Size = len(res.buffer)
case typeBit, typeBitN:
if reflect.TypeOf(val).Kind() != reflect.Bool {
err = fmt.Errorf("mssql: invalid type for bit column: %s", val)
return
}
res.ti.TypeId = typeBitN
res.ti.Size = 1
res.buffer = make([]byte, 1)
if val.(bool) {
res.buffer[0] = 1
}
case typeDateTime2N, typeDateTimeOffsetN:
switch val := val.(type) {
case time.Time:
days, ns := dateTime2(val)
ns /= int64(math.Pow10(int(col.ti.Scale)*-1) * 1000000000)
var data = make([]byte, 5)
data[0] = byte(ns)
data[1] = byte(ns >> 8)
data[2] = byte(ns >> 16)
data[3] = byte(ns >> 24)
data[4] = byte(ns >> 32)
if col.ti.Scale <= 2 {
res.ti.Size = 6
} else if col.ti.Scale <= 4 {
res.ti.Size = 7
} else {
res.ti.Size = 8
}
var buf []byte
buf = make([]byte, res.ti.Size)
copy(buf, data[0:res.ti.Size-3])
buf[res.ti.Size-3] = byte(days)
buf[res.ti.Size-2] = byte(days >> 8)
buf[res.ti.Size-1] = byte(days >> 16)
if col.ti.TypeId == typeDateTimeOffsetN {
_, offset := val.Zone()
var offsetMinute = uint16(offset / 60)
buf = append(buf, byte(offsetMinute))
buf = append(buf, byte(offsetMinute>>8))
res.ti.Size = res.ti.Size + 2
}
res.buffer = buf
default:
err = fmt.Errorf("mssql: invalid type for datetime2 column: %s", val)
return
}
case typeDateN:
switch val := val.(type) {
case time.Time:
days, _ := dateTime2(val)
res.ti.Size = 3
res.buffer = make([]byte, 3)
res.buffer[0] = byte(days)
res.buffer[1] = byte(days >> 8)
res.buffer[2] = byte(days >> 16)
default:
err = fmt.Errorf("mssql: invalid type for date column: %s", val)
return
}
case typeDateTime, typeDateTimeN, typeDateTim4:
switch val := val.(type) {
case time.Time:
if col.ti.Size == 4 {
res.ti.Size = 4
res.buffer = make([]byte, 4)
ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
dur := val.Sub(ref)
days := dur / (24 * time.Hour)
if days < 0 {
err = fmt.Errorf("mssql: Date %s is out of range", val)
return
}
mins := val.Hour()*60 + val.Minute()
binary.LittleEndian.PutUint16(res.buffer[0:2], uint16(days))
binary.LittleEndian.PutUint16(res.buffer[2:4], uint16(mins))
} else if col.ti.Size == 8 {
res.ti.Size = 8
res.buffer = make([]byte, 8)
days := divFloor(val.Unix(), 24*60*60)
//25567 - number of days since Jan 1 1900 UTC to Jan 1 1970
days = days + 25567
tm := (val.Hour()*60*60+val.Minute()*60+val.Second())*300 + int(val.Nanosecond()/10000000*3)
binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
} else {
err = fmt.Errorf("mssql: invalid size of column")
}
default:
err = fmt.Errorf("mssql: invalid type for datetime column: %s", val)
}
// case typeMoney, typeMoney4, typeMoneyN:
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
var value float64
switch v := val.(type) {
case int:
value = float64(v)
case int8:
value = float64(v)
case int16:
value = float64(v)
case int32:
value = float64(v)
case int64:
value = float64(v)
case float32:
value = float64(v)
case float64:
value = v
case string:
if value, err = strconv.ParseFloat(v, 64); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to float: %v", err)
}
default:
return res, fmt.Errorf("unknown value for decimal: %#v", v)
}
perc := col.ti.Prec
scale := col.ti.Scale
var dec Decimal
dec, err = Float64ToDecimalScale(value, scale)
if err != nil {
return res, err
}
dec.prec = perc
var length byte
switch {
case perc <= 9:
length = 4
case perc <= 19:
length = 8
case perc <= 28:
length = 12
default:
length = 16
}
buf := make([]byte, length+1)
// first byte length written by typeInfo.writer
res.ti.Size = int(length) + 1
// second byte sign
if value < 0 {
buf[0] = 0
} else {
buf[0] = 1
}
ub := dec.UnscaledBytes()
l := len(ub)
if l > int(length) {
err = fmt.Errorf("decimal out of range: %s", dec)
return res, err
}
// reverse the bytes
for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 {
buf[i] = ub[j]
}
res.buffer = buf
case typeBigVarBin:
switch val := val.(type) {
case []byte:
res.ti.Size = len(val)
res.buffer = val
default:
err = fmt.Errorf("mssql: invalid type for Binary column: %s", val)
return
}
case typeGuid:
switch val := val.(type) {
case []byte:
res.ti.Size = len(val)
res.buffer = val
default:
err = fmt.Errorf("mssql: invalid type for Guid column: %s", val)
return
}
default:
err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
}
return
}
func (b *Bulk) dlogf(format string, v ...interface{}) {
if b.Debug {
b.cn.sess.log.Printf(format, v...)
}
}

@ -0,0 +1,93 @@
package mssql
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
)
type copyin struct {
cn *Conn
bulkcopy *Bulk
closed bool
}
type serializableBulkConfig struct {
TableName string
ColumnsName []string
Options BulkOptions
}
func (d *Driver) OpenConnection(dsn string) (*Conn, error) {
return d.open(context.Background(), dsn)
}
func (c *Conn) prepareCopyIn(query string) (_ driver.Stmt, err error) {
config_json := query[11:]
bulkconfig := serializableBulkConfig{}
err = json.Unmarshal([]byte(config_json), &bulkconfig)
if err != nil {
return
}
bulkcopy := c.CreateBulk(bulkconfig.TableName, bulkconfig.ColumnsName)
bulkcopy.Options = bulkconfig.Options
ci := &copyin{
cn: c,
bulkcopy: bulkcopy,
}
return ci, nil
}
func CopyIn(table string, options BulkOptions, columns ...string) string {
bulkconfig := &serializableBulkConfig{TableName: table, Options: options, ColumnsName: columns}
config_json, err := json.Marshal(bulkconfig)
if err != nil {
panic(err)
}
stmt := "INSERTBULK " + string(config_json)
return stmt
}
func (ci *copyin) NumInput() int {
return -1
}
func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
return nil, errors.New("ErrNotSupported")
}
func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
if ci.closed {
return nil, errors.New("errCopyInClosed")
}
if len(v) == 0 {
rowCount, err := ci.bulkcopy.Done()
ci.closed = true
return driver.RowsAffected(rowCount), err
}
t := make([]interface{}, len(v))
for i, val := range v {
t[i] = val
}
err = ci.bulkcopy.AddRow(t)
if err != nil {
return
}
return driver.RowsAffected(0), nil
}
func (ci *copyin) Close() (err error) {
return nil
}

@ -1,39 +0,0 @@
package mssql
import (
"encoding/binary"
"io"
)
// http://msdn.microsoft.com/en-us/library/dd340437.aspx
type collation struct {
lcidAndFlags uint32
sortId uint8
}
func (c collation) getLcid() uint32 {
return c.lcidAndFlags & 0x000fffff
}
func (c collation) getFlags() uint32 {
return (c.lcidAndFlags & 0x0ff00000) >> 20
}
func (c collation) getVersion() uint32 {
return (c.lcidAndFlags & 0xf0000000) >> 28
}
func readCollation(r *tdsBuffer) (res collation) {
res.lcidAndFlags = r.uint32()
res.sortId = r.byte()
return
}
func writeCollation(w io.Writer, col collation) (err error) {
if err = binary.Write(w, binary.LittleEndian, col.lcidAndFlags); err != nil {
return
}
err = binary.Write(w, binary.LittleEndian, col.sortId)
return
}

@ -32,7 +32,13 @@ func (d Decimal) ToFloat64() float64 {
return val
}
const autoScale = 100
func Float64ToDecimal(f float64) (Decimal, error) {
return Float64ToDecimalScale(f, autoScale)
}
func Float64ToDecimalScale(f float64, scale uint8) (Decimal, error) {
var dec Decimal
if math.IsNaN(f) {
return dec, errors.New("NaN")
@ -49,10 +55,10 @@ func Float64ToDecimal(f float64) (Decimal, error) {
}
dec.prec = 20
var integer float64
for dec.scale = 0; dec.scale <= 20; dec.scale++ {
for dec.scale = 0; dec.scale <= scale; dec.scale++ {
integer = f * scaletblflt64[dec.scale]
_, frac := math.Modf(integer)
if frac == 0 {
if frac == 0 && scale == autoScale {
break
}
}
@ -73,7 +79,7 @@ func init() {
}
}
func (d Decimal) Bytes() []byte {
func (d Decimal) BigInt() big.Int {
bytes := make([]byte, 16)
binary.BigEndian.PutUint32(bytes[0:4], d.integer[3])
binary.BigEndian.PutUint32(bytes[4:8], d.integer[2])
@ -84,9 +90,19 @@ func (d Decimal) Bytes() []byte {
if !d.positive {
x.Neg(&x)
}
return x
}
func (d Decimal) Bytes() []byte {
x := d.BigInt()
return scaleBytes(x.String(), d.scale)
}
func (d Decimal) UnscaledBytes() []byte {
x := d.BigInt()
return x.Bytes()
}
func scaleBytes(s string, scale uint8) []byte {
z := make([]byte, 0, len(s)+1)
if s[0] == '-' || s[0] == '+' {

@ -0,0 +1,12 @@
// package mssql implements the TDS protocol used to connect to MS SQL Server (sqlserver)
// database servers.
//
// This package registers two drivers:
// sqlserver: uses native "@" parameter placeholder names and does no pre-processing.
// mssql: expects identifiers to be prefixed with ":" and pre-processes queries.
//
// If the ordinal position is used for query parameters, identifiers will be named
// "@p1", "@p2", ... "@pN".
//
// Please refer to the README for the format of the DSN.
package mssql

@ -1,14 +1,14 @@
package mssql
package cp
type charsetMap struct {
sb [256]rune // single byte runes, -1 for a double byte character lead byte
db map[int]rune // double byte runes
}
func collation2charset(col collation) *charsetMap {
func collation2charset(col Collation) *charsetMap {
// http://msdn.microsoft.com/en-us/library/ms144250.aspx
// http://msdn.microsoft.com/en-us/library/ms144250(v=sql.105).aspx
switch col.sortId {
switch col.SortId {
case 30, 31, 32, 33, 34:
return cp437
case 40, 41, 42, 44, 49, 55, 56, 57, 58, 59, 60, 61:
@ -86,7 +86,7 @@ func collation2charset(col collation) *charsetMap {
return cp1252
}
func charset2utf8(col collation, s []byte) string {
func CharsetToUTF8(col Collation, s []byte) string {
cm := collation2charset(col)
if cm == nil {
return string(s)

@ -0,0 +1,20 @@
package cp
// http://msdn.microsoft.com/en-us/library/dd340437.aspx
type Collation struct {
LcidAndFlags uint32
SortId uint8
}
func (c Collation) getLcid() uint32 {
return c.LcidAndFlags & 0x000fffff
}
func (c Collation) getFlags() uint32 {
return (c.LcidAndFlags & 0x0ff00000) >> 20
}
func (c Collation) getVersion() uint32 {
return (c.LcidAndFlags & 0xf0000000) >> 28
}

@ -1,4 +1,4 @@
package mssql
package cp
var cp1250 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp1251 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp1252 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp1253 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp1254 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp1255 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp1256 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp1257 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp1258 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp437 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp850 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp874 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp932 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp936 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp949 *charsetMap = &charsetMap{
sb: [256]rune{

@ -1,4 +1,4 @@
package mssql
package cp
var cp950 *charsetMap = &charsetMap{
sb: [256]rune{

@ -4,19 +4,26 @@ import (
"log"
)
type Logger log.Logger
type Logger interface {
Printf(format string, v ...interface{})
Println(v ...interface{})
}
type optionalLogger struct {
logger Logger
}
func (logger *Logger) Printf(format string, v ...interface{}) {
if logger != nil {
(*log.Logger)(logger).Printf(format, v...)
func (o optionalLogger) Printf(format string, v ...interface{}) {
if o.logger != nil {
o.logger.Printf(format, v...)
} else {
log.Printf(format, v...)
}
}
func (logger *Logger) Println(v ...interface{}) {
if logger != nil {
(*log.Logger)(logger).Println(v...)
func (o optionalLogger) Println(v ...interface{}) {
if o.logger != nil {
o.logger.Println(v...)
} else {
log.Println(v...)
}

@ -1,122 +1,269 @@
package mssql
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"math"
"net"
"reflect"
"strings"
"time"
)
var driverInstance = &Driver{processQueryText: true}
var driverInstanceNoProcess = &Driver{processQueryText: false}
func init() {
sql.Register("mssql", &MssqlDriver{})
sql.Register("mssql", driverInstance)
sql.Register("sqlserver", driverInstanceNoProcess)
createDialer = func(p *connectParams) dialer {
return tcpDialer{&net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive}}
}
}
// Abstract the dialer for testing and for non-TCP based connections.
type dialer interface {
Dial(ctx context.Context, addr string) (net.Conn, error)
}
var createDialer func(p *connectParams) dialer
type tcpDialer struct {
nd *net.Dialer
}
func (d tcpDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
return d.nd.DialContext(ctx, "tcp", addr)
}
type Driver struct {
log optionalLogger
processQueryText bool
}
// OpenConnector opens a new connector. Useful to dial with a context.
func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
params, err := parseConnectParams(dsn)
if err != nil {
return nil, err
}
return &Connector{
params: params,
driver: d,
}, nil
}
func (d *Driver) Open(dsn string) (driver.Conn, error) {
return d.open(context.Background(), dsn)
}
// Connector holds the parsed DSN and is ready to make a new connection
// at any time.
//
// In the future, settings that cannot be passed through a string DSN
// may be set directly on the connector.
type Connector struct {
params connectParams
driver *Driver
}
type MssqlDriver struct {
log *log.Logger
// Connect to the server and return a TDS connection.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
return c.driver.connect(ctx, c.params)
}
func (d *MssqlDriver) SetLogger(logger *log.Logger) {
d.log = logger
// Driver underlying the Connector.
func (c *Connector) Driver() driver.Driver {
return c.driver
}
func CheckBadConn(err error) error {
if err == io.EOF {
func SetLogger(logger Logger) {
driverInstance.SetLogger(logger)
driverInstanceNoProcess.SetLogger(logger)
}
func (d *Driver) SetLogger(logger Logger) {
d.log = optionalLogger{logger}
}
type Conn struct {
sess *tdsSession
transactionCtx context.Context
processQueryText bool
connectionGood bool
outs map[string]interface{}
}
func (c *Conn) checkBadConn(err error) error {
// this is a hack to address Issue #275
// we set connectionGood flag to false if
// error indicates that connection is not usable
// but we return actual error instead of ErrBadConn
// this will cause connection to stay in a pool
// but next request to this connection will return ErrBadConn
// it might be possible to revise this hack after
// https://github.com/golang/go/issues/20807
// is implemented
switch err {
case nil:
return nil
case io.EOF:
return driver.ErrBadConn
case driver.ErrBadConn:
// It is an internal programming error if driver.ErrBadConn
// is ever passed to this function. driver.ErrBadConn should
// only ever be returned in response to a *mssql.Conn.connectionGood == false
// check in the external facing API.
panic("driver.ErrBadConn in checkBadConn. This should not happen.")
}
switch e := err.(type) {
switch err.(type) {
case net.Error:
if e.Timeout() {
return e
}
return driver.ErrBadConn
c.connectionGood = false
return err
case StreamError:
c.connectionGood = false
return err
default:
return err
}
}
type MssqlConn struct {
sess *tdsSession
func (c *Conn) clearOuts() {
c.outs = nil
}
func (c *MssqlConn) Commit() error {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{c.sess.tranid, 1}.pack()},
}
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
return err
}
func (c *Conn) simpleProcessResp(ctx context.Context) error {
tokchan := make(chan tokenStruct, 5)
go processResponse(c.sess, tokchan)
go processResponse(ctx, c.sess, tokchan, c.outs)
c.clearOuts()
for tok := range tokchan {
switch token := tok.(type) {
case doneStruct:
if token.isError() {
return c.checkBadConn(token.getError())
}
case error:
return token
return c.checkBadConn(token)
}
}
return nil
}
func (c *MssqlConn) Rollback() error {
func (c *Conn) Commit() error {
if !c.connectionGood {
return driver.ErrBadConn
}
if err := c.sendCommitRequest(); err != nil {
return c.checkBadConn(err)
}
return c.simpleProcessResp(c.transactionCtx)
}
func (c *Conn) sendCommitRequest() error {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{c.sess.tranid, 1}.pack()},
}
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
return err
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send CommitXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Faild to send CommitXact: %v", err)
}
return nil
}
func (c *Conn) Rollback() error {
if !c.connectionGood {
return driver.ErrBadConn
}
if err := c.sendRollbackRequest(); err != nil {
return c.checkBadConn(err)
}
return c.simpleProcessResp(c.transactionCtx)
}
tokchan := make(chan tokenStruct, 5)
go processResponse(c.sess, tokchan)
for tok := range tokchan {
switch token := tok.(type) {
case error:
return token
func (c *Conn) sendRollbackRequest() error {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{c.sess.tranid, 1}.pack()},
}
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Failed to send RollbackXact: %v", err)
}
return nil
}
func (c *MssqlConn) Begin() (driver.Tx, error) {
func (c *Conn) Begin() (driver.Tx, error) {
return c.begin(context.Background(), isolationUseCurrent)
}
func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
err = c.sendBeginRequest(ctx, tdsIsolation)
if err != nil {
return nil, c.checkBadConn(err)
}
tx, err = c.processBeginResponse(ctx)
if err != nil {
return nil, c.checkBadConn(err)
}
return
}
func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
c.transactionCtx = ctx
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{0, 1}.pack()},
}
if err := sendBeginXact(c.sess.buf, headers, 0, ""); err != nil {
return nil, CheckBadConn(err)
}
tokchan := make(chan tokenStruct, 5)
go processResponse(c.sess, tokchan)
for tok := range tokchan {
switch token := tok.(type) {
case error:
if c.sess.tranid != 0 {
return nil, token
}
return nil, CheckBadConn(token)
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send BeginXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Failed to send BiginXant: %v", err)
}
return nil
}
func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
if err := c.simpleProcessResp(ctx); err != nil {
return nil, err
}
// successful BEGINXACT request will return sess.tranid
// for started transaction
return c, nil
}
func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
params, err := parseConnectParams(dsn)
if err != nil {
return nil, err
}
return d.connect(ctx, params)
}
sess, err := connect(params)
// connect to the server, using the provided context for dialing only.
func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, error) {
sess, err := connect(ctx, d.log, params)
if err != nil {
// main server failed, try fail-over partner
if params.failOverPartner == "" {
@ -128,24 +275,29 @@ func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
params.port = params.failOverPort
}
sess, err = connect(params)
sess, err = connect(ctx, d.log, params)
if err != nil {
// fail-over partner also failed, now fail
return nil, err
}
}
conn := &MssqlConn{sess}
conn.sess.log = (*Logger)(d.log)
conn := &Conn{
sess: sess,
transactionCtx: context.Background(),
processQueryText: d.processQueryText,
connectionGood: true,
}
conn.sess.log = d.log
return conn, nil
}
func (c *MssqlConn) Close() error {
func (c *Conn) Close() error {
return c.sess.buf.transport.Close()
}
type MssqlStmt struct {
c *MssqlConn
type Stmt struct {
c *Conn
query string
paramCount int
notifSub *queryNotifSub
@ -157,16 +309,30 @@ type queryNotifSub struct {
timeout uint32
}
func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
q, paramCount := parseParams(query)
return &MssqlStmt{c, q, paramCount, nil}, nil
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
return c.prepareCopyIn(query)
}
return c.prepareContext(context.Background(), query)
}
func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
paramCount := -1
if c.processQueryText {
query, paramCount = parseParams(query)
}
return &Stmt{c, query, paramCount, nil}, nil
}
func (s *MssqlStmt) Close() error {
func (s *Stmt) Close() error {
return nil
}
func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) {
to := uint32(timeout / time.Second)
if to < 1 {
to = 1
@ -174,183 +340,273 @@ func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Durati
s.notifSub = &queryNotifSub{id, options, to}
}
func (s *MssqlStmt) NumInput() int {
func (s *Stmt) NumInput() int {
return s.paramCount
}
func (s *MssqlStmt) sendQuery(args []driver.Value) (err error) {
func (s *Stmt) sendQuery(args []namedValue) (err error) {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
}
if s.notifSub != nil {
headers = append(headers, headerStruct{hdrtype: dataStmHdrQueryNotif,
data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()})
}
if len(args) != s.paramCount {
return errors.New(fmt.Sprintf("sql: expected %d parameters, got %d", s.paramCount, len(args)))
}
headers = append(headers,
headerStruct{
hdrtype: dataStmHdrQueryNotif,
data: queryNotifHdr{
s.notifSub.msgText,
s.notifSub.options,
s.notifSub.timeout,
}.pack(),
})
}
// no need to check number of parameters here, it is checked by database/sql
if s.c.sess.logFlags&logSQL != 0 {
s.c.sess.log.Println(s.query)
}
if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
for i := 0; i < len(args); i++ {
s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i])
if len(args[i].Name) > 0 {
s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
} else {
s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
}
}
}
if len(args) == 0 {
if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
if s.c.sess.tranid != 0 {
return err
if s.c.sess.logFlags&logErrors != 0 {
s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
}
return CheckBadConn(err)
s.c.connectionGood = false
return fmt.Errorf("failed to send SQL Batch: %v", err)
}
} else {
params := make([]Param, len(args)+2)
decls := make([]string, len(args))
params[0], err = s.makeParam(s.query)
if err != nil {
return
}
for i, val := range args {
params[i+2], err = s.makeParam(val)
proc := Sp_ExecuteSql
var params []Param
if isProc(s.query) {
proc.name = s.query
params, _, err = s.makeRPCParams(args, 0)
} else {
var decls []string
params, decls, err = s.makeRPCParams(args, 2)
if err != nil {
return
}
name := fmt.Sprintf("@p%d", i+1)
params[i+2].Name = name
decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti))
}
params[1], err = s.makeParam(strings.Join(decls, ","))
if err != nil {
return
params[0] = makeStrParam(s.query)
params[1] = makeStrParam(strings.Join(decls, ","))
}
if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil {
if s.c.sess.tranid != 0 {
return err
if err = sendRpc(s.c.sess.buf, headers, proc, 0, params); err != nil {
if s.c.sess.logFlags&logErrors != 0 {
s.c.sess.log.Printf("Failed to send Rpc with %v", err)
}
return CheckBadConn(err)
s.c.connectionGood = false
return fmt.Errorf("Failed to send RPC: %v", err)
}
}
return
}
func (s *MssqlStmt) Query(args []driver.Value) (res driver.Rows, err error) {
// isProc takes the query text in s and determines if it is a stored proc name
// or SQL text.
func isProc(s string) bool {
if len(s) == 0 {
return false
}
if s[0] == '[' && s[len(s)-1] == ']' && strings.ContainsAny(s, "\n\r") == false {
return true
}
return !strings.ContainsAny(s, " \t\n\r;")
}
func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) {
var err error
params := make([]Param, len(args)+offset)
decls := make([]string, len(args))
for i, val := range args {
params[i+offset], err = s.makeParam(val.Value)
if err != nil {
return nil, nil, err
}
var name string
if len(val.Name) > 0 {
name = "@" + val.Name
} else {
name = fmt.Sprintf("@p%d", val.Ordinal)
}
params[i+offset].Name = name
decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti))
}
return params, decls, nil
}
type namedValue struct {
Name string
Ordinal int
Value driver.Value
}
func convertOldArgs(args []driver.Value) []namedValue {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return list
}
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.queryContext(context.Background(), convertOldArgs(args))
}
func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return
return nil, s.c.checkBadConn(err)
}
return s.processQueryResponse(ctx)
}
func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
tokchan := make(chan tokenStruct, 5)
go processResponse(s.c.sess, tokchan)
ctx, cancel := context.WithCancel(ctx)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
s.c.clearOuts()
// process metadata
var cols []string
var cols []columnStruct
loop:
for tok := range tokchan {
switch token := tok.(type) {
// by ignoring DONE token we effectively
// skip empty result-sets
// this improves results in queryes like that:
// By ignoring DONE token we effectively
// skip empty result-sets.
// This improves results in queries like that:
// set nocount on; select 1
// see TestIgnoreEmptyResults test
//case doneStruct:
//break loop
case []columnStruct:
cols = make([]string, len(token))
for i, col := range token {
cols[i] = col.ColName
}
cols = token
break loop
case error:
if s.c.sess.tranid != 0 {
return nil, token
case doneStruct:
if token.isError() {
return nil, s.c.checkBadConn(token.getError())
}
return nil, CheckBadConn(token)
case error:
return nil, s.c.checkBadConn(token)
}
}
return &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols}, nil
res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
return
}
func (s *MssqlStmt) Exec(args []driver.Value) (res driver.Result, err error) {
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), convertOldArgs(args))
}
func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return
return nil, s.c.checkBadConn(err)
}
if res, err = s.processExec(ctx); err != nil {
return nil, s.c.checkBadConn(err)
}
return
}
func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
tokchan := make(chan tokenStruct, 5)
go processResponse(s.c.sess, tokchan)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
s.c.clearOuts()
var rowCount int64
for token := range tokchan {
switch token := token.(type) {
case doneInProcStruct:
if token.Status&doneCount != 0 {
rowCount = int64(token.RowCount)
rowCount += int64(token.RowCount)
}
case doneStruct:
if token.Status&doneCount != 0 {
rowCount = int64(token.RowCount)
rowCount += int64(token.RowCount)
}
case error:
if s.c.sess.logFlags&logErrors != 0 {
s.c.sess.log.Println("got error:", token)
}
if s.c.sess.tranid != 0 {
return nil, token
if token.isError() {
return nil, token.getError()
}
return nil, CheckBadConn(token)
case error:
return nil, token
}
}
return &MssqlResult{s.c, rowCount}, nil
return &Result{s.c, rowCount}, nil
}
type MssqlRows struct {
sess *tdsSession
cols []string
type Rows struct {
stmt *Stmt
cols []columnStruct
tokchan chan tokenStruct
nextCols []string
nextCols []columnStruct
cancel func()
}
func (rc *MssqlRows) Close() error {
func (rc *Rows) Close() error {
rc.cancel()
for _ = range rc.tokchan {
}
rc.tokchan = nil
return nil
}
func (rc *MssqlRows) Columns() (res []string) {
return rc.cols
func (rc *Rows) Columns() (res []string) {
res = make([]string, len(rc.cols))
for i, col := range rc.cols {
res[i] = col.ColName
}
return
}
func (rc *MssqlRows) Next(dest []driver.Value) (err error) {
func (rc *Rows) Next(dest []driver.Value) error {
if !rc.stmt.c.connectionGood {
return driver.ErrBadConn
}
if rc.nextCols != nil {
return io.EOF
}
for tok := range rc.tokchan {
switch tokdata := tok.(type) {
case []columnStruct:
cols := make([]string, len(tokdata))
for i, col := range tokdata {
cols[i] = col.ColName
}
rc.nextCols = cols
rc.nextCols = tokdata
return io.EOF
case []interface{}:
for i := range dest {
dest[i] = tokdata[i]
}
return nil
case doneStruct:
if tokdata.isError() {
return rc.stmt.c.checkBadConn(tokdata.getError())
}
case error:
return tokdata
return rc.stmt.c.checkBadConn(tokdata)
}
}
return io.EOF
}
func (rc *MssqlRows) HasNextResultSet() bool {
func (rc *Rows) HasNextResultSet() bool {
return rc.nextCols != nil
}
func (rc *MssqlRows) NextResultSet() error {
func (rc *Rows) NextResultSet() error {
rc.cols = rc.nextCols
rc.nextCols = nil
if rc.cols == nil {
@ -359,11 +615,69 @@ func (rc *MssqlRows) NextResultSet() error {
return nil
}
func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
// It should return
// the value type that can be used to scan types into. For example, the database
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
return makeGoLangScanType(r.cols[index].ti)
}
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
// database system type name without the length. Type names should be uppercase.
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
return makeGoLangTypeName(r.cols[index].ti)
}
// RowsColumnTypeLength may be implemented by Rows. It should return the length
// of the column type if the column is a variable length type. If the column is
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
return makeGoLangTypeLength(r.cols[index].ti)
}
// It should return
// the precision and scale for decimal types. If not applicable, ok should be false.
// The following are examples of returned values for various types:
// decimal(38, 4) (38, 4, true)
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
return makeGoLangTypePrecisionScale(r.cols[index].ti)
}
// The nullable value should
// be true if it is known the column may be null, or false if the column is known
// to be not nullable.
// If the column nullability is unknown, ok should be false.
func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
nullable = r.cols[index].Flags&colFlagNullable != 0
ok = true
return
}
func makeStrParam(val string) (res Param) {
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(val)
res.ti.Size = len(res.buffer)
return
}
func (s *Stmt) makeParam(val driver.Value) (res Param, err error) {
if val == nil {
res.ti.TypeId = typeNVarChar
res.ti.TypeId = typeNull
res.buffer = nil
res.ti.Size = 2
res.ti.Size = 0
return
}
switch val := val.(type) {
@ -382,9 +696,7 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
res.ti.Size = len(val)
res.buffer = val
case string:
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(val)
res.ti.Size = len(res.buffer)
res = makeStrParam(val)
case bool:
res.ti.TypeId = typeBitN
res.ti.Size = 1
@ -425,22 +737,21 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
}
default:
err = fmt.Errorf("mssql: unknown type for %T", val)
return
return s.makeParamExtra(val)
}
return
}
type MssqlResult struct {
c *MssqlConn
type Result struct {
c *Conn
rowsAffected int64
}
func (r *MssqlResult) RowsAffected() (int64, error) {
func (r *Result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
func (r *MssqlResult) LastInsertId() (int64, error) {
func (r *Result) LastInsertId() (int64, error) {
s, err := r.c.Prepare("select cast(@@identity as bigint)")
if err != nil {
return 0, err

@ -1,11 +0,0 @@
// +build go1.3
package mssql
import (
"net"
)
func createDialer(p connectParams) *net.Dialer {
return &net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive}
}

@ -1,11 +0,0 @@
// +build !go1.3
package mssql
import (
"net"
)
func createDialer(p *connectParams) *net.Dialer {
return &net.Dialer{Timeout: p.dial_timeout}
}

@ -0,0 +1,91 @@
// +build go1.8
package mssql
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"strings"
)
var _ driver.Pinger = &Conn{}
// Ping is used to check if the remote server is available and satisfies the Pinger interface.
func (c *Conn) Ping(ctx context.Context) error {
if !c.connectionGood {
return driver.ErrBadConn
}
stmt := &Stmt{c, `select 1;`, 0, nil}
_, err := stmt.ExecContext(ctx, nil)
return err
}
var _ driver.ConnBeginTx = &Conn{}
// BeginTx satisfies ConnBeginTx.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if opts.ReadOnly {
return nil, errors.New("Read-only transactions are not supported")
}
var tdsIsolation isoLevel
switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
tdsIsolation = isolationUseCurrent
case sql.LevelReadUncommitted:
tdsIsolation = isolationReadUncommited
case sql.LevelReadCommitted:
tdsIsolation = isolationReadCommited
case sql.LevelWriteCommitted:
return nil, errors.New("LevelWriteCommitted isolation level is not supported")
case sql.LevelRepeatableRead:
tdsIsolation = isolationRepeatableRead
case sql.LevelSnapshot:
tdsIsolation = isolationSnapshot
case sql.LevelSerializable:
tdsIsolation = isolationSerializable
case sql.LevelLinearizable:
return nil, errors.New("LevelLinearizable isolation level is not supported")
default:
return nil, errors.New("Isolation level is not supported or unknown")
}
return c.begin(ctx, tdsIsolation)
}
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
return c.prepareCopyIn(query)
}
return c.prepareContext(ctx, query)
}
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.queryContext(ctx, list)
}
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.exec(ctx, list)
}

@ -0,0 +1,64 @@
// +build go1.9
package mssql
import (
"database/sql"
"database/sql/driver"
"fmt"
// "github.com/cockroachdb/apd"
)
// Type alias provided for compibility.
//
// Deprecated: users should transition to the new names when possible.
type MssqlDriver = Driver
type MssqlBulk = Bulk
type MssqlBulkOptions = BulkOptions
type MssqlConn = Conn
type MssqlResult = Result
type MssqlRows = Rows
type MssqlStmt = Stmt
var _ driver.NamedValueChecker = &Conn{}
func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
switch v := nv.Value.(type) {
case sql.Out:
if c.outs == nil {
c.outs = make(map[string]interface{})
}
c.outs[nv.Name] = v.Dest
// Unwrap the Out value and check the inner value.
lnv := *nv
lnv.Value = v.Dest
err := c.CheckNamedValue(&lnv)
if err != nil {
if err != driver.ErrSkip {
return err
}
lnv.Value, err = driver.DefaultParameterConverter.ConvertValue(lnv.Value)
if err != nil {
return err
}
}
nv.Value = sql.Out{Dest: lnv.Value}
return nil
// case *apd.Decimal:
// return nil
default:
return driver.ErrSkip
}
}
func (s *Stmt) makeParamExtra(val driver.Value) (res Param, err error) {
switch val := val.(type) {
case sql.Out:
res, err = s.makeParam(val.Dest)
res.Flags = fByRevValue
default:
err = fmt.Errorf("mssql: unknown type for %T", val)
}
return
}

@ -0,0 +1,12 @@
// +build !go1.9
package mssql
import (
"database/sql/driver"
"fmt"
)
func (s *Stmt) makeParamExtra(val driver.Value) (Param, error) {
return Param{}, fmt.Errorf("mssql: unknown type for %T", val)
}

@ -33,7 +33,7 @@ func (c *timeoutConn) Read(b []byte) (n int, err error) {
c.continueRead = false
}
if !c.continueRead {
var packet uint8
var packet packetType
packet, err = c.buf.BeginRead()
if err != nil {
err = fmt.Errorf("Cannot read handshake packet: %s", err.Error())

@ -59,7 +59,7 @@ type NTLMAuth struct {
Workstation string
}
func getAuth(user, password, service, workstation string) (Auth, bool) {
func getAuth(user, password, service, workstation string) (auth, bool) {
if !strings.ContainsRune(user, '\\') {
return nil, false
}

@ -11,6 +11,9 @@ type parser struct {
w bytes.Buffer
paramCount int
paramMax int
// using map as a set
namedParams map[string]bool
}
func (p *parser) next() (rune, bool) {
@ -39,13 +42,14 @@ type stateFunc func(*parser) stateFunc
func parseParams(query string) (string, int) {
p := &parser{
r: bytes.NewReader([]byte(query)),
r: bytes.NewReader([]byte(query)),
namedParams: map[string]bool{},
}
state := parseNormal
for state != nil {
state = state(p)
}
return p.w.String(), p.paramMax
return p.w.String(), p.paramMax + len(p.namedParams)
}
func parseNormal(p *parser) stateFunc {
@ -55,7 +59,7 @@ func parseNormal(p *parser) stateFunc {
return nil
}
if ch == '?' {
return parseParameter
return parseOrdinalParameter
} else if ch == '$' || ch == ':' {
ch2, ok := p.next()
if !ok {
@ -64,7 +68,9 @@ func parseNormal(p *parser) stateFunc {
}
p.unread()
if ch2 >= '0' && ch2 <= '9' {
return parseParameter
return parseOrdinalParameter
} else if 'a' <= ch2 && ch2 <= 'z' || 'A' <= ch2 && ch2 <= 'Z' {
return parseNamedParameter
}
}
p.write(ch)
@ -83,7 +89,7 @@ func parseNormal(p *parser) stateFunc {
}
}
func parseParameter(p *parser) stateFunc {
func parseOrdinalParameter(p *parser) stateFunc {
var paramN int
var ok bool
for {
@ -113,6 +119,30 @@ func parseParameter(p *parser) stateFunc {
return parseNormal
}
func parseNamedParameter(p *parser) stateFunc {
var paramName string
var ok bool
for {
var ch rune
ch, ok = p.next()
if ok && (ch >= '0' && ch <= '9' || 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z') {
paramName = paramName + string(ch)
} else {
break
}
}
if ok {
p.unread()
}
p.namedParams[paramName] = true
p.w.WriteString("@")
p.w.WriteString(paramName)
if !ok {
return nil
}
return parseNormal
}
func parseQuote(p *parser) stateFunc {
for {
ch, ok := p.next()

@ -113,7 +113,7 @@ type SSPIAuth struct {
ctxt SecHandle
}
func getAuth(user, password, service, workstation string) (Auth, bool) {
func getAuth(user, password, service, workstation string) (auth, bool) {
if user == "" {
return &SSPIAuth{Service: service}, true
}

@ -1,6 +1,7 @@
package mssql
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/binary"
@ -9,11 +10,13 @@ import (
"io"
"io/ioutil"
"net"
"net/url"
"os"
"sort"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf16"
"unicode/utf8"
)
@ -47,8 +50,11 @@ func parseInstances(msg []byte) map[string]map[string]string {
return results
}
func getInstances(address string) (map[string]map[string]string, error) {
conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
func getInstances(ctx context.Context, address string) (map[string]map[string]string, error) {
dialer := &net.Dialer{
Timeout: 5 * time.Second,
}
conn, err := dialer.DialContext(ctx, "udp", address+":1434")
if err != nil {
return nil, err
}
@ -79,11 +85,16 @@ const (
)
// packet types
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
const (
packSQLBatch = 1
packRPCRequest = 3
packReply = 4
packCancel = 6
packSQLBatch packetType = 1
packRPCRequest = 3
packReply = 4
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
packAttention = 6
packBulkLoadBCP = 7
packTransMgrReq = 14
packNormal = 15
@ -119,7 +130,7 @@ type tdsSession struct {
columns []columnStruct
tranid uint64
logFlags uint64
log *Logger
log optionalLogger
routedServer string
routedPort uint16
}
@ -131,6 +142,7 @@ const (
logSQL = 8
logParams = 16
logTransaction = 32
logDebug = 64
)
type columnStruct struct {
@ -490,6 +502,11 @@ func readBVarChar(r io.Reader) (res string, err error) {
if err != nil {
return "", err
}
// A zero length could be returned, return an empty string
if numchars == 0 {
return "", nil
}
return readUcs2(r, int(numchars))
}
@ -588,7 +605,7 @@ func (hdr transDescrHdr) pack() (res []byte) {
}
func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
// calculatint total length
// Calculating total length.
var totallen uint32 = 4
for _, hdr := range headers {
totallen += 4 + 2 + uint32(len(hdr.data))
@ -616,9 +633,7 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
return nil
}
func sendSqlBatch72(buf *tdsBuffer,
sqltext string,
headers []headerStruct) (err error) {
func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) {
buf.BeginPacket(packSQLBatch)
if err = writeAllHeaders(buf, headers); err != nil {
@ -632,6 +647,13 @@ func sendSqlBatch72(buf *tdsBuffer,
return buf.FinishPacket()
}
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
func sendAttention(buf *tdsBuffer) error {
buf.BeginPacket(packAttention)
return buf.FinishPacket()
}
type connectParams struct {
logFlags uint64
port uint64
@ -654,6 +676,7 @@ type connectParams struct {
typeFlags uint8
failOverPartner string
failOverPort uint64
packetSize uint16
}
func splitConnectionString(dsn string) (res map[string]string) {
@ -677,9 +700,241 @@ func splitConnectionString(dsn string) (res map[string]string) {
return res
}
// Splits a URL in the ODBC format
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
res := map[string]string{}
type parserState int
const (
// Before the start of a key
parserStateBeforeKey parserState = iota
// Inside a key
parserStateKey
// Beginning of a value. May be bare or braced
parserStateBeginValue
// Inside a bare value
parserStateBareValue
// Inside a braced value
parserStateBracedValue
// A closing brace inside a braced value.
// May be the end of the value or an escaped closing brace, depending on the next character
parserStateBracedValueClosingBrace
// After a value. Next character should be a semicolon or whitespace.
parserStateEndValue
)
var state = parserStateBeforeKey
var key string
var value string
for i, c := range dsn {
switch state {
case parserStateBeforeKey:
switch {
case c == '=':
return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
case !unicode.IsSpace(c) && c != ';':
state = parserStateKey
key += string(c)
}
case parserStateKey:
switch c {
case '=':
key = normalizeOdbcKey(key)
if len(key) == 0 {
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
}
state = parserStateBeginValue
case ';':
// Key without value
key = normalizeOdbcKey(key)
if len(key) == 0 {
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
}
res[key] = value
key = ""
value = ""
state = parserStateBeforeKey
default:
key += string(c)
}
case parserStateBeginValue:
switch {
case c == '{':
state = parserStateBracedValue
case c == ';':
// Empty value
res[key] = value
key = ""
state = parserStateBeforeKey
case unicode.IsSpace(c):
// Ignore whitespace
default:
state = parserStateBareValue
value += string(c)
}
case parserStateBareValue:
if c == ';' {
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
key = ""
value = ""
state = parserStateBeforeKey
} else {
value += string(c)
}
case parserStateBracedValue:
if c == '}' {
state = parserStateBracedValueClosingBrace
} else {
value += string(c)
}
case parserStateBracedValueClosingBrace:
if c == '}' {
// Escaped closing brace
value += string(c)
state = parserStateBracedValue
continue
}
// End of braced value
res[key] = value
key = ""
value = ""
// This character is the first character past the end,
// so it needs to be parsed like the parserStateEndValue state.
state = parserStateEndValue
switch {
case c == ';':
state = parserStateBeforeKey
case unicode.IsSpace(c):
// Ignore whitespace
default:
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
}
case parserStateEndValue:
switch {
case c == ';':
state = parserStateBeforeKey
case unicode.IsSpace(c):
// Ignore whitespace
default:
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
}
}
}
switch state {
case parserStateBeforeKey: // Okay
case parserStateKey: // Unfinished key. Treat as key without value.
key = normalizeOdbcKey(key)
if len(key) == 0 {
return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
}
res[key] = value
case parserStateBeginValue: // Empty value
res[key] = value
case parserStateBareValue:
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
case parserStateBracedValue:
return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
case parserStateBracedValueClosingBrace: // End of braced value
res[key] = value
case parserStateEndValue: // Okay
}
return res, nil
}
// Normalizes the given string as an ODBC-format key
func normalizeOdbcKey(s string) string {
return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
}
// Splits a URL of the form sqlserver://username:password@host/instance?param1=value&param2=value
func splitConnectionStringURL(dsn string) (map[string]string, error) {
res := map[string]string{}
u, err := url.Parse(dsn)
if err != nil {
return res, err
}
if u.Scheme != "sqlserver" {
return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
}
if u.User != nil {
res["user id"] = u.User.Username()
p, exists := u.User.Password()
if exists {
res["password"] = p
}
}
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
host = u.Host
}
if len(u.Path) > 0 {
res["server"] = host + "\\" + u.Path[1:]
} else {
res["server"] = host
}
if len(port) > 0 {
res["port"] = port
}
query := u.Query()
for k, v := range query {
if len(v) > 1 {
return res, fmt.Errorf("key %s provided more than once", k)
}
res[strings.ToLower(k)] = v[0]
}
return res, nil
}
func parseConnectParams(dsn string) (connectParams, error) {
params := splitConnectionString(dsn)
var p connectParams
var params map[string]string
if strings.HasPrefix(dsn, "odbc:") {
parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
if err != nil {
return p, err
}
params = parameters
} else if strings.HasPrefix(dsn, "sqlserver://") {
parameters, err := splitConnectionStringURL(dsn)
if err != nil {
return p, err
}
params = parameters
} else {
params = splitConnectionString(dsn)
}
strlog, ok := params["log"]
if ok {
var err error
@ -712,7 +967,32 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
}
p.dial_timeout = 5 * time.Second
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
// Default packet size remains at 4096 bytes
p.packetSize = 4096
strpsize, ok := params["packet size"]
if ok {
var err error
psize, err := strconv.ParseUint(strpsize, 0, 16)
if err != nil {
f := "Invalid packet size '%v': %v"
return p, fmt.Errorf(f, strpsize, err.Error())
}
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
// a higher packet size, the server will respond with an ENVCHANGE request to
// alter the packet size to 16383 bytes.
p.packetSize = uint16(psize)
if p.packetSize < 512 {
p.packetSize = 512
} else if p.packetSize > 32767 {
p.packetSize = 32767
}
}
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.dial_timeout = 15 * time.Second
p.conn_timeout = 30 * time.Second
strconntimeout, ok := params["connection timeout"]
if ok {
@ -732,8 +1012,12 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
p.dial_timeout = time.Duration(timeout) * time.Second
}
keepAlive, ok := params["keepalive"]
if ok {
// default keep alive should be 30 seconds according to spec:
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.keepAlive = 30 * time.Second
if keepAlive, ok := params["keepalive"]; ok {
timeout, err := strconv.ParseUint(keepAlive, 0, 16)
if err != nil {
f := "Invalid keepAlive value '%s': %s"
@ -743,7 +1027,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
encrypt, ok := params["encrypt"]
if ok {
if strings.ToUpper(encrypt) == "DISABLE" {
if strings.EqualFold(encrypt, "DISABLE") {
p.disableEncryption = true
} else {
var err error
@ -819,7 +1103,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
return p, nil
}
type Auth interface {
type auth interface {
InitialBytes() ([]byte, error)
NextBytes([]byte) ([]byte, error)
Free()
@ -828,7 +1112,7 @@ type Auth interface {
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
// list of IP addresses. So if there is more than one, try them all and
// use the first one that allows a connection.
func dialConnection(p connectParams) (conn net.Conn, err error) {
func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err error) {
var ips []net.IP
ips, err = net.LookupIP(p.host)
if err != nil {
@ -839,9 +1123,9 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
ips = []net.IP{ip}
}
if len(ips) == 1 {
d := createDialer(p)
d := createDialer(&p)
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
conn, err = d.Dial("tcp", addr)
conn, err = d.Dial(ctx, addr)
} else {
//Try Dials in parallel to avoid waiting for timeouts.
@ -850,9 +1134,9 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
portStr := strconv.Itoa(int(p.port))
for _, ip := range ips {
go func(ip net.IP) {
d := createDialer(p)
d := createDialer(&p)
addr := net.JoinHostPort(ip.String(), portStr)
conn, err := d.Dial("tcp", addr)
conn, err := d.Dial(ctx, addr)
if err == nil {
connChan <- conn
} else {
@ -887,16 +1171,15 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
f := "Unable to open tcp connection with host '%v:%v': %v"
return nil, fmt.Errorf(f, p.host, p.port, err.Error())
}
return conn, err
}
func connect(p connectParams) (res *tdsSession, err error) {
func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tdsSession, err error) {
res = nil
// if instance is specified use instance resolution service
if p.instance != "" {
p.instance = strings.ToUpper(p.instance)
instances, err := getInstances(p.host)
instances, err := getInstances(ctx, p.host)
if err != nil {
f := "Unable to get instances from Sql Server Browser on host %v: %v"
return nil, fmt.Errorf(f, p.host, err.Error())
@ -914,16 +1197,17 @@ func connect(p connectParams) (res *tdsSession, err error) {
}
initiate_connection:
conn, err := dialConnection(p)
conn, err := dialConnection(ctx, p)
if err != nil {
return nil, err
}
toconn := NewTimeoutConn(conn, p.conn_timeout)
outbuf := newTdsBuffer(4096, toconn)
outbuf := newTdsBuffer(p.packetSize, toconn)
sess := tdsSession{
buf: outbuf,
log: log,
logFlags: p.logFlags,
}
@ -969,8 +1253,7 @@ initiate_connection:
if p.certificate != "" {
pem, err := ioutil.ReadFile(p.certificate)
if err != nil {
f := "Cannot read certificate '%s': %s"
return nil, fmt.Errorf(f, p.certificate, err.Error())
return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
}
certs := x509.NewCertPool()
certs.AppendCertsFromPEM(pem)
@ -980,15 +1263,20 @@ initiate_connection:
config.InsecureSkipVerify = true
}
config.ServerName = p.hostInCertificate
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
outbuf.transport = conn
toconn.buf = outbuf
tlsConn := tls.Client(toconn, &config)
err = tlsConn.Handshake()
toconn.buf = nil
outbuf.transport = tlsConn
if err != nil {
f := "TLS Handshake failed: %s"
return nil, fmt.Errorf(f, err.Error())
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
}
if encrypt == encryptOff {
outbuf.afterFirst = func() {
@ -999,7 +1287,7 @@ initiate_connection:
login := login{
TDSVersion: verTDS74,
PacketSize: uint32(len(outbuf.buf)),
PacketSize: uint32(outbuf.PackageSize()),
Database: p.database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
HostName: p.workstation,
@ -1028,7 +1316,7 @@ initiate_connection:
var sspi_msg []byte
continue_login:
tokchan := make(chan tokenStruct, 5)
go processResponse(&sess, tokchan)
go processResponse(context.Background(), &sess, tokchan, nil)
success := false
for tok := range tokchan {
switch token := tok.(type) {
@ -1042,6 +1330,10 @@ continue_login:
sess.loginAck = token
case error:
return nil, fmt.Errorf("Login error: %s", token.Error())
case doneStruct:
if token.isError() {
return nil, fmt.Errorf("Login error: %s", token.getError())
}
}
}
if sspi_msg != nil {

@ -1,30 +1,40 @@
package mssql
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
)
//go:generate stringer -type token
type token byte
// token ids
const (
tokenReturnStatus = 121 // 0x79
tokenColMetadata = 129 // 0x81
tokenOrder = 169 // 0xA9
tokenError = 170 // 0xAA
tokenInfo = 171 // 0xAB
tokenLoginAck = 173 // 0xad
tokenRow = 209 // 0xd1
tokenNbcRow = 210 // 0xd2
tokenEnvChange = 227 // 0xE3
tokenSSPI = 237 // 0xED
tokenDone = 253 // 0xFD
tokenDoneProc = 254
tokenDoneInProc = 255
tokenReturnStatus token = 121 // 0x79
tokenColMetadata token = 129 // 0x81
tokenOrder token = 169 // 0xA9
tokenError token = 170 // 0xAA
tokenInfo token = 171 // 0xAB
tokenReturnValue token = 0xAC
tokenLoginAck token = 173 // 0xad
tokenRow token = 209 // 0xd1
tokenNbcRow token = 210 // 0xd2
tokenEnvChange token = 227 // 0xE3
tokenSSPI token = 237 // 0xED
tokenDone token = 253 // 0xFD
tokenDoneProc token = 254
tokenDoneInProc token = 255
)
// done flags
// https://msdn.microsoft.com/en-us/library/dd340421.aspx
const (
doneFinal = 0
doneMore = 1
@ -59,6 +69,13 @@ const (
envRouting = 20
)
// COLMETADATA flags
// https://msdn.microsoft.com/en-us/library/dd357363.aspx
const (
colFlagNullable = 1
// TODO implement more flags
)
// interface for all tokens
type tokenStruct interface{}
@ -70,6 +87,19 @@ type doneStruct struct {
Status uint16
CurCmd uint16
RowCount uint64
errors []Error
}
func (d doneStruct) isError() bool {
return d.Status&doneError != 0 || len(d.errors) > 0
}
func (d doneStruct) getError() Error {
if len(d.errors) > 0 {
return d.errors[len(d.errors)-1]
} else {
return Error{Message: "Request failed but didn't provide reason"}
}
}
type doneInProcStruct doneStruct
@ -120,27 +150,23 @@ func processEnvChg(sess *tdsSession) {
badStreamPanic(err)
}
case envTypLanguage:
//currently ignored
// old value
_, err = readBVarChar(r)
if err != nil {
badStreamPanic(err)
}
// currently ignored
// new value
_, err = readBVarChar(r)
if err != nil {
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envTypCharset:
//currently ignored
// old value
_, err = readBVarChar(r)
if err != nil {
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envTypCharset:
// currently ignored
// new value
_, err = readBVarChar(r)
if err != nil {
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
// old value
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envTypPacketSize:
@ -156,38 +182,55 @@ func processEnvChg(sess *tdsSession) {
if err != nil {
badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error())
}
if len(sess.buf.buf) != packetsizei {
newbuf := make([]byte, packetsizei)
copy(newbuf, sess.buf.buf)
sess.buf.buf = newbuf
}
sess.buf.ResizeBuffer(packetsizei)
case envSortId:
// currently ignored
// old value, should be 0
// new value
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
// new value
// old value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envSortFlags:
// currently ignored
// old value, should be 0
// new value
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
// new value
// old value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envSqlCollation:
// currently ignored
// old value
if _, err = readBVarChar(r); err != nil {
var collationSize uint8
err = binary.Read(r, binary.LittleEndian, &collationSize)
if err != nil {
badStreamPanic(err)
}
// new value
// SQL Collation data should contain 5 bytes in length
if collationSize != 5 {
badStreamPanicf("Invalid SQL Collation size value returned from server: %s", collationSize)
}
// 4 bytes, contains: LCID ColFlags Version
var info uint32
err = binary.Read(r, binary.LittleEndian, &info)
if err != nil {
badStreamPanic(err)
}
// 1 byte, contains: sortID
var sortID uint8
err = binary.Read(r, binary.LittleEndian, &sortID)
if err != nil {
badStreamPanic(err)
}
// old value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
@ -226,21 +269,21 @@ func processEnvChg(sess *tdsSession) {
sess.tranid = 0
case envEnlistDTC:
// currently ignored
// old value
// new value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
// new value, should be 0
// old value
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envDefectTran:
// currently ignored
// old value, should be 0
// new value
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
// new value
// old value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
@ -358,6 +401,7 @@ func parseOrder(r *tdsBuffer) (res orderStruct) {
return res
}
// https://msdn.microsoft.com/en-us/library/dd340421.aspx
func parseDone(r *tdsBuffer) (res doneStruct) {
res.Status = r.uint16()
res.CurCmd = r.uint16()
@ -365,6 +409,7 @@ func parseDone(r *tdsBuffer) (res doneStruct) {
return res
}
// https://msdn.microsoft.com/en-us/library/dd340553.aspx
func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) {
res.Status = r.uint16()
res.CurCmd = r.uint16()
@ -473,26 +518,57 @@ func parseInfo(r *tdsBuffer) (res Error) {
return
}
func processResponse(sess *tdsSession, ch chan tokenStruct) {
// https://msdn.microsoft.com/en-us/library/dd303881.aspx
func parseReturnValue(r *tdsBuffer) (nv namedValue) {
/*
ParamOrdinal
ParamName
Status
UserType
Flags
TypeInfo
CryptoMetadata
Value
*/
r.uint16()
nv.Name = r.BVarChar()
r.byte()
r.uint32() // UserType (uint16 prior to 7.2)
r.uint16()
ti := readTypeInfo(r)
nv.Value = ti.Reader(&ti, r)
return
}
func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
defer func() {
if err := recover(); err != nil {
if sess.logFlags&logErrors != 0 {
sess.log.Printf("ERROR: Intercepted panic %v", err)
}
ch <- err
}
close(ch)
}()
packet_type, err := sess.buf.BeginRead()
if err != nil {
if sess.logFlags&logErrors != 0 {
sess.log.Printf("ERROR: BeginRead failed %v", err)
}
ch <- err
return
}
if packet_type != packReply {
badStreamPanicf("invalid response packet type, expected REPLY, actual: %d", packet_type)
badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply))
}
var columns []columnStruct
var lastError Error
var failed bool
errs := make([]Error, 0, 5)
for {
token := sess.buf.byte()
token := token(sess.buf.byte())
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got token %v", token)
}
switch token {
case tokenSSPI:
ch <- parseSSPIMsg(sess.buf)
@ -514,18 +590,17 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
ch <- done
case tokenDone, tokenDoneProc:
done := parseDone(sess.buf)
if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
}
if done.Status&doneError != 0 || failed {
ch <- lastError
return
done.errors = errs
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got DONE or DONEPROC status=%d", done.Status)
}
if done.Status&doneSrvError != 0 {
lastError.Message = "Server Error"
ch <- lastError
ch <- errors.New("SQL Server had internal error")
return
}
if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
}
ch <- done
if done.Status&doneMore == 0 {
return
@ -544,18 +619,210 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
case tokenEnvChange:
processEnvChg(sess)
case tokenError:
lastError = parseError72(sess.buf)
failed = true
err := parseError72(sess.buf)
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got ERROR %d %s", err.Number, err.Message)
}
errs = append(errs, err)
if sess.logFlags&logErrors != 0 {
sess.log.Println(lastError.Message)
sess.log.Println(err.Message)
}
case tokenInfo:
info := parseInfo(sess.buf)
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got INFO %d %s", info.Number, info.Message)
}
if sess.logFlags&logMessages != 0 {
sess.log.Println(info.Message)
}
case tokenReturnValue:
nv := parseReturnValue(sess.buf)
if len(nv.Name) > 0 {
name := nv.Name[1:] // Remove the leading "@".
if ov, has := outs[name]; has {
err = scanIntoOut(nv.Value, ov)
if err != nil {
fmt.Println("scan error", err)
ch <- err
}
}
}
default:
badStreamPanic(fmt.Errorf("unknown token type returned: %v", token))
}
}
}
func scanIntoOut(fromServer, scanInto interface{}) error {
switch fs := fromServer.(type) {
case int64:
switch si := scanInto.(type) {
case *int64:
*si = fs
default:
badStreamPanicf("Unknown token type: %d", token)
return fmt.Errorf("unsupported scan into type %[1]T for server type %[2]T", scanInto, fromServer)
}
return nil
case string:
switch si := scanInto.(type) {
case *string:
*si = fs
default:
return fmt.Errorf("unsupported scan into type %[1]T for server type %[2]T", scanInto, fromServer)
}
return nil
}
return fmt.Errorf("unsupported type from server %[1]T=%[1]v", fromServer)
}
type parseRespIter byte
const (
parseRespIterContinue parseRespIter = iota // Continue parsing current token.
parseRespIterNext // Fetch the next token.
parseRespIterDone // Done with parsing the response.
)
type parseRespState byte
const (
parseRespStateNormal parseRespState = iota // Normal response state.
parseRespStateCancel // Query is canceled, wait for server to confirm.
parseRespStateClosing // Waiting for tokens to come through.
)
type parseResp struct {
sess *tdsSession
ctxDone <-chan struct{}
state parseRespState
cancelError error
}
func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter {
if err := sendAttention(ts.sess.buf); err != nil {
ts.dlogf("failed to send attention signal %v", err)
ch <- err
return parseRespIterDone
}
ts.state = parseRespStateCancel
return parseRespIterContinue
}
func (ts *parseResp) dlog(msg string) {
if ts.sess.logFlags&logDebug != 0 {
ts.sess.log.Println(msg)
}
}
func (ts *parseResp) dlogf(f string, v ...interface{}) {
if ts.sess.logFlags&logDebug != 0 {
ts.sess.log.Printf(f, v...)
}
}
func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter {
switch ts.state {
default:
panic("unknown state")
case parseRespStateNormal:
select {
case tok, ok := <-tokChan:
if !ok {
ts.dlog("response finished")
return parseRespIterDone
}
if err, ok := tok.(net.Error); ok && err.Timeout() {
ts.cancelError = err
ts.dlog("got timeout error, sending attention signal to server")
return ts.sendAttention(ch)
}
// Pass the token along.
ch <- tok
return parseRespIterContinue
case <-ts.ctxDone:
ts.ctxDone = nil
ts.dlog("got cancel message, sending attention signal to server")
return ts.sendAttention(ch)
}
case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth
select {
case tok, ok := <-tokChan:
if !ok {
ts.dlog("response finished but waiting for attention ack")
return parseRespIterNext
}
switch tok := tok.(type) {
default:
// Ignore all other tokens while waiting.
// The TDS spec says other tokens may arrive after an attention
// signal is sent. Ignore these tokens and continue looking for
// a DONE with attention confirm mark.
case doneStruct:
if tok.Status&doneAttn != 0 {
ts.dlog("got cancellation confirmation from server")
if ts.cancelError != nil {
ch <- ts.cancelError
ts.cancelError = nil
} else {
ch <- ctx.Err()
}
return parseRespIterDone
}
// If an error happens during cancel, pass it along and just stop.
// We are uncertain to receive more tokens.
case error:
ch <- tok
ts.state = parseRespStateClosing
}
return parseRespIterContinue
case <-ts.ctxDone:
ts.ctxDone = nil
ts.state = parseRespStateClosing
return parseRespIterContinue
}
case parseRespStateClosing: // Wait for current token chan to close.
if _, ok := <-tokChan; !ok {
ts.dlog("response finished")
return parseRespIterDone
}
return parseRespIterContinue
}
}
func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
ts := &parseResp{
sess: sess,
ctxDone: ctx.Done(),
}
defer func() {
// Ensure any remaining error is piped through
// or the query may look like it executed when it actually failed.
if ts.cancelError != nil {
ch <- ts.cancelError
ts.cancelError = nil
}
close(ch)
}()
// Loop over multiple responses.
for {
ts.dlog("initiating response reading")
tokChan := make(chan tokenStruct)
go processSingleResponse(sess, tokChan, outs)
// Loop over multiple tokens in response.
tokensLoop:
for {
switch ts.iter(ctx, ch, tokChan) {
case parseRespIterContinue:
// Nothing, continue to next token.
case parseRespIterNext:
break tokensLoop
case parseRespIterDone:
return
}
}
}
}

@ -0,0 +1,53 @@
// Code generated by "stringer -type token"; DO NOT EDIT
package mssql
import "fmt"
const (
_token_name_0 = "tokenReturnStatus"
_token_name_1 = "tokenColMetadata"
_token_name_2 = "tokenOrdertokenErrortokenInfo"
_token_name_3 = "tokenLoginAck"
_token_name_4 = "tokenRowtokenNbcRow"
_token_name_5 = "tokenEnvChange"
_token_name_6 = "tokenSSPI"
_token_name_7 = "tokenDonetokenDoneProctokenDoneInProc"
)
var (
_token_index_0 = [...]uint8{0, 17}
_token_index_1 = [...]uint8{0, 16}
_token_index_2 = [...]uint8{0, 10, 20, 29}
_token_index_3 = [...]uint8{0, 13}
_token_index_4 = [...]uint8{0, 8, 19}
_token_index_5 = [...]uint8{0, 14}
_token_index_6 = [...]uint8{0, 9}
_token_index_7 = [...]uint8{0, 9, 22, 37}
)
func (i token) String() string {
switch {
case i == 121:
return _token_name_0
case i == 129:
return _token_name_1
case 169 <= i && i <= 171:
i -= 169
return _token_name_2[_token_index_2[i]:_token_index_2[i+1]]
case i == 173:
return _token_name_3
case 209 <= i && i <= 210:
i -= 209
return _token_name_4[_token_index_4[i]:_token_index_4[i+1]]
case i == 227:
return _token_name_5
case i == 237:
return _token_name_6
case 253 <= i && i <= 255:
i -= 253
return _token_name_7[_token_index_7[i]:_token_index_7[i+1]]
default:
return fmt.Sprintf("token(%d)", i)
}
}

@ -1,6 +1,7 @@
package mssql
// Transaction Manager requests
// http://msdn.microsoft.com/en-us/library/dd339887.aspx
package mssql
import (
"encoding/binary"
@ -16,7 +17,18 @@ const (
tmSaveXact = 9
)
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation uint8,
type isoLevel uint8
const (
isolationUseCurrent isoLevel = 0
isolationReadUncommited = 1
isolationReadCommited = 2
isolationRepeatableRead = 3
isolationSerializable = 4
isolationSnapshot = 5
)
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel,
name string) (err error) {
buf.BeginPacket(packTransMgrReq)
writeAllHeaders(buf, headers)

@ -6,8 +6,11 @@ import (
"fmt"
"io"
"math"
"reflect"
"strconv"
"time"
"github.com/denisenkom/go-mssqldb/internal/cp"
)
// fixed-length data types
@ -66,6 +69,9 @@ const (
typeNText = 0x63
typeVariant = 0x62
)
const PLP_NULL = 0xFFFFFFFFFFFFFFFF
const UNKNOWN_PLP_LEN = 0xFFFFFFFFFFFFFFFE
const PLP_TERMINATOR = 0x00000000
// TYPE_INFO rule
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
@ -75,11 +81,32 @@ type typeInfo struct {
Scale uint8
Prec uint8
Buffer []byte
Collation collation
Collation cp.Collation
UdtInfo udtInfo
XmlInfo xmlInfo
Reader func(ti *typeInfo, r *tdsBuffer) (res interface{})
Writer func(w io.Writer, ti typeInfo, buf []byte) (err error)
}
// Common Language Runtime (CLR) Instances
// http://msdn.microsoft.com/en-us/library/dd357962.aspx
type udtInfo struct {
//MaxByteSize uint32
DBName string
SchemaName string
TypeName string
AssemblyQualifiedName string
}
// XML Values
// http://msdn.microsoft.com/en-us/library/dd304764.aspx
type xmlInfo struct {
SchemaPresent uint8
DBName string
OwningSchema string
XmlSchemaCollection string
}
func readTypeInfo(r *tdsBuffer) (res typeInfo) {
res.TypeId = r.byte()
switch res.TypeId {
@ -114,7 +141,8 @@ func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) {
switch ti.TypeId {
case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4,
typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8:
// those are fixed length types
// those are fixed length
ti.Writer = writeFixedType
default: // all others are VARLENTYPE
err = writeVarLen(w, ti)
if err != nil {
@ -124,19 +152,25 @@ func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) {
return
}
func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) {
_, err = w.Write(buf)
return
}
func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
switch ti.TypeId {
case typeDateN:
ti.Writer = writeByteLenType
case typeTimeN, typeDateTime2N, typeDateTimeOffsetN:
if err = binary.Write(w, binary.LittleEndian, ti.Scale); err != nil {
return
}
ti.Writer = writeByteLenType
case typeGuid, typeIntN, typeDecimal, typeNumeric,
case typeIntN, typeDecimal, typeNumeric,
typeBitN, typeDecimalN, typeNumericN, typeFltN,
typeMoneyN, typeDateTimeN, typeChar,
typeVarChar, typeBinary, typeVarBinary:
// byle len types
if ti.Size > 0xff {
panic("Invalid size for BYLELEN_TYPE")
@ -156,6 +190,14 @@ func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
}
}
ti.Writer = writeByteLenType
case typeGuid:
if !(ti.Size == 0x10 || ti.Size == 0x00) {
panic("Invalid size for BYLELEN_TYPE")
}
if err = binary.Write(w, binary.LittleEndian, uint8(ti.Size)); err != nil {
return
}
ti.Writer = writeByteLenType
case typeBigVarBin, typeBigVarChar, typeBigBinary, typeBigChar,
typeNVarChar, typeNChar, typeXml, typeUdt:
// short len types
@ -176,14 +218,19 @@ func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
return
}
case typeXml:
var schemapresent uint8 = 0
if err = binary.Write(w, binary.LittleEndian, schemapresent); err != nil {
if err = binary.Write(w, binary.LittleEndian, ti.XmlInfo.SchemaPresent); err != nil {
return
}
}
case typeText, typeImage, typeNText, typeVariant:
// LONGLEN_TYPE
panic("LONGLEN_TYPE not implemented")
if err = binary.Write(w, binary.LittleEndian, uint32(ti.Size)); err != nil {
return
}
if err = writeCollation(w, ti.Collation); err != nil {
return
}
ti.Writer = writeLongLenType
default:
panic("Invalid type")
}
@ -207,7 +254,7 @@ func decodeDateTime(buf []byte) time.Time {
0, 0, secs, ns, time.UTC)
}
func readFixedType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} {
r.ReadFull(ti.Buffer)
buf := ti.Buffer
switch ti.TypeId {
@ -241,12 +288,7 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
panic("shoulnd't get here")
}
func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) {
_, err = w.Write(buf)
return
}
func readByteLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.byte()
if size == 0 {
return nil
@ -305,6 +347,10 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
default:
badStreamPanicf("Invalid size for MONEYNTYPE")
}
case typeDateTim4:
return decodeDateTim4(buf)
case typeDateTime:
return decodeDateTime(buf)
case typeDateTimeN:
switch len(buf) {
case 4:
@ -341,7 +387,7 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
return
}
func readShortLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.uint16()
if size == 0xffff {
return nil
@ -384,7 +430,7 @@ func writeShortLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
return
}
func readLongLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} {
// information about this format can be found here:
// http://msdn.microsoft.com/en-us/library/dd304783.aspx
// and here:
@ -415,10 +461,51 @@ func readLongLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
}
panic("shoulnd't get here")
}
func writeLongLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
//textptr
err = binary.Write(w, binary.LittleEndian, byte(0x10))
if err != nil {
return
}
err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
if err != nil {
return
}
err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
if err != nil {
return
}
//timestamp?
err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
if err != nil {
return
}
err = binary.Write(w, binary.LittleEndian, uint32(ti.Size))
if err != nil {
return
}
_, err = w.Write(buf)
return
}
func readCollation(r *tdsBuffer) (res cp.Collation) {
res.LcidAndFlags = r.uint32()
res.SortId = r.byte()
return
}
func writeCollation(w io.Writer, col cp.Collation) (err error) {
if err = binary.Write(w, binary.LittleEndian, col.LcidAndFlags); err != nil {
return
}
err = binary.Write(w, binary.LittleEndian, col.SortId)
return
}
// reads variant value
// http://msdn.microsoft.com/en-us/library/dd303302.aspx
func readVariantType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.int32()
if size == 0 {
return nil
@ -510,14 +597,14 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
// partially length prefixed stream
// http://msdn.microsoft.com/en-us/library/dd340469.aspx
func readPLPType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.uint64()
var buf *bytes.Buffer
switch size {
case 0xffffffffffffffff:
case PLP_NULL:
// null
return nil
case 0xfffffffffffffffe:
case UNKNOWN_PLP_LEN:
// size unknown
buf = bytes.NewBuffer(make([]byte, 0, 1000))
default:
@ -548,15 +635,16 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
}
func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
if err = binary.Write(w, binary.LittleEndian, uint64(len(buf))); err != nil {
if err = binary.Write(w, binary.LittleEndian, uint64(UNKNOWN_PLP_LEN)); err != nil {
return
}
for {
chunksize := uint32(len(buf))
if err = binary.Write(w, binary.LittleEndian, chunksize); err != nil {
if chunksize == 0 {
err = binary.Write(w, binary.LittleEndian, uint32(PLP_TERMINATOR))
return
}
if chunksize == 0 {
if err = binary.Write(w, binary.LittleEndian, chunksize); err != nil {
return
}
if _, err = w.Write(buf[:chunksize]); err != nil {
@ -606,19 +694,27 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) {
}
ti.Reader = readByteLenType
case typeXml:
schemapresent := r.byte()
if schemapresent != 0 {
// just ignore this for now
ti.XmlInfo.SchemaPresent = r.byte()
if ti.XmlInfo.SchemaPresent != 0 {
// dbname
r.BVarChar()
ti.XmlInfo.DBName = r.BVarChar()
// owning schema
r.BVarChar()
ti.XmlInfo.OwningSchema = r.BVarChar()
// xml schema collection
r.UsVarChar()
ti.XmlInfo.XmlSchemaCollection = r.UsVarChar()
}
ti.Reader = readPLPType
case typeUdt:
ti.Size = int(r.uint16())
ti.UdtInfo.DBName = r.BVarChar()
ti.UdtInfo.SchemaName = r.BVarChar()
ti.UdtInfo.TypeName = r.BVarChar()
ti.UdtInfo.AssemblyQualifiedName = r.UsVarChar()
ti.Buffer = make([]byte, ti.Size)
ti.Reader = readPLPType
case typeBigVarBin, typeBigVarChar, typeBigBinary, typeBigChar,
typeNVarChar, typeNChar, typeUdt:
typeNVarChar, typeNChar:
// short len types
ti.Size = int(r.uint16())
switch ti.TypeId {
@ -701,7 +797,8 @@ func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte {
// http://msdn.microsoft.com/en-us/library/ee780895.aspx
func decodeDateInt(buf []byte) (days int) {
return int(buf[0]) + int(buf[1])*256 + int(buf[2])*256*256
days = int(buf[0]) + int(buf[1])*256 + int(buf[2])*256*256
return
}
func decodeDate(buf []byte) time.Time {
@ -767,8 +864,8 @@ func dateTime2(t time.Time) (days int32, ns int64) {
return
}
func decodeChar(col collation, buf []byte) string {
return charset2utf8(col, buf)
func decodeChar(col cp.Collation, buf []byte) string {
return cp.CharsetToUTF8(col, buf)
}
func decodeUcs2(buf []byte) string {
@ -787,12 +884,127 @@ func decodeXml(ti typeInfo, buf []byte) string {
return decodeUcs2(buf)
}
func decodeUdt(ti typeInfo, buf []byte) int {
panic("Not implemented")
func decodeUdt(ti typeInfo, buf []byte) []byte {
return buf
}
// makes go/sql type instance as described below
// It should return
// the value type that can be used to scan types into. For example, the database
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
func makeGoLangScanType(ti typeInfo) reflect.Type {
switch ti.TypeId {
case typeInt1:
return reflect.TypeOf(int64(0))
case typeInt2:
return reflect.TypeOf(int64(0))
case typeInt4:
return reflect.TypeOf(int64(0))
case typeInt8:
return reflect.TypeOf(int64(0))
case typeFlt4:
return reflect.TypeOf(float64(0))
case typeIntN:
switch ti.Size {
case 1:
return reflect.TypeOf(int64(0))
case 2:
return reflect.TypeOf(int64(0))
case 4:
return reflect.TypeOf(int64(0))
case 8:
return reflect.TypeOf(int64(0))
default:
panic("invalid size of INTNTYPE")
}
case typeFlt8:
return reflect.TypeOf(float64(0))
case typeFltN:
switch ti.Size {
case 4:
return reflect.TypeOf(float64(0))
case 8:
return reflect.TypeOf(float64(0))
default:
panic("invalid size of FLNNTYPE")
}
case typeBigVarBin:
return reflect.TypeOf([]byte{})
case typeVarChar:
return reflect.TypeOf("")
case typeNVarChar:
return reflect.TypeOf("")
case typeBit, typeBitN:
return reflect.TypeOf(true)
case typeDecimalN, typeNumericN:
return reflect.TypeOf([]byte{})
case typeMoney, typeMoney4, typeMoneyN:
switch ti.Size {
case 4:
return reflect.TypeOf([]byte{})
case 8:
return reflect.TypeOf([]byte{})
default:
panic("invalid size of MONEYN")
}
case typeDateTim4:
return reflect.TypeOf(time.Time{})
case typeDateTime:
return reflect.TypeOf(time.Time{})
case typeDateTimeN:
switch ti.Size {
case 4:
return reflect.TypeOf(time.Time{})
case 8:
return reflect.TypeOf(time.Time{})
default:
panic("invalid size of DATETIMEN")
}
case typeDateTime2N:
return reflect.TypeOf(time.Time{})
case typeDateN:
return reflect.TypeOf(time.Time{})
case typeTimeN:
return reflect.TypeOf(time.Time{})
case typeDateTimeOffsetN:
return reflect.TypeOf(time.Time{})
case typeBigVarChar:
return reflect.TypeOf("")
case typeBigChar:
return reflect.TypeOf("")
case typeNChar:
return reflect.TypeOf("")
case typeGuid:
return reflect.TypeOf([]byte{})
case typeXml:
return reflect.TypeOf("")
case typeText:
return reflect.TypeOf("")
case typeNText:
return reflect.TypeOf("")
case typeImage:
return reflect.TypeOf([]byte{})
case typeBigBinary:
return reflect.TypeOf([]byte{})
case typeVariant:
return reflect.TypeOf(nil)
default:
panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId))
}
}
func makeDecl(ti typeInfo) string {
switch ti.TypeId {
case typeNull:
// maybe we should use something else here
// this is tested in TestNull
return "nvarchar(1)"
case typeInt1:
return "tinyint"
case typeInt2:
return "smallint"
case typeInt4:
return "int"
case typeInt8:
return "bigint"
case typeFlt4:
@ -821,24 +1033,415 @@ func makeDecl(ti typeInfo) string {
default:
panic("invalid size of FLNNTYPE")
}
case typeDecimal, typeDecimalN:
return fmt.Sprintf("decimal(%d, %d)", ti.Prec, ti.Scale)
case typeNumeric, typeNumericN:
return fmt.Sprintf("numeric(%d, %d)", ti.Prec, ti.Scale)
case typeMoney4:
return "smallmoney"
case typeMoney:
return "money"
case typeMoneyN:
switch ti.Size {
case 4:
return "smallmoney"
case 8:
return "money"
default:
panic("invalid size of MONEYNTYPE")
}
case typeBigVarBin:
if ti.Size > 8000 || ti.Size == 0 {
return fmt.Sprintf("varbinary(max)")
return "varbinary(max)"
} else {
return fmt.Sprintf("varbinary(%d)", ti.Size)
}
case typeNChar:
return fmt.Sprintf("nchar(%d)", ti.Size/2)
case typeBigChar, typeChar:
return fmt.Sprintf("char(%d)", ti.Size)
case typeBigVarChar, typeVarChar:
if ti.Size > 4000 || ti.Size == 0 {
return fmt.Sprintf("varchar(max)")
} else {
return fmt.Sprintf("varchar(%d)", ti.Size)
}
case typeNVarChar:
if ti.Size > 8000 || ti.Size == 0 {
return fmt.Sprintf("nvarchar(max)")
return "nvarchar(max)"
} else {
return fmt.Sprintf("nvarchar(%d)", ti.Size/2)
}
case typeBit, typeBitN:
return "bit"
case typeDateTimeN:
case typeDateN:
return "date"
case typeDateTim4:
return "smalldatetime"
case typeDateTime:
return "datetime"
case typeDateTimeN:
switch ti.Size {
case 4:
return "smalldatetime"
case 8:
return "datetime"
default:
panic("invalid size of DATETIMNTYPE")
}
case typeDateTime2N:
return fmt.Sprintf("datetime2(%d)", ti.Scale)
case typeDateTimeOffsetN:
return fmt.Sprintf("datetimeoffset(%d)", ti.Scale)
case typeText:
return "text"
case typeNText:
return "ntext"
case typeUdt:
return ti.UdtInfo.TypeName
case typeGuid:
return "uniqueidentifier"
default:
panic(fmt.Sprintf("not implemented makeDecl for type %#x", ti.TypeId))
}
}
// makes go/sql type name as described below
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
// database system type name without the length. Type names should be uppercase.
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
func makeGoLangTypeName(ti typeInfo) string {
switch ti.TypeId {
case typeInt1:
return "TINYINT"
case typeInt2:
return "SMALLINT"
case typeInt4:
return "INT"
case typeInt8:
return "BIGINT"
case typeFlt4:
return "REAL"
case typeIntN:
switch ti.Size {
case 1:
return "TINYINT"
case 2:
return "SMALLINT"
case 4:
return "INT"
case 8:
return "BIGINT"
default:
panic("invalid size of INTNTYPE")
}
case typeFlt8:
return "FLOAT"
case typeFltN:
switch ti.Size {
case 4:
return "REAL"
case 8:
return "FLOAT"
default:
panic("invalid size of FLNNTYPE")
}
case typeBigVarBin:
return "VARBINARY"
case typeVarChar:
return "VARCHAR"
case typeNVarChar:
return "NVARCHAR"
case typeBit, typeBitN:
return "BIT"
case typeDecimalN, typeNumericN:
return "DECIMAL"
case typeMoney, typeMoney4, typeMoneyN:
switch ti.Size {
case 4:
return "SMALLMONEY"
case 8:
return "MONEY"
default:
panic("invalid size of MONEYN")
}
case typeDateTim4:
return "SMALLDATETIME"
case typeDateTime:
return "DATETIME"
case typeDateTimeN:
switch ti.Size {
case 4:
return "SMALLDATETIME"
case 8:
return "DATETIME"
default:
panic("invalid size of DATETIMEN")
}
case typeDateTime2N:
return "DATETIME2"
case typeDateN:
return "DATE"
case typeTimeN:
return "TIME"
case typeDateTimeOffsetN:
return "DATETIMEOFFSET"
case typeBigVarChar:
return "VARCHAR"
case typeBigChar:
return "CHAR"
case typeNChar:
return "NCHAR"
case typeGuid:
return "UNIQUEIDENTIFIER"
case typeXml:
return "XML"
case typeText:
return "TEXT"
case typeNText:
return "NTEXT"
case typeImage:
return "IMAGE"
case typeVariant:
return "SQL_VARIANT"
case typeBigBinary:
return "BINARY"
default:
panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId))
}
}
// makes go/sql type length as described below
// It should return the length
// of the column type if the column is a variable length type. If the column is
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
switch ti.TypeId {
case typeInt1:
return 0, false
case typeInt2:
return 0, false
case typeInt4:
return 0, false
case typeInt8:
return 0, false
case typeFlt4:
return 0, false
case typeIntN:
switch ti.Size {
case 1:
return 0, false
case 2:
return 0, false
case 4:
return 0, false
case 8:
return 0, false
default:
panic("invalid size of INTNTYPE")
}
case typeFlt8:
return 0, false
case typeFltN:
switch ti.Size {
case 4:
return 0, false
case 8:
return 0, false
default:
panic("invalid size of FLNNTYPE")
}
case typeBit, typeBitN:
return 0, false
case typeDecimalN, typeNumericN:
return 0, false
case typeMoney, typeMoney4, typeMoneyN:
switch ti.Size {
case 4:
return 0, false
case 8:
return 0, false
default:
panic("invalid size of MONEYN")
}
case typeDateTim4, typeDateTime:
return 0, false
case typeDateTimeN:
switch ti.Size {
case 4:
return 0, false
case 8:
return 0, false
default:
panic("invalid size of DATETIMEN")
}
case typeDateTime2N:
return 0, false
case typeDateN:
return 0, false
case typeTimeN:
return 0, false
case typeDateTimeOffsetN:
return 0, false
case typeBigVarBin:
if ti.Size == 0xffff {
return 2147483645, true
} else {
return int64(ti.Size), true
}
case typeVarChar:
return int64(ti.Size), true
case typeBigVarChar:
if ti.Size == 0xffff {
return 2147483645, true
} else {
return int64(ti.Size), true
}
case typeBigChar:
return int64(ti.Size), true
case typeNVarChar:
if ti.Size == 0xffff {
return 2147483645 / 2, true
} else {
return int64(ti.Size) / 2, true
}
case typeNChar:
return int64(ti.Size) / 2, true
case typeGuid:
return 0, false
case typeXml:
return 1073741822, true
case typeText:
return 2147483647, true
case typeNText:
return 1073741823, true
case typeImage:
return 2147483647, true
case typeVariant:
return 0, false
case typeBigBinary:
return 0, false
default:
panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId))
}
}
// makes go/sql type precision and scale as described below
// It should return the length
// of the column type if the column is a variable length type. If the column is
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func makeGoLangTypePrecisionScale(ti typeInfo) (int64, int64, bool) {
switch ti.TypeId {
case typeInt1:
return 0, 0, false
case typeInt2:
return 0, 0, false
case typeInt4:
return 0, 0, false
case typeInt8:
return 0, 0, false
case typeFlt4:
return 0, 0, false
case typeIntN:
switch ti.Size {
case 1:
return 0, 0, false
case 2:
return 0, 0, false
case 4:
return 0, 0, false
case 8:
return 0, 0, false
default:
panic("invalid size of INTNTYPE")
}
case typeFlt8:
return 0, 0, false
case typeFltN:
switch ti.Size {
case 4:
return 0, 0, false
case 8:
return 0, 0, false
default:
panic("invalid size of FLNNTYPE")
}
case typeBit, typeBitN:
return 0, 0, false
case typeDecimalN, typeNumericN:
return int64(ti.Prec), int64(ti.Scale), true
case typeMoney, typeMoney4, typeMoneyN:
switch ti.Size {
case 4:
return 0, 0, false
case 8:
return 0, 0, false
default:
panic("invalid size of MONEYN")
}
case typeDateTim4, typeDateTime:
return 0, 0, false
case typeDateTimeN:
switch ti.Size {
case 4:
return 0, 0, false
case 8:
return 0, 0, false
default:
panic("invalid size of DATETIMEN")
}
case typeDateTime2N:
return 0, 0, false
case typeDateN:
return 0, 0, false
case typeTimeN:
return 0, 0, false
case typeDateTimeOffsetN:
return 0, 0, false
case typeBigVarBin:
return 0, 0, false
case typeVarChar:
return 0, 0, false
case typeBigVarChar:
return 0, 0, false
case typeBigChar:
return 0, 0, false
case typeNVarChar:
return 0, 0, false
case typeNChar:
return 0, 0, false
case typeGuid:
return 0, 0, false
case typeXml:
return 0, 0, false
case typeText:
return 0, 0, false
case typeNText:
return 0, 0, false
case typeImage:
return 0, 0, false
case typeVariant:
return 0, 0, false
case typeBigBinary:
return 0, 0, false
default:
panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId))
}

@ -0,0 +1,74 @@
package mssql
import (
"database/sql/driver"
"encoding/hex"
"errors"
"fmt"
)
type UniqueIdentifier [16]byte
func (u *UniqueIdentifier) Scan(v interface{}) error {
reverse := func(b []byte) {
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
b[i], b[j] = b[j], b[i]
}
}
switch vt := v.(type) {
case []byte:
if len(vt) != 16 {
return errors.New("mssql: invalid UniqueIdentifier length")
}
var raw UniqueIdentifier
copy(raw[:], vt)
reverse(raw[0:4])
reverse(raw[4:6])
reverse(raw[6:8])
*u = raw
return nil
case string:
if len(vt) != 36 {
return errors.New("mssql: invalid UniqueIdentifier string length")
}
b := []byte(vt)
for i, c := range b {
switch c {
case '-':
b = append(b[:i], b[i+1:]...)
}
}
_, err := hex.Decode(u[:], []byte(b))
return err
default:
return fmt.Errorf("mssql: cannot convert %T to UniqueIdentifier", v)
}
}
func (u UniqueIdentifier) Value() (driver.Value, error) {
reverse := func(b []byte) {
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
b[i], b[j] = b[j], b[i]
}
}
raw := make([]byte, len(u))
copy(raw, u[:])
reverse(raw[0:4])
reverse(raw[4:6])
reverse(raw[6:8])
return raw, nil
}
func (u UniqueIdentifier) String() string {
return fmt.Sprintf("%X-%X-%X-%X-%X", u[0:4], u[4:6], u[6:8], u[8:10], u[10:])
}

@ -85,8 +85,9 @@ github.com/couchbase/vellum/utf8
github.com/couchbaselabs/go-couchbase
# github.com/davecgh/go-spew v1.1.1
github.com/davecgh/go-spew/spew
# github.com/denisenkom/go-mssqldb v0.0.0-20190121005146-b04fd42d9952 => github.com/denisenkom/go-mssqldb v0.0.0-20161128230840-e32ca5036449
# github.com/denisenkom/go-mssqldb v0.0.0-20190121005146-b04fd42d9952 => github.com/denisenkom/go-mssqldb v0.0.0-20180314172330-6a30f4e59a44
github.com/denisenkom/go-mssqldb
github.com/denisenkom/go-mssqldb/internal/cp
# github.com/dgrijalva/jwt-go v0.0.0-20161101193935-9ed569b5d1ac
github.com/dgrijalva/jwt-go
# github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712

Loading…
Cancel
Save